From 572ba1bedc4aae42d66359468c1de8bd62e1a546 Mon Sep 17 00:00:00 2001
From: Birte Kristina Friesel
Date: Thu, 22 Feb 2024 13:49:30 +0100
Subject: Support Symbolic Regression rather than ULS in RMT leaves

---
 README.md         |  2 +-
 lib/cli.py        |  2 ++
 lib/parameters.py | 41 ++++++++++++++++++++++++++---------------
 3 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/README.md b/README.md
index 274e121..fd73245 100644
--- a/README.md
+++ b/README.md
@@ -112,7 +112,7 @@ The following variables may be set to alter the behaviour of dfatool components.
 | `DFATOOL_COMPENSATE_DRIFT` | **0**, 1 | Perform drift compensation for loaders without sync input (e.g. EnergyTrace or Keysight) |
 | `DFATOOL_DRIFT_COMPENSATION_PENALTY` | 0 .. 100 (default: majority vote over several penalties) | Specify penalty for ruptures.py PELT changepoint petection |
 | `DFATOOL_MODEL` | cart, decart, fol, lmt, **rmt**, symreg, xgb | Modeling method. See below for method-specific configuration options. |
-| `DFATOOL_RMT_SUBMODEL` | fol, static, **uls** | Modeling method for RMT leaf functions. |
+| `DFATOOL_RMT_SUBMODEL` | fol, static, symreg, **uls** | Modeling method for RMT leaf functions. |
 | `DFATOOL_RMT_ENABLED` | 0, **1** | Use decision trees in get\_fitted |
 | `DFATOOL_CART_MAX_DEPTH` | **0** .. *n* | maximum depth for sklearn CART. Default (0): unlimited. |
 | `DFATOOL_LMT_MAX_DEPTH` | **5** .. 20 | Maximum depth for LMT. |
diff --git a/lib/cli.py b/lib/cli.py
index 499e959..3b79fed 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -184,6 +184,8 @@ def print_splitinfo(info, prefix=""):
         print_splitinfo(info.child_gt, f"{prefix} {info.param_name}>{info.threshold}")
     elif type(info) is df.AnalyticFunction:
         print_analyticinfo(prefix, info)
+    elif type(info) is df.SymbolicRegressionFunction:
+        print_symreginfo(prefix, info)
     elif type(info) is df.StaticFunction:
         print(f"{prefix}: {info.value}")
     else:
diff --git a/lib/parameters.py b/lib/parameters.py
index acc77d4..521ab86 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1068,6 +1068,7 @@ class ModelAttribute:
             with_nonbinary_nodes=with_nonbinary_nodes,
             ignore_irrelevant_parameters=ignore_irrelevant_parameters,
             loss_ignore_scalar=loss_ignore_scalar,
+            submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"),
             threshold=threshold,
             relevance_threshold=relevance_threshold,
         )
@@ -1080,6 +1081,7 @@ class ModelAttribute:
         with_nonbinary_nodes=True,
         ignore_irrelevant_parameters=True,
         loss_ignore_scalar=False,
+        submodel="uls",
         threshold=100,
         relevance_threshold=0.5,
         level=0,
@@ -1232,13 +1234,17 @@ class ModelAttribute:
                 param_type=self.param_type,
                 codependent_param=codependent_param_dict(parameters),
             )
-            ParamStats.compute_for_attr(ma)
-            paramfit = ParamFit(parallel=False)
-            for key, param, args, kwargs in ma.get_data_for_paramfit():
-                paramfit.enqueue(key, param, args, kwargs)
-            paramfit.fit()
-            ma.set_data_from_paramfit(paramfit)
-            return ma.model_function
+            if submodel == "symreg":
+                if ma.build_symreg():
+                    return ma.model_function
+            else:
+                ParamStats.compute_for_attr(ma)
+                paramfit = ParamFit(parallel=False)
+                for key, param, args, kwargs in ma.get_data_for_paramfit():
+                    paramfit.enqueue(key, param, args, kwargs)
+                paramfit.fit()
+                ma.set_data_from_paramfit(paramfit)
+                return ma.model_function
         return df.StaticFunction(np.mean(data), n_samples=len(data))
 
         split_feasible = True
@@ -1272,14 +1278,18 @@ class ModelAttribute:
                 param_type=self.param_type,
                 codependent_param=codependent_param_dict(parameters),
             )
-            ParamStats.compute_for_attr(ma)
-            paramfit = ParamFit(parallel=False)
-            for key, param, args, kwargs in ma.get_data_for_paramfit():
-                paramfit.enqueue(key, param, args, kwargs)
-            paramfit.fit()
-            ma.set_data_from_paramfit(paramfit)
-            if type(ma.model_function) == df.AnalyticFunction:
-                return ma.model_function
+            if submodel == "symreg":
+                if ma.build_symreg():
+                    return ma.model_function
+            else:
+                ParamStats.compute_for_attr(ma)
+                paramfit = ParamFit(parallel=False)
+                for key, param, args, kwargs in ma.get_data_for_paramfit():
+                    paramfit.enqueue(key, param, args, kwargs)
+                paramfit.fit()
+                ma.set_data_from_paramfit(paramfit)
+                if type(ma.model_function) == df.AnalyticFunction:
+                    return ma.model_function
 
         symbol_index = np.argmin(loss)
         unique_values = list(set(map(lambda p: p[symbol_index], parameters)))
@@ -1303,6 +1313,7 @@ class ModelAttribute:
             with_nonbinary_nodes=with_nonbinary_nodes,
             ignore_irrelevant_parameters=ignore_irrelevant_parameters,
             loss_ignore_scalar=loss_ignore_scalar,
+            submodel=submodel,
             threshold=threshold,
             relevance_threshold=relevance_threshold,
             level=level + 1,
--
cgit v1.2.3
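
Note (not part of the patch): the sketch below illustrates the leaf-submodel dispatch that
this commit adds to lib/parameters.py in standalone form. When DFATOOL_RMT_SUBMODEL is
"symreg", the leaf tries to build a symbolic-regression model and falls back to a static
(mean) leaf if that fails; any other submodel goes through the existing ParamStats/ParamFit
(ULS) path. ToyModelAttribute, fit_uls, and fit_leaf are illustrative stand-ins, not
dfatool's real API; only build_symreg, model_function, and the environment variable name
come from the patch itself.

    import os

    import numpy as np


    class StaticFunction:
        """Stand-in for dfatool's df.StaticFunction, reduced to what this sketch needs."""

        def __init__(self, value, n_samples):
            self.value = value
            self.n_samples = n_samples


    class ToyModelAttribute:
        """Stand-in for dfatool's ModelAttribute; build_symreg and fit_uls are placeholders."""

        def __init__(self, data):
            self.data = np.asarray(data)
            self.model_function = None

        def build_symreg(self):
            # A real implementation would fit a symbolic-regression model and
            # return True only if fitting succeeded.
            return False

        def fit_uls(self):
            # Placeholder for the ParamStats / ParamFit (ULS) fitting pipeline.
            self.model_function = StaticFunction(self.data.mean(), len(self.data))


    def fit_leaf(ma):
        submodel = os.getenv("DFATOOL_RMT_SUBMODEL", "uls")
        if submodel == "symreg":
            # Try symbolic regression; on failure, fall through to a static leaf.
            if ma.build_symreg():
                return ma.model_function
        else:
            # Any other submodel uses the pre-existing fitting path.
            ma.fit_uls()
            return ma.model_function
        # symreg was requested but could not be built: static (mean) fallback.
        return StaticFunction(ma.data.mean(), len(ma.data))


    if __name__ == "__main__":
        leaf = fit_leaf(ToyModelAttribute([1.0, 2.0, 3.0]))
        print(leaf.value, leaf.n_samples)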