summaryrefslogtreecommitdiff
path: root/lib/automata.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/automata.py')
-rwxr-xr-xlib/automata.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/lib/automata.py b/lib/automata.py
index 71dcacc..de40eb4 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -83,11 +83,14 @@ class State:
depth -- search depth
with_arguments -- perform dfs with function+argument transitions instead of just function transitions.
trace_filter -- list of lists. Each sub-list is a trace. Only traces matching one of the provided sub-lists are returned.
- E.g. trace_filter = [['init', 'foo'], ['init', 'bar']] will only return traces with init as first and foo or bar as second element.
+ E.g. trace_filter = [['init', 'foo'], ['init', 'bar']] will only return traces with init as first and foo or bar as second element.
+ trace_filter = [['init', 'foo', '$'], ['init', 'bar'], '$'] will only return the traces ['init', 'foo'] and ['init', 'bar'].
"""
+ # A '$' entry in trace_filter indicates that the trace should (successfully) terminate here regardless of `depth`.
if trace_filter is not None and next(filter(lambda x: x == '$', map(lambda x: x[0], trace_filter)), None) is not None:
yield []
+ # there may be other entries in trace_filter that still yield results.
if depth == 0:
for trans in self.outgoing_transitions.values():
if trace_filter is not None and len(list(filter(lambda x: x == trans.name, map(lambda x: x[0], trace_filter)))) == 0:
@@ -154,7 +157,8 @@ class Transition:
argument_combination: str = 'cartesian', # or 'zip'
param_update_function = None,
arg_to_param_map: dict = None,
- set_param = None):
+ set_param = None,
+ return_value_handlers: list = []):
"""
Create a new transition between two PTA states.
@@ -179,6 +183,7 @@ class Transition:
self.param_update_function = param_update_function
self.arg_to_param_map = arg_to_param_map
self.set_param = set_param
+ self.return_value_handlers = return_value_handlers
def get_duration(self, param_dict: dict = {}, args: list = []) -> float:
u"""
@@ -444,6 +449,8 @@ class PTA:
kwargs['set_param'] = transition['set_param']
if 'is_interrupt' in transition:
kwargs['is_interrupt'] = transition['is_interrupt']
+ if 'return_value' in transition:
+ kwargs['return_value_handlers'] = transition['return_value']
if not 'src' in transition:
transition['src'] = ['UNINITIALIZED']
if not 'dst' in transition: