summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/model.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/lib/model.py b/lib/model.py
index 3595e28..f43d6a2 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -377,8 +377,12 @@ class AnalyticModel:
mean_stds.append(np.inf)
continue
- # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes
- if with_function_leaves and param == "batch_size":
+ if (
+ with_function_leaves
+ and len(unique_values) > 3
+ and all(map(lambda x: type(x) is int, unique_values))
+ ):
+ # param can be modeled as a function. Do not split on it.
mean_stds.append(np.inf)
continue
@@ -409,8 +413,8 @@ class AnalyticModel:
if np.all(np.isinf(mean_stds)):
# all children have the same configuration. We shouldn't get here due to the threshold check above...
- # temporary hack for ResKIL / kconfig-webconf evaluation of regression trees with function nodes
- if with_function_leaves and "batch_size" in parameter_names:
+ if with_function_leaves:
+ # try generating a function. if it fails, model_function is a StaticFunction.
ma = ModelAttribute("tmp", "tmp", data, parameters, self.parameters, 0)
ParamStats.compute_for_attr(ma)
paramfit = ParamFit(parallel=False)
@@ -418,8 +422,6 @@ class AnalyticModel:
paramfit.enqueue(key, param, args, kwargs)
paramfit.fit()
ma.set_data_from_paramfit(paramfit)
- print(ma)
- print(ma.model_function)
return ma.model_function
else:
logging.warning(