diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-31 16:47:21 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-31 16:47:21 +0100 |
commit | 063b277b32e8bdde5ce04efdf236c7821c79a3da (patch) | |
tree | b445255c45b5d6f71296d3f255abd84224cc19cc /lib | |
parent | c2a5c5675bce5a357f4b7f7b5020dedc7fc33caf (diff) |
RMT generation: only split on relevant variables
Diffstat (limited to 'lib')
-rw-r--r-- | lib/parameters.py | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index a619103..faa2e96 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1006,7 +1006,7 @@ class ModelAttribute: :param loss_ignore_scalar: Ignore scalar parameters when computing the loss for split candidates. Only sensible if with_function_leaves is enabled. :param threshold: Return a StaticFunction leaf node if std(data) < threshold. Default 100. - :returns: SplitFunction or StaticFunction + :returns: ModelFunction """ param_count = len(self.param_names) + self.arg_count @@ -1015,6 +1015,14 @@ class ModelAttribute: # sf.value_error["std"] = np.std(data) loss = list() + + ffs_feasible = False + by_param = partition_by_param(data, parameters) + distinct_values_by_param_index = distinct_param_values( + parameters + ) # required, "unique_values" in for loop is insufficient for std_by_param foo + std_static = np.std(data) + std_lut = np.mean([np.std(v) for v in by_param.values()]) for param_index in range(param_count): if param_index in self.ignore_param: @@ -1034,6 +1042,13 @@ class ModelAttribute: loss.append(np.inf) continue + std_by_param = _mean_std_by_param( + by_param, distinct_values_by_param_index, param_index + ) + if not _depends_on_param(None, std_by_param, std_lut): + loss.append(np.inf) + continue + if ( with_function_leaves and self.param_type[param_index] == ParamType.SCALAR @@ -1041,6 +1056,7 @@ class ModelAttribute: ): # param can be modeled as a function. Do not split on it. loss.append(np.inf) + ffs_feasible = True continue child_indexes = list() @@ -1075,13 +1091,13 @@ class ModelAttribute: children.extend((np.array(child_data) - np.mean(child_data)) ** 2) if np.any(np.isnan(children)): - loss.append(np.inf) # + loss.append(np.inf) else: loss.append(np.sum(children)) if np.all(np.isinf(loss)): # all children have the same configuration. We shouldn't get here due to the threshold check above... - if with_function_leaves: + if ffs_feasible: # try generating a function. if it fails, model_function is a StaticFunction. ma = ModelAttribute( "tmp", |