diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-08-20 10:49:17 +0200 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-08-20 10:49:17 +0200 |
commit | 2b97f76c58bfb81e15f1cb463cedf82272c3194b (patch) | |
tree | 25a7b1fdb8957ff5d633796179e9fb06fcececf8 | |
parent | 26d07c8eae44d0a6919fd775728cdb1bb2808298 (diff) |
model: support for decision tree with function leaves
pretty hacky at the moment, but good enough for eval
-rw-r--r-- | lib/model.py | 47 |
1 files changed, 41 insertions, 6 deletions
diff --git a/lib/model.py b/lib/model.py index 98a9d4f..3595e28 100644 --- a/lib/model.py +++ b/lib/model.py @@ -330,11 +330,24 @@ class AnalyticModel: self.parameters, ) + # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes + with_function_leaves = bool(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES")) + if ( + with_function_leaves + and attribute in "accuracy model_size_mb power_w".split() + ): + with_function_leaves = False + self.attr_by_name[name][attribute].model_function = self._build_dtree( - self.by_name[name]["param"], self.by_name[name][attribute], threshold + self.by_name[name]["param"], + self.by_name[name][attribute], + with_function_leaves, + threshold, ) - def _build_dtree(self, parameters, data, threshold=100, level=0): + def _build_dtree( + self, parameters, data, with_function_leaves=False, threshold=100, level=0 + ): """ Build a Decision Tree on `param` / `data` for kconfig models. @@ -364,6 +377,11 @@ class AnalyticModel: mean_stds.append(np.inf) continue + # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes + if with_function_leaves and param == "batch_size": + mean_stds.append(np.inf) + continue + child_indexes = list() for value in unique_values: child_indexes.append( @@ -391,9 +409,22 @@ class AnalyticModel: if np.all(np.isinf(mean_stds)): # all children have the same configuration. We shouldn't get here due to the threshold check above... - logging.warning( - f"While building DTree for configurations {parameters}: Children have identical configuration, but high stddev ({np.std(data)}). Falling back to Staticfunction" - ) + # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes + if with_function_leaves and "batch_size" in parameter_names: + ma = ModelAttribute("tmp", "tmp", data, parameters, self.parameters, 0) + ParamStats.compute_for_attr(ma) + paramfit = ParamFit(parallel=False) + for key, param, args, kwargs in ma.get_data_for_paramfit(): + paramfit.enqueue(key, param, args, kwargs) + paramfit.fit() + ma.set_data_from_paramfit(paramfit) + print(ma) + print(ma.model_function) + return ma.model_function + else: + logging.warning( + f"While building DTree for configurations {parameters}: Children have identical configuration, but high stddev ({np.std(data)}). Falling back to Staticfunction" + ) return StaticFunction(np.mean(data)) symbol_index = np.argmin(mean_stds) @@ -414,7 +445,11 @@ class AnalyticModel: child_data = list(map(lambda i: data[i], indexes)) if len(child_data): child[value] = self._build_dtree( - child_parameters, child_data, threshold, level + 1 + child_parameters, + child_data, + with_function_leaves, + threshold, + level + 1, ) assert len(child.values()) >= 2 |