diff options
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/lib/model.py b/lib/model.py index 3595e28..f43d6a2 100644 --- a/lib/model.py +++ b/lib/model.py @@ -377,8 +377,12 @@ class AnalyticModel: mean_stds.append(np.inf) continue - # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes - if with_function_leaves and param == "batch_size": + if ( + with_function_leaves + and len(unique_values) > 3 + and all(map(lambda x: type(x) is int, unique_values)) + ): + # param can be modeled as a function. Do not split on it. mean_stds.append(np.inf) continue @@ -409,8 +413,8 @@ class AnalyticModel: if np.all(np.isinf(mean_stds)): # all children have the same configuration. We shouldn't get here due to the threshold check above... - # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes - if with_function_leaves and "batch_size" in parameter_names: + if with_function_leaves: + # try generating a function. if it fails, model_function is a StaticFunction. ma = ModelAttribute("tmp", "tmp", data, parameters, self.parameters, 0) ParamStats.compute_for_attr(ma) paramfit = ParamFit(parallel=False) @@ -418,8 +422,6 @@ class AnalyticModel: paramfit.enqueue(key, param, args, kwargs) paramfit.fit() ma.set_data_from_paramfit(paramfit) - print(ma) - print(ma.model_function) return ma.model_function else: logging.warning( |