summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/lib/functions.py b/lib/functions.py
index f4e1709..4a6dac2 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -417,9 +417,10 @@ class SubstateFunction(ModelFunction):
class SKLearnRegressionFunction(ModelFunction):
- def __init__(self, value, regressor, ignore_index):
+ def __init__(self, value, regressor, categorial_to_index, ignore_index):
super().__init__(value)
self.regressor = regressor
+ self.categorial_to_index = categorial_to_index
self.ignore_index = ignore_index
def is_predictable(self, param_list=None):
@@ -442,7 +443,10 @@ class SKLearnRegressionFunction(ModelFunction):
actual_param_list = list()
for i, param in enumerate(param_list):
if not self.ignore_index[i]:
- actual_param_list.append(param)
+ if i in self.categorial_to_index:
+ actual_param_list.append(self.categorial_to_index[i][param])
+ else:
+ actual_param_list.append(param)
return self.regressor.predict(np.array([actual_param_list]))