summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/analyze-trace.py47
-rw-r--r--lib/functions.py17
-rw-r--r--lib/utils.py8
3 files changed, 70 insertions, 2 deletions
diff --git a/bin/analyze-trace.py b/bin/analyze-trace.py
index 372dd4e..9aa72d8 100755
--- a/bin/analyze-trace.py
+++ b/bin/analyze-trace.py
@@ -43,6 +43,7 @@ def learn_pta(observations, annotation, delta=dict(), delta_param=dict()):
if annotation.kernels:
for i in range(prev_i, annotation.kernels[0].offset):
this = observations[i]["name"] + " @ " + observations[i]["place"]
+
if not prev in delta:
delta[prev] = set()
delta[prev].add(this)
@@ -71,6 +72,7 @@ def learn_pta(observations, annotation, delta=dict(), delta_param=dict()):
prev = prev_non_kernel
for i in range(prev_i, kernel.offset):
this = observations[i]["name"] + " @ " + observations[i]["place"]
+
if not prev in delta:
delta[prev] = set()
delta[prev].add(this)
@@ -98,6 +100,7 @@ def learn_pta(observations, annotation, delta=dict(), delta_param=dict()):
prev = prev_non_kernel
for i in range(prev_i, annotation.end.offset):
this = observations[i]["name"] + " @ " + observations[i]["place"]
+
if not prev in delta:
delta[prev] = set()
delta[prev].add(this)
@@ -170,6 +173,7 @@ def main():
delta_by_name = dict()
delta_param_by_name = dict()
for annotation in annotations:
+ am_tt_param_names = sorted(annotation.start.param.keys())
if annotation.name not in delta_by_name:
delta_by_name[annotation.name] = dict()
delta_param_by_name[annotation.name] = dict()
@@ -181,16 +185,54 @@ def main():
)
observations += meta_obs
+ def format_guard(guard):
+ return "∧".join(map(lambda kv: f"{kv[0]}={kv[1]}", guard))
+
for name in sorted(delta_by_name.keys()):
+ delta_cond = dict()
for t_from, t_to_set in delta_by_name[name].items():
+ i_to_transition = dict()
delta_param_sets = list()
to_names = list()
- for t_to in t_to_set:
+ transition_guard = dict()
+
+ if len(t_to_set) > 1:
+ am_tt_by_name = {
+ name: {
+ "attributes": [t_from],
+ "param": list(),
+ t_from: list(),
+ },
+ }
+ for i, t_to in enumerate(sorted(t_to_set)):
+ for param in delta_param_by_name[name][(t_from, t_to)]:
+ am_tt_by_name[name]["param"].append(
+ dfatool.utils.param_dict_to_list(
+ dfatool.utils.param_str_to_dict(param),
+ am_tt_param_names,
+ )
+ )
+ am_tt_by_name[name][t_from].append(i)
+ i_to_transition[i] = t_to
+ am = AnalyticModel(am_tt_by_name, am_tt_param_names, force_tree=True)
+ model, info = am.get_fitted()
+ flat_model = info(name, t_from).flatten()
+
+ for prefix, output in flat_model:
+ transition_name = i_to_transition[int(output)]
+ if transition_name not in transition_guard:
+ transition_guard[transition_name] = list()
+ transition_guard[transition_name].append(prefix)
+
+ for t_to in sorted(t_to_set):
delta_params = delta_param_by_name[name][(t_from, t_to)]
delta_param_sets.append(delta_params)
to_names.append(t_to)
n_confs = len(delta_params)
- print(f"{name} {t_from} → {t_to} ({n_confs:4d}x)")
+ print(
+ f"{name} {t_from} → {t_to} ({' ∨ '.join(map(format_guard, transition_guard.get(t_to, list()))) or '⊤'})"
+ )
+
for i in range(len(delta_param_sets)):
for j in range(i + 1, len(delta_param_sets)):
if not delta_param_sets[i].isdisjoint(delta_param_sets[j]):
@@ -203,6 +245,7 @@ def main():
raise RuntimeError(
f"Outbound transitions of <{t_from}> are not deterministic"
)
+
print("")
by_name, parameter_names = dfatool.utils.observations_to_by_name(observations)
diff --git a/lib/functions.py b/lib/functions.py
index 187e6ff..35b04ef 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -466,6 +466,23 @@ class SplitFunction(ModelFunction):
)
return hyper
+ # SplitFunction only
+ def flatten(self):
+ paths = list()
+ for param_value, subtree in self.child.items():
+ if type(subtree) is SplitFunction:
+ for path, value in subtree.flatten():
+ path = [(self.param_name, param_value)] + path
+ paths.append((path, value))
+ elif type(subtree) is StaticFunction:
+ path = [(self.param_name, param_value)]
+ paths.append((path, subtree.value))
+ else:
+ raise RuntimeError(
+ "flatten is only implemented for RMTs with constant leaves"
+ )
+ return paths
+
@classmethod
def from_json(cls, data):
assert data["type"] == "split"
diff --git a/lib/utils.py b/lib/utils.py
index c148426..48a29d8 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -325,6 +325,14 @@ def param_dict_to_str(param_dict):
return " ".join(ret)
+def param_str_to_dict(param_str):
+ ret = dict()
+ for param_pair in param_str.split():
+ key, value = param_pair.split("=")
+ ret[key] = soft_cast_int_or_float(value)
+ return ret
+
+
def observations_enum_to_bool(observations: list, kconfig=False):
"""
Convert enum / categorical observations to boolean-only ones.