diff options
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/lib/model.py b/lib/model.py index 5770218..2452af7 100644 --- a/lib/model.py +++ b/lib/model.py @@ -300,7 +300,7 @@ class AnalyticModel: model_type = os.getenv("DFATOOL_MODEL", "rmt") - if model_type != "rmt": + if model_type != "rmt" and model_type != "uls": for name in self.names: for attr in self.by_name[name]["attributes"]: if model_type == "cart": @@ -319,7 +319,7 @@ class AnalyticModel: self.attr_by_name[name][attr].build_xgb() else: logger.error(f"build_fitted: unknown model type: {model_type}") - elif self.force_tree: + elif model_type == "rmt" and self.force_tree: for name in self.names: for attr in self.by_name[name]["attributes"]: if ( @@ -337,8 +337,9 @@ class AnalyticModel: threshold=threshold, ) else: + # model_type == "rmt" and not self.force_tree or model_type == "uls" paramfit = ParamFit() - tree_allowed = bool(int(os.getenv("DFATOOL_RMT_ENABLED", "1"))) + tree_allowed = model_type == "rmt" tree_required = dict() for name in self.names: |