summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py37
1 files changed, 36 insertions, 1 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 6613e36..74b5509 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -620,7 +620,7 @@ class LMTFunction(SKLearnRegressionFunction):
class XGBoostFunction(SKLearnRegressionFunction):
- def to_json(self):
+ def to_json(self, feature_names=None, **kwargs):
import json
tempfile = f"/tmp/xgb{os.getpid()}.json"
@@ -631,8 +631,43 @@ class XGBoostFunction(SKLearnRegressionFunction):
with open(tempfile, "r") as f:
data = json.load(f)
os.remove(tempfile)
+
+ if feature_names:
+ return list(
+ map(
+ lambda tree: self.tree_to_webconf_json(
+ tree, feature_names, **kwargs
+ ),
+ data,
+ )
+ )
return data
+ def tree_to_webconf_json(self, tree, feature_names, **kwargs):
+ ret = dict()
+ if "children" in tree:
+ return {
+ "functionError": None,
+ "type": "scalarSplit",
+ "paramName": feature_names[int(tree["split"][1:])],
+ "paramDecisionValue": tree["split_condition"],
+ "value": None,
+ "valueError": None,
+ "left": self.tree_to_webconf_json(
+ tree["children"][0], feature_names, **kwargs
+ ),
+ "right": self.tree_to_webconf_json(
+ tree["children"][1], feature_names, **kwargs
+ ),
+ }
+ else:
+ return {
+ "functionError": None,
+ "type": "static",
+ "value": tree["leaf"],
+ "valueError": None,
+ }
+
def get_number_of_nodes(self):
return sum(map(self._get_number_of_nodes, self.to_json()))