diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-22 13:49:30 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-22 13:49:30 +0100 |
commit | 572ba1bedc4aae42d66359468c1de8bd62e1a546 (patch) | |
tree | c813e16935f102bb63f1f7afb4d3ddde286dd1c4 /lib | |
parent | 677d24805b0bd50fc9326c43c4153a93d8592a41 (diff) |
Support Symbolic Regression rather than ULS in RMT leaves
Diffstat (limited to 'lib')
-rw-r--r-- | lib/cli.py | 2 | ||||
-rw-r--r-- | lib/parameters.py | 41 |
2 files changed, 28 insertions, 15 deletions
@@ -184,6 +184,8 @@ def print_splitinfo(info, prefix=""): print_splitinfo(info.child_gt, f"{prefix} {info.param_name}>{info.threshold}") elif type(info) is df.AnalyticFunction: print_analyticinfo(prefix, info) + elif type(info) is df.SymbolicRegressionFunction: + print_symreginfo(prefix, info) elif type(info) is df.StaticFunction: print(f"{prefix}: {info.value}") else: diff --git a/lib/parameters.py b/lib/parameters.py index acc77d4..521ab86 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1068,6 +1068,7 @@ class ModelAttribute: with_nonbinary_nodes=with_nonbinary_nodes, ignore_irrelevant_parameters=ignore_irrelevant_parameters, loss_ignore_scalar=loss_ignore_scalar, + submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"), threshold=threshold, relevance_threshold=relevance_threshold, ) @@ -1080,6 +1081,7 @@ class ModelAttribute: with_nonbinary_nodes=True, ignore_irrelevant_parameters=True, loss_ignore_scalar=False, + submodel="uls", threshold=100, relevance_threshold=0.5, level=0, @@ -1232,13 +1234,17 @@ class ModelAttribute: param_type=self.param_type, codependent_param=codependent_param_dict(parameters), ) - ParamStats.compute_for_attr(ma) - paramfit = ParamFit(parallel=False) - for key, param, args, kwargs in ma.get_data_for_paramfit(): - paramfit.enqueue(key, param, args, kwargs) - paramfit.fit() - ma.set_data_from_paramfit(paramfit) - return ma.model_function + if submodel == "symreg": + if ma.build_symreg(): + return ma.model_function + else: + ParamStats.compute_for_attr(ma) + paramfit = ParamFit(parallel=False) + for key, param, args, kwargs in ma.get_data_for_paramfit(): + paramfit.enqueue(key, param, args, kwargs) + paramfit.fit() + ma.set_data_from_paramfit(paramfit) + return ma.model_function return df.StaticFunction(np.mean(data), n_samples=len(data)) split_feasible = True @@ -1272,14 +1278,18 @@ class ModelAttribute: param_type=self.param_type, codependent_param=codependent_param_dict(parameters), ) - ParamStats.compute_for_attr(ma) - paramfit = ParamFit(parallel=False) - for key, param, args, kwargs in ma.get_data_for_paramfit(): - paramfit.enqueue(key, param, args, kwargs) - paramfit.fit() - ma.set_data_from_paramfit(paramfit) - if type(ma.model_function) == df.AnalyticFunction: - return ma.model_function + if submodel == "symreg": + if ma.build_symreg(): + return ma.model_function + else: + ParamStats.compute_for_attr(ma) + paramfit = ParamFit(parallel=False) + for key, param, args, kwargs in ma.get_data_for_paramfit(): + paramfit.enqueue(key, param, args, kwargs) + paramfit.fit() + ma.set_data_from_paramfit(paramfit) + if type(ma.model_function) == df.AnalyticFunction: + return ma.model_function symbol_index = np.argmin(loss) unique_values = list(set(map(lambda p: p[symbol_index], parameters))) @@ -1303,6 +1313,7 @@ class ModelAttribute: with_nonbinary_nodes=with_nonbinary_nodes, ignore_irrelevant_parameters=ignore_irrelevant_parameters, loss_ignore_scalar=loss_ignore_scalar, + submodel=submodel, threshold=threshold, relevance_threshold=relevance_threshold, level=level + 1, |