From 22e116e51462aa98321c88f375dabf973b0ab2c2 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Fri, 8 Mar 2024 08:45:09 +0100 Subject: move fit_parameters check to helper function --- lib/functions.py | 86 +++++++++++++++++--------------------------------------- 1 file changed, 26 insertions(+), 60 deletions(-) (limited to 'lib') diff --git a/lib/functions.py b/lib/functions.py index 88ecb76..47f2983 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -615,6 +615,13 @@ class SKLearnRegressionFunction(ModelFunction): ) self.fit_success = None + def _check_fit_param(self, fit_parameters, name, step): + if fit_parameters.shape[1] == 0: + logger.warning(f"Cannot generate {name}: {step} removed all parameters") + self.fit_success = False + return False + return True + def _preprocess_parameters(self, fit_parameters, data): if dfatool_preproc_relevance_method == "mi": return self._preprocess_parameters_mi(fit_parameters, data) @@ -785,20 +792,12 @@ class CARTFunction(SKLearnRegressionFunction): ) ) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot generate CART due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape after param_to_ndarray is {fit_parameters.shape}" - ) - self.fit_success = False + if not self._check_fit_param(fit_parameters, "CART", "param_to_ndarray"): return self fit_parameters = self._preprocess_parameters(fit_parameters, data) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot generate CART due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape after pre-processing is {fit_parameters.shape}" - ) - self.fit_success = False + if not self._check_fit_param(fit_parameters, "CART", "preprocessing"): return self logger.debug("Fitting sklearn CART ...") @@ -962,20 +961,13 @@ class LMTFunction(SKLearnRegressionFunction): with_nan=False, categorical_to_scalar=self.categorical_to_scalar, ) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot generate LMT due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape is {fit_parameters.shape}" - ) - self.fit_success = False + + if not self._check_fit_param(fit_parameters, "LMT", "param_to_ndarray"): return self fit_parameters = self._preprocess_parameters(fit_parameters, data) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot generate LMT due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape after pre-processing is {fit_parameters.shape}" - ) - self.fit_success = False + if not self._check_fit_param(fit_parameters, "LMT", "preprocessing"): return self logger.debug("Fitting LMT ...") @@ -1106,20 +1098,13 @@ class LightGBMFunction(SKLearnRegressionFunction): with_nan=False, categorical_to_scalar=self.categorical_to_scalar, ) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot run LightGBM due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape is {fit_parameters.shape}" - ) - self.fit_success = False + + if not self._check_fit_param(fit_parameters, "LightGBM", "param_to_ndarray"): return self fit_parameters = self._preprocess_parameters(fit_parameters, data) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot generate LightGBM due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape after pre-processing is {fit_parameters.shape}" - ) - self.fit_success = False + if not self._check_fit_param(fit_parameters, "LightGBM", "preprocessing"): return self import dfatool.lightgbm as lightgbm @@ -1295,20 +1280,13 @@ class XGBoostFunction(SKLearnRegressionFunction): with_nan=False, categorical_to_scalar=self.categorical_to_scalar, ) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot run XGBoost due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape is {fit_parameters.shape}" - ) - self.fit_success = False + + if not self._check_fit_param(fit_parameters, "XGBoost", "param_to_ndarray"): return self fit_parameters = self._preprocess_parameters(fit_parameters, data) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot run XGBoost due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape after pre-processing is {fit_parameters.shape}" - ) - self.fit_success = False + if not self._check_fit_param(fit_parameters, "XGBoost", "preprocessing"): return self import xgboost @@ -1495,20 +1473,16 @@ class SymbolicRegressionFunction(SKLearnRegressionFunction): ignore_indexes=ignore_param_indexes, ) - if fit_parameters.shape[1] == 0: - logger.debug( - f"Cannot use Symbolic Regression due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape is {fit_parameters.shape}" - ) - self.fit_success = False + if not self._check_fit_param( + fit_parameters, "Symbolic Regression", "param_to_ndarray" + ): return self fit_parameters = self._preprocess_parameters(fit_parameters, data) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot use Symbolic Regression due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape after pre-processing is {fit_parameters.shape}" - ) - self.fit_success = False + if not self._check_fit_param( + fit_parameters, "Symbolic Regression", "preprocessing" + ): return self from dfatool.gplearn.genetic import SymbolicRegressor @@ -1570,20 +1544,12 @@ class FOLFunction(ModelFunction): ignore_indexes=ignore_param_indexes, ) - if fit_parameters.shape[1] == 0: - logger.debug( - f"Cannot run FOL due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape is {fit_parameters.shape}" - ) - self.fit_success = False + if not self._check_fit_param(fit_parameters, "FOL", "param_to_ndarray"): return self fit_parameters = self._preprocess_parameters(fit_parameters, data) - if fit_parameters.shape[1] == 0: - logger.warning( - f"Cannot run FOL due to lack of parameters: parameter shape is {np.array(param_values).shape}, fit_parameter shape after pre-processing is {fit_parameters.shape}" - ) - self.fit_success = False + if not self._check_fit_param(fit_parameters, "FOL", "preprocessing"): return self fit_parameters = fit_parameters.swapaxes(0, 1) -- cgit v1.2.3