summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-03-22 09:23:02 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-03-22 09:23:02 +0100
commit81a1b42c554414bf9f0995741a866ebd1ece6c3c (patch)
tree686b4843ae1768143bdf1c397601ef8b51817041
parent995753a7913b89ea5f58983cf2cfca1cb8a00b77 (diff)
RMT: add max depth hyper-parameter
-rw-r--r--README.md1
-rw-r--r--lib/functions.py9
-rw-r--r--lib/parameters.py13
3 files changed, 20 insertions, 3 deletions
diff --git a/README.md b/README.md
index b60d265..2ba8e89 100644
--- a/README.md
+++ b/README.md
@@ -113,6 +113,7 @@ The following variables may be set to alter the behaviour of dfatool components.
| `DFATOOL_COMPENSATE_DRIFT` | **0**, 1 | Perform drift compensation for loaders without sync input (e.g. EnergyTrace or Keysight) |
| `DFATOOL_DRIFT_COMPENSATION_PENALTY` | 0 .. 100 (default: majority vote over several penalties) | Specify penalty for ruptures.py PELT changepoint detection |
| `DFATOOL_MODEL` | cart, decart, fol, lgbm, lmt, **rmt**, symreg, uls, xgb | Modeling method. See below for method-specific configuration options. |
+| `DFATOOL_RMT_MAX_DEPTH` | **0** .. *n* | Maximum depth for RMT. Default (0): unlimited. |
| `DFATOOL_RMT_SUBMODEL` | cart, fol, static, symreg, **uls** | Modeling method for RMT leaf functions. |
| `DFATOOL_PREPROCESSING_RELEVANCE_METHOD` | **none**, mi | Ignore parameters deemed irrelevant by the specified heuristic before passing them on to `DFATOOL_MODEL`. |
| `DFATOOL_PREPROCESSING_RELEVANCE_THRESHOLD` | .. **0.1** .. | Threshold for relevance heuristic. |
diff --git a/lib/functions.py b/lib/functions.py
index 25c0354..32fade0 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -444,6 +444,15 @@ class SplitFunction(ModelFunction):
child.to_dot(pydot, graph, feature_names, str(id(self)))
graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key))
+    def hyper_to_dref(self):
+        """Return the parent's hyper-parameter dict extended with the RMT max depth."""
+        hyper = super().hyper_to_dref()
+        # DFATOOL_RMT_MAX_DEPTH=0 (the default) means unlimited and is reported as "infty".
+        max_depth = int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0"))
+        hyper.update({"rmt/max depth": max_depth or "infty"})
+        # Hand the extended dict back; without this the method implicitly returns None.
+        return hyper
+
@classmethod
def from_json(cls, data):
assert data["type"] == "split"
diff --git a/lib/parameters.py b/lib/parameters.py
index 46a1e10..4047c10 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1072,7 +1072,9 @@ class ModelAttribute:
"build_rmt {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense."
)
- logger.debug(f"build_rmt(threshold={threshold})")
+ max_depth = int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0"))
+
+ logger.debug(f"build_rmt(threshold={threshold}, max_depth={max_depth})")
self.model_function = self._build_rmt(
self.param_values,
@@ -1081,6 +1083,7 @@ class ModelAttribute:
with_nonbinary_nodes=with_nonbinary_nodes,
loss_ignore_scalar=loss_ignore_scalar,
submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"),
+ max_depth=max_depth,
threshold=threshold,
)
@@ -1093,6 +1096,7 @@ class ModelAttribute:
loss_ignore_scalar=False,
submodel="uls",
threshold=100,
+ max_depth=0,
level=0,
):
"""
@@ -1256,8 +1260,10 @@ class ModelAttribute:
assert not np.any(np.isnan(children))
loss.append(np.sum(children))
- if np.all(np.isinf(loss)) or np.min(loss) >= np.sum(
- (np.array(data) - np.mean(data)) ** 2
+ if (
+ np.all(np.isinf(loss))
+ or (max_depth and level >= max_depth)
+ or np.min(loss) >= np.sum((np.array(data) - np.mean(data)) ** 2)
):
if ffs_feasible:
# try generating a function. if it fails, model_function is a StaticFunction.
@@ -1357,6 +1363,7 @@ class ModelAttribute:
loss_ignore_scalar=loss_ignore_scalar,
submodel=submodel,
threshold=threshold,
+ max_depth=max_depth,
level=level + 1,
)