summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xlib/automata.py123
1 files changed, 119 insertions, 4 deletions
diff --git a/lib/automata.py b/lib/automata.py
index b1cd235..8321754 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -1,3 +1,5 @@
+"""Classes and helper functions for PTA and other automata."""
+
from functions import AnalyticFunction
def _dict_to_list(input_dict: dict) -> list:
@@ -15,36 +17,70 @@ def _attribute_to_json(static_value: float, param_function: AnalyticFunction) ->
return ret
class State:
+ """A single PTA state."""
+
def __init__(self, name: str, power: float = 0,
power_function: AnalyticFunction = None):
+ u"""
+ Create a new PTA state.
+
+ arguments:
+ name -- state name
+ power -- static state power in µW
+ power_function -- parameterized state power in µW
+ """
self.name = name
self.power = power
self.power_function = power_function
self.outgoing_transitions = {}
def add_outgoing_transition(self, new_transition: object):
+ """Add a new outgoing transition."""
self.outgoing_transitions[new_transition.name] = new_transition
def get_energy(self, duration: float, param_dict: dict = {}) -> float:
+ u"""
+ Return energy spent in state in pJ.
+
+ arguments:
+ duration -- duration in µs
+ param_dict -- current parameters
+ """
if self.power_function:
return self.power_function.eval(_dict_to_list(param_dict)) * duration
return self.power * duration
def get_transition(self, transition_name: str) -> object:
+ """Return Transition object for outgoing transtion transition_name."""
return self.outgoing_transitions[transition_name]
- def has_interrupt_transitions(self):
+ def has_interrupt_transitions(self) -> bool:
+ """Check whether this state has any outgoing interrupt transitions."""
for trans in self.outgoing_transitions.values():
if trans.is_interrupt:
return True
return False
def get_next_interrupt(self, parameters: dict) -> object:
+ """
+ Return the outgoing interrupt transition with the lowet timeout.
+
+ Must only be called if has_interrupt_transitions returned true.
+
+ arguments:
+ parameters -- current parameter values
+ """
interrupts = filter(lambda x: x.is_interrupt, self.outgoing_transitions.values())
interrupts = sorted(interrupts, key = lambda x: x.get_timeout(parameters))
return interrupts[0]
def dfs(self, depth: int):
+ """
+ Return a generator object for depth-first search over all outgoing transitions.
+
+ arguments:
+ depth -- search depth
+ """
if depth == 0:
for trans in self.outgoing_transitions.values():
yield [trans.name]
@@ -56,6 +92,7 @@ class State:
yield new_suffix
def to_json(self) -> dict:
+ """Return JSON encoding of this state object."""
ret = {
'name' : self.name,
'power' : _attribute_to_json(self.power, self.power_function)
@@ -63,6 +100,8 @@ class State:
return ret
class Transition:
+ """A single PTA transition with one origin and one destination state."""
+
def __init__(self, orig_state: State, dest_state: State, name: str,
energy: float = 0, energy_function: AnalyticFunction = None,
duration: float = 0, duration_function: AnalyticFunction = None,
@@ -74,9 +113,9 @@ class Transition:
Create a new transition between two PTA states.
arguments:
- orig_state -- origin (State object)
- dest_state -- destination (State object)
- name -- transition name, typically the same as a driver/library function name (str)
+ orig_state -- origin state
+ dest_state -- destination state
+ name -- transition name, typically the same as a driver/library function name
"""
self.name = name
self.origin = orig_state
@@ -94,21 +133,50 @@ class Transition:
self.set_param = set_param
def get_duration(self, param_dict: dict = {}, args: list = []) -> float:
+ u"""
+ Return transition duration in µs.
+
+ arguments:
+ param_dict -- current parameter values
+ args -- function arguments
+ """
if self.duration_function:
return self.duration_function.eval(_dict_to_list(param_dict), args)
return self.duration
def get_energy(self, param_dict: dict = {}, args: list = []) -> float:
+ u"""
+ Return transition energy cost in pJ.
+
+ arguments:
+ param_dict -- current parameter values
+ args -- function arguments
+ """
if self.energy_function:
return self.energy_function.eval(_dict_to_list(param_dict), args)
return self.energy
def get_timeout(self, param_dict: dict = {}) -> float:
+ u"""
+ Return transition timeout in µs.
+
+ Returns 0 if the transition does not have a timeout.
+
+ arguments:
+ param_dict -- current parameter values
+ args -- function arguments
+ """
if self.timeout_function:
return self.timeout_function.eval(_dict_to_list(param_dict))
return self.timeout
def get_params_after_transition(self, param_dict: dict, args: list = []) -> dict:
+ """
+ Return the new parameter dict after taking this transition.
+
+ parameter values may be affected by this transition's update function,
+ it's argument-to-param map, and its set_param settings.
+ """
if self.param_update_function:
return self.param_update_function(param_dict, args)
ret = param_dict.copy()
@@ -121,6 +189,7 @@ class Transition:
return ret
def to_json(self) -> dict:
+ """Return JSON encoding of this transition object."""
ret = {
'name' : self.name,
'origin' : self.origin.name,
@@ -147,8 +216,23 @@ def _json_get_static(base, attribute: str):
return 0
class PTA:
+ """
+ A parameterized priced timed automaton. All states are accepting.
+
+ Suitable for simulation, model storage, and (soon) benchmark generation.
+ """
+
def __init__(self, state_names: list = [],
parameters: list = [], initial_param_values: list = None):
+ """
+ Return a new PTA object.
+
+ arguments:
+ state_names -- names of PTA states. Note that the PTA always contains
+ an initial UNINITIALIZED state, regardless of the content of state_names.
+ parameters -- names of PTA parameters
+ initial_param_values -- initial value for each parameter
+ """
self.states = dict([[state_name, State(state_name)] for state_name in state_names])
self.parameters = parameters.copy()
if initial_param_values:
@@ -162,6 +246,11 @@ class PTA:
@classmethod
def from_json(cls, json_input: dict):
+ """
+ Return a PTA created from the provided JSON data.
+
+ Compatible with the to_json method.
+ """
kwargs = {}
for key in ('state_names', 'parameters', 'initial_param_values'):
if key in json_input:
@@ -195,6 +284,11 @@ class PTA:
return pta
def to_json(self) -> dict:
+ """
+ Return JSON encoding of this PTA.
+
+ Compatible with the from_json method.
+ """
ret = {
'parameters' : self.parameters,
'initial_param_values' : self.initial_param_values,
@@ -204,12 +298,26 @@ class PTA:
return ret
def add_state(self, state_name: str, **kwargs):
+ """
+ Add a new state.
+
+ See the State() documentation for acceptable arguments.
+ """
if 'power_function' in kwargs and type(kwargs['power_function']) != AnalyticFunction:
kwargs['power_function'] = AnalyticFunction(kwargs['power_function'],
self.parameters, 0)
self.states[state_name] = State(state_name, **kwargs)
def add_transition(self, orig_state: str, dest_state: str, function_name: str, **kwargs):
+ """
+ Add function_name as new transition from orig_state to dest_state.
+
+ arguments:
+ orig_state -- origin state name. Must be known to PTA
+ dest_state -- destination state name. Must be known to PTA.
+ function_name -- function name
+ kwargs -- see Transition() documentation
+ """
orig_state = self.states[orig_state]
dest_state = self.states[dest_state]
for key in ('duration_function', 'energy_function', 'timeout_function'):
@@ -221,6 +329,13 @@ class PTA:
orig_state.add_outgoing_transition(new_transition)
def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED'):
+ """
+ Return a generator object for depth-first search starting at orig_state.
+
+ arguments:
+ depth -- search depth
+ orig_state -- initial state for depth-first search
+ """
return self.states[orig_state].dfs(depth)
def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED'):