summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-26 16:20:39 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-26 16:20:39 +0100
commit112ebbe21ff330c68faf883b125ba4932e007544 (patch)
tree3a083ac80c7af96b9c2cbc356736cd5f07236aa6
parentb89d1defee4bb8a93164b3b717d67c317d8e1db5 (diff)
add RMT_SUBMODEL=cart
-rw-r--r--README.md2
-rw-r--r--lib/parameters.py10
2 files changed, 9 insertions, 3 deletions
diff --git a/README.md b/README.md
index fd73245..76f0113 100644
--- a/README.md
+++ b/README.md
@@ -112,7 +112,7 @@ The following variables may be set to alter the behaviour of dfatool components.
| `DFATOOL_COMPENSATE_DRIFT` | **0**, 1 | Perform drift compensation for loaders without sync input (e.g. EnergyTrace or Keysight) |
| `DFATOOL_DRIFT_COMPENSATION_PENALTY` | 0 .. 100 (default: majority vote over several penalties) | Specify penalty for ruptures.py PELT changepoint petection |
| `DFATOOL_MODEL` | cart, decart, fol, lmt, **rmt**, symreg, xgb | Modeling method. See below for method-specific configuration options. |
-| `DFATOOL_RMT_SUBMODEL` | fol, static, symreg, **uls** | Modeling method for RMT leaf functions. |
+| `DFATOOL_RMT_SUBMODEL` | cart, fol, static, symreg, **uls** | Modeling method for RMT leaf functions. |
| `DFATOOL_RMT_ENABLED` | 0, **1** | Use decision trees in get\_fitted |
| `DFATOOL_CART_MAX_DEPTH` | **0** .. *n* | maximum depth for sklearn CART. Default (0): unlimited. |
| `DFATOOL_LMT_MAX_DEPTH` | **5** .. 20 | Maximum depth for LMT. |
diff --git a/lib/parameters.py b/lib/parameters.py
index 521ab86..0b0da81 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1234,7 +1234,10 @@ class ModelAttribute:
param_type=self.param_type,
codependent_param=codependent_param_dict(parameters),
)
- if submodel == "symreg":
+ if submodel == "cart":
+ if ma.build_cart():
+ return ma.model_function
+ elif submodel == "symreg":
if ma.build_symreg():
return ma.model_function
else:
@@ -1278,7 +1281,10 @@ class ModelAttribute:
param_type=self.param_type,
codependent_param=codependent_param_dict(parameters),
)
- if submodel == "symreg":
+ if submodel == "cart":
+ if ma.build_cart():
+ return ma.model_function
+ elif submodel == "symreg":
if ma.build_symreg():
return ma.model_function
else: