summaryrefslogtreecommitdiff
path: root/lib/automata.py
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2019-02-19 17:49:47 +0100
committerDaniel Friesel <derf@finalrewind.org>2019-02-19 17:59:51 +0100
commit1fa9a52245e9e5db6c9b5027e6a4db472eae2099 (patch)
tree8e03f0e944809136b1d40046fea5e2d2417a42e7 /lib/automata.py
parent07a587d243147ee855fe0448ec0c6f69f04d7d05 (diff)
automata: Add type annotations
Diffstat (limited to 'lib/automata.py')
-rwxr-xr-xlib/automata.py157
1 files changed, 80 insertions, 77 deletions
diff --git a/lib/automata.py b/lib/automata.py
index 3386761..b1cd235 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -1,16 +1,9 @@
from functions import AnalyticFunction
-def _parse_function(input_function):
- if type('input_function') == 'str':
- raise NotImplemented
- if type('input_function') == 'function':
- return 'raise ValueError', input_function
- raise ValueError('Function description must be provided as string or function')
-
-def _dict_to_list(input_dict):
+def _dict_to_list(input_dict: dict) -> list:
return [input_dict[x] for x in sorted(input_dict.keys())]
-def _attribute_to_json(static_value, param_function):
+def _attribute_to_json(static_value: float, param_function: AnalyticFunction) -> dict:
ret = {
'static' : static_value
}
@@ -21,14 +14,70 @@ def _attribute_to_json(static_value, param_function):
}
return ret
+class State:
+ def __init__(self, name: str, power: float = 0,
+ power_function: AnalyticFunction = None):
+ self.name = name
+ self.power = power
+ self.power_function = power_function
+ self.outgoing_transitions = {}
+
+ def add_outgoing_transition(self, new_transition: object):
+ self.outgoing_transitions[new_transition.name] = new_transition
+
+ def get_energy(self, duration: float, param_dict: dict = {}) -> float:
+ if self.power_function:
+ return self.power_function.eval(_dict_to_list(param_dict)) * duration
+ return self.power * duration
+
+ def get_transition(self, transition_name: str) -> object:
+ return self.outgoing_transitions[transition_name]
+
+ def has_interrupt_transitions(self):
+ for trans in self.outgoing_transitions.values():
+ if trans.is_interrupt:
+ return True
+ return False
+
+ def get_next_interrupt(self, parameters: dict) -> object:
+ interrupts = filter(lambda x: x.is_interrupt, self.outgoing_transitions.values())
+ interrupts = sorted(interrupts, key = lambda x: x.get_timeout(parameters))
+ return interrupts[0]
+
+ def dfs(self, depth: int):
+ if depth == 0:
+ for trans in self.outgoing_transitions.values():
+ yield [trans.name]
+ else:
+ for trans in self.outgoing_transitions.values():
+ for suffix in trans.destination.dfs(depth - 1):
+ new_suffix = [trans.name]
+ new_suffix.extend(suffix)
+ yield new_suffix
+
+ def to_json(self) -> dict:
+ ret = {
+ 'name' : self.name,
+ 'power' : _attribute_to_json(self.power, self.power_function)
+ }
+ return ret
+
class Transition:
- def __init__(self, orig_state, dest_state, name,
- energy = 0, energy_function = None,
- duration = 0, duration_function = None,
- timeout = 0, timeout_function = None,
- is_interrupt = False,
- arguments = [], param_update_function = None,
- arg_to_param_map = None, set_param = None):
+ def __init__(self, orig_state: State, dest_state: State, name: str,
+ energy: float = 0, energy_function: AnalyticFunction = None,
+ duration: float = 0, duration_function: AnalyticFunction = None,
+ timeout: float = 0, timeout_function: AnalyticFunction = None,
+ is_interrupt: bool = False,
+ arguments: list = [], param_update_function = None,
+ arg_to_param_map: dict = None, set_param = None):
+ """
+ Create a new transition between two PTA states.
+
+ arguments:
+ orig_state -- origin (State object)
+ dest_state -- destination (State object)
+ name -- transition name, typically the same as a driver/library function name (str)
+ """
self.name = name
self.origin = orig_state
self.destination = dest_state
@@ -44,22 +93,22 @@ class Transition:
self.arg_to_param_map = arg_to_param_map
self.set_param = set_param
- def get_duration(self, param_dict = {}, args = []):
+ def get_duration(self, param_dict: dict = {}, args: list = []) -> float:
if self.duration_function:
return self.duration_function.eval(_dict_to_list(param_dict), args)
return self.duration
- def get_energy(self, param_dict = {}, args = []):
+ def get_energy(self, param_dict: dict = {}, args: list = []) -> float:
if self.energy_function:
return self.energy_function.eval(_dict_to_list(param_dict), args)
return self.energy
- def get_timeout(self, param_dict = {}):
+ def get_timeout(self, param_dict: dict = {}) -> float:
if self.timeout_function:
return self.timeout_function.eval(_dict_to_list(param_dict))
return self.timeout
- def get_params_after_transition(self, param_dict, args = []):
+ def get_params_after_transition(self, param_dict: dict, args: list = []) -> dict:
if self.param_update_function:
return self.param_update_function(param_dict, args)
ret = param_dict.copy()
@@ -71,7 +120,7 @@ class Transition:
ret[k] = v
return ret
- def to_json(self):
+ def to_json(self) -> dict:
ret = {
'name' : self.name,
'origin' : self.origin.name,
@@ -86,66 +135,20 @@ class Transition:
}
return ret
-class State:
- def __init__(self, name, power = 0, power_function = None):
- self.name = name
- self.power = power
- self.power_function = power_function
- self.outgoing_transitions = {}
-
- def add_outgoing_transition(self, new_transition):
- self.outgoing_transitions[new_transition.name] = new_transition
-
- def get_energy(self, duration, param_dict = {}):
- if self.power_function:
- return self.power_function.eval(_dict_to_list(param_dict)) * duration
- return self.power * duration
-
- def get_transition(self, transition_name):
- return self.outgoing_transitions[transition_name]
-
- def has_interrupt_transitions(self):
- for trans in self.outgoing_transitions.values():
- if trans.is_interrupt:
- return True
- return False
-
- def get_next_interrupt(self, parameters):
- interrupts = filter(lambda x: x.is_interrupt, self.outgoing_transitions.values())
- interrupts = sorted(interrupts, key = lambda x: x.get_timeout(parameters))
- return interrupts[0]
-
- def dfs(self, depth):
- if depth == 0:
- for trans in self.outgoing_transitions.values():
- yield [trans.name]
- else:
- for trans in self.outgoing_transitions.values():
- for suffix in trans.destination.dfs(depth - 1):
- new_suffix = [trans.name]
- new_suffix.extend(suffix)
- yield new_suffix
-
- def to_json(self):
- ret = {
- 'name' : self.name,
- 'power' : _attribute_to_json(self.power, self.power_function)
- }
- return ret
-
-def _json_function_to_analytic_function(base, attribute, parameters):
+def _json_function_to_analytic_function(base, attribute: str, parameters: list):
if attribute in base and 'function' in base[attribute]:
base = base[attribute]['function']
return AnalyticFunction(base['raw'], parameters, 0, regression_args = base['regression_args'])
return None
-def _json_get_static(base, attribute):
+def _json_get_static(base, attribute: str):
if attribute in base:
return base[attribute]['static']
return 0
class PTA:
- def __init__(self, state_names = [], parameters = [], initial_param_values = None):
+ def __init__(self, state_names: list = [],
+ parameters: list = [], initial_param_values: list = None):
self.states = dict([[state_name, State(state_name)] for state_name in state_names])
self.parameters = parameters.copy()
if initial_param_values:
@@ -158,7 +161,7 @@ class PTA:
self.states['UNINITIALIZED'] = State('UNINITIALIZED')
@classmethod
- def from_json(cls, json_input):
+ def from_json(cls, json_input: dict):
kwargs = {}
for key in ('state_names', 'parameters', 'initial_param_values'):
if key in json_input:
@@ -191,7 +194,7 @@ class PTA:
return pta
- def to_json(self):
+ def to_json(self) -> dict:
ret = {
'parameters' : self.parameters,
'initial_param_values' : self.initial_param_values,
@@ -200,13 +203,13 @@ class PTA:
}
return ret
- def add_state(self, state_name, **kwargs):
+ def add_state(self, state_name: str, **kwargs):
if 'power_function' in kwargs and type(kwargs['power_function']) != AnalyticFunction:
kwargs['power_function'] = AnalyticFunction(kwargs['power_function'],
self.parameters, 0)
self.states[state_name] = State(state_name, **kwargs)
- def add_transition(self, orig_state, dest_state, function_name, **kwargs):
+ def add_transition(self, orig_state: str, dest_state: str, function_name: str, **kwargs):
orig_state = self.states[orig_state]
dest_state = self.states[dest_state]
for key in ('duration_function', 'energy_function', 'timeout_function'):
@@ -217,10 +220,10 @@ class PTA:
self.transitions.append(new_transition)
orig_state.add_outgoing_transition(new_transition)
- def dfs(self, depth = 10, orig_state = 'UNINITIALIZED'):
+ def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED'):
return self.states[orig_state].dfs(depth)
- def simulate(self, trace, orig_state = 'UNINITIALIZED'):
+ def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED'):
total_duration = 0.
total_energy = 0.
state = self.states[orig_state]