diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-06 11:58:09 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-06 11:58:09 +0100 |
commit | 01860ccf2addcb1dd84418887f76b88c4acdf53a (patch) | |
tree | 36fd3ba51210caab7315891c24d019eeb069378d /lib/functions.py | |
parent | 09a1140ee677f633c28fb887692b417874402356 (diff) |
sklearn (CART, XGBoost): support mapping of categorial parameter values to scalar indexes
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/lib/functions.py b/lib/functions.py index f4e1709..4a6dac2 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -417,9 +417,10 @@ class SubstateFunction(ModelFunction): class SKLearnRegressionFunction(ModelFunction): - def __init__(self, value, regressor, ignore_index): + def __init__(self, value, regressor, categorial_to_index, ignore_index): super().__init__(value) self.regressor = regressor + self.categorial_to_index = categorial_to_index self.ignore_index = ignore_index def is_predictable(self, param_list=None): @@ -442,7 +443,10 @@ class SKLearnRegressionFunction(ModelFunction): actual_param_list = list() for i, param in enumerate(param_list): if not self.ignore_index[i]: - actual_param_list.append(param) + if i in self.categorial_to_index: + actual_param_list.append(self.categorial_to_index[i][param]) + else: + actual_param_list.append(param) return self.regressor.predict(np.array([actual_param_list])) |