Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 85 |
1 file changed, 84 insertions(+), 1 deletion(-)
diff --git a/lib/functions.py b/lib/functions.py
index 3c2b424..0b0044b 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -11,7 +11,7 @@ import numpy as np
 import os
 import re
 from scipy import optimize
-from .utils import is_numeric
+from .utils import is_numeric, param_to_ndarray
 
 logger = logging.getLogger(__name__)
 
@@ -600,6 +600,89 @@ class XGBoostFunction(SKLearnRegressionFunction):
         return 1 + max(ret)
 
 
+# first-order linear function (no feature interaction)
+class FOLFunction(ModelFunction):
+    def __init__(self, value, parameters, num_args=0):
+        super().__init__(value)
+        self.parameter_names = parameters
+        self._num_args = num_args
+        self.fit_success = False
+
+    def fit(self, param_values, data):
+        categorial_to_scalar = bool(
+            int(os.getenv("DFATOOL_PARAM_CATEGORIAL_TO_SCALAR", "0"))
+        )
+        fit_parameters, categorial_to_index, ignore_index = param_to_ndarray(
+            param_values,
+            with_nan=False,
+            categorial_to_scalar=categorial_to_scalar,
+        )
+        self.categorial_to_index = categorial_to_index
+        self.ignore_index = ignore_index
+        fit_parameters = fit_parameters.swapaxes(0, 1)
+        num_vars = fit_parameters.shape[0]
+        funbuf = "lambda reg_param, model_param: 0"
+        for i in range(num_vars):
+            funbuf += f" + reg_param[{i}] * model_param[{i}]"
+        self._function_str = self.model_function = funbuf
+        self._function = eval(funbuf)
+
+        error_function = lambda P, X, y: self._function(P, X) - y
+        self.model_args = list(np.ones((num_vars)))
+        try:
+            res = optimize.least_squares(
+                error_function, self.model_args, args=(fit_parameters, data), xtol=2e-15
+            )
+        except ValueError as err:
+            logger.warning(f"Fit failed: {err} (function: {self.model_function})")
+            return
+        if res.status > 0:
+            self.model_args = res.x
+            self.fit_success = True
+        else:
+            logger.warning(
+                f"Fit failed: {res.message} (function: {self.model_function})"
+            )
+
+    def is_predictable(self, param_list=None):
+        """
+        Return whether the model function can be evaluated on the given parameter values.
+        """
+        return True
+
+    def eval(self, param_list=None):
+        """
+        Evaluate model function with specified param/arg values.
+
+        If param_list is None, the static value is returned.
+
+        """
+        if param_list is None:
+            return self.value
+        actual_param_list = list()
+        for i, param in enumerate(param_list):
+            if not self.ignore_index[i]:
+                if i in self.categorial_to_index:
+                    try:
+                        actual_param_list.append(self.categorial_to_index[i][param])
+                    except KeyError:
+                        # param was not part of training data. substitute an unused scalar.
+                        # Note that all param values which were not part of training data map to the same scalar this way.
+                        # This should be harmless.
+                        actual_param_list.append(
+                            max(self.categorial_to_index[i].values()) + 1
+                        )
+                else:
+                    actual_param_list.append(param)
+        try:
+            return self._function(self.model_args, actual_param_list)
+        except FloatingPointError as e:
+            logger.error(
+                f"{e} when predicting {self._function_str}({param_list}), returning static value"
+            )
+            return self.value
+
+
 class AnalyticFunction(ModelFunction):
     """
     A multi-dimensional model function, generated from a string, which can be optimized using regression.
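
The core of FOLFunction.fit is that it builds a purely additive (first-order, no interaction terms) model as a lambda string with one regression coefficient per parameter dimension, and then fits the coefficients with scipy.optimize.least_squares. The standalone sketch below reproduces that mechanism on synthetic data; the variable names and the synthetic ground truth are made up for illustration and are not part of the commit.

    # Minimal reproduction of the fitting approach used by FOLFunction (illustrative only).
    import numpy as np
    from scipy import optimize

    # Synthetic training data: two numeric parameters, 200 observations.
    rng = np.random.default_rng(0)
    params = rng.uniform(1, 100, size=(200, 2))   # shape: (n_samples, n_params)
    data = 3.0 * params[:, 0] + 0.5 * params[:, 1] + rng.normal(0, 0.1, 200)

    fit_parameters = params.swapaxes(0, 1)        # shape: (n_params, n_samples)
    num_vars = fit_parameters.shape[0]

    # Build "lambda reg_param, model_param: 0 + reg_param[0] * model_param[0] + ..."
    funbuf = "lambda reg_param, model_param: 0"
    for i in range(num_vars):
        funbuf += f" + reg_param[{i}] * model_param[{i}]"
    model = eval(funbuf)

    # Least-squares fit of the regression coefficients, starting from all-ones.
    residuals = lambda P, X, y: model(P, X) - y
    res = optimize.least_squares(
        residuals, np.ones(num_vars), args=(fit_parameters, data), xtol=2e-15
    )
    print(res.x)  # approximately [3.0, 0.5]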
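Used from within dfatool, the new class is fit on a list of parameter tuples plus a matching array of measurements, with a static fallback value passed to the constructor. The following is a hypothetical usage sketch: the import path, parameter names, and measurement values are assumptions for illustration, not taken from this commit.

    import numpy as np
    from dfatool.functions import FOLFunction   # import path assumed; the commit only shows lib/functions.py

    # Hypothetical training data: (buffer_size, interval_ms) -> measured value.
    param_values = [(8, 100), (8, 200), (16, 100), (16, 200), (32, 100), (32, 200)]
    data = np.array([1.2, 2.1, 1.5, 2.4, 2.0, 2.9])

    # The first argument is the static fallback value returned when prediction fails.
    fol = FOLFunction(np.mean(data), ["buffer_size", "interval_ms"])
    fol.fit(param_values, data)

    if fol.fit_success:
        # Predict for a parameter combination not present in the training data.
        print(fol.eval([24, 150]))
    else:
        print(fol.value)  # static fallback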