diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-26 16:13:04 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-26 16:13:04 +0100 |
commit | e149c6bc24935ff8383471759c8775d3174ec29d (patch) | |
tree | 119776318c5c678c549d04dfacc5a3daeca82d6a /lib/functions.py | |
parent | 943c8bee7a511e1aff5e1639a479ea868d7656a7 (diff) |
Add tree attribute export for XGBoost
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 34 |
1 files changed, 26 insertions, 8 deletions
```diff
diff --git a/lib/functions.py b/lib/functions.py
index 14893ad..43ee58a 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -471,19 +471,37 @@ class CARTFunction(SKLearnRegressionFunction):
 
 class XGBoostFunction(SKLearnRegressionFunction):
     def get_number_of_nodes(self):
+        import json
+
+        self.regressor.get_booster().dump_model(
+            "/tmp/xgb.json", dump_format="json", with_stats=True
+        )
+        with open("/tmp/xgb.json", "r") as f:
+            data = json.load(f)
+
+        return sum(map(self._get_number_of_nodes, data))
+
+    def _get_number_of_nodes(self, data):
         ret = 1
-        for v in self.child.values():
-            if type(v) is SplitFunction:
-                ret += v.get_number_of_nodes()
-            else:
-                ret += 1
+        for child in data.get("children", list()):
+            ret += self._get_number_of_nodes(child)
         return ret
 
     def get_max_depth(self):
+        import json
+
+        self.regressor.get_booster().dump_model(
+            "/tmp/xgb.json", dump_format="json", with_stats=True
+        )
+        with open("/tmp/xgb.json", "r") as f:
+            data = json.load(f)
+
+        return max(map(self._get_max_depth, data))
+
+    def _get_max_depth(self, data):
         ret = [0]
-        for v in self.child.values():
-            if type(v) is SplitFunction:
-                ret.append(v.get_max_depth())
+        for child in data.get("children", list()):
+            ret.append(self._get_max_depth(child))
         return 1 + max(ret)
 
 
```