diff options
-rwxr-xr-x | bin/generate-dfa-benchmark.py | 20 | ||||
-rwxr-xr-x | lib/automata.py | 18 | ||||
-rwxr-xr-x | test/pta.py | 8 |
3 files changed, 28 insertions, 18 deletions
diff --git a/bin/generate-dfa-benchmark.py b/bin/generate-dfa-benchmark.py index 8e3c227..ac1bdd6 100755 --- a/bin/generate-dfa-benchmark.py +++ b/bin/generate-dfa-benchmark.py @@ -37,16 +37,7 @@ import yaml from automata import PTA from harness import OnboardTimerHarness -opt = {} - -def trace_matches_filter(trace: list, trace_filter: list) -> bool: - for allowed_trace in trace_filter: - if len(trace) < len(allowed_trace): - continue - different_element_count = len(list(filter(None, map(lambda x,y: x[0].name != y, trace, allowed_trace)))) - if different_element_count == 0: - return True - return False +opt = dict() def benchmark_from_runs(pta: PTA, runs: list, harness: object, benchmark_id: int = 0) -> io.StringIO: outbuf = io.StringIO() @@ -215,6 +206,8 @@ if __name__ == '__main__': for trace in opt['trace-filter'].split(): trace_filter.append(trace.split(',')) opt['trace-filter'] = trace_filter + else: + opt['trace-filter'] = None except getopt.GetoptError as err: print(err) @@ -233,12 +226,7 @@ if __name__ == '__main__': else: timer_pin = 'GPIO::p1_0' - runs = list() - - for run in pta.dfs(opt['depth'], with_arguments = True, with_parameters = True): - if 'trace-filter' in opt and not trace_matches_filter(run, opt['trace-filter']): - continue - runs.append(run) + runs = list(pta.dfs(opt['depth'], with_arguments = True, with_parameters = True, trace_filter = opt['trace-filter'])) num_transitions = len(runs) diff --git a/lib/automata.py b/lib/automata.py index d9a722b..502dad8 100755 --- a/lib/automata.py +++ b/lib/automata.py @@ -75,16 +75,20 @@ class State: interrupts = sorted(interrupts, key = lambda x: x.get_timeout(parameters)) return interrupts[0] - def dfs(self, depth: int, with_arguments: bool = False): + def dfs(self, depth: int, with_arguments: bool = False, trace_filter = None): """ Return a generator object for depth-first search over all outgoing transitions. arguments: depth -- search depth with_arguments -- perform dfs with function+argument transitions instead of just function transitions. + trace_filter -- list of lists. Each sub-list is a trace. Only traces matching one of the provided sub-lists are returned. + E.g. trace_filter = [['init', 'foo'], ['init', 'bar']] will only return traces with init as first and foo or bar as second element. """ if depth == 0: for trans in self.outgoing_transitions.values(): + if trace_filter is not None and len(list(filter(lambda x: x == trans.name, map(lambda x: x[0], trace_filter)))) == 0: + continue if with_arguments: if trans.argument_combination == 'cartesian': for args in itertools.product(*trans.argument_values): @@ -96,7 +100,16 @@ class State: yield [(trans,)] else: for trans in self.outgoing_transitions.values(): - for suffix in trans.destination.dfs(depth - 1, with_arguments = with_arguments): + if trace_filter is not None and len(list(filter(lambda x: x == trans.name, map(lambda x: x[0], trace_filter)))) == 0: + continue + if trace_filter is not None: + new_trace_filter = map(lambda x: x[1:], filter(lambda x: x[0] == trans.name, trace_filter)) + new_trace_filter = list(filter(len, new_trace_filter)) + if len(new_trace_filter) == 0: + new_trace_filter = None + else: + new_trace_filter = None + for suffix in trans.destination.dfs(depth - 1, with_arguments = with_arguments, trace_filter = new_trace_filter): if with_arguments: if trans.argument_combination == 'cartesian': for args in itertools.product(*trans.argument_values): @@ -494,6 +507,7 @@ class PTA: param_dict: initial parameter values with_arguments: perform dfs with argument values with_parameters: include parameters in trace? + trace_filter: list of lists. Each sub-list is a trace. Only traces matching one of the provided sub-lists are returned. The returned generator emits traces. Each trace consts of a list of tuples describing the corresponding transition and (if enabled) diff --git a/test/pta.py b/test/pta.py index 25ec20c..8ca41aa 100755 --- a/test/pta.py +++ b/test/pta.py @@ -155,6 +155,14 @@ class TestPTA(unittest.TestCase): ['init', 'set2', 'set1'], ['init', 'set2', 'set2']]) + def test_dfs_trace_filter(self): + pta = PTA(['IDLE']) + pta.add_transition('UNINITIALIZED', 'IDLE', 'init') + pta.add_transition('IDLE', 'IDLE', 'set1') + pta.add_transition('IDLE', 'IDLE', 'set2') + self.assertEqual(sorted(dfs_tran_to_name(pta.dfs(2, trace_filter=[['init', 'set1', 'set2'], ['init', 'set2', 'set1']]), False)), + [['init', 'set1', 'set2'], ['init', 'set2', 'set1']]) + def test_dfs_accepting(self): pta = PTA(['IDLE', 'TX'], accepting_states = ['IDLE']) pta.add_transition('UNINITIALIZED', 'IDLE', 'init') |