summary | refs | log | tree | commit | diff
path: root/lib
diff options
context:
space:
mode:
author Daniel Friesel <daniel.friesel@uos.de> 2021-12-03 14:46:23 +0100
committer Daniel Friesel <daniel.friesel@uos.de> 2021-12-03 14:46:23 +0100
commit a9d538afbd9d766a35093e851fbe5c12112fb2eb (patch)
tree bb706dc2dd04da432209f8a90458f1913b9af510 /lib
parent de1187bdd18d7494d58665ac8f46ad9a42790384 (diff)
dtree: switch to ssr loss
Diffstat (limited to 'lib')
-rw-r--r-- lib/parameters.py 25
1 files changed, 13 insertions, 12 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 266bcec..e199153 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -901,24 +901,24 @@ class ModelAttribute:
return df.StaticFunction(np.mean(data))
# sf.value_error["std"] = np.std(data)
- mean_stds = list()
+ loss = list()
for param_index in range(param_count):
if param_index in self.ignore_param:
- mean_stds.append(np.inf)
+ loss.append(np.inf)
continue
unique_values = list(set(map(lambda p: p[param_index], parameters)))
if None in unique_values:
# param is a choice and undefined in some configs. Do not split on it.
- mean_stds.append(np.inf)
+ loss.append(np.inf)
continue
# if not with_nonbinary_nodes and sorted(unique_values) != [0, 1]:
if not with_nonbinary_nodes and len(unique_values) > 2:
# param cannot be handled with a binary split
- mean_stds.append(np.inf)
+ loss.append(np.inf)
continue
if (
@@ -927,7 +927,7 @@ class ModelAttribute:
and len(unique_values) >= self.min_values_for_analytic_model
):
# param can be modeled as a function. Do not split on it.
- mean_stds.append(np.inf)
+ loss.append(np.inf)
continue
child_indexes = list()
@@ -943,7 +943,7 @@ class ModelAttribute:
if len(list(filter(len, child_indexes))) <= 1:
# this param only has a single value. there's no point in splitting.
- mean_stds.append(np.inf)
+ loss.append(np.inf)
continue
children = list()
@@ -956,16 +956,17 @@ class ModelAttribute:
child_param,
ignore_parameters=self.nonscalar_param_indexes,
)
- children.extend(map(np.std, child_data_by_scalar.values()))
+ for sub_data in child_data_by_scalar.values():
+ children.extend((np.array(sub_data) - np.mean(sub_data)) ** 2)
else:
- children.append(np.std(child_data))
+ children.extend((np.array(child_data) - np.mean(child_data)) ** 2)
if np.any(np.isnan(children)):
- mean_stds.append(np.inf)
+ loss.append(np.inf) #
else:
- mean_stds.append(np.mean(children))
+ loss.append(np.sum(children))
- if np.all(np.isinf(mean_stds)):
+ if np.all(np.isinf(loss)):
# all children have the same configuration. We shouldn't get here due to the threshold check above...
if with_function_leaves:
# try generating a function. if it fails, model_function is a StaticFunction.
@@ -991,7 +992,7 @@ class ModelAttribute:
# )
return df.StaticFunction(np.mean(data))
- symbol_index = np.argmin(mean_stds)
+ symbol_index = np.argmin(loss)
unique_values = list(set(map(lambda p: p[symbol_index], parameters)))
child = dict()