summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-22 13:49:30 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-22 13:49:30 +0100
commit572ba1bedc4aae42d66359468c1de8bd62e1a546 (patch)
treec813e16935f102bb63f1f7afb4d3ddde286dd1c4 /lib
parent677d24805b0bd50fc9326c43c4153a93d8592a41 (diff)
Support Symbolic Regression rather than ULS in RMT leaves
Diffstat (limited to 'lib')
-rw-r--r--lib/cli.py2
-rw-r--r--lib/parameters.py41
2 files changed, 28 insertions, 15 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 499e959..3b79fed 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -184,6 +184,8 @@ def print_splitinfo(info, prefix=""):
print_splitinfo(info.child_gt, f"{prefix} {info.param_name}>{info.threshold}")
elif type(info) is df.AnalyticFunction:
print_analyticinfo(prefix, info)
+ elif type(info) is df.SymbolicRegressionFunction:
+ print_symreginfo(prefix, info)
elif type(info) is df.StaticFunction:
print(f"{prefix}: {info.value}")
else:
diff --git a/lib/parameters.py b/lib/parameters.py
index acc77d4..521ab86 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1068,6 +1068,7 @@ class ModelAttribute:
with_nonbinary_nodes=with_nonbinary_nodes,
ignore_irrelevant_parameters=ignore_irrelevant_parameters,
loss_ignore_scalar=loss_ignore_scalar,
+ submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"),
threshold=threshold,
relevance_threshold=relevance_threshold,
)
@@ -1080,6 +1081,7 @@ class ModelAttribute:
with_nonbinary_nodes=True,
ignore_irrelevant_parameters=True,
loss_ignore_scalar=False,
+ submodel="uls",
threshold=100,
relevance_threshold=0.5,
level=0,
@@ -1232,13 +1234,17 @@ class ModelAttribute:
param_type=self.param_type,
codependent_param=codependent_param_dict(parameters),
)
- ParamStats.compute_for_attr(ma)
- paramfit = ParamFit(parallel=False)
- for key, param, args, kwargs in ma.get_data_for_paramfit():
- paramfit.enqueue(key, param, args, kwargs)
- paramfit.fit()
- ma.set_data_from_paramfit(paramfit)
- return ma.model_function
+ if submodel == "symreg":
+ if ma.build_symreg():
+ return ma.model_function
+ else:
+ ParamStats.compute_for_attr(ma)
+ paramfit = ParamFit(parallel=False)
+ for key, param, args, kwargs in ma.get_data_for_paramfit():
+ paramfit.enqueue(key, param, args, kwargs)
+ paramfit.fit()
+ ma.set_data_from_paramfit(paramfit)
+ return ma.model_function
return df.StaticFunction(np.mean(data), n_samples=len(data))
split_feasible = True
@@ -1272,14 +1278,18 @@ class ModelAttribute:
param_type=self.param_type,
codependent_param=codependent_param_dict(parameters),
)
- ParamStats.compute_for_attr(ma)
- paramfit = ParamFit(parallel=False)
- for key, param, args, kwargs in ma.get_data_for_paramfit():
- paramfit.enqueue(key, param, args, kwargs)
- paramfit.fit()
- ma.set_data_from_paramfit(paramfit)
- if type(ma.model_function) == df.AnalyticFunction:
- return ma.model_function
+ if submodel == "symreg":
+ if ma.build_symreg():
+ return ma.model_function
+ else:
+ ParamStats.compute_for_attr(ma)
+ paramfit = ParamFit(parallel=False)
+ for key, param, args, kwargs in ma.get_data_for_paramfit():
+ paramfit.enqueue(key, param, args, kwargs)
+ paramfit.fit()
+ ma.set_data_from_paramfit(paramfit)
+ if type(ma.model_function) == df.AnalyticFunction:
+ return ma.model_function
symbol_index = np.argmin(loss)
unique_values = list(set(map(lambda p: p[symbol_index], parameters)))
@@ -1303,6 +1313,7 @@ class ModelAttribute:
with_nonbinary_nodes=with_nonbinary_nodes,
ignore_irrelevant_parameters=ignore_irrelevant_parameters,
loss_ignore_scalar=loss_ignore_scalar,
+ submodel=submodel,
threshold=threshold,
relevance_threshold=relevance_threshold,
level=level + 1,