diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-20 07:23:36 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-20 07:23:36 +0100 |
commit | ddf6e00b3b16a07b994107a79450909f43588445 (patch) | |
tree | beb8d4cb2b9de4a20e427fe30caa177599710cb8 | |
parent | 8290d9ee20b5d2305c5fd519bf534029f36c30d9 (diff) |
Re-add (very very basic, for now) Symbolic Regression support
-rw-r--r-- | lib/cli.py | 6 | ||||
-rw-r--r-- | lib/functions.py | 103 | ||||
-rw-r--r-- | lib/model.py | 19 | ||||
-rw-r--r-- | lib/parameters.py | 26 |
4 files changed, 148 insertions, 6 deletions
@@ -124,6 +124,10 @@ def print_staticinfo(prefix, info): print(f"{prefix}: {info.value}") +def print_symreginfo(prefix, info): + print(f"{prefix}: {str(info.regressor)}") + + def print_cartinfo(prefix, info): _print_cartinfo(prefix, info.to_json()) @@ -203,6 +207,8 @@ def print_model(prefix, info): print_lmtinfo(prefix, info) elif type(info) is df.XGBoostFunction: print_xgbinfo(prefix, info) + elif type(info) is df.SymbolicRegressionFunction: + print_symreginfo(prefix, info) else: print(f"{prefix}: {type(info)} UNIMPLEMENTED") diff --git a/lib/functions.py b/lib/functions.py index 4940956..d0ef7e2 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -949,6 +949,109 @@ class XGBoostFunction(SKLearnRegressionFunction): } +class SymbolicRegressionFunction(ModelFunction): + def __init__(self, value, parameters, num_args=0, **kwargs): + super().__init__(value, **kwargs) + self.parameter_names = parameters + self._num_args = num_args + self.fit_success = False + + def fit(self, param_values, data, ignore_param_indexes=None): + self.categorical_to_scalar = bool( + int(os.getenv("DFATOOL_PARAM_CATEGORICAL_TO_SCALAR", "0")) + ) + fit_parameters, categorical_to_index, ignore_index = param_to_ndarray( + param_values, + with_nan=False, + categorical_to_scalar=self.categorical_to_scalar, + ignore_indexes=ignore_param_indexes, + ) + self.categorical_to_index = categorical_to_index + self.ignore_index = ignore_index + + if fit_parameters.shape[1] == 0: + logger.debug( + f"Cannot use Symbolic Regression due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape is {fit_parameters.shape}" + ) + return + + from dfatool.gplearn.genetic import SymbolicRegressor + + self.regressor = SymbolicRegressor() + self.regressor.fit(fit_parameters, data) + self.fit_success = True + + # TODO inherit from SKLearnRegressionFunction, making this obsolete. + # requires SKLearnRegressionFunction to provide .fit and refactoring + # build_tree to move .fit into SKLearnRegressionFunction descendants. + def is_predictable(self, param_list=None): + """ + Return whether the model function can be evaluated on the given parameter values. + + For a StaticFunction, this is always the case (i.e., this function always returns true). + """ + return True + + # TODO inherit from SKLearnRegressionFunction, making this obsolete. + # requires SKLearnRegressionFunction to provide .fit and refactoring + # build_tree to move .fit into SKLearnRegressionFunction descendants. + def eval(self, param_list=None): + """ + Evaluate model function with specified param/arg values. + + Far a Staticfunction, this is just the static value + + """ + if param_list is None: + return self.value + actual_param_list = list() + for i, param in enumerate(param_list): + if not self.ignore_index[i]: + if i in self.categorical_to_index: + try: + actual_param_list.append(self.categorical_to_index[i][param]) + except KeyError: + # param was not part of training data. substitute an unused scalar. + # Note that all param values which were not part of training data map to the same scalar this way. + # This should be harmless. + actual_param_list.append( + max(self.categorical_to_index[i].values()) + 1 + ) + else: + actual_param_list.append(int(param)) + predictions = self.regressor.predict(np.array([actual_param_list])) + if predictions.shape == (1,): + return predictions[0] + return predictions + + # TODO inherit from SKLearnRegressionFunction, making this obsolete. + # requires SKLearnRegressionFunction to provide .fit and refactoring + # build_tree to move .fit into SKLearnRegressionFunction descendants. + def eval_arr(self, params): + actual_params = list() + for param_tuple in params: + actual_param_list = list() + for i, param in enumerate(param_tuple): + if not self.ignore_index[i]: + if i in self.categorical_to_index: + try: + actual_param_list.append( + self.categorical_to_index[i][param] + ) + except KeyError: + # param was not part of training data. substitute an unused scalar. + # Note that all param values which were not part of training data map to the same scalar this way. + # This should be harmless. + actual_param_list.append( + max(self.categorical_to_index[i].values()) + 1 + ) + else: + actual_param_list.append(int(param)) + actual_params.append(actual_param_list) + predictions = self.regressor.predict(np.array(actual_params)) + return predictions + + # first-order linear function (no feature interaction) class FOLFunction(ModelFunction): always_predictable = True diff --git a/lib/model.py b/lib/model.py index 5844113..e2a455d 100644 --- a/lib/model.py +++ b/lib/model.py @@ -296,9 +296,11 @@ class AnalyticModel: if fallback: return list( map( - lambda p: lut_model[name][key][tuple(p)] - if tuple(p) in lut_model[name][key] - else static_model[name][key], + lambda p: ( + lut_model[name][key][tuple(p)] + if tuple(p) in lut_model[name][key] + else static_model[name][key] + ), params, ) ) @@ -320,6 +322,7 @@ class AnalyticModel: paramfit = ParamFit() tree_allowed = bool(int(os.getenv("DFATOOL_DTREE_ENABLED", "1"))) use_fol = bool(int(os.getenv("DFATOOL_FIT_FOL", "0"))) + use_symreg = bool(int(os.getenv("DFATOOL_FIT_SYMREG", "0"))) tree_required = dict() for name in self.names: @@ -329,6 +332,8 @@ class AnalyticModel: self.attr_by_name[name][attr].fit_override_function() elif use_fol: self.attr_by_name[name][attr].build_fol_model() + elif use_symreg: + self.attr_by_name[name][attr].build_symreg_model() elif self.attr_by_name[name][ attr ].all_relevant_parameters_are_none_or_numeric(): @@ -395,9 +400,11 @@ class AnalyticModel: return model_function.eval_arr(kwargs["params"]) return list( map( - lambda p: model_function.eval(p) - if model_function.is_predictable(p) - else static_model[name][key], + lambda p: ( + model_function.eval(p) + if model_function.is_predictable(p) + else static_model[name][key] + ), kwargs["params"], ) ) diff --git a/lib/parameters.py b/lib/parameters.py index d3f77a3..06dc70a 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -859,6 +859,29 @@ class ModelAttribute: else: logger.warning(f"Fit of first-order linear model function failed.") + def build_symreg_model(self): + ignore_irrelevant = bool( + int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0")) + ) + ignore_param_indexes = list() + if ignore_irrelevant: + for param_index, param in enumerate(self.param_names): + if not self.stats.depends_on_param(param): + ignore_param_indexes.append(param_index) + x = df.SymbolicRegressionFunction( + self.median, + self.param_names, + n_samples=self.data.shape[0], + num_args=self.arg_count, + ) + x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes) + if x.fit_success: + self.model_function = x + else: + logger.debug( + f"Symbolic Regression model generation for {self.name} {self.attr} failed." + ) + def fit_override_function(self): function_str = self.function_override x = df.AnalyticFunction( @@ -906,6 +929,7 @@ class ModelAttribute: with_sklearn_decart=None, with_lmt=None, with_xgboost=None, + with_gplearn_symreg=None, ignore_irrelevant_parameters=None, loss_ignore_scalar=None, threshold=100, @@ -948,6 +972,8 @@ class ModelAttribute: with_lmt = bool(int(os.getenv("DFATOOL_DTREE_LMT", "0"))) if with_xgboost is None: with_xgboost = bool(int(os.getenv("DFATOOL_USE_XGBOOST", "0"))) + if with_gplearn_symreg is None: + with_gplearn_symreg = bool(int(os.getenv("DFATOOL_USE_SYMREG", "0"))) if ignore_irrelevant_parameters is None: ignore_irrelevant_parameters = bool( int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0")) |