summary | refs | log | tree | commit | diff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- lib/functions.py 85
1 files changed, 84 insertions, 1 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 3c2b424..0b0044b 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -11,7 +11,7 @@ import numpy as np
import os
import re
from scipy import optimize
-from .utils import is_numeric
+from .utils import is_numeric, param_to_ndarray
logger = logging.getLogger(__name__)
@@ -600,6 +600,89 @@ class XGBoostFunction(SKLearnRegressionFunction):
return 1 + max(ret)
+# first-order linear function (no feature interaction)
+class FOLFunction(ModelFunction):
+    """
+    First-order linear model function without feature interaction.
+
+    The fitted function is a plain weighted sum of the numeric parameter values:
+    0 + reg_param[0] * model_param[0] + reg_param[1] * model_param[1] + ...
+    It has one regression coefficient per model variable and no constant offset term.
+    """
+
+    def __init__(self, value, parameters, num_args=0):
+        """
+        Create an unfitted function.
+
+        :param value: forwarded to ModelFunction.__init__ (presumably stored as self.value,
+            which eval() returns as fallback -- confirm in ModelFunction)
+        :param parameters: parameter names, stored as self.parameter_names
+        :param num_args: stored as self._num_args; not used by the methods of this class
+        """
+        super().__init__(value)
+        self.parameter_names = parameters
+        self._num_args = num_args
+        self.fit_success = False
+
+    def fit(self, param_values, data):
+        """
+        Fit the regression coefficients to the given training data via least squares.
+
+        On success, self.model_args holds the fitted coefficients and self.fit_success is True.
+        On failure, a warning is logged, self.fit_success is False, and self.model_args is
+        left at its all-ones initial guess.
+
+        :param param_values: parameter values of the training samples; converted by param_to_ndarray
+        :param data: observed values the function output is fitted against
+        """
+        # A failed re-fit must not inherit the success flag of an earlier, successful fit.
+        self.fit_success = False
+        categorial_to_scalar = bool(
+            int(os.getenv("DFATOOL_PARAM_CATEGORIAL_TO_SCALAR", "0"))
+        )
+        fit_parameters, categorial_to_index, ignore_index = param_to_ndarray(
+            param_values,
+            with_nan=False,
+            categorial_to_scalar=categorial_to_scalar,
+        )
+        # eval() needs both mappings to translate its input the same way.
+        self.categorial_to_index = categorial_to_index
+        self.ignore_index = ignore_index
+        # After the swap, the first axis enumerates the model variables
+        # (presumably (samples, variables) -> (variables, samples); confirm in param_to_ndarray).
+        fit_parameters = fit_parameters.swapaxes(0, 1)
+        num_vars = fit_parameters.shape[0]
+        terms = "".join(
+            f" + reg_param[{i}] * model_param[{i}]" for i in range(num_vars)
+        )
+        funbuf = f"lambda reg_param, model_param: 0{terms}"
+        self._function_str = self.model_function = funbuf
+        # funbuf consists solely of the integer indices generated above, not of external input.
+        self._function = eval(funbuf)
+
+        def error_function(P, X, y):
+            # Residual between the model output for coefficients P and the observations y.
+            return self._function(P, X) - y
+
+        # Initial guess: all coefficients set to one.
+        self.model_args = list(np.ones((num_vars)))
+        try:
+            res = optimize.least_squares(
+                error_function, self.model_args, args=(fit_parameters, data), xtol=2e-15
+            )
+        except ValueError as err:
+            logger.warning(f"Fit failed: {err} (function: {self.model_function})")
+            return
+        # least_squares reports a positive status when a termination condition was satisfied.
+        if res.status > 0:
+            self.model_args = res.x
+            self.fit_success = True
+        else:
+            logger.warning(
+                f"Fit failed: {res.message} (function: {self.model_function})"
+            )
+
+    def is_predictable(self, param_list=None):
+        """
+        Return whether the model function can be evaluated on the given parameter values.
+
+        This implementation always returns True.
+        """
+        return True
+
+    def eval(self, param_list=None):
+        """
+        Evaluate the fitted linear function for the specified parameter values.
+
+        :param param_list: parameter values, indexed consistently with the ignore_index and
+            categorial_to_index mappings stored by fit(); None requests the static value
+        :returns: model function output, or self.value if param_list is None, if no
+            successful fit is available, or if evaluation raises FloatingPointError
+        """
+        if param_list is None or not self.fit_success:
+            # Without a successful fit, the coefficients are either missing entirely
+            # (fit() never ran) or just the all-ones initial guess.
+            return self.value
+        actual_param_list = list()
+        for i, param in enumerate(param_list):
+            if self.ignore_index[i]:
+                continue
+            if i in self.categorial_to_index:
+                try:
+                    actual_param_list.append(self.categorial_to_index[i][param])
+                except KeyError:
+                    # param was not part of training data. substitute an unused scalar.
+                    # Note that all param values which were not part of training data map to the same scalar this way.
+                    # This should be harmless.
+                    actual_param_list.append(
+                        max(self.categorial_to_index[i].values()) + 1
+                    )
+            else:
+                actual_param_list.append(param)
+        try:
+            return self._function(self.model_args, actual_param_list)
+        except FloatingPointError as e:
+            logger.error(
+                f"{e} when predicting {self._function_str}({param_list}), returning static value"
+            )
+            return self.value
+
+
class AnalyticFunction(ModelFunction):
"""
A multi-dimensional model function, generated from a string, which can be optimized using regression.