diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-06 08:10:40 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-06 08:10:40 +0100 |
commit | 30792ce0630d8f0768379d8e7e0b8556fbfa9894 (patch) | |
tree | b17ae9e88dfb3d8ad74bf48b4ac21c8f7e638764 /lib/parameters.py | |
parent | 063858aadafc168e17ac87a9687ece2cd6fe5519 (diff) |
add preliminary xgboost support
Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 38e36b2..ca28cbb 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -876,6 +876,7 @@ class ModelAttribute: with_function_leaves=False, with_nonbinary_nodes=True, with_sklearn_cart=False, + with_xgboost=False, loss_ignore_scalar=False, threshold=100, ): @@ -908,6 +909,25 @@ class ModelAttribute: ) return + if with_xgboost: + from xgboost import XGBRegressor + + # TODO retrieve parameters from env + xgb = XGBRegressor( + n_estimators=100, + max_depth=10, + eta=0.2, + subsample=0.7, + gamma=0.01, + alpha=0.0006, + ) + fit_parameters, ignore_index = param_to_ndarray(parameters, with_nan=False) + xgb.fit(fit_parameters, data) + self.model_function = df.SKLearnRegressionFunction( + np.mean(data), xgb, ignore_index + ) + return + if loss_ignore_scalar and not with_function_leaves: logger.warning( "build_dtree called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense." |