diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/parameters.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 1cbc5af..dae6e2a 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1194,12 +1194,6 @@ class ModelAttribute: loss.append(np.inf) continue - # if not with_nonbinary_nodes and sorted(unique_values) != [0, 1]: - if not with_nonbinary_nodes and len(unique_values) > 2: - # param cannot be handled with a binary split - loss.append(np.inf) - continue - if ( with_function_leaves and self.param_type[param_index] == ParamType.SCALAR @@ -1210,6 +1204,12 @@ class ModelAttribute: ffs_feasible = True continue + # if not with_nonbinary_nodes and sorted(unique_values) != [0, 1]: + if not with_nonbinary_nodes and len(unique_values) > 2: + # param cannot be handled with a binary split + loss.append(np.inf) + continue + if ignore_irrelevant_parameters: std_by_param = _mean_std_by_param( by_param, distinct_values_by_param_index, param_index @@ -1257,6 +1257,7 @@ class ModelAttribute: if np.all(np.isinf(loss)): # all children have the same configuration. We shouldn't get here due to the threshold check above... if ffs_feasible: + logger.debug("ffs feasible, attempting to fit a leaf") # try generating a function. if it fails, model_function is a StaticFunction. ma = ModelAttribute( self.name + "_", @@ -1296,7 +1297,6 @@ class ModelAttribute: child_parameters = list(map(lambda i: parameters[i], indexes)) child_data = list(map(lambda i: data[i], indexes)) if len(child_data): - logger.debug(f"subtree level {level+1}") child[value] = self._build_dtree( child_parameters, child_data, |