diff options
-rw-r--r-- | lib/model.py | 60 | ||||
-rw-r--r-- | lib/validation.py | 3 |
2 files changed, 23 insertions, 40 deletions
diff --git a/lib/model.py b/lib/model.py index e2a455d..cb08fce 100644 --- a/lib/model.py +++ b/lib/model.py @@ -154,6 +154,7 @@ class AnalyticModel: self.fit_done = True return + self.force_tree = force_tree self.fit_done = False if compute_stats: @@ -177,23 +178,6 @@ class AnalyticModel: (name, attr) ] - if force_tree: - for name in self.names: - for attr in self.by_name[name]["attributes"]: - if max_std and name in max_std and attr in max_std[name]: - threshold = max_std[name][attr] - elif compute_stats: - threshold = self.attr_by_name[name][attr].stats.std_param_lut - else: - threshold = 0 - logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})") - self.build_dtree( - name, - attr, - threshold=threshold, - ) - self.fit_done = True - def get_by_param(self): if not "by_param" in self.cache: self.cache["by_param"] = by_name_to_by_param(self.by_name) @@ -318,6 +302,27 @@ class AnalyticModel: model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None """ + if self.force_tree and not self.fit_done: + for name in self.names: + for attr in self.by_name[name]["attributes"]: + if ( + self.dtree_max_std + and name in self.dtree_max_std + and attr in self.dtree_max_std[name] + ): + threshold = self.dtree_max_std[name][attr] + elif self.attr_by_name[name][attr].stats: + threshold = self.attr_by_name[name][attr].stats.std_param_lut + else: + threshold = 0 + logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})") + self.build_dtree( + name, + attr, + threshold=threshold, + ) + self.fit_done = True + if not self.fit_done: paramfit = ParamFit() tree_allowed = bool(int(os.getenv("DFATOOL_DTREE_ENABLED", "1"))) @@ -772,26 +777,7 @@ class PTAModel(AnalyticModel): if compute_stats: self._compute_stats(by_name) - if force_tree: - for name in self.names: - for attr in self.by_name[name]["attributes"]: - if ( - dtree_max_std - and name in dtree_max_std - and attr in dtree_max_std[name] - ): - threshold = dtree_max_std[name][attr] - elif compute_stats: - threshold = (self.attr_by_name[name][attr].stats.std_param_lut,) - else: - threshold = 0 - logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})") - self.build_dtree( - name, - attr, - threshold=threshold, - ) - self.fit_done = True + self.force_tree = force_tree if self.pelt is not None: # cluster_substates uses self.attr_by_name[*]["power"].param_values, which is set by _compute_stats diff --git a/lib/validation.py b/lib/validation.py index 6203003..958a9e0 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -302,8 +302,6 @@ class CrossValidator: logger.debug("Creating training model instance") kwargs = self.kwargs.copy() - if static: - kwargs["force_tree"] = False training_data = self.model_class( training, self.parameters, *self.args, **kwargs ) @@ -311,7 +309,6 @@ class CrossValidator: training_model = model_getter(training_data) kwargs = self.kwargs.copy() kwargs["compute_stats"] = False - kwargs["force_tree"] = False logger.debug("Creating validation model instance") validation_data = self.model_class( validation, self.parameters, *self.args, **kwargs |