summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-08-24 12:15:36 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2021-08-24 12:15:36 +0200
commit656d7dd32d069f784e0e9398e9f98d11f2127b1f (patch)
tree616d669d3dc8b6acdda1e220ea51d6d020d3b55f /lib
parenta87a547d3e8c0c07b0d58ec34adf0f3bcd0cbce8 (diff)
dtree: Add function arg support. Update tests to reflect new models.
Diffstat (limited to 'lib')
-rw-r--r--lib/parameters.py29
1 files changed, 19 insertions, 10 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index eae04c8..d64be86 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -772,13 +772,17 @@ class ModelAttribute:
# TODO remove data entries which are None (and remove corresponding parameters, too!)
- parameter_names = self.param_names
- if len(parameter_names) == 0 or np.std(data) < threshold:
+ param_count = len(self.param_names) + self.arg_count
+ if param_count == 0 or np.std(data) < threshold:
return df.StaticFunction(np.mean(data))
# sf.value_error["std"] = np.std(data)
mean_stds = list()
- for param_index, param in enumerate(parameter_names):
+ for param_index in range(param_count):
+
+ if param_index in self.ignore_param:
+ mean_stds.append(np.inf)
+ continue
unique_values = list(set(map(lambda p: p[param_index], parameters)))
@@ -825,7 +829,14 @@ class ModelAttribute:
# all children have the same configuration. We shouldn't get here due to the threshold check above...
if with_function_leaves:
# try generating a function. if it fails, model_function is a StaticFunction.
- ma = ModelAttribute("tmp", "tmp", data, parameters, self.param_names, 0)
+ ma = ModelAttribute(
+ "tmp",
+ "tmp",
+ data,
+ parameters,
+ self.param_names,
+ arg_count=self.arg_count,
+ )
ParamStats.compute_for_attr(ma)
paramfit = ParamFit(parallel=False)
for key, param, args, kwargs in ma.get_data_for_paramfit():
@@ -833,15 +844,13 @@ class ModelAttribute:
paramfit.fit()
ma.set_data_from_paramfit(paramfit)
return ma.model_function
- else:
- logging.warning(
- f"While building DTree for configurations {parameters}: Children have identical configuration, but high stddev ({np.std(data)}). Falling back to Staticfunction"
- )
+ # else:
+ # logging.warning(
+ # f"While building DTree for configurations {parameters}: Children have identical configuration, but high stddev ({np.std(data)}). Falling back to Staticfunction"
+ # )
return df.StaticFunction(np.mean(data))
symbol_index = np.argmin(mean_stds)
- symbol = parameter_names[symbol_index]
-
unique_values = list(set(map(lambda p: p[symbol_index], parameters)))
child = dict()