author     Daniel Friesel <daniel.friesel@uos.de>  2022-01-06 08:10:40 +0100
committer  Daniel Friesel <daniel.friesel@uos.de>  2022-01-06 08:10:40 +0100
commit     30792ce0630d8f0768379d8e7e0b8556fbfa9894
tree       b17ae9e88dfb3d8ad74bf48b4ac21c8f7e638764
parent     063858aadafc168e17ac87a9687ece2cd6fe5519
add preliminary xgboost support
-rw-r--r--   lib/functions.py    2
-rw-r--r--   lib/model.py        6
-rw-r--r--   lib/parameters.py  20
3 files changed, 27 insertions(+), 1 deletion(-)
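Usage note (not part of the patch): like the existing DFATOOL_DTREE_SKLEARN_CART switch, the new xgboost backend is opt-in and read from the environment. A minimal standalone sketch of the flag handling this commit adds, assuming the same env-flag convention:

    import os

    # "0" (default) keeps the current behaviour; DFATOOL_USE_XGBOOST=1 selects an
    # XGBRegressor for the model attribute instead of building a decision tree.
    with_xgboost = bool(int(os.getenv("DFATOOL_USE_XGBOOST", "0")))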
diff --git a/lib/functions.py b/lib/functions.py
index 5358d8e..f4e1709 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -443,7 +443,7 @@ class SKLearnRegressionFunction(ModelFunction):
         for i, param in enumerate(param_list):
             if not self.ignore_index[i]:
                 actual_param_list.append(param)
-        return self.regressor.predict([actual_param_list])
+        return self.regressor.predict(np.array([actual_param_list]))
 
 
 class AnalyticFunction(ModelFunction):
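The functions.py change converts the query parameters to an ndarray before calling predict(), presumably because the XGBRegressor wrapper is stricter about input types than the scikit-learn regressors used so far. A minimal sketch of the call pattern with toy data (names are illustrative, not from the patch):

    import numpy as np
    from xgboost import XGBRegressor

    reg = XGBRegressor(n_estimators=10)
    reg.fit(np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([0.5, 1.5]))

    # predict() takes a 2-D array of shape (n_samples, n_features); wrapping the
    # single query row in np.array([...]) mirrors the patched call above
    print(reg.predict(np.array([[1.0, 2.0]])))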
diff --git a/lib/model.py b/lib/model.py
index e84e680..1270cf6 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -160,6 +160,7 @@ class AnalyticModel:
             with_sklearn_cart = bool(
                 int(os.getenv("DFATOOL_DTREE_SKLEARN_CART", "0"))
             )
+            with_xgboost = bool(int(os.getenv("DFATOOL_USE_XGBOOST", "0")))
             loss_ignore_scalar = bool(
                 int(os.getenv("DFATOOL_DTREE_LOSS_IGNORE_SCALAR", "0"))
             )
@@ -173,6 +174,7 @@ class AnalyticModel:
                 with_function_leaves=with_function_leaves,
                 with_nonbinary_nodes=with_nonbinary_nodes,
                 with_sklearn_cart=with_sklearn_cart,
+                with_xgboost=with_xgboost,
                 loss_ignore_scalar=loss_ignore_scalar,
             )
         self.fit_done = True
@@ -324,6 +326,7 @@ class AnalyticModel:
             with_sklearn_cart = bool(
                 int(os.getenv("DFATOOL_DTREE_SKLEARN_CART", "0"))
             )
+            with_xgboost = bool(int(os.getenv("DFATOOL_USE_XGBOOST", "0")))
             loss_ignore_scalar = bool(
                 int(os.getenv("DFATOOL_DTREE_LOSS_IGNORE_SCALAR", "0"))
             )
@@ -344,6 +347,7 @@ class AnalyticModel:
                 with_function_leaves=with_function_leaves,
                 with_nonbinary_nodes=with_nonbinary_nodes,
                 with_sklearn_cart=with_sklearn_cart,
+                with_xgboost=with_xgboost,
                 loss_ignore_scalar=loss_ignore_scalar,
             )
         else:
@@ -423,6 +427,7 @@ class AnalyticModel:
         with_function_leaves=False,
         with_nonbinary_nodes=True,
         with_sklearn_cart=False,
+        with_xgboost=False,
         loss_ignore_scalar=False,
     ):
@@ -445,6 +450,7 @@ class AnalyticModel:
             with_function_leaves=with_function_leaves,
             with_nonbinary_nodes=with_nonbinary_nodes,
             with_sklearn_cart=with_sklearn_cart,
+            with_xgboost=with_xgboost,
             loss_ignore_scalar=loss_ignore_scalar,
             threshold=threshold,
         )
diff --git a/lib/parameters.py b/lib/parameters.py
index 38e36b2..ca28cbb 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -876,6 +876,7 @@ class ModelAttribute:
         with_function_leaves=False,
         with_nonbinary_nodes=True,
         with_sklearn_cart=False,
+        with_xgboost=False,
         loss_ignore_scalar=False,
         threshold=100,
     ):
@@ -908,6 +909,25 @@ class ModelAttribute:
             )
             return
 
+        if with_xgboost:
+            from xgboost import XGBRegressor
+
+            # TODO retrieve parameters from env
+            xgb = XGBRegressor(
+                n_estimators=100,
+                max_depth=10,
+                eta=0.2,
+                subsample=0.7,
+                gamma=0.01,
+                alpha=0.0006,
+            )
+            fit_parameters, ignore_index = param_to_ndarray(parameters, with_nan=False)
+            xgb.fit(fit_parameters, data)
+            self.model_function = df.SKLearnRegressionFunction(
+                np.mean(data), xgb, ignore_index
+            )
+            return
+
         if loss_ignore_scalar and not with_function_leaves:
             logger.warning(
                 "build_dtree called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense."