summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-19 10:02:31 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-19 10:02:31 +0100
commit1ebae6fef9b19ef5e2bee8412658f2d8bc04e35b (patch)
tree18c36ceb740d2f20e2869a1e38aa902ab9b12499
parentd475b21ac86d0684e059a36641ecaa1f0621e680 (diff)
Remove unused ModelAttribute.split code; splits are handled via SplitFunction
-rw-r--r--lib/parameters.py114
1 files changed, 34 insertions, 80 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index d0894e4..dd435e6 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -585,9 +585,6 @@ class ModelAttribute:
# LUT model used as upper bound of model accuracy
self.by_param = None # set via ParallelParamStats
- # Split (decision tree) information
- self.split = None
-
# param model override
self.function_override = None
@@ -826,25 +823,38 @@ class ModelAttribute:
return np.median(self.by_param[param])
def get_data_for_paramfit(self, safe_functions_enabled=False):
- if self.split:
- return self.get_data_for_paramfit_split(
- safe_functions_enabled=safe_functions_enabled
- )
- else:
- return self.get_data_for_paramfit_this(
- safe_functions_enabled=safe_functions_enabled
- )
-
- def get_data_for_paramfit_split(self, safe_functions_enabled=False):
- # currently unused
- split_param_index, child_by_param_value = self.split
ret = list()
- for param_value, child in child_by_param_value.items():
- child_ret = child.get_data_for_paramfit(
- safe_functions_enabled=safe_functions_enabled
- )
- for key, param, args, kwargs in child_ret:
- ret.append((key[:2] + (param_value,) + key[2:], param, args, kwargs))
+ for param_index, param_name in enumerate(self.param_names):
+ if (
+ self.stats.depends_on_param(param_name)
+ and not param_index in self.ignore_codependent_param
+ ):
+ by_param = self._by_param_for_index(param_index)
+ ret.append(
+ (
+ (self.name, self.attr),
+ param_name,
+ (by_param, param_index, safe_functions_enabled),
+ dict(),
+ )
+ )
+ if self.arg_count:
+ for arg_index in range(self.arg_count):
+ param_index = len(self.param_names) + arg_index
+ if (
+ self.stats.depends_on_arg(arg_index)
+ and not param_index in self.ignore_codependent_param
+ ):
+ by_param = self._by_param_for_index(param_index)
+ ret.append(
+ (
+ (self.name, self.attr),
+ arg_index,
+ (by_param, param_index, safe_functions_enabled),
+ dict(),
+ )
+ )
+
return ret
def _by_param_for_index(self, param_index):
@@ -884,41 +894,6 @@ class ModelAttribute:
return False
return True
- def get_data_for_paramfit_this(self, safe_functions_enabled=False):
- ret = list()
- for param_index, param_name in enumerate(self.param_names):
- if (
- self.stats.depends_on_param(param_name)
- and not param_index in self.ignore_codependent_param
- ):
- by_param = self._by_param_for_index(param_index)
- ret.append(
- (
- (self.name, self.attr),
- param_name,
- (by_param, param_index, safe_functions_enabled),
- dict(),
- )
- )
- if self.arg_count:
- for arg_index in range(self.arg_count):
- param_index = len(self.param_names) + arg_index
- if (
- self.stats.depends_on_arg(arg_index)
- and not param_index in self.ignore_codependent_param
- ):
- by_param = self._by_param_for_index(param_index)
- ret.append(
- (
- (self.name, self.attr),
- arg_index,
- (by_param, param_index, safe_functions_enabled),
- dict(),
- )
- )
-
- return ret
-
def build_fol_model(self):
ignore_irrelevant = bool(
int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0"))
@@ -954,32 +929,11 @@ class ModelAttribute:
logger.warning(f"Fit of user-defined model function {function_str} failed.")
def set_data_from_paramfit(self, paramfit, prefix=tuple()):
- if self.split:
- self.set_data_from_paramfit_split(paramfit, prefix)
- else:
- self.set_data_from_paramfit_this(paramfit, prefix)
-
- def set_data_from_paramfit_split(self, paramfit, prefix):
- # currently unused
- split_param_index, child_by_param_value = self.split
- function_map = {
- "split_by": split_param_index,
- "child": dict(),
- "child_static": dict(),
- }
- function_child = dict()
- info_child = dict()
- for param_value, child in child_by_param_value.items():
- child.set_data_from_paramfit(paramfit, prefix + (param_value,))
- function_child[param_value] = child.model_function
- self.model_function = df.SplitFunction(
- self.median, split_param_index, function_child
- )
-
- def set_data_from_paramfit_this(self, paramfit, prefix):
fit_result = paramfit.get_result((self.name, self.attr) + prefix)
if self.model_function is None:
- self.model_function = df.StaticFunction(self.median)
+ self.model_function = df.StaticFunction(
+ self.median, n_samples=self.data.shape[0]
+ )
if os.getenv("DFATOOL_NO_PARAM"):
pass
elif len(fit_result.keys()):