diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-22 09:23:02 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-22 09:23:02 +0100 |
commit | 81a1b42c554414bf9f0995741a866ebd1ece6c3c (patch) | |
tree | 686b4843ae1768143bdf1c397601ef8b51817041 /lib/parameters.py | |
parent | 995753a7913b89ea5f58983cf2cfca1cb8a00b77 (diff) |
RMT: add max depth hyper-parameter
Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 46a1e10..4047c10 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1072,7 +1072,9 @@ class ModelAttribute: "build_rmt {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense." ) - logger.debug(f"build_rmt(threshold={threshold})") + max_depth = int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0")) + + logger.debug(f"build_rmt(threshold={threshold}, max_depth={max_depth})") self.model_function = self._build_rmt( self.param_values, @@ -1081,6 +1083,7 @@ class ModelAttribute: with_nonbinary_nodes=with_nonbinary_nodes, loss_ignore_scalar=loss_ignore_scalar, submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"), + max_depth=max_depth, threshold=threshold, ) @@ -1093,6 +1096,7 @@ class ModelAttribute: loss_ignore_scalar=False, submodel="uls", threshold=100, + max_depth=0, level=0, ): """ @@ -1256,8 +1260,10 @@ class ModelAttribute: assert not np.any(np.isnan(children)) loss.append(np.sum(children)) - if np.all(np.isinf(loss)) or np.min(loss) >= np.sum( - (np.array(data) - np.mean(data)) ** 2 + if ( + np.all(np.isinf(loss)) + or (max_depth and level >= max_depth) + or np.min(loss) >= np.sum((np.array(data) - np.mean(data)) ** 2) ): if ffs_feasible: # try generating a function. if it fails, model_function is a StaticFunction. @@ -1357,6 +1363,7 @@ class ModelAttribute: loss_ignore_scalar=loss_ignore_scalar, submodel=submodel, threshold=threshold, + max_depth=max_depth, level=level + 1, ) |