summaryrefslogtreecommitdiff
path: root/lib/parameters.py
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-20 07:23:36 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-20 07:23:36 +0100
commitddf6e00b3b16a07b994107a79450909f43588445 (patch)
treebeb8d4cb2b9de4a20e427fe30caa177599710cb8 /lib/parameters.py
parent8290d9ee20b5d2305c5fd519bf534029f36c30d9 (diff)
Re-add (very very basic, for now) Symbolic Regression support
Diffstat (limited to 'lib/parameters.py')
-rw-r--r--lib/parameters.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index d3f77a3..06dc70a 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -859,6 +859,29 @@ class ModelAttribute:
else:
logger.warning(f"Fit of first-order linear model function failed.")
+ def build_symreg_model(self):
+ ignore_irrelevant = bool(
+ int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0"))
+ )
+ ignore_param_indexes = list()
+ if ignore_irrelevant:
+ for param_index, param in enumerate(self.param_names):
+ if not self.stats.depends_on_param(param):
+ ignore_param_indexes.append(param_index)
+ x = df.SymbolicRegressionFunction(
+ self.median,
+ self.param_names,
+ n_samples=self.data.shape[0],
+ num_args=self.arg_count,
+ )
+ x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes)
+ if x.fit_success:
+ self.model_function = x
+ else:
+ logger.debug(
+ f"Symbolic Regression model generation for {self.name} {self.attr} failed."
+ )
+
def fit_override_function(self):
function_str = self.function_override
x = df.AnalyticFunction(
@@ -906,6 +929,7 @@ class ModelAttribute:
with_sklearn_decart=None,
with_lmt=None,
with_xgboost=None,
+ with_gplearn_symreg=None,
ignore_irrelevant_parameters=None,
loss_ignore_scalar=None,
threshold=100,
@@ -948,6 +972,8 @@ class ModelAttribute:
with_lmt = bool(int(os.getenv("DFATOOL_DTREE_LMT", "0")))
if with_xgboost is None:
with_xgboost = bool(int(os.getenv("DFATOOL_USE_XGBOOST", "0")))
+ if with_gplearn_symreg is None:
+ with_gplearn_symreg = bool(int(os.getenv("DFATOOL_USE_SYMREG", "0")))
if ignore_irrelevant_parameters is None:
ignore_irrelevant_parameters = bool(
int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0"))