diff options
-rw-r--r-- | lib/functions.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/lib/functions.py b/lib/functions.py index 63523a8..6366f0a 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -859,7 +859,7 @@ class LMTFunction(SKLearnRegressionFunction): class XGBoostFunction(SKLearnRegressionFunction): - def to_json(self, **kwargs): + def to_json(self, internal=False, **kwargs): import json tempfile = f"/tmp/xgb{os.getpid()}.json" @@ -871,6 +871,9 @@ class XGBoostFunction(SKLearnRegressionFunction): data = json.load(f) os.remove(tempfile) + if internal: + return data + return list( map( lambda tree: self.tree_to_webconf_json(tree, **kwargs), @@ -896,7 +899,7 @@ class XGBoostFunction(SKLearnRegressionFunction): } def get_number_of_nodes(self): - return sum(map(self._get_number_of_nodes, self.to_json())) + return sum(map(self._get_number_of_nodes, self.to_json(internal=True))) def _get_number_of_nodes(self, data): ret = 1 @@ -905,7 +908,7 @@ class XGBoostFunction(SKLearnRegressionFunction): return ret def get_number_of_leaves(self): - return sum(map(self._get_number_of_leaves, self.to_json())) + return sum(map(self._get_number_of_leaves, self.to_json(internal=True))) def _get_number_of_leaves(self, data): if "leaf" in data: @@ -916,7 +919,7 @@ class XGBoostFunction(SKLearnRegressionFunction): return ret def get_max_depth(self): - return max(map(self._get_max_depth, self.to_json())) + return max(map(self._get_max_depth, self.to_json(internal=True))) def _get_max_depth(self, data): ret = [0] |