summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-03-08 10:55:46 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-03-08 10:55:46 +0100
commit164c9e5a7e1015578ee6005c35dc1cbc39125550 (patch)
tree3f21e799eec930460b1a008da0f43160f16a06a4 /lib
parente31a672ba4dc8a9b605cb02882f8afbfa4bbbe7b (diff)
make FOLFunction inherit from SKLearnRegressionFunction
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py31
1 files changed, 15 insertions, 16 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 212e47a..0f19668 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -1555,14 +1555,9 @@ class SymbolicRegressionFunction(SKLearnRegressionFunction):
# first-order linear function (no feature interaction)
-class FOLFunction(ModelFunction):
+class FOLFunction(SKLearnRegressionFunction):
always_predictable = True
-
- def __init__(self, value, parameters, num_args=0, **kwargs):
- super().__init__(value, **kwargs)
- self.parameter_names = parameters
- self._num_args = num_args
- self.fit_success = False
+ has_eval_arr = False
def fit(self, param_values, data, ignore_param_indexes=None):
self.categorical_to_scalar = bool(
@@ -1599,7 +1594,7 @@ class FOLFunction(ModelFunction):
num_vars += 1
funbuf = "regression_arg(0)"
num_vars = 1
- for j, param_name in enumerate(self.parameter_names):
+ for j, param_name in enumerate(self.param_names):
if self.ignore_index[j]:
continue
else:
@@ -1608,10 +1603,10 @@ class FOLFunction(ModelFunction):
f" + regression_arg({num_vars}) * parameter({param_name})"
)
num_vars += 1
- for k in range(j + 1, len(self.parameter_names)):
+ for k in range(j + 1, len(self.param_names)):
if self.ignore_index[j]:
continue
- funbuf += f" + regression_arg({num_vars}) * parameter({param_name}) * parameter({self.parameter_names[k]})"
+ funbuf += f" + regression_arg({num_vars}) * parameter({param_name}) * parameter({self.param_names[k]})"
num_vars += 1
else:
num_vars = fit_parameters.shape[0] + 1
@@ -1620,7 +1615,7 @@ class FOLFunction(ModelFunction):
rawbuf += f" + reg_param[{i}] * model_param[{i-1}]"
funbuf = "regression_arg(0)"
i = 1
- for j, param_name in enumerate(self.parameter_names):
+ for j, param_name in enumerate(self.param_names):
if self.ignore_index[j]:
continue
else:
@@ -1711,17 +1706,21 @@ class FOLFunction(ModelFunction):
{
"type": "analytic",
"functionStr": self.model_function,
- "argCount": self._num_args,
- "parameterNames": self.parameter_names,
+ "argCount": self.arg_count,
+ "parameterNames": self.param_names,
"regressionModel": list(self.model_args),
}
)
return ret
def hyper_to_dref(self):
- return {
- "fol/categorical to scalar": int(self.categorical_to_scalar),
- }
+ hyper = super().hyper_to_dref()
+ hyper.update(
+ {
+ "fol/categorical to scalar": int(self.categorical_to_scalar),
+ }
+ )
+ return hyper
class AnalyticFunction(ModelFunction):