diff options
Diffstat (limited to 'lib')
-rwxr-xr-x | lib/automata.py | 40 |
1 files changed, 29 insertions, 11 deletions
diff --git a/lib/automata.py b/lib/automata.py index 563ff4d..e27aa91 100755 --- a/lib/automata.py +++ b/lib/automata.py @@ -86,18 +86,32 @@ class State: if depth == 0: for trans in self.outgoing_transitions.values(): if with_arguments: - for args in itertools.product(*trans.argument_values): - yield [(trans.name, args)] + if trans.argument_combination == 'cartesian': + for args in itertools.product(*trans.argument_values): + yield [(trans.name, args)] + else: + for args in zip(*trans.argument_values): + yield [(trans.name, args)] else: yield [trans.name] else: for trans in self.outgoing_transitions.values(): for suffix in trans.destination.dfs(depth - 1, with_arguments = with_arguments): if with_arguments: - for args in itertools.product(*trans.argument_values): - new_suffix = [(trans.name, args)] - new_suffix.extend(suffix) - yield new_suffix + if trans.argument_combination == 'cartesian': + for args in itertools.product(*trans.argument_values): + new_suffix = [(trans.name, args)] + new_suffix.extend(suffix) + yield new_suffix + else: + if len(trans.argument_values): + arg_values = zip(*trans.argument_values) + else: + arg_values = [tuple()] + for args in arg_values: + new_suffix = [(trans.name, args)] + new_suffix.extend(suffix) + yield new_suffix else: new_suffix = [trans.name] new_suffix.extend(suffix) @@ -121,6 +135,7 @@ class Transition: is_interrupt: bool = False, arguments: list = [], argument_values: list = [], + argument_combination: str = 'cartesian', # or 'zip' param_update_function = None, arg_to_param_map: dict = None, set_param = None): @@ -144,6 +159,7 @@ class Transition: self.is_interrupt = is_interrupt self.arguments = arguments.copy() self.argument_values = argument_values.copy() + self.argument_combination = argument_combination self.param_update_function = param_update_function self.arg_to_param_map = arg_to_param_map self.set_param = set_param @@ -213,6 +229,7 @@ class Transition: 'is_interrupt' : self.is_interrupt, 'arguments' : self.arguments, 'argument_values' : self.argument_values, + 'argument_combination' : self.argument_combination, 'arg_to_param_map' : self.arg_to_param_map, 'set_param' : self.set_param, 'duration' : _attribute_to_json(self.duration, self.duration_function), @@ -271,7 +288,7 @@ class PTA: if 'transition' in json_input: return cls.from_legacy_json(json_input) - kwargs = {} + kwargs = dict() for key in ('state_names', 'parameters', 'initial_param_values'): if key in json_input: kwargs[key] = json_input[key] @@ -283,9 +300,10 @@ class PTA: duration_function = _json_function_to_analytic_function(transition, 'duration', pta.parameters) energy_function = _json_function_to_analytic_function(transition, 'energy', pta.parameters) timeout_function = _json_function_to_analytic_function(transition, 'timeout', pta.parameters) - arg_to_param_map = None - if 'arg_to_param_map' in transition: - arg_to_param_map = transition['arg_to_param_map'] + kwargs = dict() + for key in ['arg_to_param_map', 'argument_values', 'argument_combination']: + if key in transition: + kwargs[key] = transition[key] origins = transition['origin'] if type(origins) != list: origins = [origins] @@ -298,7 +316,7 @@ class PTA: energy_function = energy_function, timeout = _json_get_static(transition, 'timeout'), timeout_function = timeout_function, - arg_to_param_map = arg_to_param_map + **kwargs ) return pta |