summaryrefslogtreecommitdiff
path: root/lib/parameters.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/parameters.py')
-rw-r--r--lib/parameters.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 46a1e10..4047c10 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1072,7 +1072,9 @@ class ModelAttribute:
"build_rmt {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense."
)
- logger.debug(f"build_rmt(threshold={threshold})")
+ max_depth = int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0"))
+
+ logger.debug(f"build_rmt(threshold={threshold}, max_depth={max_depth})")
self.model_function = self._build_rmt(
self.param_values,
@@ -1081,6 +1083,7 @@ class ModelAttribute:
with_nonbinary_nodes=with_nonbinary_nodes,
loss_ignore_scalar=loss_ignore_scalar,
submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"),
+ max_depth=max_depth,
threshold=threshold,
)
@@ -1093,6 +1096,7 @@ class ModelAttribute:
loss_ignore_scalar=False,
submodel="uls",
threshold=100,
+ max_depth=0,
level=0,
):
"""
@@ -1256,8 +1260,10 @@ class ModelAttribute:
assert not np.any(np.isnan(children))
loss.append(np.sum(children))
- if np.all(np.isinf(loss)) or np.min(loss) >= np.sum(
- (np.array(data) - np.mean(data)) ** 2
+ if (
+ np.all(np.isinf(loss))
+ or (max_depth and level >= max_depth)
+ or np.min(loss) >= np.sum((np.array(data) - np.mean(data)) ** 2)
):
if ffs_feasible:
# try generating a function. if it fails, model_function is a StaticFunction.
@@ -1357,6 +1363,7 @@ class ModelAttribute:
loss_ignore_scalar=loss_ignore_scalar,
submodel=submodel,
threshold=threshold,
+ max_depth=max_depth,
level=level + 1,
)