diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-26 16:13:04 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-26 16:13:04 +0100 |
commit | e149c6bc24935ff8383471759c8775d3174ec29d (patch) | |
tree | 119776318c5c678c549d04dfacc5a3daeca82d6a /lib | |
parent | 943c8bee7a511e1aff5e1639a479ea868d7656a7 (diff) |
Add tree attribute export for XGBoost
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 34 | ||||
-rw-r--r-- | lib/parameters.py | 8 |
2 files changed, 32 insertions, 10 deletions
diff --git a/lib/functions.py b/lib/functions.py index 14893ad..43ee58a 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -471,19 +471,37 @@ class CARTFunction(SKLearnRegressionFunction): class XGBoostFunction(SKLearnRegressionFunction): def get_number_of_nodes(self): + import json + + self.regressor.get_booster().dump_model( + "/tmp/xgb.json", dump_format="json", with_stats=True + ) + with open("/tmp/xgb.json", "r") as f: + data = json.load(f) + + return sum(map(self._get_number_of_nodes, data)) + + def _get_number_of_nodes(self, data): ret = 1 - for v in self.child.values(): - if type(v) is SplitFunction: - ret += v.get_number_of_nodes() - else: - ret += 1 + for child in data.get("children", list()): + ret += self._get_number_of_nodes(child) return ret def get_max_depth(self): + import json + + self.regressor.get_booster().dump_model( + "/tmp/xgb.json", dump_format="json", with_stats=True + ) + with open("/tmp/xgb.json", "r") as f: + data = json.load(f) + + return max(map(self._get_max_depth, data)) + + def _get_max_depth(self, data): ret = [0] - for v in self.child.values(): - if type(v) is SplitFunction: - ret.append(v.get_max_depth()) + for child in data.get("children", list()): + ret.append(self._get_max_depth(child)) return 1 + max(ret) diff --git a/lib/parameters.py b/lib/parameters.py index 9350a06..4e98f54 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -666,7 +666,11 @@ class ModelAttribute: def to_dref(self, unit=None): ret = {"mean": (self.mean, unit), "median": (self.median, unit)} - if type(self.model_function) in (df.SplitFunction, df.CARTFunction): + if type(self.model_function) in ( + df.SplitFunction, + df.CARTFunction, + df.XGBoostFunction, + ): ret["decision tree/nodes"] = self.model_function.get_number_of_nodes() ret["decision tree/max depth"] = self.model_function.get_max_depth() @@ -961,7 +965,7 @@ class ModelAttribute: self.model_function = df.StaticFunction(np.mean(data)) return xgb.fit(fit_parameters, np.reshape(data, (-1, 1))) - self.model_function = df.SKLearnRegressionFunction( + self.model_function = df.XGBoostFunction( np.mean(data), xgb, category_to_index, ignore_index ) output_filename = os.getenv("DFATOOL_XGB_DUMP_MODEL", None) |