diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 11:09:15 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 11:09:15 +0100 |
commit | 6bf411afb9289408c62e1696d8fb2b4da47a9fab (patch) | |
tree | 4eb44b80585e44f581c14233e465731b6bf14685 | |
parent | bf95fffe4fb017e2dcc0bbb268089d3231e15573 (diff) |
model: Remove no longer useful build_dtree wrapper
-rw-r--r-- | lib/model.py | 63 |
1 files changed, 24 insertions, 39 deletions
diff --git a/lib/model.py b/lib/model.py index cb08fce..9266153 100644 --- a/lib/model.py +++ b/lib/model.py @@ -293,16 +293,9 @@ class AnalyticModel: return lut_median_getter - def get_fitted(self, use_mean=False, safe_functions_enabled=False): - """ - Get parameter-aware model function and model information function. + def build_fitted(self, safe_functions_enabled=False): - Returns two functions: - model_function(name, attribute, param=parameter values) -> model value. - model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None - """ - - if self.force_tree and not self.fit_done: + if self.force_tree: for name in self.names: for attr in self.by_name[name]["attributes"]: if ( @@ -316,14 +309,12 @@ class AnalyticModel: else: threshold = 0 logger.debug(f"build_dtree({name}, {attr}, threshold={threshold})") - self.build_dtree( - name, - attr, + self.attr_by_name[name][attr].build_dtree( + self.by_name[name]["param"], + self.by_name[name][attr], threshold=threshold, ) - self.fit_done = True - - if not self.fit_done: + else: paramfit = ParamFit() tree_allowed = bool(int(os.getenv("DFATOOL_DTREE_ENABLED", "1"))) use_fol = bool(int(os.getenv("DFATOOL_FIT_FOL", "0"))) @@ -368,11 +359,27 @@ class AnalyticModel: logger.debug( f"build_dtree({name}, {attr}, threshold={threshold})" ) - self.build_dtree(name, attr, threshold=threshold) + self.attr_by_name[name][attr].build_dtree( + self.by_name[name]["param"], + self.by_name[name][attr], + threshold=threshold, + ) else: self.attr_by_name[name][attr].set_data_from_paramfit(paramfit) - self.fit_done = True + self.fit_done = True + + def get_fitted(self, use_mean=False, safe_functions_enabled=False): + """ + Get parameter-aware model function and model information function. + + Returns two functions: + model_function(name, attribute, param=parameter values) -> model value. + model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None + """ + + if not self.fit_done: + self.build_fitted(safe_functions_enabled=safe_functions_enabled) static_model = dict() for name, attr in self.attr_by_name.items(): @@ -464,28 +471,6 @@ class AnalyticModel: return detailed_results, raw_results return detailed_results - def build_dtree(self, name, attribute, threshold=100, **kwargs): - if name not in self.attr_by_name: - self.attr_by_name[name] = dict() - - if attribute not in self.attr_by_name[name]: - self.attr_by_name[name][attribute] = ModelAttribute( - name, - attribute, - self.by_name[name][attribute], - self.by_name[name]["param"], - self.parameters, - self._num_args.get(name, 0), - param_type=ParamType(self.by_name[name]["param"]), - ) - - self.attr_by_name[name][attribute].build_dtree( - self.by_name[name]["param"], - self.by_name[name][attribute], - threshold=threshold, - **kwargs, - ) - def to_dref( self, static_quality, lut_quality, model_quality, xv_models=None ) -> dict: |