summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-03-02 21:40:48 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-03-02 21:40:48 +0100
commitb5dda07a577232f3dff6abe88b2e829e35d6602e (patch)
tree8ce5933a3c242cc2c31981dac5958031dbfb1aa9
parent4c5b7b2f07b3e1f661dc85828f8fee5c7c88ac72 (diff)
ModelAttribute.to_json: special handling for CARTFunction. TODO refactoring
-rw-r--r--lib/parameters.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 91548f9..d9ad4c0 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -626,11 +626,56 @@ class ModelAttribute:
return f"ModelAttribute<{self.name}, {self.attr}, mean={mean}>"
def to_json(self, **kwargs):
+ if type(self.model_function) == df.CARTFunction:
+ import sklearn.tree
+
+ feature_names = list(
+ map(
+ lambda i: self.param_names[i],
+ filter(
+ lambda i: not self.model_function.ignore_index[i],
+ range(len(self.param_names)),
+ ),
+ )
+ )
+ feature_names += list(
+ map(
+ lambda i: f"arg{i-len(self.param_names)}",
+ filter(
+ lambda i: not self.model_function.ignore_index[i],
+ range(
+ len(self.param_names),
+ len(self.param_names) + self.arg_count,
+ ),
+ ),
+ )
+ )
+ kwargs["feature_names"] = feature_names
ret = {
"paramNames": self.param_names,
"argCount": self.arg_count,
"modelFunction": self.model_function.to_json(**kwargs),
}
+ if type(self.model_function) == df.CARTFunction:
+ feature_names = self.param_names
+ feature_names += list(
+ map(
+ lambda i: f"arg{i-len(self.param_names)}",
+ filter(
+ lambda i: not self.model_function.ignore_index[i],
+ range(
+ len(self.param_names),
+ len(self.param_names) + self.arg_count,
+ ),
+ ),
+ )
+ )
+ ret["paramValueToIndex"] = dict(
+ map(
+ lambda kv: (feature_names[kv[0]], kv[1]),
+ self.model_function.categorial_to_index.items(),
+ )
+ )
return ret
def to_dref(self, unit=None):