summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py29
1 files changed, 21 insertions, 8 deletions
diff --git a/lib/model.py b/lib/model.py
index f43d6a2..c9b0e33 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -234,21 +234,33 @@ class AnalyticModel:
if not self.fit_done:
paramfit = ParamFit()
+ tree_required = dict()
for name in self.names:
+ tree_required[name] = dict()
for attr in self.attr_by_name[name].keys():
- for key, param, args, kwargs in self.attr_by_name[name][
+ if self.attr_by_name[name][
attr
- ].get_data_for_paramfit(
- safe_functions_enabled=safe_functions_enabled
- ):
- paramfit.enqueue(key, param, args, kwargs)
+ ].all_relevant_parameters_are_none_or_numeric():
+ for key, param, args, kwargs in self.attr_by_name[name][
+ attr
+ ].get_data_for_paramfit(
+ safe_functions_enabled=safe_functions_enabled
+ ):
+ paramfit.enqueue(key, param, args, kwargs)
+ else:
+ tree_required[name][attr] = self.attr_by_name[name][
+ attr
+ ].depends_on_any_param()
paramfit.fit()
for name in self.names:
for attr in self.attr_by_name[name].keys():
- self.attr_by_name[name][attr].set_data_from_paramfit(paramfit)
+ if tree_required[name].get(attr, False):
+ self.build_dtree(name, attr, 0.1, with_function_leaves=True)
+ else:
+ self.attr_by_name[name][attr].set_data_from_paramfit(paramfit)
self.fit_done = True
@@ -316,7 +328,7 @@ class AnalyticModel:
return detailed_results
- def build_dtree(self, name, attribute, threshold=100):
+ def build_dtree(self, name, attribute, threshold=100, with_function_leaves=False):
if name not in self.attr_by_name:
self.attr_by_name[name] = dict()
@@ -331,7 +343,8 @@ class AnalyticModel:
)
# temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes
- with_function_leaves = bool(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES"))
+ if not with_function_leaves:
+ with_function_leaves = bool(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES"))
if (
with_function_leaves
and attribute in "accuracy model_size_mb power_w".split()