From 745f57c5226a3203678ee23cc0f293ee8ea45949 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Thu, 8 Feb 2024 07:55:00 +0100 Subject: model export/import: rename paramDecisionValue to threshold --- lib/cli.py | 8 ++++---- lib/functions.py | 12 +++++------- 2 files changed, 9 insertions(+), 11 deletions(-) (limited to 'lib') diff --git a/lib/cli.py b/lib/cli.py index 0a5f79a..c2839fb 100644 --- a/lib/cli.py +++ b/lib/cli.py @@ -137,11 +137,11 @@ def _print_lmtinfo(prefix, model): print(f"""{prefix}: {model["value"]}""") elif model["type"] == "scalarSplit": _print_lmtinfo( - f"""{prefix} {model["paramName"]}≤{model["paramDecisionValue"]} """, + f"""{prefix} {model["paramName"]}≤{model["threshold"]} """, model["left"], ) _print_lmtinfo( - f"""{prefix} {model["paramName"]}>{model["paramDecisionValue"]} """, + f"""{prefix} {model["paramName"]}>{model["threshold"]} """, model["right"], ) else: @@ -157,12 +157,12 @@ def _print_cartinfo(prefix, model, feature_names): print(f"""{prefix}: {model["value"]}""") else: _print_cartinfo( - f"""{prefix} {model["paramName"]}≤{model["paramDecisionValue"]} """, + f"""{prefix} {model["paramName"]}≤{model["threshold"]} """, model["left"], feature_names, ) _print_cartinfo( - f"""{prefix} {model["paramName"]}>{model["paramDecisionValue"]} """, + f"""{prefix} {model["paramName"]}>{model["threshold"]} """, model["right"], feature_names, ) diff --git a/lib/functions.py b/lib/functions.py index 5bf43aa..23e7523 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -461,7 +461,7 @@ class ScalarSplitFunction(ModelFunction): "type": "scalarSplit", "paramIndex": self.param_index, "paramName": feature_names[self.param_index], - "paramDecisionValue": self.threshold, + "threshold": self.threshold, "left": self.child_le.to_json(), "right": self.child_gt.to_json(), } @@ -516,9 +516,7 @@ class ScalarSplitFunction(ModelFunction): assert data["type"] == "scalarSplit" left = ModelFunction.from_json(data["left"]) right = ModelFunction.from_json(data["right"]) - self = cls( - data["value"], data["paramIndex"], data["paramDecisionValue"], left, right - ) + self = cls(data["value"], data["paramIndex"], data["threshold"], left, right) return self @@ -708,7 +706,7 @@ class CARTFunction(SKLearnRegressionFunction): sub_data["paramName"] = self.feature_names[ self.regressor.tree_.feature[node_id] ] - sub_data["paramDecisionValue"] = tree.threshold[node_id] + sub_data["threshold"] = tree.threshold[node_id] sub_data["type"] = "scalarSplit" # child value @@ -762,7 +760,7 @@ class LMTFunction(SKLearnRegressionFunction): return { "type": "scalarSplit", "paramName": self.feature_names[node["col"]], - "paramDecisionValue": node["th"], + "threshold": node["th"], "left": self.recurse_(node_hash, node["children"][0]), "right": self.recurse_(node_hash, node["children"][1]), } @@ -810,7 +808,7 @@ class XGBoostFunction(SKLearnRegressionFunction): "functionError": None, "type": "scalarSplit", "paramName": feature_names[int(tree["split"][1:])], - "paramDecisionValue": tree["split_condition"], + "threshold": tree["split_condition"], "value": None, "valueError": None, "left": self.tree_to_webconf_json( -- cgit v1.2.3