diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/parameters.py | 25 |
1 files changed, 13 insertions, 12 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 266bcec..e199153 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -901,24 +901,24 @@ class ModelAttribute: return df.StaticFunction(np.mean(data)) # sf.value_error["std"] = np.std(data) - mean_stds = list() + loss = list() for param_index in range(param_count): if param_index in self.ignore_param: - mean_stds.append(np.inf) + loss.append(np.inf) continue unique_values = list(set(map(lambda p: p[param_index], parameters))) if None in unique_values: # param is a choice and undefined in some configs. Do not split on it. - mean_stds.append(np.inf) + loss.append(np.inf) continue # if not with_nonbinary_nodes and sorted(unique_values) != [0, 1]: if not with_nonbinary_nodes and len(unique_values) > 2: # param cannot be handled with a binary split - mean_stds.append(np.inf) + loss.append(np.inf) continue if ( @@ -927,7 +927,7 @@ class ModelAttribute: and len(unique_values) >= self.min_values_for_analytic_model ): # param can be modeled as a function. Do not split on it. - mean_stds.append(np.inf) + loss.append(np.inf) continue child_indexes = list() @@ -943,7 +943,7 @@ class ModelAttribute: if len(list(filter(len, child_indexes))) <= 1: # this param only has a single value. there's no point in splitting. - mean_stds.append(np.inf) + loss.append(np.inf) continue children = list() @@ -956,16 +956,17 @@ class ModelAttribute: child_param, ignore_parameters=self.nonscalar_param_indexes, ) - children.extend(map(np.std, child_data_by_scalar.values())) + for sub_data in child_data_by_scalar.values(): + children.extend((np.array(sub_data) - np.mean(sub_data)) ** 2) else: - children.append(np.std(child_data)) + children.extend((np.array(child_data) - np.mean(child_data)) ** 2) if np.any(np.isnan(children)): - mean_stds.append(np.inf) + loss.append(np.inf) # else: - mean_stds.append(np.mean(children)) + loss.append(np.sum(children)) - if np.all(np.isinf(mean_stds)): + if np.all(np.isinf(loss)): # all children have the same configuration. We shouldn't get here due to the threshold check above... if with_function_leaves: # try generating a function. if it fails, model_function is a StaticFunction. @@ -991,7 +992,7 @@ class ModelAttribute: # ) return df.StaticFunction(np.mean(data)) - symbol_index = np.argmin(mean_stds) + symbol_index = np.argmin(loss) unique_values = list(set(map(lambda p: p[symbol_index], parameters))) child = dict() |