summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-08-20 10:49:17 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2021-08-20 10:49:17 +0200
commit2b97f76c58bfb81e15f1cb463cedf82272c3194b (patch)
tree25a7b1fdb8957ff5d633796179e9fb06fcececf8
parent26d07c8eae44d0a6919fd775728cdb1bb2808298 (diff)
model: support for decision tree with function leaves
pretty hacky at the moment, but good enough for eval
-rw-r--r--lib/model.py47
1 files changed, 41 insertions, 6 deletions
diff --git a/lib/model.py b/lib/model.py
index 98a9d4f..3595e28 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -330,11 +330,24 @@ class AnalyticModel:
self.parameters,
)
+ # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes
+ with_function_leaves = bool(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES"))
+ if (
+ with_function_leaves
+ and attribute in "accuracy model_size_mb power_w".split()
+ ):
+ with_function_leaves = False
+
self.attr_by_name[name][attribute].model_function = self._build_dtree(
- self.by_name[name]["param"], self.by_name[name][attribute], threshold
+ self.by_name[name]["param"],
+ self.by_name[name][attribute],
+ with_function_leaves,
+ threshold,
)
- def _build_dtree(self, parameters, data, threshold=100, level=0):
+ def _build_dtree(
+ self, parameters, data, with_function_leaves=False, threshold=100, level=0
+ ):
"""
Build a Decision Tree on `param` / `data` for kconfig models.
@@ -364,6 +377,11 @@ class AnalyticModel:
mean_stds.append(np.inf)
continue
+ # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes
+ if with_function_leaves and param == "batch_size":
+ mean_stds.append(np.inf)
+ continue
+
child_indexes = list()
for value in unique_values:
child_indexes.append(
@@ -391,9 +409,22 @@ class AnalyticModel:
if np.all(np.isinf(mean_stds)):
# all children have the same configuration. We shouldn't get here due to the threshold check above...
- logging.warning(
- f"While building DTree for configurations {parameters}: Children have identical configuration, but high stddev ({np.std(data)}). Falling back to Staticfunction"
- )
+ # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes
+ if with_function_leaves and "batch_size" in parameter_names:
+ ma = ModelAttribute("tmp", "tmp", data, parameters, self.parameters, 0)
+ ParamStats.compute_for_attr(ma)
+ paramfit = ParamFit(parallel=False)
+ for key, param, args, kwargs in ma.get_data_for_paramfit():
+ paramfit.enqueue(key, param, args, kwargs)
+ paramfit.fit()
+ ma.set_data_from_paramfit(paramfit)
+ print(ma)
+ print(ma.model_function)
+ return ma.model_function
+ else:
+ logging.warning(
+ f"While building DTree for configurations {parameters}: Children have identical configuration, but high stddev ({np.std(data)}). Falling back to Staticfunction"
+ )
return StaticFunction(np.mean(data))
symbol_index = np.argmin(mean_stds)
@@ -414,7 +445,11 @@ class AnalyticModel:
child_data = list(map(lambda i: data[i], indexes))
if len(child_data):
child[value] = self._build_dtree(
- child_parameters, child_data, threshold, level + 1
+ child_parameters,
+ child_data,
+ with_function_leaves,
+ threshold,
+ level + 1,
)
assert len(child.values()) >= 2