diff options
-rwxr-xr-x | bin/test_automata.py | 13 | ||||
-rwxr-xr-x | lib/automata.py | 40 |
2 files changed, 42 insertions, 11 deletions
diff --git a/bin/test_automata.py b/bin/test_automata.py index c117c56..2b6b3d4 100755 --- a/bin/test_automata.py +++ b/bin/test_automata.py @@ -42,6 +42,7 @@ example_json_1 = { 'duration' : { 'static' : 120 }, 'energy ' : { 'static' : 10000 }, 'arg_to_param_map' : { 'txpower' : 0 }, + 'argument_values' : [ [10, 20, 30] ], }, { 'name' : 'send', @@ -64,6 +65,8 @@ example_json_1 = { }, }, 'arg_to_param_map' : { 'txbytes' : 1 }, + 'argument_values' : [ ['"foo"', '"hodor"'], [3, 5] ], + 'argument_combination' : 'zip', }, { 'name' : 'txComplete', @@ -120,6 +123,16 @@ class TestPTA(unittest.TestCase): def test_from_json_dfs(self): pta = PTA.from_json(example_json_1) self.assertEqual(sorted(pta.dfs(1)), [['init', 'init'], ['init', 'send'], ['init', 'setTxPower']]) + self.assertEqual(sorted(pta.dfs(1, with_arguments = True)), + [ + [('init', ()), ('init', ())], + [('init', ()), ('send', ('"foo"', 3))], + [('init', ()), ('send', ('"hodor"', 5))], + [('init', ()), ('setTxPower', (10,))], + [('init', ()), ('setTxPower', (20,))], + [('init', ()), ('setTxPower', (30,))], + ] + ) def test_from_json_function(self): pta = PTA.from_json(example_json_1) diff --git a/lib/automata.py b/lib/automata.py index 563ff4d..e27aa91 100755 --- a/lib/automata.py +++ b/lib/automata.py @@ -86,18 +86,32 @@ class State: if depth == 0: for trans in self.outgoing_transitions.values(): if with_arguments: - for args in itertools.product(*trans.argument_values): - yield [(trans.name, args)] + if trans.argument_combination == 'cartesian': + for args in itertools.product(*trans.argument_values): + yield [(trans.name, args)] + else: + for args in zip(*trans.argument_values): + yield [(trans.name, args)] else: yield [trans.name] else: for trans in self.outgoing_transitions.values(): for suffix in trans.destination.dfs(depth - 1, with_arguments = with_arguments): if with_arguments: - for args in itertools.product(*trans.argument_values): - new_suffix = [(trans.name, args)] - new_suffix.extend(suffix) - yield new_suffix + if trans.argument_combination == 'cartesian': + for args in itertools.product(*trans.argument_values): + new_suffix = [(trans.name, args)] + new_suffix.extend(suffix) + yield new_suffix + else: + if len(trans.argument_values): + arg_values = zip(*trans.argument_values) + else: + arg_values = [tuple()] + for args in arg_values: + new_suffix = [(trans.name, args)] + new_suffix.extend(suffix) + yield new_suffix else: new_suffix = [trans.name] new_suffix.extend(suffix) @@ -121,6 +135,7 @@ class Transition: is_interrupt: bool = False, arguments: list = [], argument_values: list = [], + argument_combination: str = 'cartesian', # or 'zip' param_update_function = None, arg_to_param_map: dict = None, set_param = None): @@ -144,6 +159,7 @@ class Transition: self.is_interrupt = is_interrupt self.arguments = arguments.copy() self.argument_values = argument_values.copy() + self.argument_combination = argument_combination self.param_update_function = param_update_function self.arg_to_param_map = arg_to_param_map self.set_param = set_param @@ -213,6 +229,7 @@ class Transition: 'is_interrupt' : self.is_interrupt, 'arguments' : self.arguments, 'argument_values' : self.argument_values, + 'argument_combination' : self.argument_combination, 'arg_to_param_map' : self.arg_to_param_map, 'set_param' : self.set_param, 'duration' : _attribute_to_json(self.duration, self.duration_function), @@ -271,7 +288,7 @@ class PTA: if 'transition' in json_input: return cls.from_legacy_json(json_input) - kwargs = {} + kwargs = dict() for key in ('state_names', 'parameters', 'initial_param_values'): if key in json_input: kwargs[key] = json_input[key] @@ -283,9 +300,10 @@ class PTA: duration_function = _json_function_to_analytic_function(transition, 'duration', pta.parameters) energy_function = _json_function_to_analytic_function(transition, 'energy', pta.parameters) timeout_function = _json_function_to_analytic_function(transition, 'timeout', pta.parameters) - arg_to_param_map = None - if 'arg_to_param_map' in transition: - arg_to_param_map = transition['arg_to_param_map'] + kwargs = dict() + for key in ['arg_to_param_map', 'argument_values', 'argument_combination']: + if key in transition: + kwargs[key] = transition[key] origins = transition['origin'] if type(origins) != list: origins = [origins] @@ -298,7 +316,7 @@ class PTA: energy_function = energy_function, timeout = _json_get_static(transition, 'timeout'), timeout_function = timeout_function, - arg_to_param_map = arg_to_param_map + **kwargs ) return pta |