summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-01-06 11:58:09 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-01-06 11:58:09 +0100
commit01860ccf2addcb1dd84418887f76b88c4acdf53a (patch)
tree36fd3ba51210caab7315891c24d019eeb069378d /lib/functions.py
parent09a1140ee677f633c28fb887692b417874402356 (diff)
sklearn (CART, XGBoost): support mapping of categorial parameter values to scalar indexes
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/lib/functions.py b/lib/functions.py
index f4e1709..4a6dac2 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -417,9 +417,10 @@ class SubstateFunction(ModelFunction):
class SKLearnRegressionFunction(ModelFunction):
- def __init__(self, value, regressor, ignore_index):
+ def __init__(self, value, regressor, categorial_to_index, ignore_index):
super().__init__(value)
self.regressor = regressor
+ self.categorial_to_index = categorial_to_index
self.ignore_index = ignore_index
def is_predictable(self, param_list=None):
@@ -442,7 +443,10 @@ class SKLearnRegressionFunction(ModelFunction):
actual_param_list = list()
for i, param in enumerate(param_list):
if not self.ignore_index[i]:
- actual_param_list.append(param)
+ if i in self.categorial_to_index:
+ actual_param_list.append(self.categorial_to_index[i][param])
+ else:
+ actual_param_list.append(param)
return self.regressor.predict(np.array([actual_param_list]))