summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Birte Kristina Friesel <birte.friesel@uos.de> 2024-02-21 10:48:47 +0100
committer Birte Kristina Friesel <birte.friesel@uos.de> 2024-02-21 10:49:09 +0100
commit bf95fffe4fb017e2dcc0bbb268089d3231e15573 (patch)
tree 5bf609b10a7bcb997f88de76d83bc6f559e48a3b
parent 3d19e24370798f37d8119b4366b8486d67ed3110 (diff)
Always build model in get_fitted; never in constructor
-rw-r--r-- lib/model.py 60
-rw-r--r-- lib/validation.py 3
2 files changed, 23 insertions, 40 deletions
diff --git a/lib/model.py b/lib/model.py
index e2a455d..cb08fce 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -154,6 +154,7 @@ class AnalyticModel:
self.fit_done = True
return
+ self.force_tree = force_tree
self.fit_done = False
if compute_stats:
@@ -177,23 +178,6 @@ class AnalyticModel:
(name, attr)
]
- if force_tree:
- for name in self.names:
- for attr in self.by_name[name]["attributes"]:
- if max_std and name in max_std and attr in max_std[name]:
- threshold = max_std[name][attr]
- elif compute_stats:
- threshold = self.attr_by_name[name][attr].stats.std_param_lut
- else:
- threshold = 0
- logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})")
- self.build_dtree(
- name,
- attr,
- threshold=threshold,
- )
- self.fit_done = True
-
def get_by_param(self):
if not "by_param" in self.cache:
self.cache["by_param"] = by_name_to_by_param(self.by_name)
@@ -318,6 +302,27 @@ class AnalyticModel:
model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
"""
+ if self.force_tree and not self.fit_done:
+ for name in self.names:
+ for attr in self.by_name[name]["attributes"]:
+ if (
+ self.dtree_max_std
+ and name in self.dtree_max_std
+ and attr in self.dtree_max_std[name]
+ ):
+ threshold = self.dtree_max_std[name][attr]
+ elif self.attr_by_name[name][attr].stats:
+ threshold = self.attr_by_name[name][attr].stats.std_param_lut
+ else:
+ threshold = 0
+ logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})")
+ self.build_dtree(
+ name,
+ attr,
+ threshold=threshold,
+ )
+ self.fit_done = True
+
if not self.fit_done:
paramfit = ParamFit()
tree_allowed = bool(int(os.getenv("DFATOOL_DTREE_ENABLED", "1")))
@@ -772,26 +777,7 @@ class PTAModel(AnalyticModel):
if compute_stats:
self._compute_stats(by_name)
- if force_tree:
- for name in self.names:
- for attr in self.by_name[name]["attributes"]:
- if (
- dtree_max_std
- and name in dtree_max_std
- and attr in dtree_max_std[name]
- ):
- threshold = dtree_max_std[name][attr]
- elif compute_stats:
- threshold = (self.attr_by_name[name][attr].stats.std_param_lut,)
- else:
- threshold = 0
- logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})")
- self.build_dtree(
- name,
- attr,
- threshold=threshold,
- )
- self.fit_done = True
+ self.force_tree = force_tree
if self.pelt is not None:
# cluster_substates uses self.attr_by_name[*]["power"].param_values, which is set by _compute_stats
diff --git a/lib/validation.py b/lib/validation.py
index 6203003..958a9e0 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -302,8 +302,6 @@ class CrossValidator:
logger.debug("Creating training model instance")
kwargs = self.kwargs.copy()
- if static:
- kwargs["force_tree"] = False
training_data = self.model_class(
training, self.parameters, *self.args, **kwargs
)
@@ -311,7 +309,6 @@ class CrossValidator:
training_model = model_getter(training_data)
kwargs = self.kwargs.copy()
kwargs["compute_stats"] = False
- kwargs["force_tree"] = False
logger.debug("Creating validation model instance")
validation_data = self.model_class(
validation, self.parameters, *self.args, **kwargs