summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-01-31 16:47:21 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-01-31 16:47:21 +0100
commit063b277b32e8bdde5ce04efdf236c7821c79a3da (patch)
treeb445255c45b5d6f71296d3f255abd84224cc19cc /lib
parentc2a5c5675bce5a357f4b7f7b5020dedc7fc33caf (diff)
RMT generation: only split on relevant variables
Diffstat (limited to 'lib')
-rw-r--r--lib/parameters.py22
1 files changed, 19 insertions, 3 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index a619103..faa2e96 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1006,7 +1006,7 @@ class ModelAttribute:
:param loss_ignore_scalar: Ignore scalar parameters when computing the loss for split candidates. Only sensible if with_function_leaves is enabled.
:param threshold: Return a StaticFunction leaf node if std(data) < threshold. Default 100.
- :returns: SplitFunction or StaticFunction
+ :returns: ModelFunction
"""
param_count = len(self.param_names) + self.arg_count
@@ -1015,6 +1015,14 @@ class ModelAttribute:
# sf.value_error["std"] = np.std(data)
loss = list()
+
+ ffs_feasible = False
+ by_param = partition_by_param(data, parameters)
+ distinct_values_by_param_index = distinct_param_values(
+ parameters
+ ) # required, "unique_values" in for loop is insufficient for std_by_param foo
+ std_static = np.std(data)
+ std_lut = np.mean([np.std(v) for v in by_param.values()])
for param_index in range(param_count):
if param_index in self.ignore_param:
@@ -1034,6 +1042,13 @@ class ModelAttribute:
loss.append(np.inf)
continue
+ std_by_param = _mean_std_by_param(
+ by_param, distinct_values_by_param_index, param_index
+ )
+ if not _depends_on_param(None, std_by_param, std_lut):
+ loss.append(np.inf)
+ continue
+
if (
with_function_leaves
and self.param_type[param_index] == ParamType.SCALAR
@@ -1041,6 +1056,7 @@ class ModelAttribute:
):
# param can be modeled as a function. Do not split on it.
loss.append(np.inf)
+ ffs_feasible = True
continue
child_indexes = list()
@@ -1075,13 +1091,13 @@ class ModelAttribute:
children.extend((np.array(child_data) - np.mean(child_data)) ** 2)
if np.any(np.isnan(children)):
- loss.append(np.inf) #
+ loss.append(np.inf)
else:
loss.append(np.sum(children))
if np.all(np.isinf(loss)):
# all children have the same configuration. We shouldn't get here due to the threshold check above...
- if with_function_leaves:
+ if ffs_feasible:
# try generating a function. if it fails, model_function is a StaticFunction.
ma = ModelAttribute(
"tmp",