summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/functions.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 63523a8..6366f0a 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -859,7 +859,7 @@ class LMTFunction(SKLearnRegressionFunction):
class XGBoostFunction(SKLearnRegressionFunction):
- def to_json(self, **kwargs):
+ def to_json(self, internal=False, **kwargs):
import json
tempfile = f"/tmp/xgb{os.getpid()}.json"
@@ -871,6 +871,9 @@ class XGBoostFunction(SKLearnRegressionFunction):
data = json.load(f)
os.remove(tempfile)
+ if internal:
+ return data
+
return list(
map(
lambda tree: self.tree_to_webconf_json(tree, **kwargs),
@@ -896,7 +899,7 @@ class XGBoostFunction(SKLearnRegressionFunction):
}
def get_number_of_nodes(self):
- return sum(map(self._get_number_of_nodes, self.to_json()))
+ return sum(map(self._get_number_of_nodes, self.to_json(internal=True)))
def _get_number_of_nodes(self, data):
ret = 1
@@ -905,7 +908,7 @@ class XGBoostFunction(SKLearnRegressionFunction):
return ret
def get_number_of_leaves(self):
- return sum(map(self._get_number_of_leaves, self.to_json()))
+ return sum(map(self._get_number_of_leaves, self.to_json(internal=True)))
def _get_number_of_leaves(self, data):
if "leaf" in data:
@@ -916,7 +919,7 @@ class XGBoostFunction(SKLearnRegressionFunction):
return ret
def get_max_depth(self):
- return max(map(self._get_max_depth, self.to_json()))
+ return max(map(self._get_max_depth, self.to_json(internal=True)))
def _get_max_depth(self, data):
ret = [0]