summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-01-26 16:13:04 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-01-26 16:13:04 +0100
commite149c6bc24935ff8383471759c8775d3174ec29d (patch)
tree119776318c5c678c549d04dfacc5a3daeca82d6a /lib
parent943c8bee7a511e1aff5e1639a479ea868d7656a7 (diff)
Add tree attribute export for XGBoost
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py34
-rw-r--r--lib/parameters.py8
2 files changed, 32 insertions, 10 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 14893ad..43ee58a 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -471,19 +471,37 @@ class CARTFunction(SKLearnRegressionFunction):
class XGBoostFunction(SKLearnRegressionFunction):
def get_number_of_nodes(self):
+ import json
+
+ self.regressor.get_booster().dump_model(
+ "/tmp/xgb.json", dump_format="json", with_stats=True
+ )
+ with open("/tmp/xgb.json", "r") as f:
+ data = json.load(f)
+
+ return sum(map(self._get_number_of_nodes, data))
+
+ def _get_number_of_nodes(self, data):
ret = 1
- for v in self.child.values():
- if type(v) is SplitFunction:
- ret += v.get_number_of_nodes()
- else:
- ret += 1
+ for child in data.get("children", list()):
+ ret += self._get_number_of_nodes(child)
return ret
def get_max_depth(self):
+ import json
+
+ self.regressor.get_booster().dump_model(
+ "/tmp/xgb.json", dump_format="json", with_stats=True
+ )
+ with open("/tmp/xgb.json", "r") as f:
+ data = json.load(f)
+
+ return max(map(self._get_max_depth, data))
+
+ def _get_max_depth(self, data):
ret = [0]
- for v in self.child.values():
- if type(v) is SplitFunction:
- ret.append(v.get_max_depth())
+ for child in data.get("children", list()):
+ ret.append(self._get_max_depth(child))
return 1 + max(ret)
diff --git a/lib/parameters.py b/lib/parameters.py
index 9350a06..4e98f54 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -666,7 +666,11 @@ class ModelAttribute:
def to_dref(self, unit=None):
ret = {"mean": (self.mean, unit), "median": (self.median, unit)}
- if type(self.model_function) in (df.SplitFunction, df.CARTFunction):
+ if type(self.model_function) in (
+ df.SplitFunction,
+ df.CARTFunction,
+ df.XGBoostFunction,
+ ):
ret["decision tree/nodes"] = self.model_function.get_number_of_nodes()
ret["decision tree/max depth"] = self.model_function.get_max_depth()
@@ -961,7 +965,7 @@ class ModelAttribute:
self.model_function = df.StaticFunction(np.mean(data))
return
xgb.fit(fit_parameters, np.reshape(data, (-1, 1)))
- self.model_function = df.SKLearnRegressionFunction(
+ self.model_function = df.XGBoostFunction(
np.mean(data), xgb, category_to_index, ignore_index
)
output_filename = os.getenv("DFATOOL_XGB_DUMP_MODEL", None)