diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2023-02-01 15:04:16 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2023-02-01 15:04:16 +0100 |
commit | 5cc4f63a7fdacd49551b996bebf9927290ef7beb (patch) | |
tree | 7b78cdf14204ec9de9fc5c4f99c9a2233b1fac19 | |
parent | 131c0b1026bd8b6c7fb584cfe13f4bc62c6da555 (diff) |
XGB: Add --export-webconf support (generate a list of CART)
-rw-r--r-- | lib/functions.py | 37 | 36 insertions, 1 deletion |
-rw-r--r-- | lib/parameters.py | 8 | 6 insertions, 2 deletions |
2 files changed, 42 insertions, 3 deletions
diff --git a/lib/functions.py b/lib/functions.py index 6613e36..74b5509 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -620,7 +620,7 @@ class LMTFunction(SKLearnRegressionFunction): class XGBoostFunction(SKLearnRegressionFunction): - def to_json(self): + def to_json(self, feature_names=None, **kwargs): import json tempfile = f"/tmp/xgb{os.getpid()}.json" @@ -631,8 +631,43 @@ class XGBoostFunction(SKLearnRegressionFunction): with open(tempfile, "r") as f: data = json.load(f) os.remove(tempfile) + + if feature_names: + return list( + map( + lambda tree: self.tree_to_webconf_json( + tree, feature_names, **kwargs + ), + data, + ) + ) return data + def tree_to_webconf_json(self, tree, feature_names, **kwargs): + ret = dict() + if "children" in tree: + return { + "functionError": None, + "type": "scalarSplit", + "paramName": feature_names[int(tree["split"][1:])], + "paramDecisionValue": tree["split_condition"], + "value": None, + "valueError": None, + "left": self.tree_to_webconf_json( + tree["children"][0], feature_names, **kwargs + ), + "right": self.tree_to_webconf_json( + tree["children"][1], feature_names, **kwargs + ), + } + else: + return { + "functionError": None, + "type": "static", + "value": tree["leaf"], + "valueError": None, + } + def get_number_of_nodes(self): return sum(map(self._get_number_of_nodes, self.to_json())) diff --git a/lib/parameters.py b/lib/parameters.py index 503e779..99418b6 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -591,7 +591,7 @@ class ModelAttribute: return f"ModelAttribute<{self.name}, {self.attr}, mean={mean}>" def to_json(self, **kwargs): - if type(self.model_function) == df.CARTFunction: + if type(self.model_function) in (df.CARTFunction, df.XGBoostFunction): import sklearn.tree feature_names = list( @@ -621,7 +621,11 @@ class ModelAttribute: "argCount": self.arg_count, "modelFunction": self.model_function.to_json(**kwargs), } - if type(self.model_function) in (df.CARTFunction, df.FOLFunction): + if 
type(self.model_function) in ( + df.CARTFunction, + df.FOLFunction, + df.XGBoostFunction, + ): feature_names = self.param_names feature_names += list( map( |