author    | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-08 10:55:46 +0100
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-08 10:55:46 +0100
commit    | 164c9e5a7e1015578ee6005c35dc1cbc39125550 (patch)
tree      | 3f21e799eec930460b1a008da0f43160f16a06a4
parent    | e31a672ba4dc8a9b605cb02882f8afbfa4bbbe7b (diff)
make FOLFunction inherit from SKLearnRegressionFunction
-rw-r--r-- | lib/functions.py | 31
1 file changed, 15 insertions(+), 16 deletions(-)
diff --git a/lib/functions.py b/lib/functions.py
index 212e47a..0f19668 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -1555,14 +1555,9 @@ class SymbolicRegressionFunction(SKLearnRegressionFunction):
 
 
 # first-order linear function (no feature interaction)
-class FOLFunction(ModelFunction):
+class FOLFunction(SKLearnRegressionFunction):
     always_predictable = True
-
-    def __init__(self, value, parameters, num_args=0, **kwargs):
-        super().__init__(value, **kwargs)
-        self.parameter_names = parameters
-        self._num_args = num_args
-        self.fit_success = False
+    has_eval_arr = False
 
     def fit(self, param_values, data, ignore_param_indexes=None):
         self.categorical_to_scalar = bool(
@@ -1599,7 +1594,7 @@ class FOLFunction(ModelFunction):
                     num_vars += 1
             funbuf = "regression_arg(0)"
             num_vars = 1
-            for j, param_name in enumerate(self.parameter_names):
+            for j, param_name in enumerate(self.param_names):
                 if self.ignore_index[j]:
                     continue
                 else:
@@ -1608,10 +1603,10 @@ class FOLFunction(ModelFunction):
                         f" + regression_arg({num_vars}) * parameter({param_name})"
                     )
                     num_vars += 1
-                    for k in range(j + 1, len(self.parameter_names)):
+                    for k in range(j + 1, len(self.param_names)):
                         if self.ignore_index[j]:
                             continue
-                        funbuf += f" + regression_arg({num_vars}) * parameter({param_name}) * parameter({self.parameter_names[k]})"
+                        funbuf += f" + regression_arg({num_vars}) * parameter({param_name}) * parameter({self.param_names[k]})"
                         num_vars += 1
         else:
             num_vars = fit_parameters.shape[0] + 1
@@ -1620,7 +1615,7 @@ class FOLFunction(ModelFunction):
                 rawbuf += f" + reg_param[{i}] * model_param[{i-1}]"
             funbuf = "regression_arg(0)"
             i = 1
-            for j, param_name in enumerate(self.parameter_names):
+            for j, param_name in enumerate(self.param_names):
                 if self.ignore_index[j]:
                     continue
                 else:
@@ -1711,17 +1706,21 @@ class FOLFunction(ModelFunction):
             {
                 "type": "analytic",
                 "functionStr": self.model_function,
-                "argCount": self._num_args,
-                "parameterNames": self.parameter_names,
+                "argCount": self.arg_count,
+                "parameterNames": self.param_names,
                 "regressionModel": list(self.model_args),
             }
         )
         return ret
 
     def hyper_to_dref(self):
-        return {
-            "fol/categorical to scalar": int(self.categorical_to_scalar),
-        }
+        hyper = super().hyper_to_dref()
+        hyper.update(
+            {
+                "fol/categorical to scalar": int(self.categorical_to_scalar),
+            }
+        )
+        return hyper
 
 
 class AnalyticFunction(ModelFunction):
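The hyper_to_dref() hunk is the part of this patch whose pattern generalises: instead of returning a fresh dict, the subclass now extends whatever the parent class reports, so shared hyperparameters are no longer dropped. Below is a minimal, runnable sketch of that pattern; BaseRegressionFunction, FOLSketch, and the "base/shared option" key are illustrative stand-ins, not dfatool's actual API.

```python
class BaseRegressionFunction:
    """Stand-in for SKLearnRegressionFunction; only models hyper_to_dref()."""

    def hyper_to_dref(self):
        # The parent reports hyperparameters shared by all of its subclasses
        # (illustrative key only).
        return {"base/shared option": 1}


class FOLSketch(BaseRegressionFunction):
    """Stand-in for FOLFunction after this refactoring."""

    def __init__(self, categorical_to_scalar=False):
        self.categorical_to_scalar = categorical_to_scalar

    def hyper_to_dref(self):
        # Merge FOL-specific keys into the parent's dict instead of
        # replacing it, so the shared hyperparameters survive.
        hyper = super().hyper_to_dref()
        hyper.update(
            {"fol/categorical to scalar": int(self.categorical_to_scalar)}
        )
        return hyper


if __name__ == "__main__":
    print(FOLSketch(categorical_to_scalar=True).hyper_to_dref())
    # -> {'base/shared option': 1, 'fol/categorical to scalar': 1}
```

The same reasoning presumably motivates dropping FOLFunction's own __init__: attributes such as param_names, arg_count, and fit_success now come from SKLearnRegressionFunction, which is why the patch renames self.parameter_names to self.param_names and self._num_args to self.arg_count throughout.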