summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xlib/automata.py32
-rwxr-xr-xtest/pta.py51
2 files changed, 76 insertions, 7 deletions
diff --git a/lib/automata.py b/lib/automata.py
index 94b3717..7964b58 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -393,11 +393,14 @@ class PTA:
transition = yaml_input['transition'][trans_name]
arguments = list()
argument_values = list()
+ arg_to_param_map = dict()
is_interrupt = False
if 'arguments' in transition:
- for argument in transition['arguments']:
+ for i, argument in enumerate(transition['arguments']):
arguments.append(argument['name'])
argument_values.append(argument['values'])
+ if 'parameter' in argument:
+ arg_to_param_map[argument['parameter']] = i
for origin in transition['src']:
pta.add_transition(origin, transition['dst'], trans_name,
arguments = arguments, argument_values = argument_values)
@@ -454,7 +457,17 @@ class PTA:
"""Return PTA-specific ID of transition."""
return self.transitions.index(transition)
- def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED', **kwargs):
+ def _dfs_with_param(self, generator, param_dict):
+ for trace in generator:
+ param = param_dict.copy()
+ ret = list()
+ for elem in trace:
+ transition, arguments = elem
+ param = transition.get_params_after_transition(param, arguments)
+ ret.append((transition, arguments, param.copy()))
+ yield ret
+
+ def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED', param_dict: dict = None, with_parameters: bool = False, **kwargs):
"""
Return a generator object for depth-first search starting at orig_state.
@@ -462,11 +475,22 @@ class PTA:
depth -- search depth
orig_state -- initial state for depth-first search
"""
+ if with_parameters and not param_dict:
+ param_dict = dict([[self.parameters[i], self.initial_param_values[i]] for i in range(len(self.parameters))])
+
+ if with_parameters and not 'with_arguments' in kwargs:
+ raise ValueError("with_parameters = True requires with_arguments = True")
+
if self.accepting_states:
- return filter(lambda x: x[-1][0].destination.name in self.accepting_states,
+ generator = filter(lambda x: x[-1][0].destination.name in self.accepting_states,
self.state[orig_state].dfs(depth, **kwargs))
else:
- return self.state[orig_state].dfs(depth, **kwargs)
+ generator = self.state[orig_state].dfs(depth, **kwargs)
+
+ if with_parameters:
+ return self._dfs_with_param(generator, param_dict)
+ else:
+ return generator
def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED'):
total_duration = 0.
diff --git a/test/pta.py b/test/pta.py
index 0ed95be..bb6ae45 100755
--- a/test/pta.py
+++ b/test/pta.py
@@ -5,7 +5,7 @@ import unittest
example_json_1 = {
'parameters' : ['datarate', 'txbytes', 'txpower'],
- 'initial_param_values' : [None, None],
+ 'initial_param_values' : [None, None, None],
'state' : {
'IDLE' : {
'power' : {
@@ -85,7 +85,9 @@ example_json_1 = {
],
}
-def dfs_tran_to_name(runs: list, with_args: bool) -> list:
+def dfs_tran_to_name(runs: list, with_args: bool = False, with_param: bool = False) -> list:
+ if with_param:
+ return list(map(lambda run: list(map(lambda x: (x[0].name, x[1], x[2]), run)), runs))
if with_args:
return list(map(lambda run: list(map(lambda x: (x[0].name, x[1]), run)), runs))
return list(map(lambda run: list(map(lambda x: (x[0].name), run)), runs))
@@ -158,7 +160,7 @@ class TestPTA(unittest.TestCase):
# print(json)
# self.assertDictEqual(json, example_json_1)
- def test_from_json_dfs(self):
+ def test_from_json_dfs_arg(self):
pta = PTA.from_json(example_json_1)
self.assertEqual(sorted(dfs_tran_to_name(pta.dfs(1), False)), [['init', 'init'], ['init', 'send'], ['init', 'setTxPower']])
self.assertEqual(sorted(dfs_tran_to_name(pta.dfs(1, with_arguments = True), True)),
@@ -172,6 +174,49 @@ class TestPTA(unittest.TestCase):
]
)
+ def test_from_json_dfs_param(self):
+ pta = PTA.from_json(example_json_1)
+ no_param = {
+ 'datarate' : None,
+ 'txbytes' : None,
+ 'txpower' : None,
+ }
+ param_tx3 = {
+ 'datarate' : None,
+ 'txbytes' : 3,
+ 'txpower' : None,
+ }
+ param_tx5 = {
+ 'datarate' : None,
+ 'txbytes' : 5,
+ 'txpower' : None,
+ }
+ param_txp10 = {
+ 'datarate' : None,
+ 'txbytes' : None,
+ 'txpower' : 10,
+ }
+ param_txp20 = {
+ 'datarate' : None,
+ 'txbytes' : None,
+ 'txpower' : 20,
+ }
+ param_txp30 = {
+ 'datarate' : None,
+ 'txbytes' : None,
+ 'txpower' : 30,
+ }
+ self.assertEqual(sorted(dfs_tran_to_name(pta.dfs(1, with_arguments = True, with_parameters = True), True, True)),
+ [
+ [('init', (), no_param), ('init', (), no_param)],
+ [('init', (), no_param), ('send', ('"foo"', 3), param_tx3)],
+ [('init', (), no_param), ('send', ('"hodor"', 5), param_tx5)],
+ [('init', (), no_param), ('setTxPower', (10,), param_txp10)],
+ [('init', (), no_param), ('setTxPower', (20,), param_txp20)],
+ [('init', (), no_param), ('setTxPower', (30,), param_txp30)],
+ ]
+ )
+
def test_from_json_function(self):
pta = PTA.from_json(example_json_1)
self.assertEqual(pta.state['TX'].get_energy(1000, {'datarate' : 10, 'txbytes' : 6, 'txpower' : 10 }), 1000 * (100 + 2 * 10))