summaryrefslogtreecommitdiff
path: root/lib/automata.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/automata.py')
-rwxr-xr-xlib/automata.py40
1 files changed, 29 insertions, 11 deletions
diff --git a/lib/automata.py b/lib/automata.py
index 563ff4d..e27aa91 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -86,18 +86,32 @@ class State:
if depth == 0:
for trans in self.outgoing_transitions.values():
if with_arguments:
- for args in itertools.product(*trans.argument_values):
- yield [(trans.name, args)]
+ if trans.argument_combination == 'cartesian':
+ for args in itertools.product(*trans.argument_values):
+ yield [(trans.name, args)]
+ else:
+ for args in zip(*trans.argument_values):
+ yield [(trans.name, args)]
else:
yield [trans.name]
else:
for trans in self.outgoing_transitions.values():
for suffix in trans.destination.dfs(depth - 1, with_arguments = with_arguments):
if with_arguments:
- for args in itertools.product(*trans.argument_values):
- new_suffix = [(trans.name, args)]
- new_suffix.extend(suffix)
- yield new_suffix
+ if trans.argument_combination == 'cartesian':
+ for args in itertools.product(*trans.argument_values):
+ new_suffix = [(trans.name, args)]
+ new_suffix.extend(suffix)
+ yield new_suffix
+ else:
+ if len(trans.argument_values):
+ arg_values = zip(*trans.argument_values)
+ else:
+ arg_values = [tuple()]
+ for args in arg_values:
+ new_suffix = [(trans.name, args)]
+ new_suffix.extend(suffix)
+ yield new_suffix
else:
new_suffix = [trans.name]
new_suffix.extend(suffix)
@@ -121,6 +135,7 @@ class Transition:
is_interrupt: bool = False,
arguments: list = [],
argument_values: list = [],
+ argument_combination: str = 'cartesian', # or 'zip'
param_update_function = None,
arg_to_param_map: dict = None,
set_param = None):
@@ -144,6 +159,7 @@ class Transition:
self.is_interrupt = is_interrupt
self.arguments = arguments.copy()
self.argument_values = argument_values.copy()
+ self.argument_combination = argument_combination
self.param_update_function = param_update_function
self.arg_to_param_map = arg_to_param_map
self.set_param = set_param
@@ -213,6 +229,7 @@ class Transition:
'is_interrupt' : self.is_interrupt,
'arguments' : self.arguments,
'argument_values' : self.argument_values,
+ 'argument_combination' : self.argument_combination,
'arg_to_param_map' : self.arg_to_param_map,
'set_param' : self.set_param,
'duration' : _attribute_to_json(self.duration, self.duration_function),
@@ -271,7 +288,7 @@ class PTA:
if 'transition' in json_input:
return cls.from_legacy_json(json_input)
- kwargs = {}
+ kwargs = dict()
for key in ('state_names', 'parameters', 'initial_param_values'):
if key in json_input:
kwargs[key] = json_input[key]
@@ -283,9 +300,10 @@ class PTA:
duration_function = _json_function_to_analytic_function(transition, 'duration', pta.parameters)
energy_function = _json_function_to_analytic_function(transition, 'energy', pta.parameters)
timeout_function = _json_function_to_analytic_function(transition, 'timeout', pta.parameters)
- arg_to_param_map = None
- if 'arg_to_param_map' in transition:
- arg_to_param_map = transition['arg_to_param_map']
+ kwargs = dict()
+ for key in ['arg_to_param_map', 'argument_values', 'argument_combination']:
+ if key in transition:
+ kwargs[key] = transition[key]
origins = transition['origin']
if type(origins) != list:
origins = [origins]
@@ -298,7 +316,7 @@ class PTA:
energy_function = energy_function,
timeout = _json_get_static(transition, 'timeout'),
timeout_function = timeout_function,
- arg_to_param_map = arg_to_param_map
+ **kwargs
)
return pta