From 01ddab4bc7cff2c06b67e6327c848baa8141ed5c Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Mon, 12 Feb 2024 13:33:18 +0100 Subject: XGBoostFunction: fix complexity score and statistics calculation --- lib/functions.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lib/functions.py b/lib/functions.py index 63523a8..6366f0a 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -859,7 +859,7 @@ class LMTFunction(SKLearnRegressionFunction): class XGBoostFunction(SKLearnRegressionFunction): - def to_json(self, **kwargs): + def to_json(self, internal=False, **kwargs): import json tempfile = f"/tmp/xgb{os.getpid()}.json" @@ -871,6 +871,9 @@ class XGBoostFunction(SKLearnRegressionFunction): data = json.load(f) os.remove(tempfile) + if internal: + return data + return list( map( lambda tree: self.tree_to_webconf_json(tree, **kwargs), @@ -896,7 +899,7 @@ class XGBoostFunction(SKLearnRegressionFunction): } def get_number_of_nodes(self): - return sum(map(self._get_number_of_nodes, self.to_json())) + return sum(map(self._get_number_of_nodes, self.to_json(internal=True))) def _get_number_of_nodes(self, data): ret = 1 @@ -905,7 +908,7 @@ class XGBoostFunction(SKLearnRegressionFunction): return ret def get_number_of_leaves(self): - return sum(map(self._get_number_of_leaves, self.to_json())) + return sum(map(self._get_number_of_leaves, self.to_json(internal=True))) def _get_number_of_leaves(self, data): if "leaf" in data: @@ -916,7 +919,7 @@ class XGBoostFunction(SKLearnRegressionFunction): return ret def get_max_depth(self): - return max(map(self._get_max_depth, self.to_json())) + return max(map(self._get_max_depth, self.to_json(internal=True))) def _get_max_depth(self, data): ret = [0] -- cgit v1.2.3