summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2023-02-01 15:04:16 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2023-02-01 15:04:16 +0100
commit5cc4f63a7fdacd49551b996bebf9927290ef7beb (patch)
tree7b78cdf14204ec9de9fc5c4f99c9a2233b1fac19
parent131c0b1026bd8b6c7fb584cfe13f4bc62c6da555 (diff)
XGB: Add --export-webconf support (generate a list of CART)
-rw-r--r--lib/functions.py37
-rw-r--r--lib/parameters.py8
2 files changed, 42 insertions, 3 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 6613e36..74b5509 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -620,7 +620,7 @@ class LMTFunction(SKLearnRegressionFunction):
class XGBoostFunction(SKLearnRegressionFunction):
- def to_json(self):
+ def to_json(self, feature_names=None, **kwargs):
import json
tempfile = f"/tmp/xgb{os.getpid()}.json"
@@ -631,8 +631,43 @@ class XGBoostFunction(SKLearnRegressionFunction):
with open(tempfile, "r") as f:
data = json.load(f)
os.remove(tempfile)
+
+ if feature_names:
+ return list(
+ map(
+ lambda tree: self.tree_to_webconf_json(
+ tree, feature_names, **kwargs
+ ),
+ data,
+ )
+ )
return data
+ def tree_to_webconf_json(self, tree, feature_names, **kwargs):
+ ret = dict()
+ if "children" in tree:
+ return {
+ "functionError": None,
+ "type": "scalarSplit",
+ "paramName": feature_names[int(tree["split"][1:])],
+ "paramDecisionValue": tree["split_condition"],
+ "value": None,
+ "valueError": None,
+ "left": self.tree_to_webconf_json(
+ tree["children"][0], feature_names, **kwargs
+ ),
+ "right": self.tree_to_webconf_json(
+ tree["children"][1], feature_names, **kwargs
+ ),
+ }
+ else:
+ return {
+ "functionError": None,
+ "type": "static",
+ "value": tree["leaf"],
+ "valueError": None,
+ }
+
def get_number_of_nodes(self):
return sum(map(self._get_number_of_nodes, self.to_json()))
diff --git a/lib/parameters.py b/lib/parameters.py
index 503e779..99418b6 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -591,7 +591,7 @@ class ModelAttribute:
return f"ModelAttribute<{self.name}, {self.attr}, mean={mean}>"
def to_json(self, **kwargs):
- if type(self.model_function) == df.CARTFunction:
+ if type(self.model_function) in (df.CARTFunction, df.XGBoostFunction):
import sklearn.tree
feature_names = list(
@@ -621,7 +621,11 @@ class ModelAttribute:
"argCount": self.arg_count,
"modelFunction": self.model_function.to_json(**kwargs),
}
- if type(self.model_function) in (df.CARTFunction, df.FOLFunction):
+ if type(self.model_function) in (
+ df.CARTFunction,
+ df.FOLFunction,
+ df.XGBoostFunction,
+ ):
feature_names = self.param_names
feature_names += list(
map(