diff options
-rw-r--r-- | lib/functions.py | 44 | ||||
-rw-r--r-- | lib/parameters.py | 8 |
2 files changed, 38 insertions, 14 deletions
diff --git a/lib/functions.py b/lib/functions.py index de5c722..978a993 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -355,6 +355,15 @@ class SplitFunction(ModelFunction): ret.append(v.get_max_depth()) return 1 + max(ret) + def get_number_of_leaves(self): + ret = 0 + for v in self.child.values(): + if type(v) is SplitFunction: + ret += v.get_number_of_leaves() + else: + ret += 1 + return ret + def to_dot(self, pydot, graph, feature_names, parent=None): try: label = feature_names[self.param_index] @@ -484,6 +493,9 @@ class CARTFunction(SKLearnRegressionFunction): def get_number_of_nodes(self): return self.regressor.tree_.node_count + def get_number_of_leaves(self): + return self.regressor.tree_.n_leaves + def get_max_depth(self): return self.regressor.get_depth() @@ -543,16 +555,21 @@ class LMTFunction(SKLearnRegressionFunction): class XGBoostFunction(SKLearnRegressionFunction): - def get_number_of_nodes(self): + def to_json(self): import json + tempfile = f"/tmp/xgb{os.getpid()}.json" + self.regressor.get_booster().dump_model( - "/tmp/xgb.json", dump_format="json", with_stats=True + tempfile, dump_format="json", with_stats=True ) - with open("/tmp/xgb.json", "r") as f: + with open(tempfile, "r") as f: data = json.load(f) + os.remove(tempfile) + return data - return sum(map(self._get_number_of_nodes, data)) + def get_number_of_nodes(self): + return sum(map(self._get_number_of_nodes, self.to_json())) def _get_number_of_nodes(self, data): ret = 1 @@ -560,16 +577,19 @@ class XGBoostFunction(SKLearnRegressionFunction): ret += self._get_number_of_nodes(child) return ret - def get_max_depth(self): - import json + def get_number_of_leaves(self): + return sum(map(self._get_number_of_leaves, self.to_json())) - self.regressor.get_booster().dump_model( - "/tmp/xgb.json", dump_format="json", with_stats=True - ) - with open("/tmp/xgb.json", "r") as f: - data = json.load(f) + def _get_number_of_leaves(self, data): + if "leaf" in data: + return 1 + ret = 0 + for child in data.get("children", list()): + ret += self._get_number_of_leaves(child) + return ret - return max(map(self._get_max_depth, data)) + def get_max_depth(self): + return max(map(self._get_max_depth, self.to_json())) def _get_max_depth(self, data): ret = [0] diff --git a/lib/parameters.py b/lib/parameters.py index dae6e2a..cb4b76f 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -693,7 +693,12 @@ class ModelAttribute: ret["decision tree/nodes"] = 1 ret["decision tree/max depth"] = 1 - if type(self.model_function) in (df.LMTFunction,): + if type(self.model_function) in ( + df.SplitFunction, + df.CARTFunction, + df.LMTFunction, + df.XGBoostFunction, + ): ret["decision tree/leaves"] = self.model_function.get_number_of_leaves() return ret @@ -1257,7 +1262,6 @@ class ModelAttribute: if np.all(np.isinf(loss)): # all children have the same configuration. We shouldn't get here due to the threshold check above... if ffs_feasible: - logger.debug("ffs feasible, attempting to fit a leaf") # try generating a function. if it fails, model_function is a StaticFunction. ma = ModelAttribute( self.name + "_", |