diff options
-rw-r--r-- | README.md | 1 | ||||
-rw-r--r-- | lib/functions.py | 10 | ||||
-rw-r--r-- | lib/parameters.py | 13 |
3 files changed, 21 insertions, 3 deletions
@@ -113,6 +113,7 @@ The following variables may be set to alter the behaviour of dfatool components. | `DFATOOL_COMPENSATE_DRIFT` | **0**, 1 | Perform drift compensation for loaders without sync input (e.g. EnergyTrace or Keysight) | | `DFATOOL_DRIFT_COMPENSATION_PENALTY` | 0 .. 100 (default: majority vote over several penalties) | Specify penalty for ruptures.py PELT changepoint petection | | `DFATOOL_MODEL` | cart, decart, fol, lgbm, lmt, **rmt**, symreg, uls, xgb | Modeling method. See below for method-specific configuration options. | +| `DFATOOL_RMT_MAX_DEPTH` | **0** .. *n* | Maximum depth for RMT. Default (0): unlimited. | | `DFATOOL_RMT_SUBMODEL` | cart, fol, static, symreg, **uls** | Modeling method for RMT leaf functions. | | `DFATOOL_PREPROCESSING_RELEVANCE_METHOD` | **none**, mi | Ignore parameters deemed irrelevant by the specified heuristic before passing them on to `DFATOOL_MODEL`. | | `DFATOOL_PREPROCESSING_RELEVANCE_THRESHOLD` | .. **0.1** .. | Threshold for relevance heuristic. | diff --git a/lib/functions.py b/lib/functions.py index 25c0354..32fade0 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -444,6 +444,16 @@ class SplitFunction(ModelFunction): child.to_dot(pydot, graph, feature_names, str(id(self))) graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key)) + def hyper_to_dref(self): + hyper = super().hyper_to_dref() + hyper.update( + { + "rmt/max depth": int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0")) + or "infty", + } + ) + return hyper + @classmethod def from_json(cls, data): assert data["type"] == "split" diff --git a/lib/parameters.py b/lib/parameters.py index 46a1e10..4047c10 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1072,7 +1072,9 @@ class ModelAttribute: "build_rmt {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense." 
) - logger.debug(f"build_rmt(threshold={threshold})") + max_depth = int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0")) + + logger.debug(f"build_rmt(threshold={threshold}, max_depth={max_depth})") self.model_function = self._build_rmt( self.param_values, @@ -1081,6 +1083,7 @@ class ModelAttribute: with_nonbinary_nodes=with_nonbinary_nodes, loss_ignore_scalar=loss_ignore_scalar, submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"), + max_depth=max_depth, threshold=threshold, ) @@ -1093,6 +1096,7 @@ class ModelAttribute: loss_ignore_scalar=False, submodel="uls", threshold=100, + max_depth=0, level=0, ): """ @@ -1256,8 +1260,10 @@ class ModelAttribute: assert not np.any(np.isnan(children)) loss.append(np.sum(children)) - if np.all(np.isinf(loss)) or np.min(loss) >= np.sum( - (np.array(data) - np.mean(data)) ** 2 + if ( + np.all(np.isinf(loss)) + or (max_depth and level >= max_depth) + or np.min(loss) >= np.sum((np.array(data) - np.mean(data)) ** 2) ): if ffs_feasible: # try generating a function. if it fails, model_function is a StaticFunction. @@ -1357,6 +1363,7 @@ class ModelAttribute: loss_ignore_scalar=loss_ignore_scalar, submodel=submodel, threshold=threshold, + max_depth=max_depth, level=level + 1, ) |