diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 102 |
1 files changed, 101 insertions, 1 deletions
diff --git a/lib/functions.py b/lib/functions.py index d18477e..5bf43aa 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -234,6 +234,8 @@ class ModelFunction: mf = StaticFunction.from_json(data) elif data["type"] == "split": mf = SplitFunction.from_json(data) + elif data["type"] == "scalarSplit": + mf = ScalarSplitFunction.from_json(data) elif data["type"] == "analytic": mf = AnalyticFunction.from_json(data) else: @@ -372,7 +374,7 @@ class SplitFunction(ModelFunction): def get_number_of_nodes(self): ret = 1 for v in self.child.values(): - if type(v) is SplitFunction: + if type(v) in (SplitFunction, ScalarSplitFunction): ret += v.get_number_of_nodes() else: ret += 1 @@ -426,6 +428,104 @@ class SplitFunction(ModelFunction): return f"SplitFunction<{self.value}, param_index={self.param_index}>" +class ScalarSplitFunction(ModelFunction): + def __init__(self, value, param_index, threshold, child_le, child_gt, **kwargs): + super().__init__(value, **kwargs) + self.param_index = param_index + self.threshold = threshold + self.child_le = child_le + self.child_gt = child_gt + + def is_predictable(self, param_list): + """ + Return whether the model function can be evaluated on the given parameter values. + """ + return is_numeric(param_list[self.param_index]) + + def eval(self, param_list): + param_value = param_list[self.param_index] + if param_value <= self.threshold: + return self.child_le.eval(param_list) + return self.child_gt.eval(param_list) + + def webconf_function_map(self): + return ( + self.child_le.webconf_function_map() + self.child_gt.webconf_function_map() + ) + + def to_json(self, feature_names=None, **kwargs): + ret = super().to_json(**kwargs) + with_param_name = kwargs.get("with_param_name", False) + param_names = kwargs.get("param_names", list()) + update = { + "type": "scalarSplit", + "paramIndex": self.param_index, + "paramName": feature_names[self.param_index], + "paramDecisionValue": self.threshold, + "left": self.child_le.to_json(), + "right": self.child_gt.to_json(), + } + if with_param_name and param_names: + update["paramName"] = param_names[self.param_index] + ret.update(update) + return ret + + def get_number_of_nodes(self): + ret = 1 + for v in (self.child_le, self.child_gt): + if type(v) in (SplitFunction, ScalarSplitFunction): + ret += v.get_number_of_nodes() + else: + ret += 1 + return ret + + def get_max_depth(self): + ret = [0] + for v in (self.child_le, self.child_gt): + if type(v) in (SplitFunction, ScalarSplitFunction): + ret.append(v.get_max_depth()) + return 1 + max(ret) + + def get_number_of_leaves(self): + ret = 0 + for v in (self.child_le, self.child_gt): + if type(v) in (SplitFunction, ScalarSplitFunction): + ret += v.get_number_of_leaves() + else: + ret += 1 + return ret + + def get_complexity_score(self): + ret = 1 + for v in (self.child_le, self.child_gt): + ret += v.get_complexity_score() + return ret + + def to_dot(self, pydot, graph, feature_names, parent=None): + try: + label = feature_names[self.param_index] + except IndexError: + label = f"param{self.param_index}" + graph.add_node(pydot.Node(str(id(self)), label=label)) + for key, child in self.child.items(): + child.to_dot(pydot, graph, feature_names, str(id(self))) + graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key)) + + @classmethod + def from_json(cls, data): + assert data["type"] == "scalarSplit" + left = ModelFunction.from_json(data["left"]) + right = ModelFunction.from_json(data["right"]) + self = cls( + data["value"], data["paramIndex"], data["paramDecisionValue"], left, right + ) + + return self + + def __repr__(self): + return f"ScalarSplitFunction<{self.value}, param_index={self.param_index}>" + + class SubstateFunction(ModelFunction): def __init__(self, value, sequence_by_count, count_model, sub_model, **kwargs): super().__init__(value, **kwargs) |