diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 10:48:47 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 10:49:09 +0100 |
commit | bf95fffe4fb017e2dcc0bbb268089d3231e15573 (patch) | |
tree | 5bf609b10a7bcb997f88de76d83bc6f559e48a3b /lib/model.py | |
parent | 3d19e24370798f37d8119b4366b8486d67ed3110 (diff) |
Always build model in get_fitted; never in constructor
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 60 |
1 files changed, 23 insertions, 37 deletions
diff --git a/lib/model.py b/lib/model.py index e2a455d..cb08fce 100644 --- a/lib/model.py +++ b/lib/model.py @@ -154,6 +154,7 @@ class AnalyticModel: self.fit_done = True return + self.force_tree = force_tree self.fit_done = False if compute_stats: @@ -177,23 +178,6 @@ class AnalyticModel: (name, attr) ] - if force_tree: - for name in self.names: - for attr in self.by_name[name]["attributes"]: - if max_std and name in max_std and attr in max_std[name]: - threshold = max_std[name][attr] - elif compute_stats: - threshold = self.attr_by_name[name][attr].stats.std_param_lut - else: - threshold = 0 - logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})") - self.build_dtree( - name, - attr, - threshold=threshold, - ) - self.fit_done = True - def get_by_param(self): if not "by_param" in self.cache: self.cache["by_param"] = by_name_to_by_param(self.by_name) @@ -318,6 +302,27 @@ class AnalyticModel: model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None """ + if self.force_tree and not self.fit_done: + for name in self.names: + for attr in self.by_name[name]["attributes"]: + if ( + self.dtree_max_std + and name in self.dtree_max_std + and attr in self.dtree_max_std[name] + ): + threshold = self.dtree_max_std[name][attr] + elif self.attr_by_name[name][attr].stats: + threshold = self.attr_by_name[name][attr].stats.std_param_lut + else: + threshold = 0 + logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})") + self.build_dtree( + name, + attr, + threshold=threshold, + ) + self.fit_done = True + if not self.fit_done: paramfit = ParamFit() tree_allowed = bool(int(os.getenv("DFATOOL_DTREE_ENABLED", "1"))) @@ -772,26 +777,7 @@ class PTAModel(AnalyticModel): if compute_stats: self._compute_stats(by_name) - if force_tree: - for name in self.names: - for attr in self.by_name[name]["attributes"]: - if ( - dtree_max_std - and name in dtree_max_std - and attr in dtree_max_std[name] - ): - threshold = dtree_max_std[name][attr] - elif compute_stats: - threshold = (self.attr_by_name[name][attr].stats.std_param_lut,) - else: - threshold = 0 - logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})") - self.build_dtree( - name, - attr, - threshold=threshold, - ) - self.fit_done = True + self.force_tree = force_tree if self.pelt is not None: # cluster_substates uses self.attr_by_name[*]["power"].param_values, which is set by _compute_stats |