diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-19 10:43:42 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-19 10:43:42 +0100 |
commit | 5b030d650f398b0b6190260235c45aedb4959fa2 (patch) | |
tree | c2c357062a61cc23765cca27e292a31c1049fce1 /lib | |
parent | a3f2c15b16d1194b18e189783380d092072b2d45 (diff) |
dataref export: include XGB hyper-parameters
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 11 | ||||
-rw-r--r-- | lib/parameters.py | 5 |
2 files changed, 15 insertions, 1 deletion
diff --git a/lib/functions.py b/lib/functions.py index 1b8612e..82e516e 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -731,6 +731,17 @@ class XGBoostFunction(SKLearnRegressionFunction): def get_complexity_score(self): return self.get_number_of_nodes() + def to_dref(self): + return { + "hyper/n estimators": self.regressor.n_estimators, + "hyper/max depth": self.regressor.max_depth, + "hyper/subsample": self.regressor.subsample, + "hyper/eta": self.regressor.learning_rate, + "hyper/gamma": self.regressor.gamma, + "hyper/alpha": self.regressor.reg_alpha, + "hyper/lambda": self.regressor.reg_lambda, + } + # first-order linear function (no feature interaction) class FOLFunction(ModelFunction): diff --git a/lib/parameters.py b/lib/parameters.py index 0f23ef8..6183ccc 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -692,6 +692,9 @@ class ModelAttribute: ret["decision tree/inner nodes"] = 0 ret["decision tree/max depth"] = 0 + if type(self.model_function) == df.XGBoostFunction: + ret.update(self.model_function.to_dref()) + return ret def to_dot(self): @@ -1056,7 +1059,7 @@ class ModelAttribute: n_estimators=int(os.getenv("DFATOOL_XGB_N_ESTIMATORS", "100")), max_depth=int(os.getenv("DFATOOL_XGB_MAX_DEPTH", "10")), subsample=float(os.getenv("DFATOOL_XGB_SUBSAMPLE", "0.7")), - eta=float(os.getenv("DFATOOL_XGB_ETA", "0.3")), + learning_rate=float(os.getenv("DFATOOL_XGB_ETA", "0.3")), gamma=float(os.getenv("DFATOOL_XGB_GAMMA", "0.01")), reg_alpha=float(os.getenv("DFATOOL_XGB_REG_ALPHA", "0.0006")), reg_lambda=float(os.getenv("DFATOOL_XGB_REG_LAMBDA", "1")), |