summaryrefslogtreecommitdiff
path: root/lib/parameters.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-01-06 08:10:40 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-01-06 08:10:40 +0100
commit30792ce0630d8f0768379d8e7e0b8556fbfa9894 (patch)
treeb17ae9e88dfb3d8ad74bf48b4ac21c8f7e638764 /lib/parameters.py
parent063858aadafc168e17ac87a9687ece2cd6fe5519 (diff)
add preliminary xgboost support
Diffstat (limited to 'lib/parameters.py')
-rw-r--r--lib/parameters.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 38e36b2..ca28cbb 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -876,6 +876,7 @@ class ModelAttribute:
with_function_leaves=False,
with_nonbinary_nodes=True,
with_sklearn_cart=False,
+ with_xgboost=False,
loss_ignore_scalar=False,
threshold=100,
):
@@ -908,6 +909,25 @@ class ModelAttribute:
)
return
+ if with_xgboost:
+ from xgboost import XGBRegressor
+
+ # TODO retrieve parameters from env
+ xgb = XGBRegressor(
+ n_estimators=100,
+ max_depth=10,
+ eta=0.2,
+ subsample=0.7,
+ gamma=0.01,
+ alpha=0.0006,
+ )
+ fit_parameters, ignore_index = param_to_ndarray(parameters, with_nan=False)
+ xgb.fit(fit_parameters, data)
+ self.model_function = df.SKLearnRegressionFunction(
+ np.mean(data), xgb, ignore_index
+ )
+ return
+
if loss_ignore_scalar and not with_function_leaves:
logger.warning(
"build_dtree called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense."