diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 9 | ||||
-rw-r--r-- | lib/parameters.py | 13 |
2 files changed, 19 insertions, 3 deletions
diff --git a/lib/functions.py b/lib/functions.py index 25c0354..32fade0 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -444,6 +444,15 @@ class SplitFunction(ModelFunction): child.to_dot(pydot, graph, feature_names, str(id(self))) graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key)) + def hyper_to_dref(self): + hyper = super().hyper_to_dref() + hyper.update( + { + "rmt/max depth": int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0")) + or "infty", + } + ) + @classmethod def from_json(cls, data): assert data["type"] == "split" diff --git a/lib/parameters.py b/lib/parameters.py index 46a1e10..4047c10 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1072,7 +1072,9 @@ class ModelAttribute: "build_rmt {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense." ) - logger.debug(f"build_rmt(threshold={threshold})") + max_depth = int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0")) + + logger.debug(f"build_rmt(threshold={threshold}, max_depth={max_depth})") self.model_function = self._build_rmt( self.param_values, @@ -1081,6 +1083,7 @@ class ModelAttribute: with_nonbinary_nodes=with_nonbinary_nodes, loss_ignore_scalar=loss_ignore_scalar, submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"), + max_depth=max_depth, threshold=threshold, ) @@ -1093,6 +1096,7 @@ class ModelAttribute: loss_ignore_scalar=False, submodel="uls", threshold=100, + max_depth=0, level=0, ): """ @@ -1256,8 +1260,10 @@ class ModelAttribute: assert not np.any(np.isnan(children)) loss.append(np.sum(children)) - if np.all(np.isinf(loss)) or np.min(loss) >= np.sum( - (np.array(data) - np.mean(data)) ** 2 + if ( + np.all(np.isinf(loss)) + or (max_depth and level >= max_depth) + or np.min(loss) >= np.sum((np.array(data) - np.mean(data)) ** 2) ): if ffs_feasible: # try generating a function. if it fails, model_function is a StaticFunction. @@ -1357,6 +1363,7 @@ class ModelAttribute: loss_ignore_scalar=loss_ignore_scalar, submodel=submodel, threshold=threshold, + max_depth=max_depth, level=level + 1, ) |