summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-21 11:09:15 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-21 11:09:15 +0100
commit6bf411afb9289408c62e1696d8fb2b4da47a9fab (patch)
tree4eb44b80585e44f581c14233e465731b6bf14685
parentbf95fffe4fb017e2dcc0bbb268089d3231e15573 (diff)
model: Remove no longer useful build_dtree wrapper
-rw-r--r--lib/model.py63
1 files changed, 24 insertions, 39 deletions
diff --git a/lib/model.py b/lib/model.py
index cb08fce..9266153 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -293,16 +293,9 @@ class AnalyticModel:
return lut_median_getter
- def get_fitted(self, use_mean=False, safe_functions_enabled=False):
- """
- Get parameter-aware model function and model information function.
+ def build_fitted(self, safe_functions_enabled=False):
- Returns two functions:
- model_function(name, attribute, param=parameter values) -> model value.
- model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
- """
-
- if self.force_tree and not self.fit_done:
+ if self.force_tree:
for name in self.names:
for attr in self.by_name[name]["attributes"]:
if (
@@ -316,14 +309,12 @@ class AnalyticModel:
else:
threshold = 0
logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})")
- self.build_dtree(
- name,
- attr,
+ self.attr_by_name[name][attr].build_dtree(
+ self.by_name[name]["param"],
+ self.by_name[name][attr],
threshold=threshold,
)
- self.fit_done = True
-
- if not self.fit_done:
+ else:
paramfit = ParamFit()
tree_allowed = bool(int(os.getenv("DFATOOL_DTREE_ENABLED", "1")))
use_fol = bool(int(os.getenv("DFATOOL_FIT_FOL", "0")))
@@ -368,11 +359,27 @@ class AnalyticModel:
logger.debug(
f"build_dtree({name}, {attr}, threshold={threshold})"
)
- self.build_dtree(name, attr, threshold=threshold)
+ self.attr_by_name[name][attr].build_dtree(
+ self.by_name[name]["param"],
+ self.by_name[name][attr],
+ threshold=threshold,
+ )
else:
self.attr_by_name[name][attr].set_data_from_paramfit(paramfit)
- self.fit_done = True
+ self.fit_done = True
+
+ def get_fitted(self, use_mean=False, safe_functions_enabled=False):
+ """
+ Get parameter-aware model function and model information function.
+
+ Returns two functions:
+ model_function(name, attribute, param=parameter values) -> model value.
+ model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
+ """
+
+ if not self.fit_done:
+ self.build_fitted(safe_functions_enabled=safe_functions_enabled)
static_model = dict()
for name, attr in self.attr_by_name.items():
@@ -464,28 +471,6 @@ class AnalyticModel:
return detailed_results, raw_results
return detailed_results
- def build_dtree(self, name, attribute, threshold=100, **kwargs):
- if name not in self.attr_by_name:
- self.attr_by_name[name] = dict()
-
- if attribute not in self.attr_by_name[name]:
- self.attr_by_name[name][attribute] = ModelAttribute(
- name,
- attribute,
- self.by_name[name][attribute],
- self.by_name[name]["param"],
- self.parameters,
- self._num_args.get(name, 0),
- param_type=ParamType(self.by_name[name]["param"]),
- )
-
- self.attr_by_name[name][attribute].build_dtree(
- self.by_name[name]["param"],
- self.by_name[name][attribute],
- threshold=threshold,
- **kwargs,
- )
-
def to_dref(
self, static_quality, lut_quality, model_quality, xv_models=None
) -> dict: