summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/model.py29
-rw-r--r--lib/parameters.py24
2 files changed, 42 insertions, 11 deletions
diff --git a/lib/model.py b/lib/model.py
index f43d6a2..c9b0e33 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -234,21 +234,33 @@ class AnalyticModel:
if not self.fit_done:
paramfit = ParamFit()
+ tree_required = dict()
for name in self.names:
+ tree_required[name] = dict()
for attr in self.attr_by_name[name].keys():
- for key, param, args, kwargs in self.attr_by_name[name][
+ if self.attr_by_name[name][
attr
- ].get_data_for_paramfit(
- safe_functions_enabled=safe_functions_enabled
- ):
- paramfit.enqueue(key, param, args, kwargs)
+ ].all_relevant_parameters_are_none_or_numeric():
+ for key, param, args, kwargs in self.attr_by_name[name][
+ attr
+ ].get_data_for_paramfit(
+ safe_functions_enabled=safe_functions_enabled
+ ):
+ paramfit.enqueue(key, param, args, kwargs)
+ else:
+ tree_required[name][attr] = self.attr_by_name[name][
+ attr
+ ].depends_on_any_param()
paramfit.fit()
for name in self.names:
for attr in self.attr_by_name[name].keys():
- self.attr_by_name[name][attr].set_data_from_paramfit(paramfit)
+ if tree_required[name].get(attr, False):
+ self.build_dtree(name, attr, 0.1, with_function_leaves=True)
+ else:
+ self.attr_by_name[name][attr].set_data_from_paramfit(paramfit)
self.fit_done = True
@@ -316,7 +328,7 @@ class AnalyticModel:
return detailed_results
- def build_dtree(self, name, attribute, threshold=100):
+ def build_dtree(self, name, attribute, threshold=100, with_function_leaves=False):
if name not in self.attr_by_name:
self.attr_by_name[name] = dict()
@@ -331,7 +343,8 @@ class AnalyticModel:
)
# temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes
- with_function_leaves = bool(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES"))
+ if not with_function_leaves:
+ with_function_leaves = bool(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES"))
if (
with_function_leaves
and attribute in "accuracy model_size_mb power_w".split()
diff --git a/lib/parameters.py b/lib/parameters.py
index 82b1fdc..9ecd7e9 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -291,9 +291,7 @@ def _compute_param_statistics_parallel(arg):
def _all_params_are_numeric(data, param_idx):
"""Check if all `data['param'][*][param_idx]` elements are numeric, as reported by `utils.is_numeric`."""
param_values = list(map(lambda x: x[param_idx], data))
- if len(list(filter(is_numeric, param_values))) == len(param_values):
- return True
- return False
+ return all(map(is_numeric, param_values))
class ParallelParamStats:
@@ -738,6 +736,26 @@ class ModelAttribute:
new_param_values.append(param_tuple)
return partition_by_param(self.data, new_param_values)
+ def depends_on_any_param(self):
+ for param_index, param_name in enumerate(self.param_names):
+ if (
+ self.stats.depends_on_param(param_name)
+ and not param_index in self.ignore_param
+ ):
+ return True
+ return False
+
+ def all_relevant_parameters_are_none_or_numeric(self):
+ for param_index, param_name in enumerate(self.param_names):
+ if (
+ self.stats.depends_on_param(param_name)
+ and not param_index in self.ignore_param
+ ):
+ param_values = list(map(lambda x: x[param_index], self.param_values))
+ if not all(map(lambda n: n is None or is_numeric(n), param_values)):
+ return False
+ return True
+
def get_data_for_paramfit_this(self, safe_functions_enabled=False):
ret = list()
for param_index, param_name in enumerate(self.param_names):