From 572ba1bedc4aae42d66359468c1de8bd62e1a546 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Thu, 22 Feb 2024 13:49:30 +0100 Subject: Support Symbolic Regression rather than ULS in RMT leaves --- lib/cli.py | 2 ++ lib/parameters.py | 41 ++++++++++++++++++++++++++--------------- 2 files changed, 28 insertions(+), 15 deletions(-) (limited to 'lib') diff --git a/lib/cli.py b/lib/cli.py index 499e959..3b79fed 100644 --- a/lib/cli.py +++ b/lib/cli.py @@ -184,6 +184,8 @@ def print_splitinfo(info, prefix=""): print_splitinfo(info.child_gt, f"{prefix} {info.param_name}>{info.threshold}") elif type(info) is df.AnalyticFunction: print_analyticinfo(prefix, info) + elif type(info) is df.SymbolicRegressionFunction: + print_symreginfo(prefix, info) elif type(info) is df.StaticFunction: print(f"{prefix}: {info.value}") else: diff --git a/lib/parameters.py b/lib/parameters.py index acc77d4..521ab86 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -1068,6 +1068,7 @@ class ModelAttribute: with_nonbinary_nodes=with_nonbinary_nodes, ignore_irrelevant_parameters=ignore_irrelevant_parameters, loss_ignore_scalar=loss_ignore_scalar, + submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"), threshold=threshold, relevance_threshold=relevance_threshold, ) @@ -1080,6 +1081,7 @@ class ModelAttribute: with_nonbinary_nodes=True, ignore_irrelevant_parameters=True, loss_ignore_scalar=False, + submodel="uls", threshold=100, relevance_threshold=0.5, level=0, @@ -1232,13 +1234,17 @@ class ModelAttribute: param_type=self.param_type, codependent_param=codependent_param_dict(parameters), ) - ParamStats.compute_for_attr(ma) - paramfit = ParamFit(parallel=False) - for key, param, args, kwargs in ma.get_data_for_paramfit(): - paramfit.enqueue(key, param, args, kwargs) - paramfit.fit() - ma.set_data_from_paramfit(paramfit) - return ma.model_function + if submodel == "symreg": + if ma.build_symreg(): + return ma.model_function + else: + ParamStats.compute_for_attr(ma) + paramfit = ParamFit(parallel=False) + for key, param, args, kwargs in ma.get_data_for_paramfit(): + paramfit.enqueue(key, param, args, kwargs) + paramfit.fit() + ma.set_data_from_paramfit(paramfit) + return ma.model_function return df.StaticFunction(np.mean(data), n_samples=len(data)) split_feasible = True @@ -1272,14 +1278,18 @@ class ModelAttribute: param_type=self.param_type, codependent_param=codependent_param_dict(parameters), ) - ParamStats.compute_for_attr(ma) - paramfit = ParamFit(parallel=False) - for key, param, args, kwargs in ma.get_data_for_paramfit(): - paramfit.enqueue(key, param, args, kwargs) - paramfit.fit() - ma.set_data_from_paramfit(paramfit) - if type(ma.model_function) == df.AnalyticFunction: - return ma.model_function + if submodel == "symreg": + if ma.build_symreg(): + return ma.model_function + else: + ParamStats.compute_for_attr(ma) + paramfit = ParamFit(parallel=False) + for key, param, args, kwargs in ma.get_data_for_paramfit(): + paramfit.enqueue(key, param, args, kwargs) + paramfit.fit() + ma.set_data_from_paramfit(paramfit) + if type(ma.model_function) == df.AnalyticFunction: + return ma.model_function symbol_index = np.argmin(loss) unique_values = list(set(map(lambda p: p[symbol_index], parameters))) @@ -1303,6 +1313,7 @@ class ModelAttribute: with_nonbinary_nodes=with_nonbinary_nodes, ignore_irrelevant_parameters=ignore_irrelevant_parameters, loss_ignore_scalar=loss_ignore_scalar, + submodel=submodel, threshold=threshold, relevance_threshold=relevance_threshold, level=level + 1, -- cgit v1.2.3