diff options
-rw-r--r-- | lib/model.py | 29 | ||||
-rw-r--r-- | lib/parameters.py | 24 |
2 files changed, 42 insertions, 11 deletions
diff --git a/lib/model.py b/lib/model.py index f43d6a2..c9b0e33 100644 --- a/lib/model.py +++ b/lib/model.py @@ -234,21 +234,33 @@ class AnalyticModel: if not self.fit_done: paramfit = ParamFit() + tree_required = dict() for name in self.names: + tree_required[name] = dict() for attr in self.attr_by_name[name].keys(): - for key, param, args, kwargs in self.attr_by_name[name][ + if self.attr_by_name[name][ attr - ].get_data_for_paramfit( - safe_functions_enabled=safe_functions_enabled - ): - paramfit.enqueue(key, param, args, kwargs) + ].all_relevant_parameters_are_none_or_numeric(): + for key, param, args, kwargs in self.attr_by_name[name][ + attr + ].get_data_for_paramfit( + safe_functions_enabled=safe_functions_enabled + ): + paramfit.enqueue(key, param, args, kwargs) + else: + tree_required[name][attr] = self.attr_by_name[name][ + attr + ].depends_on_any_param() paramfit.fit() for name in self.names: for attr in self.attr_by_name[name].keys(): - self.attr_by_name[name][attr].set_data_from_paramfit(paramfit) + if tree_required[name].get(attr, False): + self.build_dtree(name, attr, 0.1, with_function_leaves=True) + else: + self.attr_by_name[name][attr].set_data_from_paramfit(paramfit) self.fit_done = True @@ -316,7 +328,7 @@ class AnalyticModel: return detailed_results - def build_dtree(self, name, attribute, threshold=100): + def build_dtree(self, name, attribute, threshold=100, with_function_leaves=False): if name not in self.attr_by_name: self.attr_by_name[name] = dict() @@ -331,7 +343,8 @@ class AnalyticModel: ) # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes - with_function_leaves = bool(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES")) + if not with_function_leaves: + with_function_leaves = bool(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES")) if ( with_function_leaves and attribute in "accuracy model_size_mb power_w".split() diff --git a/lib/parameters.py b/lib/parameters.py index 82b1fdc..9ecd7e9 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -291,9 +291,7 @@ def _compute_param_statistics_parallel(arg): def _all_params_are_numeric(data, param_idx): """Check if all `data['param'][*][param_idx]` elements are numeric, as reported by `utils.is_numeric`.""" param_values = list(map(lambda x: x[param_idx], data)) - if len(list(filter(is_numeric, param_values))) == len(param_values): - return True - return False + return all(map(is_numeric, param_values)) class ParallelParamStats: @@ -738,6 +736,26 @@ class ModelAttribute: new_param_values.append(param_tuple) return partition_by_param(self.data, new_param_values) + def depends_on_any_param(self): + for param_index, param_name in enumerate(self.param_names): + if ( + self.stats.depends_on_param(param_name) + and not param_index in self.ignore_param + ): + return True + return False + + def all_relevant_parameters_are_none_or_numeric(self): + for param_index, param_name in enumerate(self.param_names): + if ( + self.stats.depends_on_param(param_name) + and not param_index in self.ignore_param + ): + param_values = list(map(lambda x: x[param_index], self.param_values)) + if not all(map(lambda n: n is None or is_numeric(n), param_values)): + return False + return True + def get_data_for_paramfit_this(self, safe_functions_enabled=False): ret = list() for param_index, param_name in enumerate(self.param_names): |