summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-06-23 16:35:26 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2021-06-23 16:35:26 +0200
commite761a12419477776a0846de5210d58c8ee57ffee (patch)
tree77e263fbd94a9e32179d4ecb9d4c3ccdc3d3ca71 /lib/model.py
parent31a87322198be21b2db3e71fd3070e198d368333 (diff)
simplify dtree builder
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py77
1 files changed, 52 insertions, 25 deletions
diff --git a/lib/model.py b/lib/model.py
index 8c00fd2..293d2a9 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -316,9 +316,27 @@ class AnalyticModel:
return detailed_results
- def build_tree(self, this_symbols, this_data, data_index=0, threshold=100, level=0):
+ def build_dtree(self, name, attribute, threshold=100):
+
+ if name not in self.attr_by_name:
+ self.attr_by_name[name] = dict()
+
+ if attribute not in self.attr_by_name[name]:
+ self.attr_by_name[name][attribute] = ModelAttribute(
+ name,
+ attribute,
+ self.by_name[name][attribute],
+ self.by_name[name]["param"],
+ self.parameters,
+ )
+
+ self.attr_by_name[name][attribute].model_function = self._build_dtree(
+ self.by_name[name]["param"], self.by_name[name][attribute], threshold
+ )
+
+ def _build_dtree(self, parameters, data, threshold=100, level=0):
"""
- Build a Decision Tree on `this_data` for kconfig models.
+ Build a Decision Tree on `param` / `data` for kconfig models.
:param this_symbols: parameter names
:param this_data: list of measurements. Each entry is a (param vector, mearusements vector) tuple.
@@ -329,36 +347,40 @@ class AnalyticModel:
:returns: SplitFunction or StaticFunction
"""
- rom_sizes = list(map(lambda x: x[1][data_index], this_data))
-
- if np.std(rom_sizes) < threshold or len(this_symbols) == 0:
- return StaticFunction(np.mean(rom_sizes))
- # sf.value_error["std"] = np.std(rom_sizes)
+ parameter_names = self.parameters
+ if len(parameter_names) == 0 or np.std(data) < threshold:
+ return StaticFunction(np.mean(data))
+ # sf.value_error["std"] = np.std(data)
mean_stds = list()
- for i, param in enumerate(this_symbols):
+ for param_index, param in enumerate(parameter_names):
- unique_values = list(set(map(lambda vrr: vrr[0][i], this_data)))
+ unique_values = list(set(map(lambda p: p[param_index], parameters)))
if None in unique_values:
# param is a choice and undefined in some configs. Do not split on it.
mean_stds.append(np.inf)
continue
- child_values = list()
+ child_indexes = list()
for value in unique_values:
- child_values.append(
- list(filter(lambda vrr: vrr[0][i] == value, this_data))
+ child_indexes.append(
+ list(
+ filter(
+ lambda i: parameters[i][param_index] == value,
+ range(len(parameters)),
+ )
+ )
)
- if len(list(filter(len, child_values))) < 2:
+ if len(list(filter(len, child_indexes))) < 2:
# this param only has a single value. there's no point in splitting.
mean_stds.append(np.inf)
continue
children = list()
- for child in child_values:
- children.append(np.std(list(map(lambda x: x[1][data_index], child))))
+ for child in child_indexes:
+ children.append(np.std(list(map(lambda i: data[i], child))))
if np.any(np.isnan(children)):
mean_stds.append(np.inf)
@@ -368,30 +390,35 @@ class AnalyticModel:
if np.all(np.isinf(mean_stds)):
# all children have the same configuration. We shouldn't get here due to the threshold check above...
logging.warning("Waht")
- return StaticFunction(np.mean(rom_sizes))
+ return StaticFunction(np.mean(data))
symbol_index = np.argmin(mean_stds)
- symbol = this_symbols[symbol_index]
+ symbol = parameter_names[symbol_index]
- unique_values = list(set(map(lambda vrr: vrr[0][symbol_index], this_data)))
+ unique_values = list(set(map(lambda p: p[symbol_index], parameters)))
child = dict()
for value in unique_values:
- children = list(
- filter(lambda vrr: vrr[0][symbol_index] == value, this_data)
+ indexes = list(
+ filter(
+ lambda i: parameters[i][symbol_index] == value,
+ range(len(parameters)),
+ )
)
- if len(children):
+ child_parameters = list(map(lambda i: parameters[i], indexes))
+ child_data = list(map(lambda i: data[i], indexes))
+ if len(child_data):
print(
- f"Level {level} split on {symbol} == {value} has {len(children)} children"
+ f"Level {level} split on {symbol} == {value} has {len(child_data)} children"
)
- child[value] = self.build_tree(
- this_symbols, children, data_index, threshold, level + 1
+ child[value] = self._build_dtree(
+ child_parameters, child_data, threshold, level + 1
)
assert len(child.values()) >= 2
- return SplitFunction(np.mean(rom_sizes), symbol_index, child)
+ return SplitFunction(np.mean(data), symbol_index, child)
def to_dref(self, static_quality, lut_quality, model_quality) -> dict:
ret = dict()