diff options
-rw-r--r-- | lib/cli.py | 41 | ||||
-rw-r--r-- | lib/functions.py | 102 | ||||
-rw-r--r-- | lib/parameters.py | 2 |
3 files changed, 135 insertions, 10 deletions
@@ -39,15 +39,27 @@ def print_static(model, static_model, name, attribute, with_dependence=False): unit = "µs" elif attribute == "substate_count": unit = "su" - print( - "{:10s}: {:28s} : {:.2f} {:s} ({:.2f})".format( - name, - attribute, - static_model(name, attribute), - unit, - model.attr_by_name[name][attribute].stats.generic_param_dependence_ratio(), + if model.attr_by_name[name][attribute].stats: + print( + "{:10s}: {:28s} : {:.2f} {:s} ({:.2f})".format( + name, + attribute, + static_model(name, attribute), + unit, + model.attr_by_name[name][ + attribute + ].stats.generic_param_dependence_ratio(), + ) + ) + else: + print( + "{:10s}: {:28s} : {:.2f} {:s}".format( + name, + attribute, + static_model(name, attribute), + unit, + ) ) - ) if with_dependence: for param in model.parameters: print( @@ -164,6 +176,17 @@ def print_splitinfo(param_names, info, prefix=""): else: param_name = f"arg{info.param_index - len(param_names)}" print_splitinfo(param_names, v, f"{prefix} {param_name}={k}") + if type(info) is df.ScalarSplitFunction: + if info.param_index < len(param_names): + param_name = param_names[info.param_index] + else: + param_name = f"arg{info.param_index - len(param_names)}" + print_splitinfo( + param_names, info.child_le, f"{prefix} {param_name}≤{info.threshold}" + ) + print_splitinfo( + param_names, info.child_gt, f"{prefix} {param_name}>{info.threshold}" + ) elif type(info) is df.AnalyticFunction: print_analyticinfo(prefix, info) elif type(info) is df.StaticFunction: @@ -183,6 +206,8 @@ def print_model(prefix, info, feature_names): print_cartinfo(prefix, info, feature_names) elif type(info) is df.SplitFunction: print_splitinfo(feature_names, info, prefix) + elif type(info) is df.ScalarSplitFunction: + print_splitinfo(feature_names, info, prefix) elif type(info) is df.LMTFunction: print_lmtinfo(prefix, info, feature_names) else: diff --git a/lib/functions.py b/lib/functions.py index d18477e..5bf43aa 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -234,6 +234,8 @@ class ModelFunction: mf = StaticFunction.from_json(data) elif data["type"] == "split": mf = SplitFunction.from_json(data) + elif data["type"] == "scalarSplit": + mf = ScalarSplitFunction.from_json(data) elif data["type"] == "analytic": mf = AnalyticFunction.from_json(data) else: @@ -372,7 +374,7 @@ class SplitFunction(ModelFunction): def get_number_of_nodes(self): ret = 1 for v in self.child.values(): - if type(v) is SplitFunction: + if type(v) in (SplitFunction, ScalarSplitFunction): ret += v.get_number_of_nodes() else: ret += 1 @@ -426,6 +428,104 @@ class SplitFunction(ModelFunction): return f"SplitFunction<{self.value}, param_index={self.param_index}>" +class ScalarSplitFunction(ModelFunction): + def __init__(self, value, param_index, threshold, child_le, child_gt, **kwargs): + super().__init__(value, **kwargs) + self.param_index = param_index + self.threshold = threshold + self.child_le = child_le + self.child_gt = child_gt + + def is_predictable(self, param_list): + """ + Return whether the model function can be evaluated on the given parameter values. + """ + return is_numeric(param_list[self.param_index]) + + def eval(self, param_list): + param_value = param_list[self.param_index] + if param_value <= self.threshold: + return self.child_le.eval(param_list) + return self.child_gt.eval(param_list) + + def webconf_function_map(self): + return ( + self.child_le.webconf_function_map() + self.child_gt.webconf_function_map() + ) + + def to_json(self, feature_names=None, **kwargs): + ret = super().to_json(**kwargs) + with_param_name = kwargs.get("with_param_name", False) + param_names = kwargs.get("param_names", list()) + update = { + "type": "scalarSplit", + "paramIndex": self.param_index, + "paramName": feature_names[self.param_index], + "paramDecisionValue": self.threshold, + "left": self.child_le.to_json(), + "right": self.child_gt.to_json(), + } + if with_param_name and param_names: + update["paramName"] = param_names[self.param_index] + ret.update(update) + return ret + + def get_number_of_nodes(self): + ret = 1 + for v in (self.child_le, self.child_gt): + if type(v) in (SplitFunction, ScalarSplitFunction): + ret += v.get_number_of_nodes() + else: + ret += 1 + return ret + + def get_max_depth(self): + ret = [0] + for v in (self.child_le, self.child_gt): + if type(v) in (SplitFunction, ScalarSplitFunction): + ret.append(v.get_max_depth()) + return 1 + max(ret) + + def get_number_of_leaves(self): + ret = 0 + for v in (self.child_le, self.child_gt): + if type(v) in (SplitFunction, ScalarSplitFunction): + ret += v.get_number_of_leaves() + else: + ret += 1 + return ret + + def get_complexity_score(self): + ret = 1 + for v in (self.child_le, self.child_gt): + ret += v.get_complexity_score() + return ret + + def to_dot(self, pydot, graph, feature_names, parent=None): + try: + label = feature_names[self.param_index] + except IndexError: + label = f"param{self.param_index}" + graph.add_node(pydot.Node(str(id(self)), label=label)) + for key, child in self.child.items(): + child.to_dot(pydot, graph, feature_names, str(id(self))) + graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key)) + + @classmethod + def from_json(cls, data): + assert data["type"] == "scalarSplit" + left = ModelFunction.from_json(data["left"]) + right = ModelFunction.from_json(data["right"]) + self = cls( + data["value"], data["paramIndex"], data["paramDecisionValue"], left, right + ) + + return self + + def __repr__(self): + return f"ScalarSplitFunction<{self.value}, param_index={self.param_index}>" + + class SubstateFunction(ModelFunction): def __init__(self, value, sequence_by_count, count_model, sub_model, **kwargs): super().__init__(value, **kwargs) diff --git a/lib/parameters.py b/lib/parameters.py index 96a996e..210cbe3 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -828,7 +828,7 @@ class ModelAttribute: return np.median(self.by_param[param]) def get_by_param(self): - if self.by_param is None: + if self.by_param is None and self.param_values is not None: self.by_param = partition_by_param(self.data, self.param_values) return self.by_param |