diff options
-rwxr-xr-x | bin/analyze-trace.py | 47 | ||||
-rw-r--r-- | lib/functions.py | 17 | ||||
-rw-r--r-- | lib/utils.py | 8 |
3 files changed, 70 insertions, 2 deletions
diff --git a/bin/analyze-trace.py b/bin/analyze-trace.py index 372dd4e..9aa72d8 100755 --- a/bin/analyze-trace.py +++ b/bin/analyze-trace.py @@ -43,6 +43,7 @@ def learn_pta(observations, annotation, delta=dict(), delta_param=dict()): if annotation.kernels: for i in range(prev_i, annotation.kernels[0].offset): this = observations[i]["name"] + " @ " + observations[i]["place"] + if not prev in delta: delta[prev] = set() delta[prev].add(this) @@ -71,6 +72,7 @@ def learn_pta(observations, annotation, delta=dict(), delta_param=dict()): prev = prev_non_kernel for i in range(prev_i, kernel.offset): this = observations[i]["name"] + " @ " + observations[i]["place"] + if not prev in delta: delta[prev] = set() delta[prev].add(this) @@ -98,6 +100,7 @@ def learn_pta(observations, annotation, delta=dict(), delta_param=dict()): prev = prev_non_kernel for i in range(prev_i, annotation.end.offset): this = observations[i]["name"] + " @ " + observations[i]["place"] + if not prev in delta: delta[prev] = set() delta[prev].add(this) @@ -170,6 +173,7 @@ def main(): delta_by_name = dict() delta_param_by_name = dict() for annotation in annotations: + am_tt_param_names = sorted(annotation.start.param.keys()) if annotation.name not in delta_by_name: delta_by_name[annotation.name] = dict() delta_param_by_name[annotation.name] = dict() @@ -181,16 +185,54 @@ def main(): ) observations += meta_obs + def format_guard(guard): + return "∧".join(map(lambda kv: f"{kv[0]}={kv[1]}", guard)) + for name in sorted(delta_by_name.keys()): + delta_cond = dict() for t_from, t_to_set in delta_by_name[name].items(): + i_to_transition = dict() delta_param_sets = list() to_names = list() - for t_to in t_to_set: + transition_guard = dict() + + if len(t_to_set) > 1: + am_tt_by_name = { + name: { + "attributes": [t_from], + "param": list(), + t_from: list(), + }, + } + for i, t_to in enumerate(sorted(t_to_set)): + for param in delta_param_by_name[name][(t_from, t_to)]: + am_tt_by_name[name]["param"].append( + dfatool.utils.param_dict_to_list( + dfatool.utils.param_str_to_dict(param), + am_tt_param_names, + ) + ) + am_tt_by_name[name][t_from].append(i) + i_to_transition[i] = t_to + am = AnalyticModel(am_tt_by_name, am_tt_param_names, force_tree=True) + model, info = am.get_fitted() + flat_model = info(name, t_from).flatten() + + for prefix, output in flat_model: + transition_name = i_to_transition[int(output)] + if transition_name not in transition_guard: + transition_guard[transition_name] = list() + transition_guard[transition_name].append(prefix) + + for t_to in sorted(t_to_set): delta_params = delta_param_by_name[name][(t_from, t_to)] delta_param_sets.append(delta_params) to_names.append(t_to) n_confs = len(delta_params) - print(f"{name} {t_from} → {t_to} ({n_confs:4d}x)") + print( + f"{name} {t_from} → {t_to} ({' ∨ '.join(map(format_guard, transition_guard.get(t_to, list()))) or '⊤'})" + ) + for i in range(len(delta_param_sets)): for j in range(i + 1, len(delta_param_sets)): if not delta_param_sets[i].isdisjoint(delta_param_sets[j]): @@ -203,6 +245,7 @@ def main(): raise RuntimeError( f"Outbound transitions of <{t_from}> are not deterministic" ) + print("") by_name, parameter_names = dfatool.utils.observations_to_by_name(observations) diff --git a/lib/functions.py b/lib/functions.py index 187e6ff..35b04ef 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -466,6 +466,23 @@ class SplitFunction(ModelFunction): ) return hyper + # SplitFunction only + def flatten(self): + paths = list() + for param_value, subtree in self.child.items(): + if type(subtree) is SplitFunction: + for path, value in subtree.flatten(): + path = [(self.param_name, param_value)] + path + paths.append((path, value)) + elif type(subtree) is StaticFunction: + path = [(self.param_name, param_value)] + paths.append((path, subtree.value)) + else: + raise RuntimeError( + "flatten is only implemented for RMTs with constant leaves" + ) + return paths + @classmethod def from_json(cls, data): assert data["type"] == "split" diff --git a/lib/utils.py b/lib/utils.py index c148426..48a29d8 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -325,6 +325,14 @@ def param_dict_to_str(param_dict): return " ".join(ret) +def param_str_to_dict(param_str): + ret = dict() + for param_pair in param_str.split(): + key, value = param_pair.split("=") + ret[key] = soft_cast_int_or_float(value) + return ret + + def observations_enum_to_bool(observations: list, kconfig=False): """ Convert enum / categorical observations to boolean-only ones. |