diff options
author | Daniel Friesel <derf@finalrewind.org> | 2019-02-19 17:49:47 +0100 |
---|---|---|
committer | Daniel Friesel <derf@finalrewind.org> | 2019-02-19 17:59:51 +0100 |
commit | 1fa9a52245e9e5db6c9b5027e6a4db472eae2099 (patch) | |
tree | 8e03f0e944809136b1d40046fea5e2d2417a42e7 /lib | |
parent | 07a587d243147ee855fe0448ec0c6f69f04d7d05 (diff) |
automata: Add type annotations
Diffstat (limited to 'lib')
-rwxr-xr-x | lib/automata.py | 157 |
1 files changed, 80 insertions, 77 deletions
diff --git a/lib/automata.py b/lib/automata.py index 3386761..b1cd235 100755 --- a/lib/automata.py +++ b/lib/automata.py @@ -1,16 +1,9 @@ from functions import AnalyticFunction -def _parse_function(input_function): - if type('input_function') == 'str': - raise NotImplemented - if type('input_function') == 'function': - return 'raise ValueError', input_function - raise ValueError('Function description must be provided as string or function') - -def _dict_to_list(input_dict): +def _dict_to_list(input_dict: dict) -> list: return [input_dict[x] for x in sorted(input_dict.keys())] -def _attribute_to_json(static_value, param_function): +def _attribute_to_json(static_value: float, param_function: AnalyticFunction) -> dict: ret = { 'static' : static_value } @@ -21,14 +14,70 @@ def _attribute_to_json(static_value, param_function): } return ret +class State: + def __init__(self, name: str, power: float = 0, + power_function: AnalyticFunction = None): + self.name = name + self.power = power + self.power_function = power_function + self.outgoing_transitions = {} + + def add_outgoing_transition(self, new_transition: object): + self.outgoing_transitions[new_transition.name] = new_transition + + def get_energy(self, duration: float, param_dict: dict = {}) -> float: + if self.power_function: + return self.power_function.eval(_dict_to_list(param_dict)) * duration + return self.power * duration + + def get_transition(self, transition_name: str) -> object: + return self.outgoing_transitions[transition_name] + + def has_interrupt_transitions(self): + for trans in self.outgoing_transitions.values(): + if trans.is_interrupt: + return True + return False + + def get_next_interrupt(self, parameters: dict) -> object: + interrupts = filter(lambda x: x.is_interrupt, self.outgoing_transitions.values()) + interrupts = sorted(interrupts, key = lambda x: x.get_timeout(parameters)) + return interrupts[0] + + def dfs(self, depth: int): + if depth == 0: + for trans in self.outgoing_transitions.values(): + yield [trans.name] + else: + for trans in self.outgoing_transitions.values(): + for suffix in trans.destination.dfs(depth - 1): + new_suffix = [trans.name] + new_suffix.extend(suffix) + yield new_suffix + + def to_json(self) -> dict: + ret = { + 'name' : self.name, + 'power' : _attribute_to_json(self.power, self.power_function) + } + return ret + class Transition: - def __init__(self, orig_state, dest_state, name, - energy = 0, energy_function = None, - duration = 0, duration_function = None, - timeout = 0, timeout_function = None, - is_interrupt = False, - arguments = [], param_update_function = None, - arg_to_param_map = None, set_param = None): + def __init__(self, orig_state: State, dest_state: State, name: str, + energy: float = 0, energy_function: AnalyticFunction = None, + duration: float = 0, duration_function: AnalyticFunction = None, + timeout: float = 0, timeout_function: AnalyticFunction = None, + is_interrupt: bool = False, + arguments: list = [], param_update_function = None, + arg_to_param_map: dict = None, set_param = None): + """ + Create a new transition between two PTA states. + + arguments: + orig_state -- origin (State object) + dest_state -- destination (State object) + name -- transition name, typically the same as a driver/library function name (str) + """ self.name = name self.origin = orig_state self.destination = dest_state @@ -44,22 +93,22 @@ class Transition: self.arg_to_param_map = arg_to_param_map self.set_param = set_param - def get_duration(self, param_dict = {}, args = []): + def get_duration(self, param_dict: dict = {}, args: list = []) -> float: if self.duration_function: return self.duration_function.eval(_dict_to_list(param_dict), args) return self.duration - def get_energy(self, param_dict = {}, args = []): + def get_energy(self, param_dict: dict = {}, args: list = []) -> float: if self.energy_function: return self.energy_function.eval(_dict_to_list(param_dict), args) return self.energy - def get_timeout(self, param_dict = {}): + def get_timeout(self, param_dict: dict = {}) -> float: if self.timeout_function: return self.timeout_function.eval(_dict_to_list(param_dict)) return self.timeout - def get_params_after_transition(self, param_dict, args = []): + def get_params_after_transition(self, param_dict: dict, args: list = []) -> dict: if self.param_update_function: return self.param_update_function(param_dict, args) ret = param_dict.copy() @@ -71,7 +120,7 @@ class Transition: ret[k] = v return ret - def to_json(self): + def to_json(self) -> dict: ret = { 'name' : self.name, 'origin' : self.origin.name, @@ -86,66 +135,20 @@ class Transition: } return ret -class State: - def __init__(self, name, power = 0, power_function = None): - self.name = name - self.power = power - self.power_function = power_function - self.outgoing_transitions = {} - - def add_outgoing_transition(self, new_transition): - self.outgoing_transitions[new_transition.name] = new_transition - - def get_energy(self, duration, param_dict = {}): - if self.power_function: - return self.power_function.eval(_dict_to_list(param_dict)) * duration - return self.power * duration - - def get_transition(self, transition_name): - return self.outgoing_transitions[transition_name] - - def has_interrupt_transitions(self): - for trans in self.outgoing_transitions.values(): - if trans.is_interrupt: - return True - return False - - def get_next_interrupt(self, parameters): - interrupts = filter(lambda x: x.is_interrupt, self.outgoing_transitions.values()) - interrupts = sorted(interrupts, key = lambda x: x.get_timeout(parameters)) - return interrupts[0] - - def dfs(self, depth): - if depth == 0: - for trans in self.outgoing_transitions.values(): - yield [trans.name] - else: - for trans in self.outgoing_transitions.values(): - for suffix in trans.destination.dfs(depth - 1): - new_suffix = [trans.name] - new_suffix.extend(suffix) - yield new_suffix - - def to_json(self): - ret = { - 'name' : self.name, - 'power' : _attribute_to_json(self.power, self.power_function) - } - return ret - -def _json_function_to_analytic_function(base, attribute, parameters): +def _json_function_to_analytic_function(base, attribute: str, parameters: list): if attribute in base and 'function' in base[attribute]: base = base[attribute]['function'] return AnalyticFunction(base['raw'], parameters, 0, regression_args = base['regression_args']) return None -def _json_get_static(base, attribute): +def _json_get_static(base, attribute: str): if attribute in base: return base[attribute]['static'] return 0 class PTA: - def __init__(self, state_names = [], parameters = [], initial_param_values = None): + def __init__(self, state_names: list = [], + parameters: list = [], initial_param_values: list = None): self.states = dict([[state_name, State(state_name)] for state_name in state_names]) self.parameters = parameters.copy() if initial_param_values: @@ -158,7 +161,7 @@ class PTA: self.states['UNINITIALIZED'] = State('UNINITIALIZED') @classmethod - def from_json(cls, json_input): + def from_json(cls, json_input: dict): kwargs = {} for key in ('state_names', 'parameters', 'initial_param_values'): if key in json_input: @@ -191,7 +194,7 @@ class PTA: return pta - def to_json(self): + def to_json(self) -> dict: ret = { 'parameters' : self.parameters, 'initial_param_values' : self.initial_param_values, @@ -200,13 +203,13 @@ class PTA: } return ret - def add_state(self, state_name, **kwargs): + def add_state(self, state_name: str, **kwargs): if 'power_function' in kwargs and type(kwargs['power_function']) != AnalyticFunction: kwargs['power_function'] = AnalyticFunction(kwargs['power_function'], self.parameters, 0) self.states[state_name] = State(state_name, **kwargs) - def add_transition(self, orig_state, dest_state, function_name, **kwargs): + def add_transition(self, orig_state: str, dest_state: str, function_name: str, **kwargs): orig_state = self.states[orig_state] dest_state = self.states[dest_state] for key in ('duration_function', 'energy_function', 'timeout_function'): @@ -217,10 +220,10 @@ class PTA: self.transitions.append(new_transition) orig_state.add_outgoing_transition(new_transition) - def dfs(self, depth = 10, orig_state = 'UNINITIALIZED'): + def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED'): return self.states[orig_state].dfs(depth) - def simulate(self, trace, orig_state = 'UNINITIALIZED'): + def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED'): total_duration = 0. total_energy = 0. state = self.states[orig_state] |