author     Birte Kristina Friesel <birte.friesel@uos.de>  2024-02-20 07:23:36 +0100
committer  Birte Kristina Friesel <birte.friesel@uos.de>  2024-02-20 07:23:36 +0100
commit     ddf6e00b3b16a07b994107a79450909f43588445 (patch)
tree       beb8d4cb2b9de4a20e427fe30caa177599710cb8
parent     8290d9ee20b5d2305c5fd519bf534029f36c30d9 (diff)
Re-add (very very basic, for now) Symbolic Regression support
-rw-r--r--  lib/cli.py          6
-rw-r--r--  lib/functions.py  103
-rw-r--r--  lib/model.py       19
-rw-r--r--  lib/parameters.py  26
4 files changed, 148 insertions(+), 6 deletions(-)
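
The new backend is opt-in, mirroring the existing FOL and XGBoost toggles. A minimal sketch of enabling it before model construction (the surrounding dfatool setup is assumed and not shown):

    import os

    # Both variables are introduced by this commit:
    # DFATOOL_FIT_SYMREG selects the backend in AnalyticModel (lib/model.py);
    # DFATOOL_USE_SYMREG provides the default for with_gplearn_symreg in
    # ModelAttribute.build_dtree (lib/parameters.py).
    os.environ["DFATOOL_FIT_SYMREG"] = "1"
    os.environ["DFATOOL_USE_SYMREG"] = "1"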
diff --git a/lib/cli.py b/lib/cli.py
index 81fd9ae..499e959 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -124,6 +124,10 @@ def print_staticinfo(prefix, info):
print(f"{prefix}: {info.value}")
+def print_symreginfo(prefix, info):
+ print(f"{prefix}: {str(info.regressor)}")
+
+
def print_cartinfo(prefix, info):
_print_cartinfo(prefix, info.to_json())
@@ -203,6 +207,8 @@ def print_model(prefix, info):
print_lmtinfo(prefix, info)
elif type(info) is df.XGBoostFunction:
print_xgbinfo(prefix, info)
+ elif type(info) is df.SymbolicRegressionFunction:
+ print_symreginfo(prefix, info)
else:
print(f"{prefix}: {type(info)} UNIMPLEMENTED")
diff --git a/lib/functions.py b/lib/functions.py
index 4940956..d0ef7e2 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -949,6 +949,109 @@ class XGBoostFunction(SKLearnRegressionFunction):
}
+class SymbolicRegressionFunction(ModelFunction):
+ def __init__(self, value, parameters, num_args=0, **kwargs):
+ super().__init__(value, **kwargs)
+ self.parameter_names = parameters
+ self._num_args = num_args
+ self.fit_success = False
+
+ def fit(self, param_values, data, ignore_param_indexes=None):
+ self.categorical_to_scalar = bool(
+ int(os.getenv("DFATOOL_PARAM_CATEGORICAL_TO_SCALAR", "0"))
+ )
+ fit_parameters, categorical_to_index, ignore_index = param_to_ndarray(
+ param_values,
+ with_nan=False,
+ categorical_to_scalar=self.categorical_to_scalar,
+ ignore_indexes=ignore_param_indexes,
+ )
+ self.categorical_to_index = categorical_to_index
+ self.ignore_index = ignore_index
+
+ if fit_parameters.shape[1] == 0:
+ logger.debug(
+ f"Cannot use Symbolic Regression due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape is {fit_parameters.shape}"
+ )
+ return
+
+ from dfatool.gplearn.genetic import SymbolicRegressor
+
+ self.regressor = SymbolicRegressor()
+ self.regressor.fit(fit_parameters, data)
+ self.fit_success = True
+
+ # TODO inherit from SKLearnRegressionFunction, making this obsolete.
+ # requires SKLearnRegressionFunction to provide .fit and refactoring
+ # build_tree to move .fit into SKLearnRegressionFunction descendants.
+ def is_predictable(self, param_list=None):
+ """
+ Return whether the model function can be evaluated on the given parameter values.
+
+ For SymbolicRegressionFunction, this is always the case (i.e., this method always returns true).
+ """
+ return True
+
+ # TODO inherit from SKLearnRegressionFunction, making this obsolete.
+ # requires SKLearnRegressionFunction to provide .fit and refactoring
+ # build_tree to move .fit into SKLearnRegressionFunction descendants.
+ def eval(self, param_list=None):
+ """
+ Evaluate the model function with the specified param/arg values.
+
+ If param_list is None, the static value is returned instead of a prediction.
+ """
+ if param_list is None:
+ return self.value
+ actual_param_list = list()
+ for i, param in enumerate(param_list):
+ if not self.ignore_index[i]:
+ if i in self.categorical_to_index:
+ try:
+ actual_param_list.append(self.categorical_to_index[i][param])
+ except KeyError:
+ # param was not part of training data. substitute an unused scalar.
+ # Note that all param values which were not part of training data map to the same scalar this way.
+ # This should be harmless.
+ actual_param_list.append(
+ max(self.categorical_to_index[i].values()) + 1
+ )
+ else:
+ actual_param_list.append(int(param))
+ predictions = self.regressor.predict(np.array([actual_param_list]))
+ if predictions.shape == (1,):
+ return predictions[0]
+ return predictions
+
+ # TODO inherit from SKLearnRegressionFunction, making this obsolete.
+ # requires SKLearnRegressionFunction to provide .fit and refactoring
+ # build_tree to move .fit into SKLearnRegressionFunction descendants.
+ def eval_arr(self, params):
+ actual_params = list()
+ for param_tuple in params:
+ actual_param_list = list()
+ for i, param in enumerate(param_tuple):
+ if not self.ignore_index[i]:
+ if i in self.categorical_to_index:
+ try:
+ actual_param_list.append(
+ self.categorical_to_index[i][param]
+ )
+ except KeyError:
+ # param was not part of training data. substitute an unused scalar.
+ # Note that all param values which were not part of training data map to the same scalar this way.
+ # This should be harmless.
+ actual_param_list.append(
+ max(self.categorical_to_index[i].values()) + 1
+ )
+ else:
+ actual_param_list.append(int(param))
+ actual_params.append(actual_param_list)
+ predictions = self.regressor.predict(np.array(actual_params))
+ return predictions
+
+
# first-order linear function (no feature interaction)
class FOLFunction(ModelFunction):
always_predictable = True
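
For reference, a standalone usage sketch of the new class; the toy data is invented for illustration, and in dfatool this path is normally driven by ModelAttribute.build_symreg_model (see lib/parameters.py below):

    import dfatool.functions as df

    # four observations over two scalar parameters
    param_values = [[10, 1], [20, 1], [10, 2], [20, 2]]
    data = [11.0, 21.0, 12.0, 22.0]

    mf = df.SymbolicRegressionFunction(
        sum(data) / len(data), ["freq", "copies"], n_samples=len(data)
    )
    mf.fit(param_values, data)
    if mf.fit_success:
        print(mf.eval([15, 1]))  # prediction from the evolved program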
diff --git a/lib/model.py b/lib/model.py
index 5844113..e2a455d 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -296,9 +296,11 @@ class AnalyticModel:
if fallback:
return list(
map(
- lambda p: lut_model[name][key][tuple(p)]
- if tuple(p) in lut_model[name][key]
- else static_model[name][key],
+ lambda p: (
+ lut_model[name][key][tuple(p)]
+ if tuple(p) in lut_model[name][key]
+ else static_model[name][key]
+ ),
params,
)
)
@@ -320,6 +322,7 @@ class AnalyticModel:
paramfit = ParamFit()
tree_allowed = bool(int(os.getenv("DFATOOL_DTREE_ENABLED", "1")))
use_fol = bool(int(os.getenv("DFATOOL_FIT_FOL", "0")))
+ use_symreg = bool(int(os.getenv("DFATOOL_FIT_SYMREG", "0")))
tree_required = dict()
for name in self.names:
@@ -329,6 +332,8 @@ class AnalyticModel:
self.attr_by_name[name][attr].fit_override_function()
elif use_fol:
self.attr_by_name[name][attr].build_fol_model()
+ elif use_symreg:
+ self.attr_by_name[name][attr].build_symreg_model()
elif self.attr_by_name[name][
attr
].all_relevant_parameters_are_none_or_numeric():
@@ -395,9 +400,11 @@ class AnalyticModel:
return model_function.eval_arr(kwargs["params"])
return list(
map(
- lambda p: model_function.eval(p)
- if model_function.is_predictable(p)
- else static_model[name][key],
+ lambda p: (
+ model_function.eval(p)
+ if model_function.is_predictable(p)
+ else static_model[name][key]
+ ),
kwargs["params"],
)
)
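
Note the resulting precedence: an explicit function override wins, then DFATOOL_FIT_FOL, then DFATOOL_FIT_SYMREG; the decision-tree code path is only considered when none of these are requested.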
diff --git a/lib/parameters.py b/lib/parameters.py
index d3f77a3..06dc70a 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -859,6 +859,29 @@ class ModelAttribute:
else:
logger.warning(f"Fit of first-order linear model function failed.")
+ def build_symreg_model(self):
+ ignore_irrelevant = bool(
+ int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0"))
+ )
+ ignore_param_indexes = list()
+ if ignore_irrelevant:
+ for param_index, param in enumerate(self.param_names):
+ if not self.stats.depends_on_param(param):
+ ignore_param_indexes.append(param_index)
+ x = df.SymbolicRegressionFunction(
+ self.median,
+ self.param_names,
+ n_samples=self.data.shape[0],
+ num_args=self.arg_count,
+ )
+ x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes)
+ if x.fit_success:
+ self.model_function = x
+ else:
+ logger.debug(
+ f"Symbolic Regression model generation for {self.name} {self.attr} failed."
+ )
+
def fit_override_function(self):
function_str = self.function_override
x = df.AnalyticFunction(
@@ -906,6 +929,7 @@ class ModelAttribute:
with_sklearn_decart=None,
with_lmt=None,
with_xgboost=None,
+ with_gplearn_symreg=None,
ignore_irrelevant_parameters=None,
loss_ignore_scalar=None,
threshold=100,
@@ -948,6 +972,8 @@ class ModelAttribute:
with_lmt = bool(int(os.getenv("DFATOOL_DTREE_LMT", "0")))
if with_xgboost is None:
with_xgboost = bool(int(os.getenv("DFATOOL_USE_XGBOOST", "0")))
+ if with_gplearn_symreg is None:
+ with_gplearn_symreg = bool(int(os.getenv("DFATOOL_USE_SYMREG", "0")))
if ignore_irrelevant_parameters is None:
ignore_irrelevant_parameters = bool(
int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0"))