summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-08 07:55:00 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-08 07:55:00 +0100
commit745f57c5226a3203678ee23cc0f293ee8ea45949 (patch)
tree0ce59d137fdcedc29e5b22cae4eb705b27a732be /lib
parent5413ff7c90d6e02693356bc0359ce9863ce80456 (diff)
model export/import: rename paramDecisionValue to threshold
Diffstat (limited to 'lib')
-rw-r--r--lib/cli.py8
-rw-r--r--lib/functions.py12
2 files changed, 9 insertions, 11 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 0a5f79a..c2839fb 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -137,11 +137,11 @@ def _print_lmtinfo(prefix, model):
print(f"""{prefix}: {model["value"]}""")
elif model["type"] == "scalarSplit":
_print_lmtinfo(
- f"""{prefix} {model["paramName"]}≤{model["paramDecisionValue"]} """,
+ f"""{prefix} {model["paramName"]}≤{model["threshold"]} """,
model["left"],
)
_print_lmtinfo(
- f"""{prefix} {model["paramName"]}>{model["paramDecisionValue"]} """,
+ f"""{prefix} {model["paramName"]}>{model["threshold"]} """,
model["right"],
)
else:
@@ -157,12 +157,12 @@ def _print_cartinfo(prefix, model, feature_names):
print(f"""{prefix}: {model["value"]}""")
else:
_print_cartinfo(
- f"""{prefix} {model["paramName"]}≤{model["paramDecisionValue"]} """,
+ f"""{prefix} {model["paramName"]}≤{model["threshold"]} """,
model["left"],
feature_names,
)
_print_cartinfo(
- f"""{prefix} {model["paramName"]}>{model["paramDecisionValue"]} """,
+ f"""{prefix} {model["paramName"]}>{model["threshold"]} """,
model["right"],
feature_names,
)
diff --git a/lib/functions.py b/lib/functions.py
index 5bf43aa..23e7523 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -461,7 +461,7 @@ class ScalarSplitFunction(ModelFunction):
"type": "scalarSplit",
"paramIndex": self.param_index,
"paramName": feature_names[self.param_index],
- "paramDecisionValue": self.threshold,
+ "threshold": self.threshold,
"left": self.child_le.to_json(),
"right": self.child_gt.to_json(),
}
@@ -516,9 +516,7 @@ class ScalarSplitFunction(ModelFunction):
assert data["type"] == "scalarSplit"
left = ModelFunction.from_json(data["left"])
right = ModelFunction.from_json(data["right"])
- self = cls(
- data["value"], data["paramIndex"], data["paramDecisionValue"], left, right
- )
+ self = cls(data["value"], data["paramIndex"], data["threshold"], left, right)
return self
@@ -708,7 +706,7 @@ class CARTFunction(SKLearnRegressionFunction):
sub_data["paramName"] = self.feature_names[
self.regressor.tree_.feature[node_id]
]
- sub_data["paramDecisionValue"] = tree.threshold[node_id]
+ sub_data["threshold"] = tree.threshold[node_id]
sub_data["type"] = "scalarSplit"
# child value
@@ -762,7 +760,7 @@ class LMTFunction(SKLearnRegressionFunction):
return {
"type": "scalarSplit",
"paramName": self.feature_names[node["col"]],
- "paramDecisionValue": node["th"],
+ "threshold": node["th"],
"left": self.recurse_(node_hash, node["children"][0]),
"right": self.recurse_(node_hash, node["children"][1]),
}
@@ -810,7 +808,7 @@ class XGBoostFunction(SKLearnRegressionFunction):
"functionError": None,
"type": "scalarSplit",
"paramName": feature_names[int(tree["split"][1:])],
- "paramDecisionValue": tree["split_condition"],
+ "threshold": tree["split_condition"],
"value": None,
"valueError": None,
"left": self.tree_to_webconf_json(