diff options
-rw-r--r-- | lib/parameters.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index cc6e10c..eae04c8 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -524,6 +524,10 @@ class ModelAttribute: self._check_codependent_param() + # There must be at least 3 distinct data values (≠ None) if an analytic model + # is to be fitted. For 2 (or less) values, decision trees are better. + self.min_values_for_analytic_model = 3 + def __repr__(self): mean = np.mean(self.data) return f"ModelAttribute<{self.name}, {self.attr}, mean={mean}>" @@ -638,6 +642,12 @@ class ModelAttribute: param_values = list(map(lambda x: x[param_index], self.param_values)) if not all(map(lambda n: n is None or is_numeric(n), param_values)): return False + distinct_values = self.stats.distinct_values_by_param_index[param_index] + if ( + None in distinct_values + and len(distinct_values) - 1 < self.min_values_for_analytic_model + ) or len(distinct_values) < self.min_values_for_analytic_model: + return False return True def get_data_for_paramfit_this(self, safe_functions_enabled=False): @@ -779,7 +789,7 @@ class ModelAttribute: if ( with_function_leaves - and len(unique_values) > 3 + and len(unique_values) >= self.min_values_for_analytic_model and all(map(lambda x: type(x) is int, unique_values)) ): # param can be modeled as a function. Do not split on it. @@ -797,7 +807,7 @@ class ModelAttribute: ) ) - if len(list(filter(len, child_indexes))) < 2: + if len(list(filter(len, child_indexes))) <= 1: # this param only has a single value. there's no point in splitting. mean_stds.append(np.inf) continue |