summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/parameters.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index cc6e10c..eae04c8 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -524,6 +524,10 @@ class ModelAttribute:
self._check_codependent_param()
+ # There must be at least 3 distinct data values (≠ None) if an analytic model
+ # is to be fitted. For 2 (or fewer) values, decision trees are better.
+ self.min_values_for_analytic_model = 3
+
def __repr__(self):
mean = np.mean(self.data)
return f"ModelAttribute<{self.name}, {self.attr}, mean={mean}>"
@@ -638,6 +642,12 @@ class ModelAttribute:
param_values = list(map(lambda x: x[param_index], self.param_values))
if not all(map(lambda n: n is None or is_numeric(n), param_values)):
return False
+ distinct_values = self.stats.distinct_values_by_param_index[param_index]
+ if (
+ None in distinct_values
+ and len(distinct_values) - 1 < self.min_values_for_analytic_model
+ ) or len(distinct_values) < self.min_values_for_analytic_model:
+ return False
return True
def get_data_for_paramfit_this(self, safe_functions_enabled=False):
@@ -779,7 +789,7 @@ class ModelAttribute:
if (
with_function_leaves
- and len(unique_values) > 3
+ and len(unique_values) >= self.min_values_for_analytic_model
and all(map(lambda x: type(x) is int, unique_values))
):
# param can be modeled as a function. Do not split on it.
@@ -797,7 +807,7 @@ class ModelAttribute:
)
)
- if len(list(filter(len, child_indexes))) < 2:
+ if len(list(filter(len, child_indexes))) <= 1:
# this param only has a single value. there's no point in splitting.
mean_stds.append(np.inf)
continue