summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index b1477da..320e8ed 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -409,6 +409,36 @@ class SubstateFunction(ModelFunction):
return "SubstateFunction"
+class SKLearnRegressionFunction(ModelFunction):
+ def __init__(self, value, regressor, ignore_index):
+ super().__init__(value)
+ self.regressor = regressor
+ self.ignore_index = ignore_index
+
+ def is_predictable(self, param_list=None):
+ """
+ Return whether the model function can be evaluated on the given parameter values.
+
+ For a StaticFunction, this is always the case (i.e., this function always returns true).
+ """
+ return True
+
+ def eval(self, param_list=None):
+ """
+ Evaluate model function with specified param/arg values.
+
+ Far a Staticfunction, this is just the static value
+
+ """
+ if param_list is None:
+ return self.value
+ actual_param_list = list()
+ for i, param in enumerate(param_list):
+ if not self.ignore_index[i]:
+ actual_param_list.append(param)
+ return self.regressor.predict([actual_param_list])
+
+
class AnalyticFunction(ModelFunction):
"""
A multi-dimensional model function, generated from a string, which can be optimized using regression.