summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py102
1 files changed, 101 insertions, 1 deletions
diff --git a/lib/functions.py b/lib/functions.py
index d18477e..5bf43aa 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -234,6 +234,8 @@ class ModelFunction:
mf = StaticFunction.from_json(data)
elif data["type"] == "split":
mf = SplitFunction.from_json(data)
+ elif data["type"] == "scalarSplit":
+ mf = ScalarSplitFunction.from_json(data)
elif data["type"] == "analytic":
mf = AnalyticFunction.from_json(data)
else:
@@ -372,7 +374,7 @@ class SplitFunction(ModelFunction):
def get_number_of_nodes(self):
ret = 1
for v in self.child.values():
- if type(v) is SplitFunction:
+ if type(v) in (SplitFunction, ScalarSplitFunction):
ret += v.get_number_of_nodes()
else:
ret += 1
@@ -426,6 +428,104 @@ class SplitFunction(ModelFunction):
return f"SplitFunction<{self.value}, param_index={self.param_index}>"
class ScalarSplitFunction(ModelFunction):
    """
    Binary decision node that splits on one parameter at a threshold.

    Parameter values less than or equal to `threshold` are handled by
    `child_le`; all other values are handled by `child_gt`.
    """

    def __init__(self, value, param_index, threshold, child_le, child_gt, **kwargs):
        """
        Create a threshold split node.

        :param value: passed through to the ModelFunction constructor
        :param param_index: index into the parameter list of the parameter to split on
        :param threshold: decision value; `param <= threshold` selects `child_le`
        :param child_le: model function used when the parameter is <= threshold
        :param child_gt: model function used when the parameter is > threshold
        :param kwargs: passed through to the ModelFunction constructor
        """
        super().__init__(value, **kwargs)
        self.param_index = param_index
        self.threshold = threshold
        self.child_le = child_le
        self.child_gt = child_gt

    def is_predictable(self, param_list):
        """
        Return whether the model function can be evaluated on the given parameter values.

        Only the parameter this node splits on is checked; it must be numeric.
        """
        return is_numeric(param_list[self.param_index])

    def eval(self, param_list):
        """
        Evaluate the sub-model selected by comparing the split parameter to the threshold.

        :param param_list: parameter values, indexed by parameter index
        :returns: result of `child_le.eval` if the parameter is <= threshold,
            otherwise the result of `child_gt.eval`
        """
        param_value = param_list[self.param_index]
        if param_value <= self.threshold:
            return self.child_le.eval(param_list)
        return self.child_gt.eval(param_list)

    def webconf_function_map(self):
        """Return the function maps of both children, `child_le` entries first."""
        return (
            self.child_le.webconf_function_map() + self.child_gt.webconf_function_map()
        )

    def to_json(self, feature_names=None, **kwargs):
        """
        Serialize this node and its sub-trees to a JSON-compatible dict.

        :param feature_names: optional sequence of parameter names, indexed by
            parameter index. If None (and no name is available from `param_names`),
            the "paramName" entry is omitted.
        :param kwargs: forwarded to the parent class and to the children.
            `with_param_name` together with a non-empty `param_names` takes
            precedence over `feature_names` for the "paramName" entry.
        :returns: dict containing the parent class's entries plus "type",
            "paramIndex", "paramDecisionValue", "left", "right", and, if a name
            is known, "paramName"
        """
        ret = super().to_json(**kwargs)
        with_param_name = kwargs.get("with_param_name", False)
        param_names = kwargs.get("param_names", list())

        # feature_names is optional; indexing None would raise a TypeError.
        param_name = None
        if feature_names is not None:
            param_name = feature_names[self.param_index]
        if with_param_name and param_names:
            param_name = param_names[self.param_index]

        def child_to_json(child):
            # Nested threshold splits need feature_names as well. Other model
            # functions only receive the generic kwargs, since their to_json
            # signature is not known to accept feature_names.
            if type(child) is ScalarSplitFunction:
                return child.to_json(feature_names=feature_names, **kwargs)
            return child.to_json(**kwargs)

        update = {
            "type": "scalarSplit",
            "paramIndex": self.param_index,
        }
        if param_name is not None:
            update["paramName"] = param_name
        update["paramDecisionValue"] = self.threshold
        update["left"] = child_to_json(self.child_le)
        update["right"] = child_to_json(self.child_gt)
        ret.update(update)
        return ret

    def get_number_of_nodes(self):
        """
        Return the number of nodes in this sub-tree, including this node.

        Children that are not split functions count as a single node.
        """
        ret = 1
        for v in (self.child_le, self.child_gt):
            if type(v) in (SplitFunction, ScalarSplitFunction):
                ret += v.get_number_of_nodes()
            else:
                ret += 1
        return ret

    def get_max_depth(self):
        """
        Return the depth of this sub-tree, counted in split nodes.

        A node whose children are both non-split functions has depth 1.
        """
        ret = [0]
        for v in (self.child_le, self.child_gt):
            if type(v) in (SplitFunction, ScalarSplitFunction):
                ret.append(v.get_max_depth())
        return 1 + max(ret)

    def get_number_of_leaves(self):
        """
        Return the number of leaves below this node.

        Every child that is not a split function counts as one leaf.
        """
        ret = 0
        for v in (self.child_le, self.child_gt):
            if type(v) in (SplitFunction, ScalarSplitFunction):
                ret += v.get_number_of_leaves()
            else:
                ret += 1
        return ret

    def get_complexity_score(self):
        """Return 1 plus the sum of both children's complexity scores."""
        ret = 1
        for v in (self.child_le, self.child_gt):
            ret += v.get_complexity_score()
        return ret

    def to_dot(self, pydot, graph, feature_names, parent=None):
        """
        Add this node, its sub-trees, and the connecting edges to a pydot graph.

        :param pydot: pydot module (or compatible object providing Node and Edge)
        :param graph: graph object to which nodes and edges are added
        :param feature_names: sequence of parameter names used as node labels;
            falls back to "param<index>" if the index is out of range
        :param parent: unused by this node; children receive this node's ID
        """
        try:
            label = feature_names[self.param_index]
        except IndexError:
            label = f"param{self.param_index}"
        node_id = str(id(self))
        graph.add_node(pydot.Node(node_id, label=label))
        # This class has no `child` dict; its two sub-trees are stored in
        # child_le and child_gt. Label each edge with its threshold condition.
        branches = (
            (f"<= {self.threshold}", self.child_le),
            (f"> {self.threshold}", self.child_gt),
        )
        for edge_label, child in branches:
            child.to_dot(pydot, graph, feature_names, node_id)
            graph.add_edge(pydot.Edge(node_id, str(id(child)), label=edge_label))

    @classmethod
    def from_json(cls, data):
        """
        Build a ScalarSplitFunction (including both sub-trees) from its JSON dict.

        :param data: dict with "type" == "scalarSplit" and the keys "value",
            "paramIndex", "paramDecisionValue", "left", and "right"
        :returns: new instance; "left" becomes child_le, "right" becomes child_gt
        """
        assert data["type"] == "scalarSplit"
        left = ModelFunction.from_json(data["left"])
        right = ModelFunction.from_json(data["right"])
        return cls(
            data["value"], data["paramIndex"], data["paramDecisionValue"], left, right
        )

    def __repr__(self):
        return f"ScalarSplitFunction<{self.value}, param_index={self.param_index}>"
+
+
class SubstateFunction(ModelFunction):
def __init__(self, value, sequence_by_count, count_model, sub_model, **kwargs):
super().__init__(value, **kwargs)