From 81a1b42c554414bf9f0995741a866ebd1ece6c3c Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Fri, 22 Mar 2024 09:23:02 +0100 Subject: RMT: add max depth hyper-parameter --- lib/functions.py | 9 +++++++++ lib/parameters.py | 13 ++++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) (limited to 'lib') diff --git a/lib/functions.py b/lib/functions.py index 25c0354..32fade0 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -444,6 +444,15 @@ class SplitFunction(ModelFunction): child.to_dot(pydot, graph, feature_names, str(id(self))) graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key)) + def hyper_to_dref(self): + hyper = super().hyper_to_dref() + hyper.update( + { + "rmt/max depth": int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0")) + or "infty", + } + ) + @classmethod def from_json(cls, data): assert data["type"] == "split" diff --git a/lib/parameters.py b/lib/parameters.py index 46a1e10..4047c10 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1072,7 +1072,9 @@ class ModelAttribute: "build_rmt {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense." ) - logger.debug(f"build_rmt(threshold={threshold})") + max_depth = int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0")) + + logger.debug(f"build_rmt(threshold={threshold}, max_depth={max_depth})") self.model_function = self._build_rmt( self.param_values, @@ -1081,6 +1083,7 @@ class ModelAttribute: with_nonbinary_nodes=with_nonbinary_nodes, loss_ignore_scalar=loss_ignore_scalar, submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"), + max_depth=max_depth, threshold=threshold, ) @@ -1093,6 +1096,7 @@ class ModelAttribute: loss_ignore_scalar=False, submodel="uls", threshold=100, + max_depth=0, level=0, ): """ @@ -1256,8 +1260,10 @@ class ModelAttribute: assert not np.any(np.isnan(children)) loss.append(np.sum(children)) - if np.all(np.isinf(loss)) or np.min(loss) >= np.sum( - (np.array(data) - np.mean(data)) ** 2 + if ( + np.all(np.isinf(loss)) + or (max_depth and level >= max_depth) + or np.min(loss) >= np.sum((np.array(data) - np.mean(data)) ** 2) ): if ffs_feasible: # try generating a function. if it fails, model_function is a StaticFunction. @@ -1357,6 +1363,7 @@ class ModelAttribute: loss_ignore_scalar=loss_ignore_scalar, submodel=submodel, threshold=threshold, + max_depth=max_depth, level=level + 1, ) -- cgit v1.2.3