diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-06-23 16:35:26 +0200 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-06-23 16:35:26 +0200 |
commit | e761a12419477776a0846de5210d58c8ee57ffee (patch) | |
tree | 77e263fbd94a9e32179d4ecb9d4c3ccdc3d3ca71 /lib/model.py | |
parent | 31a87322198be21b2db3e71fd3070e198d368333 (diff) |
simplify dtree builder
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 77 |
1 files changed, 52 insertions, 25 deletions
diff --git a/lib/model.py b/lib/model.py index 8c00fd2..293d2a9 100644 --- a/lib/model.py +++ b/lib/model.py @@ -316,9 +316,27 @@ class AnalyticModel: return detailed_results - def build_tree(self, this_symbols, this_data, data_index=0, threshold=100, level=0): + def build_dtree(self, name, attribute, threshold=100): + + if name not in self.attr_by_name: + self.attr_by_name[name] = dict() + + if attribute not in self.attr_by_name[name]: + self.attr_by_name[name][attribute] = ModelAttribute( + name, + attribute, + self.by_name[name][attribute], + self.by_name[name]["param"], + self.parameters, + ) + + self.attr_by_name[name][attribute].model_function = self._build_dtree( + self.by_name[name]["param"], self.by_name[name][attribute], threshold + ) + + def _build_dtree(self, parameters, data, threshold=100, level=0): """ - Build a Decision Tree on `this_data` for kconfig models. + Build a Decision Tree on `param` / `data` for kconfig models. :param this_symbols: parameter names :param this_data: list of measurements. Each entry is a (param vector, mearusements vector) tuple. @@ -329,36 +347,40 @@ class AnalyticModel: :returns: SplitFunction or StaticFunction """ - rom_sizes = list(map(lambda x: x[1][data_index], this_data)) - - if np.std(rom_sizes) < threshold or len(this_symbols) == 0: - return StaticFunction(np.mean(rom_sizes)) - # sf.value_error["std"] = np.std(rom_sizes) + parameter_names = self.parameters + if len(parameter_names) == 0 or np.std(data) < threshold: + return StaticFunction(np.mean(data)) + # sf.value_error["std"] = np.std(data) mean_stds = list() - for i, param in enumerate(this_symbols): + for param_index, param in enumerate(parameter_names): - unique_values = list(set(map(lambda vrr: vrr[0][i], this_data))) + unique_values = list(set(map(lambda p: p[param_index], parameters))) if None in unique_values: # param is a choice and undefined in some configs. Do not split on it. mean_stds.append(np.inf) continue - child_values = list() + child_indexes = list() for value in unique_values: - child_values.append( - list(filter(lambda vrr: vrr[0][i] == value, this_data)) + child_indexes.append( + list( + filter( + lambda i: parameters[i][param_index] == value, + range(len(parameters)), + ) + ) ) - if len(list(filter(len, child_values))) < 2: + if len(list(filter(len, child_indexes))) < 2: # this param only has a single value. there's no point in splitting. mean_stds.append(np.inf) continue children = list() - for child in child_values: - children.append(np.std(list(map(lambda x: x[1][data_index], child)))) + for child in child_indexes: + children.append(np.std(list(map(lambda i: data[i], child)))) if np.any(np.isnan(children)): mean_stds.append(np.inf) @@ -368,30 +390,35 @@ class AnalyticModel: if np.all(np.isinf(mean_stds)): # all children have the same configuration. We shouldn't get here due to the threshold check above... logging.warning("Waht") - return StaticFunction(np.mean(rom_sizes)) + return StaticFunction(np.mean(data)) symbol_index = np.argmin(mean_stds) - symbol = this_symbols[symbol_index] + symbol = parameter_names[symbol_index] - unique_values = list(set(map(lambda vrr: vrr[0][symbol_index], this_data))) + unique_values = list(set(map(lambda p: p[symbol_index], parameters))) child = dict() for value in unique_values: - children = list( - filter(lambda vrr: vrr[0][symbol_index] == value, this_data) + indexes = list( + filter( + lambda i: parameters[i][symbol_index] == value, + range(len(parameters)), + ) ) - if len(children): + child_parameters = list(map(lambda i: parameters[i], indexes)) + child_data = list(map(lambda i: data[i], indexes)) + if len(child_data): print( - f"Level {level} split on {symbol} == {value} has {len(children)} children" + f"Level {level} split on {symbol} == {value} has {len(child_data)} children" ) - child[value] = self.build_tree( - this_symbols, children, data_index, threshold, level + 1 + child[value] = self._build_dtree( + child_parameters, child_data, threshold, level + 1 ) assert len(child.values()) >= 2 - return SplitFunction(np.mean(rom_sizes), symbol_index, child) + return SplitFunction(np.mean(data), symbol_index, child) def to_dref(self, static_quality, lut_quality, model_quality) -> dict: ret = dict() |