summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py44
1 files changed, 32 insertions, 12 deletions
diff --git a/lib/functions.py b/lib/functions.py
index de5c722..978a993 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -355,6 +355,15 @@ class SplitFunction(ModelFunction):
ret.append(v.get_max_depth())
return 1 + max(ret)
+ def get_number_of_leaves(self):
+ ret = 0
+ for v in self.child.values():
+ if type(v) is SplitFunction:
+ ret += v.get_number_of_leaves()
+ else:
+ ret += 1
+ return ret
+
def to_dot(self, pydot, graph, feature_names, parent=None):
try:
label = feature_names[self.param_index]
@@ -484,6 +493,9 @@ class CARTFunction(SKLearnRegressionFunction):
def get_number_of_nodes(self):
return self.regressor.tree_.node_count
+ def get_number_of_leaves(self):
+ return self.regressor.tree_.n_leaves
+
def get_max_depth(self):
return self.regressor.get_depth()
@@ -543,16 +555,21 @@ class LMTFunction(SKLearnRegressionFunction):
class XGBoostFunction(SKLearnRegressionFunction):
- def get_number_of_nodes(self):
+ def to_json(self):
import json
+ tempfile = f"/tmp/xgb{os.getpid()}.json"
+
self.regressor.get_booster().dump_model(
- "/tmp/xgb.json", dump_format="json", with_stats=True
+ tempfile, dump_format="json", with_stats=True
)
- with open("/tmp/xgb.json", "r") as f:
+ with open(tempfile, "r") as f:
data = json.load(f)
+ os.remove(tempfile)
+ return data
- return sum(map(self._get_number_of_nodes, data))
+ def get_number_of_nodes(self):
+ return sum(map(self._get_number_of_nodes, self.to_json()))
def _get_number_of_nodes(self, data):
ret = 1
@@ -560,16 +577,19 @@ class XGBoostFunction(SKLearnRegressionFunction):
ret += self._get_number_of_nodes(child)
return ret
- def get_max_depth(self):
- import json
+ def get_number_of_leaves(self):
+ return sum(map(self._get_number_of_leaves, self.to_json()))
- self.regressor.get_booster().dump_model(
- "/tmp/xgb.json", dump_format="json", with_stats=True
- )
- with open("/tmp/xgb.json", "r") as f:
- data = json.load(f)
+ def _get_number_of_leaves(self, data):
+ if "leaf" in data:
+ return 1
+ ret = 0
+ for child in data.get("children", list()):
+ ret += self._get_number_of_leaves(child)
+ return ret
- return max(map(self._get_max_depth, data))
+ def get_max_depth(self):
+ return max(map(self._get_max_depth, self.to_json()))
def _get_max_depth(self, data):
ret = [0]