-rw-r--r--  lib/cli.py          41
-rw-r--r--  lib/functions.py   102
-rw-r--r--  lib/parameters.py    2
3 files changed, 135 insertions(+), 10 deletions(-)
diff --git a/lib/cli.py b/lib/cli.py
index 7997482..0a5f79a 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -39,15 +39,27 @@ def print_static(model, static_model, name, attribute, with_dependence=False):
         unit = "µs"
     elif attribute == "substate_count":
         unit = "su"
-    print(
-        "{:10s}: {:28s} : {:.2f} {:s} ({:.2f})".format(
-            name,
-            attribute,
-            static_model(name, attribute),
-            unit,
-            model.attr_by_name[name][attribute].stats.generic_param_dependence_ratio(),
+    if model.attr_by_name[name][attribute].stats:
+        print(
+            "{:10s}: {:28s} : {:.2f} {:s} ({:.2f})".format(
+                name,
+                attribute,
+                static_model(name, attribute),
+                unit,
+                model.attr_by_name[name][
+                    attribute
+                ].stats.generic_param_dependence_ratio(),
+            )
+        )
+    else:
+        print(
+            "{:10s}: {:28s} : {:.2f} {:s}".format(
+                name,
+                attribute,
+                static_model(name, attribute),
+                unit,
+            )
         )
-    )
     if with_dependence:
         for param in model.parameters:
             print(
@@ -164,6 +176,17 @@ def print_splitinfo(param_names, info, prefix=""):
             else:
                 param_name = f"arg{info.param_index - len(param_names)}"
             print_splitinfo(param_names, v, f"{prefix} {param_name}={k}")
+    if type(info) is df.ScalarSplitFunction:
+        if info.param_index < len(param_names):
+            param_name = param_names[info.param_index]
+        else:
+            param_name = f"arg{info.param_index - len(param_names)}"
+        print_splitinfo(
+            param_names, info.child_le, f"{prefix} {param_name}≤{info.threshold}"
+        )
+        print_splitinfo(
+            param_names, info.child_gt, f"{prefix} {param_name}>{info.threshold}"
+        )
     elif type(info) is df.AnalyticFunction:
         print_analyticinfo(prefix, info)
     elif type(info) is df.StaticFunction:
@@ -183,6 +206,8 @@ def print_model(prefix, info, feature_names):
         print_cartinfo(prefix, info, feature_names)
     elif type(info) is df.SplitFunction:
         print_splitinfo(feature_names, info, prefix)
+    elif type(info) is df.ScalarSplitFunction:
+        print_splitinfo(feature_names, info, prefix)
     elif type(info) is df.LMTFunction:
         print_lmtinfo(prefix, info, feature_names)
     else:
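For illustration only (not part of this commit): the new print_splitinfo branch prefixes each child with `param≤threshold` / `param>threshold` before recursing. A minimal sketch, assuming lib/ is installed as the dfatool package and that StaticFunction leaves are printed as "prefix: value" by the case visible at the end of the hunk above:

# illustrative sketch -- import path and leaf output format are assumptions
import dfatool.functions as df
from dfatool.cli import print_splitinfo

# hypothetical two-leaf model: parameter 0 ("txbytes") split at 64
info = df.ScalarSplitFunction(
    175.0, 0, 64, df.StaticFunction(100.0), df.StaticFunction(250.0)
)
print_splitinfo(["txbytes"], info, prefix="TX")
# expected output, roughly:
#   TX txbytes≤64: 100.0
#   TX txbytes>64: 250.0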
diff --git a/lib/functions.py b/lib/functions.py
index d18477e..5bf43aa 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -234,6 +234,8 @@ class ModelFunction:
             mf = StaticFunction.from_json(data)
         elif data["type"] == "split":
             mf = SplitFunction.from_json(data)
+        elif data["type"] == "scalarSplit":
+            mf = ScalarSplitFunction.from_json(data)
         elif data["type"] == "analytic":
             mf = AnalyticFunction.from_json(data)
         else:
@@ -372,7 +374,7 @@ class SplitFunction(ModelFunction):
     def get_number_of_nodes(self):
         ret = 1
         for v in self.child.values():
-            if type(v) is SplitFunction:
+            if type(v) in (SplitFunction, ScalarSplitFunction):
                 ret += v.get_number_of_nodes()
             else:
                 ret += 1
@@ -426,6 +428,104 @@ class SplitFunction(ModelFunction):
         return f"SplitFunction<{self.value}, param_index={self.param_index}>"
 
 
+class ScalarSplitFunction(ModelFunction):
+    def __init__(self, value, param_index, threshold, child_le, child_gt, **kwargs):
+        super().__init__(value, **kwargs)
+        self.param_index = param_index
+        self.threshold = threshold
+        self.child_le = child_le
+        self.child_gt = child_gt
+
+    def is_predictable(self, param_list):
+        """
+        Return whether the model function can be evaluated on the given parameter values.
+        """
+        return is_numeric(param_list[self.param_index])
+
+    def eval(self, param_list):
+        param_value = param_list[self.param_index]
+        if param_value <= self.threshold:
+            return self.child_le.eval(param_list)
+        return self.child_gt.eval(param_list)
+
+    def webconf_function_map(self):
+        return (
+            self.child_le.webconf_function_map() + self.child_gt.webconf_function_map()
+        )
+
+    def to_json(self, feature_names=None, **kwargs):
+        ret = super().to_json(**kwargs)
+        with_param_name = kwargs.get("with_param_name", False)
+        param_names = kwargs.get("param_names", list())
+        update = {
+            "type": "scalarSplit",
+            "paramIndex": self.param_index,
+            "paramName": feature_names[self.param_index],
+            "paramDecisionValue": self.threshold,
+            "left": self.child_le.to_json(),
+            "right": self.child_gt.to_json(),
+        }
+        if with_param_name and param_names:
+            update["paramName"] = param_names[self.param_index]
+        ret.update(update)
+        return ret
+
+    def get_number_of_nodes(self):
+        ret = 1
+        for v in (self.child_le, self.child_gt):
+            if type(v) in (SplitFunction, ScalarSplitFunction):
+                ret += v.get_number_of_nodes()
+            else:
+                ret += 1
+        return ret
+
+    def get_max_depth(self):
+        ret = [0]
+        for v in (self.child_le, self.child_gt):
+            if type(v) in (SplitFunction, ScalarSplitFunction):
+                ret.append(v.get_max_depth())
+        return 1 + max(ret)
+
+    def get_number_of_leaves(self):
+        ret = 0
+        for v in (self.child_le, self.child_gt):
+            if type(v) in (SplitFunction, ScalarSplitFunction):
+                ret += v.get_number_of_leaves()
+            else:
+                ret += 1
+        return ret
+
+    def get_complexity_score(self):
+        ret = 1
+        for v in (self.child_le, self.child_gt):
+            ret += v.get_complexity_score()
+        return ret
+
+    def to_dot(self, pydot, graph, feature_names, parent=None):
+        try:
+            label = feature_names[self.param_index]
+        except IndexError:
+            label = f"param{self.param_index}"
+        graph.add_node(pydot.Node(str(id(self)), label=label))
+        for key, child in (("≤", self.child_le), (">", self.child_gt)):
+            child.to_dot(pydot, graph, feature_names, str(id(self)))
+            graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key))
+
+    @classmethod
+    def from_json(cls, data):
+        assert data["type"] == "scalarSplit"
+        left = ModelFunction.from_json(data["left"])
+        right = ModelFunction.from_json(data["right"])
+        self = cls(
+            data["value"], data["paramIndex"], data["paramDecisionValue"], left, right
+        )
+
+        return self
+
+    def __repr__(self):
+        return f"ScalarSplitFunction<{self.value}, param_index={self.param_index}>"
+
+
 class SubstateFunction(ModelFunction):
     def __init__(self, value, sequence_by_count, count_model, sub_model, **kwargs):
         super().__init__(value, **kwargs)
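For illustration only (not part of this commit), a minimal sketch of the new class's behaviour using the methods added above; the StaticFunction leaves and the dfatool import path are assumptions:

# illustrative sketch -- leaf construction and import path are assumptions
from dfatool.functions import ScalarSplitFunction, StaticFunction

# hypothetical attribute: 12 µW at or below the threshold, 48 µW above it
split = ScalarSplitFunction(30.0, 0, 100, StaticFunction(12.0), StaticFunction(48.0))

assert split.is_predictable([50])   # parameter 0 is numeric
assert split.eval([50]) == 12.0     # 50 ≤ 100 → child_le
assert split.eval([200]) == 48.0    # 200 > 100 → child_gt

# to_json emits the "scalarSplit" tag that ModelFunction.from_json dispatches on
data = split.to_json(feature_names=["datarate"])
assert data["type"] == "scalarSplit"
assert data["paramName"] == "datarate"
assert data["paramDecisionValue"] == 100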
diff --git a/lib/parameters.py b/lib/parameters.py
index 96a996e..210cbe3 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -828,7 +828,7 @@ class ModelAttribute:
         return np.median(self.by_param[param])
 
     def get_by_param(self):
-        if self.by_param is None:
+        if self.by_param is None and self.param_values is not None:
             self.by_param = partition_by_param(self.data, self.param_values)
         return self.by_param
 