summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py34
1 files changed, 26 insertions, 8 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 14893ad..43ee58a 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -471,19 +471,37 @@ class CARTFunction(SKLearnRegressionFunction):
class XGBoostFunction(SKLearnRegressionFunction):
def get_number_of_nodes(self):
+ import json
+
+ self.regressor.get_booster().dump_model(
+ "/tmp/xgb.json", dump_format="json", with_stats=True
+ )
+ with open("/tmp/xgb.json", "r") as f:
+ data = json.load(f)
+
+ return sum(map(self._get_number_of_nodes, data))
+
+ def _get_number_of_nodes(self, data):
ret = 1
- for v in self.child.values():
- if type(v) is SplitFunction:
- ret += v.get_number_of_nodes()
- else:
- ret += 1
+ for child in data.get("children", list()):
+ ret += self._get_number_of_nodes(child)
return ret
def get_max_depth(self):
+ import json
+
+ self.regressor.get_booster().dump_model(
+ "/tmp/xgb.json", dump_format="json", with_stats=True
+ )
+ with open("/tmp/xgb.json", "r") as f:
+ data = json.load(f)
+
+ return max(map(self._get_max_depth, data))
+
+ def _get_max_depth(self, data):
ret = [0]
- for v in self.child.values():
- if type(v) is SplitFunction:
- ret.append(v.get_max_depth())
+ for child in data.get("children", list()):
+ ret.append(self._get_max_depth(child))
return 1 + max(ret)