diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 34 |
1 files changed, 26 insertions, 8 deletions
diff --git a/lib/functions.py b/lib/functions.py index 14893ad..43ee58a 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -471,19 +471,37 @@ class CARTFunction(SKLearnRegressionFunction): class XGBoostFunction(SKLearnRegressionFunction): def get_number_of_nodes(self): + import json + + self.regressor.get_booster().dump_model( + "/tmp/xgb.json", dump_format="json", with_stats=True + ) + with open("/tmp/xgb.json", "r") as f: + data = json.load(f) + + return sum(map(self._get_number_of_nodes, data)) + + def _get_number_of_nodes(self, data): ret = 1 - for v in self.child.values(): - if type(v) is SplitFunction: - ret += v.get_number_of_nodes() - else: - ret += 1 + for child in data.get("children", list()): + ret += self._get_number_of_nodes(child) return ret def get_max_depth(self): + import json + + self.regressor.get_booster().dump_model( + "/tmp/xgb.json", dump_format="json", with_stats=True + ) + with open("/tmp/xgb.json", "r") as f: + data = json.load(f) + + return max(map(self._get_max_depth, data)) + + def _get_max_depth(self, data): ret = [0] - for v in self.child.values(): - if type(v) is SplitFunction: - ret.append(v.get_max_depth()) + for child in data.get("children", list()): + ret.append(self._get_max_depth(child)) return 1 + max(ret) |