diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index b1477da..320e8ed 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -409,6 +409,36 @@ class SubstateFunction(ModelFunction): return "SubstateFunction" +class SKLearnRegressionFunction(ModelFunction): + def __init__(self, value, regressor, ignore_index): + super().__init__(value) + self.regressor = regressor + self.ignore_index = ignore_index + + def is_predictable(self, param_list=None): + """ + Return whether the model function can be evaluated on the given parameter values. + + For a StaticFunction, this is always the case (i.e., this function always returns true). + """ + return True + + def eval(self, param_list=None): + """ + Evaluate model function with specified param/arg values. + + Far a Staticfunction, this is just the static value + + """ + if param_list is None: + return self.value + actual_param_list = list() + for i, param in enumerate(param_list): + if not self.ignore_index[i]: + actual_param_list.append(param) + return self.regressor.predict([actual_param_list]) + + class AnalyticFunction(ModelFunction): """ A multi-dimensional model function, generated from a string, which can be optimized using regression. |