summaryrefslogtreecommitdiff
path: root/lib/parameters.py
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-03-22 09:23:02 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-03-22 09:23:02 +0100
commit81a1b42c554414bf9f0995741a866ebd1ece6c3c (patch)
tree686b4843ae1768143bdf1c397601ef8b51817041 /lib/parameters.py
parent995753a7913b89ea5f58983cf2cfca1cb8a00b77 (diff)
RMT: add max depth hyper-parameter
Diffstat (limited to 'lib/parameters.py')
-rw-r--r--lib/parameters.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 46a1e10..4047c10 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1072,7 +1072,9 @@ class ModelAttribute:
"build_rmt {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense."
)
- logger.debug(f"build_rmt(threshold={threshold})")
+ max_depth = int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0"))
+
+ logger.debug(f"build_rmt(threshold={threshold}, max_depth={max_depth})")
self.model_function = self._build_rmt(
self.param_values,
@@ -1081,6 +1083,7 @@ class ModelAttribute:
with_nonbinary_nodes=with_nonbinary_nodes,
loss_ignore_scalar=loss_ignore_scalar,
submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"),
+ max_depth=max_depth,
threshold=threshold,
)
@@ -1093,6 +1096,7 @@ class ModelAttribute:
loss_ignore_scalar=False,
submodel="uls",
threshold=100,
+ max_depth=0,
level=0,
):
"""
@@ -1256,8 +1260,10 @@ class ModelAttribute:
assert not np.any(np.isnan(children))
loss.append(np.sum(children))
- if np.all(np.isinf(loss)) or np.min(loss) >= np.sum(
- (np.array(data) - np.mean(data)) ** 2
+ if (
+ np.all(np.isinf(loss))
+ or (max_depth and level >= max_depth)
+ or np.min(loss) >= np.sum((np.array(data) - np.mean(data)) ** 2)
):
if ffs_feasible:
# try generating a function. if it fails, model_function is a StaticFunction.
@@ -1357,6 +1363,7 @@ class ModelAttribute:
loss_ignore_scalar=loss_ignore_scalar,
submodel=submodel,
threshold=threshold,
+ max_depth=max_depth,
level=level + 1,
)