From 164c9e5a7e1015578ee6005c35dc1cbc39125550 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Fri, 8 Mar 2024 10:55:46 +0100 Subject: make FOLFunction inherit from SKLearnRegressionFunction --- lib/functions.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/lib/functions.py b/lib/functions.py index 212e47a..0f19668 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -1555,14 +1555,9 @@ class SymbolicRegressionFunction(SKLearnRegressionFunction): # first-order linear function (no feature interaction) -class FOLFunction(ModelFunction): +class FOLFunction(SKLearnRegressionFunction): always_predictable = True - - def __init__(self, value, parameters, num_args=0, **kwargs): - super().__init__(value, **kwargs) - self.parameter_names = parameters - self._num_args = num_args - self.fit_success = False + has_eval_arr = False def fit(self, param_values, data, ignore_param_indexes=None): self.categorical_to_scalar = bool( @@ -1599,7 +1594,7 @@ class FOLFunction(ModelFunction): num_vars += 1 funbuf = "regression_arg(0)" num_vars = 1 - for j, param_name in enumerate(self.parameter_names): + for j, param_name in enumerate(self.param_names): if self.ignore_index[j]: continue else: @@ -1608,10 +1603,10 @@ class FOLFunction(ModelFunction): f" + regression_arg({num_vars}) * parameter({param_name})" ) num_vars += 1 - for k in range(j + 1, len(self.parameter_names)): + for k in range(j + 1, len(self.param_names)): if self.ignore_index[j]: continue - funbuf += f" + regression_arg({num_vars}) * parameter({param_name}) * parameter({self.parameter_names[k]})" + funbuf += f" + regression_arg({num_vars}) * parameter({param_name}) * parameter({self.param_names[k]})" num_vars += 1 else: num_vars = fit_parameters.shape[0] + 1 @@ -1620,7 +1615,7 @@ class FOLFunction(ModelFunction): rawbuf += f" + reg_param[{i}] * model_param[{i-1}]" funbuf = "regression_arg(0)" i = 1 - for j, param_name in enumerate(self.parameter_names): + for j, param_name in enumerate(self.param_names): if self.ignore_index[j]: continue else: @@ -1711,17 +1706,21 @@ class FOLFunction(ModelFunction): { "type": "analytic", "functionStr": self.model_function, - "argCount": self._num_args, - "parameterNames": self.parameter_names, + "argCount": self.arg_count, + "parameterNames": self.param_names, "regressionModel": list(self.model_args), } ) return ret def hyper_to_dref(self): - return { - "fol/categorical to scalar": int(self.categorical_to_scalar), - } + hyper = super().hyper_to_dref() + hyper.update( + { + "fol/categorical to scalar": int(self.categorical_to_scalar), + } + ) + return hyper class AnalyticFunction(ModelFunction): -- cgit v1.2.3