diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-19 10:02:31 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-19 10:02:31 +0100 |
commit | 1ebae6fef9b19ef5e2bee8412658f2d8bc04e35b (patch) | |
tree | 18c36ceb740d2f20e2869a1e38aa902ab9b12499 | |
parent | d475b21ac86d0684e059a36641ecaa1f0621e680 (diff) |
Remove unused ModelAttribute.split code; splits are handled via SplitFunction
-rw-r--r-- | lib/parameters.py | 114 |
1 files changed, 34 insertions, 80 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index d0894e4..dd435e6 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -585,9 +585,6 @@ class ModelAttribute: # LUT model used as upper bound of model accuracy self.by_param = None # set via ParallelParamStats - # Split (decision tree) information - self.split = None - # param model override self.function_override = None @@ -826,25 +823,38 @@ class ModelAttribute: return np.median(self.by_param[param]) def get_data_for_paramfit(self, safe_functions_enabled=False): - if self.split: - return self.get_data_for_paramfit_split( - safe_functions_enabled=safe_functions_enabled - ) - else: - return self.get_data_for_paramfit_this( - safe_functions_enabled=safe_functions_enabled - ) - - def get_data_for_paramfit_split(self, safe_functions_enabled=False): - # currently unused - split_param_index, child_by_param_value = self.split ret = list() - for param_value, child in child_by_param_value.items(): - child_ret = child.get_data_for_paramfit( - safe_functions_enabled=safe_functions_enabled - ) - for key, param, args, kwargs in child_ret: - ret.append((key[:2] + (param_value,) + key[2:], param, args, kwargs)) + for param_index, param_name in enumerate(self.param_names): + if ( + self.stats.depends_on_param(param_name) + and not param_index in self.ignore_codependent_param + ): + by_param = self._by_param_for_index(param_index) + ret.append( + ( + (self.name, self.attr), + param_name, + (by_param, param_index, safe_functions_enabled), + dict(), + ) + ) + if self.arg_count: + for arg_index in range(self.arg_count): + param_index = len(self.param_names) + arg_index + if ( + self.stats.depends_on_arg(arg_index) + and not param_index in self.ignore_codependent_param + ): + by_param = self._by_param_for_index(param_index) + ret.append( + ( + (self.name, self.attr), + arg_index, + (by_param, param_index, safe_functions_enabled), + dict(), + ) + ) + return ret def _by_param_for_index(self, param_index): @@ -884,41 +894,6 @@ class ModelAttribute: return False return True - def get_data_for_paramfit_this(self, safe_functions_enabled=False): - ret = list() - for param_index, param_name in enumerate(self.param_names): - if ( - self.stats.depends_on_param(param_name) - and not param_index in self.ignore_codependent_param - ): - by_param = self._by_param_for_index(param_index) - ret.append( - ( - (self.name, self.attr), - param_name, - (by_param, param_index, safe_functions_enabled), - dict(), - ) - ) - if self.arg_count: - for arg_index in range(self.arg_count): - param_index = len(self.param_names) + arg_index - if ( - self.stats.depends_on_arg(arg_index) - and not param_index in self.ignore_codependent_param - ): - by_param = self._by_param_for_index(param_index) - ret.append( - ( - (self.name, self.attr), - arg_index, - (by_param, param_index, safe_functions_enabled), - dict(), - ) - ) - - return ret - def build_fol_model(self): ignore_irrelevant = bool( int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0")) @@ -954,32 +929,11 @@ class ModelAttribute: logger.warning(f"Fit of user-defined model function {function_str} failed.") def set_data_from_paramfit(self, paramfit, prefix=tuple()): - if self.split: - self.set_data_from_paramfit_split(paramfit, prefix) - else: - self.set_data_from_paramfit_this(paramfit, prefix) - - def set_data_from_paramfit_split(self, paramfit, prefix): - # currently unused - split_param_index, child_by_param_value = self.split - function_map = { - "split_by": split_param_index, - "child": dict(), - "child_static": dict(), - } - function_child = dict() - info_child = dict() - for param_value, child in child_by_param_value.items(): - child.set_data_from_paramfit(paramfit, prefix + (param_value,)) - function_child[param_value] = child.model_function - self.model_function = df.SplitFunction( - self.median, split_param_index, function_child - ) - - def set_data_from_paramfit_this(self, paramfit, prefix): fit_result = paramfit.get_result((self.name, self.attr) + prefix) if self.model_function is None: - self.model_function = df.StaticFunction(self.median) + self.model_function = df.StaticFunction( + self.median, n_samples=self.data.shape[0] + ) if os.getenv("DFATOOL_NO_PARAM"): pass elif len(fit_result.keys()): |