From 9754b3a46dad43211539a3dbfbc7c5095bdf30f5 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Thu, 7 Mar 2024 10:28:40 +0100 Subject: Replace RMT_ENABLED=0 with MODEL=uls --- lib/model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'lib') diff --git a/lib/model.py b/lib/model.py index 5770218..2452af7 100644 --- a/lib/model.py +++ b/lib/model.py @@ -300,7 +300,7 @@ class AnalyticModel: model_type = os.getenv("DFATOOL_MODEL", "rmt") - if model_type != "rmt": + if model_type != "rmt" and model_type != "uls": for name in self.names: for attr in self.by_name[name]["attributes"]: if model_type == "cart": @@ -319,7 +319,7 @@ class AnalyticModel: self.attr_by_name[name][attr].build_xgb() else: logger.error(f"build_fitted: unknown model type: {model_type}") - elif self.force_tree: + elif model_type == "rmt" and self.force_tree: for name in self.names: for attr in self.by_name[name]["attributes"]: if ( @@ -337,8 +337,9 @@ class AnalyticModel: threshold=threshold, ) else: + # model_type == "rmt" and not self.force_tree or model_type == "uls" paramfit = ParamFit() - tree_allowed = bool(int(os.getenv("DFATOOL_RMT_ENABLED", "1"))) + tree_allowed = model_type == "rmt" tree_required = dict() for name in self.names: -- cgit v1.2.3