summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rwxr-xr-xlib/automata.py36
1 files changed, 26 insertions, 10 deletions
diff --git a/lib/automata.py b/lib/automata.py
index 8321754..3913670 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -1,6 +1,7 @@
"""Classes and helper functions for PTA and other automata."""
from functions import AnalyticFunction
+import itertools
def _dict_to_list(input_dict: dict) -> list:
return [input_dict[x] for x in sorted(input_dict.keys())]
@@ -74,22 +75,33 @@ class State:
interrupts = sorted(interrupts, key = lambda x: x.get_timeout(parameters))
return interrupts[0]
- def dfs(self, depth: int):
+ def dfs(self, depth: int, with_arguments: bool = False):
"""
Return a generator object for depth-first search over all outgoing transitions.
arguments:
depth -- search depth
+ with_arguments -- perform dfs with function+argument transitions instead of just function transitions.
"""
if depth == 0:
for trans in self.outgoing_transitions.values():
- yield [trans.name]
+ if with_arguments:
+ for args in itertools.product(*trans.argument_values):
+ yield [[trans.name, args]]
+ else:
+ yield [trans.name]
else:
for trans in self.outgoing_transitions.values():
- for suffix in trans.destination.dfs(depth - 1):
- new_suffix = [trans.name]
- new_suffix.extend(suffix)
- yield new_suffix
+ for suffix in trans.destination.dfs(depth - 1, with_arguments = with_arguments):
+ if with_arguments:
+ for args in itertools.product(*trans.argument_values):
+ new_suffix = [[trans.name, args]]
+ new_suffix.extend(suffix)
+ yield new_suffix
+ else:
+ new_suffix = [trans.name]
+ new_suffix.extend(suffix)
+ yield new_suffix
def to_json(self) -> dict:
"""Return JSON encoding of this state object."""
@@ -107,8 +119,11 @@ class Transition:
duration: float = 0, duration_function: AnalyticFunction = None,
timeout: float = 0, timeout_function: AnalyticFunction = None,
is_interrupt: bool = False,
- arguments: list = [], param_update_function = None,
- arg_to_param_map: dict = None, set_param = None):
+ arguments: list = [],
+ argument_values: list = [],
+ param_update_function = None,
+ arg_to_param_map: dict = None,
+ set_param = None):
"""
Create a new transition between two PTA states.
@@ -128,6 +143,7 @@ class Transition:
self.timeout_function = timeout_function
self.is_interrupt = is_interrupt
self.arguments = arguments.copy()
+ self.argument_values = argument_values.copy()
self.param_update_function = param_update_function
self.arg_to_param_map = arg_to_param_map
self.set_param = set_param
@@ -328,7 +344,7 @@ class PTA:
self.transitions.append(new_transition)
orig_state.add_outgoing_transition(new_transition)
- def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED'):
+ def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED', **kwargs):
"""
Return a generator object for depth-first search starting at orig_state.
@@ -336,7 +352,7 @@ class PTA:
depth -- search depth
orig_state -- initial state for depth-first search
"""
- return self.states[orig_state].dfs(depth)
+ return self.states[orig_state].dfs(depth, **kwargs)
def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED'):
total_duration = 0.