summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-03-21 19:20:19 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-03-21 19:20:19 +0100
commit2de3d7c6a79ad6a8d2e595181cfebf1e5d69cfd1 (patch)
treefc55bc94d56c8ed0f981c946177302be17b49784
parent2592ddc7f3c5d8ef073e984a51e7887c5238c111 (diff)
dtree: fix missing function leaves when nonbinary nodes are disabled
-rw-r--r--lib/parameters.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 1cbc5af..dae6e2a 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1194,12 +1194,6 @@ class ModelAttribute:
loss.append(np.inf)
continue
- # if not with_nonbinary_nodes and sorted(unique_values) != [0, 1]:
- if not with_nonbinary_nodes and len(unique_values) > 2:
- # param cannot be handled with a binary split
- loss.append(np.inf)
- continue
-
if (
with_function_leaves
and self.param_type[param_index] == ParamType.SCALAR
@@ -1210,6 +1204,12 @@ class ModelAttribute:
ffs_feasible = True
continue
+ # if not with_nonbinary_nodes and sorted(unique_values) != [0, 1]:
+ if not with_nonbinary_nodes and len(unique_values) > 2:
+ # param cannot be handled with a binary split
+ loss.append(np.inf)
+ continue
+
if ignore_irrelevant_parameters:
std_by_param = _mean_std_by_param(
by_param, distinct_values_by_param_index, param_index
@@ -1257,6 +1257,7 @@ class ModelAttribute:
if np.all(np.isinf(loss)):
# all children have the same configuration. We shouldn't get here due to the threshold check above...
if ffs_feasible:
+ logger.debug("ffs feasible, attempting to fit a leaf")
# try generating a function. if it fails, model_function is a StaticFunction.
ma = ModelAttribute(
self.name + "_",
@@ -1296,7 +1297,6 @@ class ModelAttribute:
child_parameters = list(map(lambda i: parameters[i], indexes))
child_data = list(map(lambda i: data[i], indexes))
if len(child_data):
- logger.debug(f"subtree level {level+1}")
child[value] = self._build_dtree(
child_parameters,
child_data,