diff options
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 115 |
1 files changed, 1 insertions, 114 deletions
diff --git a/lib/model.py b/lib/model.py index dc0cb4d..76444b0 100644 --- a/lib/model.py +++ b/lib/model.py @@ -346,126 +346,13 @@ class AnalyticModel: ): with_function_leaves = False - self.attr_by_name[name][attribute].model_function = self._build_dtree( + self.attr_by_name[name][attribute].build_dtree( self.by_name[name]["param"], self.by_name[name][attribute], with_function_leaves, threshold, ) - def _build_dtree( - self, parameters, data, with_function_leaves=False, threshold=100, level=0 - ): - """ - Build a Decision Tree on `param` / `data` for kconfig models. - - :param this_symbols: parameter names - :param this_data: list of measurements. Each entry is a (param vector, mearusements vector) tuple. - param vector holds parameter values (same order as parameter names). mearuserements vector holds measurements. - :param data_index: Index in measurements vector to use for model generation. Default 0. - :param threshold: Return a StaticFunction leaf node if std(data[data_index]) < threshold. Default 100. - - :returns: SplitFunction or StaticFunction - """ - - # TODO remove data entries which are None (and remove corresponding parameters, too!) - - parameter_names = self.parameters - if len(parameter_names) == 0 or np.std(data) < threshold: - return StaticFunction(np.mean(data)) - # sf.value_error["std"] = np.std(data) - - mean_stds = list() - for param_index, param in enumerate(parameter_names): - - unique_values = list(set(map(lambda p: p[param_index], parameters))) - - if None in unique_values: - # param is a choice and undefined in some configs. Do not split on it. - mean_stds.append(np.inf) - continue - - if ( - with_function_leaves - and len(unique_values) > 3 - and all(map(lambda x: type(x) is int, unique_values)) - ): - # param can be modeled as a function. Do not split on it. - mean_stds.append(np.inf) - continue - - child_indexes = list() - for value in unique_values: - child_indexes.append( - list( - filter( - lambda i: parameters[i][param_index] == value, - range(len(parameters)), - ) - ) - ) - - if len(list(filter(len, child_indexes))) < 2: - # this param only has a single value. there's no point in splitting. - mean_stds.append(np.inf) - continue - - children = list() - for child in child_indexes: - children.append(np.std(list(map(lambda i: data[i], child)))) - - if np.any(np.isnan(children)): - mean_stds.append(np.inf) - else: - mean_stds.append(np.mean(children)) - - if np.all(np.isinf(mean_stds)): - # all children have the same configuration. We shouldn't get here due to the threshold check above... - if with_function_leaves: - # try generating a function. if it fails, model_function is a StaticFunction. - ma = ModelAttribute("tmp", "tmp", data, parameters, self.parameters, 0) - ParamStats.compute_for_attr(ma) - paramfit = ParamFit(parallel=False) - for key, param, args, kwargs in ma.get_data_for_paramfit(): - paramfit.enqueue(key, param, args, kwargs) - paramfit.fit() - ma.set_data_from_paramfit(paramfit) - return ma.model_function - else: - logging.warning( - f"While building DTree for configurations {parameters}: Children have identical configuration, but high stddev ({np.std(data)}). Falling back to Staticfunction" - ) - return StaticFunction(np.mean(data)) - - symbol_index = np.argmin(mean_stds) - symbol = parameter_names[symbol_index] - - unique_values = list(set(map(lambda p: p[symbol_index], parameters))) - - child = dict() - - for value in unique_values: - indexes = list( - filter( - lambda i: parameters[i][symbol_index] == value, - range(len(parameters)), - ) - ) - child_parameters = list(map(lambda i: parameters[i], indexes)) - child_data = list(map(lambda i: data[i], indexes)) - if len(child_data): - child[value] = self._build_dtree( - child_parameters, - child_data, - with_function_leaves, - threshold, - level + 1, - ) - - assert len(child.values()) >= 2 - - return SplitFunction(np.mean(data), symbol_index, child) - def to_dref(self, static_quality, lut_quality, model_quality) -> dict: ret = dict() for name in self.names: |