summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/test_automata.py13
-rwxr-xr-xlib/automata.py40
2 files changed, 42 insertions, 11 deletions
diff --git a/bin/test_automata.py b/bin/test_automata.py
index c117c56..2b6b3d4 100755
--- a/bin/test_automata.py
+++ b/bin/test_automata.py
@@ -42,6 +42,7 @@ example_json_1 = {
'duration' : { 'static' : 120 },
'energy ' : { 'static' : 10000 },
'arg_to_param_map' : { 'txpower' : 0 },
+ 'argument_values' : [ [10, 20, 30] ],
},
{
'name' : 'send',
@@ -64,6 +65,8 @@ example_json_1 = {
},
},
'arg_to_param_map' : { 'txbytes' : 1 },
+ 'argument_values' : [ ['"foo"', '"hodor"'], [3, 5] ],
+ 'argument_combination' : 'zip',
},
{
'name' : 'txComplete',
@@ -120,6 +123,16 @@ class TestPTA(unittest.TestCase):
def test_from_json_dfs(self):
pta = PTA.from_json(example_json_1)
self.assertEqual(sorted(pta.dfs(1)), [['init', 'init'], ['init', 'send'], ['init', 'setTxPower']])
+ self.assertEqual(sorted(pta.dfs(1, with_arguments = True)),
+ [
+ [('init', ()), ('init', ())],
+ [('init', ()), ('send', ('"foo"', 3))],
+ [('init', ()), ('send', ('"hodor"', 5))],
+ [('init', ()), ('setTxPower', (10,))],
+ [('init', ()), ('setTxPower', (20,))],
+ [('init', ()), ('setTxPower', (30,))],
+ ]
+ )
def test_from_json_function(self):
pta = PTA.from_json(example_json_1)
diff --git a/lib/automata.py b/lib/automata.py
index 563ff4d..e27aa91 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -86,18 +86,32 @@ class State:
if depth == 0:
for trans in self.outgoing_transitions.values():
if with_arguments:
- for args in itertools.product(*trans.argument_values):
- yield [(trans.name, args)]
+ if trans.argument_combination == 'cartesian':
+ for args in itertools.product(*trans.argument_values):
+ yield [(trans.name, args)]
+ else:
+ for args in zip(*trans.argument_values):
+ yield [(trans.name, args)]
else:
yield [trans.name]
else:
for trans in self.outgoing_transitions.values():
for suffix in trans.destination.dfs(depth - 1, with_arguments = with_arguments):
if with_arguments:
- for args in itertools.product(*trans.argument_values):
- new_suffix = [(trans.name, args)]
- new_suffix.extend(suffix)
- yield new_suffix
+ if trans.argument_combination == 'cartesian':
+ for args in itertools.product(*trans.argument_values):
+ new_suffix = [(trans.name, args)]
+ new_suffix.extend(suffix)
+ yield new_suffix
+ else:
+ if len(trans.argument_values):
+ arg_values = zip(*trans.argument_values)
+ else:
+ arg_values = [tuple()]
+ for args in arg_values:
+ new_suffix = [(trans.name, args)]
+ new_suffix.extend(suffix)
+ yield new_suffix
else:
new_suffix = [trans.name]
new_suffix.extend(suffix)
@@ -121,6 +135,7 @@ class Transition:
is_interrupt: bool = False,
arguments: list = [],
argument_values: list = [],
+ argument_combination: str = 'cartesian', # or 'zip'
param_update_function = None,
arg_to_param_map: dict = None,
set_param = None):
@@ -144,6 +159,7 @@ class Transition:
self.is_interrupt = is_interrupt
self.arguments = arguments.copy()
self.argument_values = argument_values.copy()
+ self.argument_combination = argument_combination
self.param_update_function = param_update_function
self.arg_to_param_map = arg_to_param_map
self.set_param = set_param
@@ -213,6 +229,7 @@ class Transition:
'is_interrupt' : self.is_interrupt,
'arguments' : self.arguments,
'argument_values' : self.argument_values,
+ 'argument_combination' : self.argument_combination,
'arg_to_param_map' : self.arg_to_param_map,
'set_param' : self.set_param,
'duration' : _attribute_to_json(self.duration, self.duration_function),
@@ -271,7 +288,7 @@ class PTA:
if 'transition' in json_input:
return cls.from_legacy_json(json_input)
- kwargs = {}
+ kwargs = dict()
for key in ('state_names', 'parameters', 'initial_param_values'):
if key in json_input:
kwargs[key] = json_input[key]
@@ -283,9 +300,10 @@ class PTA:
duration_function = _json_function_to_analytic_function(transition, 'duration', pta.parameters)
energy_function = _json_function_to_analytic_function(transition, 'energy', pta.parameters)
timeout_function = _json_function_to_analytic_function(transition, 'timeout', pta.parameters)
- arg_to_param_map = None
- if 'arg_to_param_map' in transition:
- arg_to_param_map = transition['arg_to_param_map']
+ kwargs = dict()
+ for key in ['arg_to_param_map', 'argument_values', 'argument_combination']:
+ if key in transition:
+ kwargs[key] = transition[key]
origins = transition['origin']
if type(origins) != list:
origins = [origins]
@@ -298,7 +316,7 @@ class PTA:
energy_function = energy_function,
timeout = _json_get_static(transition, 'timeout'),
timeout_function = timeout_function,
- arg_to_param_map = arg_to_param_map
+ **kwargs
)
return pta