diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-03-21 19:20:19 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-03-21 19:20:19 +0100 |
commit | 2de3d7c6a79ad6a8d2e595181cfebf1e5d69cfd1 (patch) | |
tree | fc55bc94d56c8ed0f981c946177302be17b49784 /lib/parameters.py | |
parent | 2592ddc7f3c5d8ef073e984a51e7887c5238c111 (diff) |
dtree: fix missing function leaves when nonbinary nodes are disabled
Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 1cbc5af..dae6e2a 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1194,12 +1194,6 @@ class ModelAttribute: loss.append(np.inf) continue - # if not with_nonbinary_nodes and sorted(unique_values) != [0, 1]: - if not with_nonbinary_nodes and len(unique_values) > 2: - # param cannot be handled with a binary split - loss.append(np.inf) - continue - if ( with_function_leaves and self.param_type[param_index] == ParamType.SCALAR @@ -1210,6 +1204,12 @@ class ModelAttribute: ffs_feasible = True continue + # if not with_nonbinary_nodes and sorted(unique_values) != [0, 1]: + if not with_nonbinary_nodes and len(unique_values) > 2: + # param cannot be handled with a binary split + loss.append(np.inf) + continue + if ignore_irrelevant_parameters: std_by_param = _mean_std_by_param( by_param, distinct_values_by_param_index, param_index @@ -1257,6 +1257,7 @@ class ModelAttribute: if np.all(np.isinf(loss)): # all children have the same configuration. We shouldn't get here due to the threshold check above... if ffs_feasible: + logger.debug("ffs feasible, attempting to fit a leaf") # try generating a function. if it fails, model_function is a StaticFunction. ma = ModelAttribute( self.name + "_", @@ -1296,7 +1297,6 @@ class ModelAttribute: child_parameters = list(map(lambda i: parameters[i], indexes)) child_data = list(map(lambda i: data[i], indexes)) if len(child_data): - logger.debug(f"subtree level {level+1}") child[value] = self._build_dtree( child_parameters, child_data, |