summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/lib/model.py b/lib/model.py
index 5770218..2452af7 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -300,7 +300,7 @@ class AnalyticModel:
model_type = os.getenv("DFATOOL_MODEL", "rmt")
- if model_type != "rmt":
+ if model_type != "rmt" and model_type != "uls":
for name in self.names:
for attr in self.by_name[name]["attributes"]:
if model_type == "cart":
@@ -319,7 +319,7 @@ class AnalyticModel:
self.attr_by_name[name][attr].build_xgb()
else:
logger.error(f"build_fitted: unknown model type: {model_type}")
- elif self.force_tree:
+ elif model_type == "rmt" and self.force_tree:
for name in self.names:
for attr in self.by_name[name]["attributes"]:
if (
@@ -337,8 +337,9 @@ class AnalyticModel:
threshold=threshold,
)
else:
+ # model_type == "rmt" and not self.force_tree or model_type == "uls"
paramfit = ParamFit()
- tree_allowed = bool(int(os.getenv("DFATOOL_RMT_ENABLED", "1")))
+ tree_allowed = model_type == "rmt"
tree_required = dict()
for name in self.names: