diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 44 |
1 files changed, 32 insertions, 12 deletions
diff --git a/lib/functions.py b/lib/functions.py index de5c722..978a993 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -355,6 +355,15 @@ class SplitFunction(ModelFunction): ret.append(v.get_max_depth()) return 1 + max(ret) + def get_number_of_leaves(self): + ret = 0 + for v in self.child.values(): + if type(v) is SplitFunction: + ret += v.get_number_of_leaves() + else: + ret += 1 + return ret + def to_dot(self, pydot, graph, feature_names, parent=None): try: label = feature_names[self.param_index] @@ -484,6 +493,9 @@ class CARTFunction(SKLearnRegressionFunction): def get_number_of_nodes(self): return self.regressor.tree_.node_count + def get_number_of_leaves(self): + return self.regressor.tree_.n_leaves + def get_max_depth(self): return self.regressor.get_depth() @@ -543,16 +555,21 @@ class LMTFunction(SKLearnRegressionFunction): class XGBoostFunction(SKLearnRegressionFunction): - def get_number_of_nodes(self): + def to_json(self): import json + tempfile = f"/tmp/xgb{os.getpid()}.json" + self.regressor.get_booster().dump_model( - "/tmp/xgb.json", dump_format="json", with_stats=True + tempfile, dump_format="json", with_stats=True ) - with open("/tmp/xgb.json", "r") as f: + with open(tempfile, "r") as f: data = json.load(f) + os.remove(tempfile) + return data - return sum(map(self._get_number_of_nodes, data)) + def get_number_of_nodes(self): + return sum(map(self._get_number_of_nodes, self.to_json())) def _get_number_of_nodes(self, data): ret = 1 @@ -560,16 +577,19 @@ class XGBoostFunction(SKLearnRegressionFunction): ret += self._get_number_of_nodes(child) return ret - def get_max_depth(self): - import json + def get_number_of_leaves(self): + return sum(map(self._get_number_of_leaves, self.to_json())) - self.regressor.get_booster().dump_model( - "/tmp/xgb.json", dump_format="json", with_stats=True - ) - with open("/tmp/xgb.json", "r") as f: - data = json.load(f) + def _get_number_of_leaves(self, data): + if "leaf" in data: + return 1 + ret = 0 + for child in data.get("children", list()): + ret += self._get_number_of_leaves(child) + return ret - return max(map(self._get_max_depth, data)) + def get_max_depth(self): + return max(map(self._get_max_depth, self.to_json())) def _get_max_depth(self, data): ret = [0] |