summaryrefslogtreecommitdiff
path: root/lib/parameters.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-01-17 08:41:21 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-01-17 08:41:21 +0100
commit5ad2c5ef6d84763b579c22121052e875e434a119 (patch)
tree3622116b155e102d87a6a89d51fcd1f72093adc2 /lib/parameters.py
parent4fb775981b2ab8fa57ccd3ef22d3f4f2e9149e25 (diff)
CART, XGBoost: Fall back to StaticFunction if no parameters are available
Diffstat (limited to 'lib/parameters.py')
-rw-r--r--lib/parameters.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 4b6dc2c..401c7c6 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -924,6 +924,11 @@ class ModelAttribute:
fit_parameters, category_to_index, ignore_index = param_to_ndarray(
parameters, with_nan=False, categorial_to_scalar=categorial_to_scalar
)
+ if fit_parameters.shape[1] == 0:
+ logger.warning(
+ f"Cannot generate CART due to lack of parameters: parameter shape is {np.array(parameters).shape}, fit_parameter shape is {fit_parameters.shape}"
+ )
+ return df.StaticFunction(np.mean(data))
cart.fit(fit_parameters, data)
self.model_function = df.SKLearnRegressionFunction(
np.mean(data), cart, category_to_index, ignore_index
@@ -945,6 +950,11 @@ class ModelAttribute:
fit_parameters, category_to_index, ignore_index = param_to_ndarray(
parameters, with_nan=False, categorial_to_scalar=categorial_to_scalar
)
+ if fit_parameters.shape[1] == 0:
+ logger.warning(
+ f"Cannot run XGBoost due to lack of parameters: parameter shape is {np.array(parameters).shape}, fit_parameter shape is {fit_parameters.shape}"
+ )
+ return df.StaticFunction(np.mean(data))
xgb.fit(fit_parameters, np.reshape(data, (-1, 1)))
self.model_function = df.SKLearnRegressionFunction(
np.mean(data), xgb, category_to_index, ignore_index