diff options
-rwxr-xr-x | lib/automata.py | 32 | ||||
-rwxr-xr-x | test/pta.py | 51 |
2 files changed, 76 insertions, 7 deletions
diff --git a/lib/automata.py b/lib/automata.py index 94b3717..7964b58 100755 --- a/lib/automata.py +++ b/lib/automata.py @@ -393,11 +393,14 @@ class PTA: transition = yaml_input['transition'][trans_name] arguments = list() argument_values = list() + arg_to_param_map = dict() is_interrupt = False if 'arguments' in transition: - for argument in transition['arguments']: + for i, argument in enumerate(transition['arguments']): arguments.append(argument['name']) argument_values.append(argument['values']) + if 'parameter' in argument: + arg_to_param_map[argument['parameter']] = i for origin in transition['src']: pta.add_transition(origin, transition['dst'], trans_name, arguments = arguments, argument_values = argument_values) @@ -454,7 +457,17 @@ class PTA: """Return PTA-specific ID of transition.""" return self.transitions.index(transition) - def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED', **kwargs): + def _dfs_with_param(self, generator, param_dict): + for trace in generator: + param = param_dict.copy() + ret = list() + for elem in trace: + transition, arguments = elem + param = transition.get_params_after_transition(param, arguments) + ret.append((transition, arguments, param.copy())) + yield ret + + def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED', param_dict: dict = None, with_parameters: bool = False, **kwargs): """ Return a generator object for depth-first search starting at orig_state. @@ -462,11 +475,22 @@ class PTA: depth -- search depth orig_state -- initial state for depth-first search """ + if with_parameters and not param_dict: + param_dict = dict([[self.parameters[i], self.initial_param_values[i]] for i in range(len(self.parameters))]) + + if with_parameters and not 'with_arguments' in kwargs: + raise ValueError("with_parameters = True requires with_arguments = True") + if self.accepting_states: - return filter(lambda x: x[-1][0].destination.name in self.accepting_states, + generator = filter(lambda x: x[-1][0].destination.name in self.accepting_states, self.state[orig_state].dfs(depth, **kwargs)) else: - return self.state[orig_state].dfs(depth, **kwargs) + generator = self.state[orig_state].dfs(depth, **kwargs) + + if with_parameters: + return self._dfs_with_param(generator, param_dict) + else: + return generator def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED'): total_duration = 0. diff --git a/test/pta.py b/test/pta.py index 0ed95be..bb6ae45 100755 --- a/test/pta.py +++ b/test/pta.py @@ -5,7 +5,7 @@ import unittest example_json_1 = { 'parameters' : ['datarate', 'txbytes', 'txpower'], - 'initial_param_values' : [None, None], + 'initial_param_values' : [None, None, None], 'state' : { 'IDLE' : { 'power' : { @@ -85,7 +85,9 @@ example_json_1 = { ], } -def dfs_tran_to_name(runs: list, with_args: bool) -> list: +def dfs_tran_to_name(runs: list, with_args: bool = False, with_param: bool = False) -> list: + if with_param: + return list(map(lambda run: list(map(lambda x: (x[0].name, x[1], x[2]), run)), runs)) if with_args: return list(map(lambda run: list(map(lambda x: (x[0].name, x[1]), run)), runs)) return list(map(lambda run: list(map(lambda x: (x[0].name), run)), runs)) @@ -158,7 +160,7 @@ class TestPTA(unittest.TestCase): # print(json) # self.assertDictEqual(json, example_json_1) - def test_from_json_dfs(self): + def test_from_json_dfs_arg(self): pta = PTA.from_json(example_json_1) self.assertEqual(sorted(dfs_tran_to_name(pta.dfs(1), False)), [['init', 'init'], ['init', 'send'], ['init', 'setTxPower']]) self.assertEqual(sorted(dfs_tran_to_name(pta.dfs(1, with_arguments = True), True)), @@ -172,6 +174,49 @@ class TestPTA(unittest.TestCase): ] ) + def test_from_json_dfs_param(self): + pta = PTA.from_json(example_json_1) + no_param = { + 'datarate' : None, + 'txbytes' : None, + 'txpower' : None, + } + param_tx3 = { + 'datarate' : None, + 'txbytes' : 3, + 'txpower' : None, + } + param_tx5 = { + 'datarate' : None, + 'txbytes' : 5, + 'txpower' : None, + } + param_txp10 = { + 'datarate' : None, + 'txbytes' : None, + 'txpower' : 10, + } + param_txp20 = { + 'datarate' : None, + 'txbytes' : None, + 'txpower' : 20, + } + param_txp30 = { + 'datarate' : None, + 'txbytes' : None, + 'txpower' : 30, + } + self.assertEqual(sorted(dfs_tran_to_name(pta.dfs(1, with_arguments = True, with_parameters = True), True, True)), + [ + [('init', (), no_param), ('init', (), no_param)], + [('init', (), no_param), ('send', ('"foo"', 3), param_tx3)], + [('init', (), no_param), ('send', ('"hodor"', 5), param_tx5)], + [('init', (), no_param), ('setTxPower', (10,), param_txp10)], + [('init', (), no_param), ('setTxPower', (20,), param_txp20)], + [('init', (), no_param), ('setTxPower', (30,), param_txp30)], + ] + ) + def test_from_json_function(self): pta = PTA.from_json(example_json_1) self.assertEqual(pta.state['TX'].get_energy(1000, {'datarate' : 10, 'txbytes' : 6, 'txpower' : 10 }), 1000 * (100 + 2 * 10)) |