diff options
Diffstat (limited to 'lib/automata.py')
-rwxr-xr-x | lib/automata.py | 196 |
1 files changed, 33 insertions, 163 deletions
diff --git a/lib/automata.py b/lib/automata.py index 1e47596..b1e5623 100755 --- a/lib/automata.py +++ b/lib/automata.py @@ -1,7 +1,12 @@ #!/usr/bin/env python3 """Classes and helper functions for PTA and other automata.""" -from .functions import AnalyticFunction, NormalizationFunction +from .functions import ( + AnalyticFunction, + NormalizationFunction, + ModelFunction, + StaticFunction, +) from .parameters import ModelAttribute from .utils import is_numeric import itertools @@ -14,7 +19,7 @@ import yaml logger = logging.getLogger(__name__) -def _dict_to_list(input_dict: dict) -> list: +def dict_to_list(input_dict: dict) -> list: return [input_dict[x] for x in sorted(input_dict.keys())] @@ -74,122 +79,15 @@ class SimulationResult: self.mean_power = 0 -class PTAAttribute: - u""" - A single PTA attribute (e.g. power, duration). - - A PTA attribute can be described by a static value and an analytic - function (depending on parameters and function arguments). - - It is not specified how value_error and function_error are determined -- - at the moment, they do not use cross validation. - - :param value: static value, typically in µW/µs/pJ - :param value_error: mean absolute error of value (optional) - :param function: AnalyticFunction for parameter-aware prediction (optional) - :param function_error: mean absolute error of function (optional) - """ - - def __init__( - self, - value: float = 0, - function: AnalyticFunction = None, - value_error=None, - function_error=None, - ): - self.value = value - self.function = function - self.value_error = value_error - self.function_error = function_error - - def __repr__(self): - if self.function is not None: - return "PTAATtribute<{:.0f}, {}>".format( - self.value, self.function.model_function - ) - return "PTAATtribute<{:.0f}, None>".format(self.value) - - def eval(self, param_dict=dict(), args=list()): - """ - Return attribute for given `param_dict` and `args` value. - - Uses `function` if set and usable for the given `param_dict` and - `value` otherwise. - """ - param_list = _dict_to_list(param_dict) - if self.function and self.function.is_predictable(param_list): - return self.function.eval(param_list, args) - return self.value - - def eval_mae(self, param_dict=dict(), args=list()): - """ - Return attribute mean absolute error for given `param_dict` and `args` value. - - Uses `function_error` if `function` is set and usable for the given `param_dict` and `value_error` otherwise. - """ - param_list = _dict_to_list(param_dict) - if self.function and self.function.is_predictable(param_list): - return self.function_error["mae"] - return self.value_error["mae"] - - def to_json(self): - ret = {"static": self.value, "static_error": self.value_error} - if self.function: - ret["function"] = { - "raw": self.function.model_function, - "regression_args": list(self.function.model_args), - } - ret["function_error"] = self.function_error - return ret - - @classmethod - def from_json(cls, json_input: dict, parameters: dict): - ret = cls() - if "static" in json_input: - ret.value = json_input["static"] - if "static_error" in json_input: - ret.value_error = json_input["static_error"] - if "function" in json_input: - ret.function = AnalyticFunction( - json_input["function"]["raw"], - parameters, - 0, - regression_args=json_input["function"]["regression_args"], - ) - if "function_error" in json_input: - ret.function_error = json_input["function_error"] - return ret - - @classmethod - def from_json_maybe(cls, json_wrapped: dict, attribute: str, parameters: dict): - if type(json_wrapped) is dict and attribute in json_wrapped: - return cls.from_json(json_wrapped[attribute], parameters) - return cls() - - -def _json_function_to_analytic_function(base, attribute: str, parameters: list): - if attribute in base and "function" in base[attribute]: - base = base[attribute]["function"] - return AnalyticFunction( - base["raw"], parameters, 0, regression_args=base["regression_args"] - ) - return None - - class State: """A single PTA state.""" - def __init__( - self, - name: str, - power: PTAAttribute = PTAAttribute(), - power_function: AnalyticFunction = None, - ): + def __init__(self, name: str, power: ModelFunction = StaticFunction(0)): u""" Create a new PTA state. :param name: state name - :param power: state power PTAAttribute in µW, default static 0 / parameterized None + :param power: state power ModelFunction in µW, default static StaticFunction(0) :param power_function: Legacy support """ self.name = name @@ -197,13 +95,7 @@ class State: self.outgoing_transitions = {} if type(self.power) is float or type(self.power) is int: - self.power = PTAAttribute(self.power) - - if power_function is not None: - if type(power_function) is AnalyticFunction: - self.power.function = power_function - else: - raise ValueError("power_function must be None or AnalyticFunction") + self.power = StaticFunction(self.power) def __repr__(self): return "State<{:s}, {}>".format(self.name, self.power) @@ -220,7 +112,7 @@ class State: :param param_dict: current parameters :returns: energy spent in pJ """ - return self.power.eval(param_dict) * duration + return self.power.eval(dict_to_list(param_dict)) * duration def set_random_energy_model(self, static_model=True): u"""Set a random static state power between 0 µW and 50 mW.""" @@ -417,12 +309,9 @@ class Transition: orig_state: State, dest_state: State, name: str, - energy: PTAAttribute = PTAAttribute(), - energy_function: AnalyticFunction = None, - duration: PTAAttribute = PTAAttribute(), - duration_function: AnalyticFunction = None, - timeout: PTAAttribute = PTAAttribute(), - timeout_function: AnalyticFunction = None, + energy: ModelFunction = StaticFunction(0), + duration: ModelFunction = StaticFunction(0), + timeout: ModelFunction = StaticFunction(0), is_interrupt: bool = False, arguments: list = [], argument_values: list = [], @@ -457,22 +346,13 @@ class Transition: self.codegen = codegen if type(self.energy) is float or type(self.energy) is int: - self.energy = PTAAttribute(self.energy) - if energy_function is not None: - if type(energy_function) is AnalyticFunction: - self.energy.function = energy_function + self.energy = StaticFunction(self.energy) if type(self.duration) is float or type(self.duration) is int: - self.duration = PTAAttribute(self.duration) - if duration_function is not None: - if type(duration_function) is AnalyticFunction: - self.duration.function = duration_function + self.duration = StaticFunction(self.duration) if type(self.timeout) is float or type(self.timeout) is int: - self.timeout = PTAAttribute(self.timeout) - if timeout_function is not None: - if type(timeout_function) is AnalyticFunction: - self.timeout.function = timeout_function + self.timeout = StaticFunction(self.timeout) for handler in self.return_value_handlers: if "formula" in handler: @@ -487,7 +367,7 @@ class Transition: :returns: transition duration in µs """ - return self.duration.eval(param_dict, args) + return self.duration.eval(dict_to_list(param_dict) + args) def get_energy(self, param_dict: dict = {}, args: list = []) -> float: u""" @@ -496,15 +376,14 @@ class Transition: :param param_dict: current parameter values :param args: function arguments """ - return self.energy.eval(param_dict, args) + return self.energy.eval(dict_to_list(param_dict) + args) def set_random_energy_model(self, static_model=True): self.energy.value = int(np.random.sample() * 50000) self.duration.value = int(np.random.sample() * 50000) - if self.is_interrupt: - self.timeout.value = int(np.random.sample() * 50000) + self.timeout.value = int(np.random.sample() * 50000) - def get_timeout(self, param_dict: dict = {}) -> float: + def get_timeout(self, param_dict: dict = {}, args: list = list()) -> float: u""" Return transition timeout in µs. @@ -513,7 +392,7 @@ class Transition: :param param_dict: current parameter values :param args: function arguments """ - return self.timeout.eval(param_dict) + return self.timeout.eval(dict_to_list(param_dict) + args) def get_params_after_transition(self, param_dict: dict, args: list = []) -> dict: """ @@ -703,9 +582,7 @@ class PTA: kwargs[key] = json_input[key] pta = cls(**kwargs) for name, state in json_input["state"].items(): - pta.add_state( - name, power=PTAAttribute.from_json_maybe(state, "power", pta.parameters) - ) + pta.add_state(name, power=ModelFunction.from_json_maybe(state, "power")) for transition in json_input["transitions"]: kwargs = dict() for key in [ @@ -730,15 +607,9 @@ class PTA: origin, transition["destination"], transition["name"], - duration=PTAAttribute.from_json_maybe( - transition, "duration", pta.parameters - ), - energy=PTAAttribute.from_json_maybe( - transition, "energy", pta.parameters - ), - timeout=PTAAttribute.from_json_maybe( - transition, "timeout", pta.parameters - ), + duration=ModelFunction.from_json_maybe(transition, "duration"), + energy=ModelFunction.from_json_maybe(transition, "energy"), + timeout=ModelFunction.from_json_maybe(transition, "timeout"), **kwargs ) @@ -762,9 +633,7 @@ class PTA: pta = cls(**kwargs) for name, state in json_input["state"].items(): - pta.add_state( - name, power=PTAAttribute(value=float(state["power"]["static"])) - ) + pta.add_state(name, power=StaticFunction(float(state["power"]["static"]))) for trans_name in sorted(json_input["transition"].keys()): transition = json_input["transition"][trans_name] @@ -818,8 +687,7 @@ class PTA: if "state" in yaml_input: for state_name, state in yaml_input["state"].items(): pta.add_state( - state_name, - power=PTAAttribute.from_json_maybe(state, "power", pta.parameters), + state_name, power=ModelFunction.from_json_maybe(state, "power") ) for trans_name in sorted(yaml_input["transition"].keys()): @@ -902,7 +770,7 @@ class PTA: and kwargs["power_function"] is not None ): kwargs["power_function"] = AnalyticFunction( - kwargs["power_function"], self.parameters, 0 + None, kwargs["power_function"], self.parameters, 0 ) self.state[state_name] = State(state_name, **kwargs) @@ -925,7 +793,7 @@ class PTA: and kwargs[key] is not None and type(kwargs[key]) != AnalyticFunction ): - kwargs[key] = AnalyticFunction(kwargs[key], self.parameters, 0) + kwargs[key] = AnalyticFunction(None, kwargs[key], self.parameters, 0) new_transition = Transition(orig_state, dest_state, function_name, **kwargs) self.transitions.append(new_transition) @@ -1252,7 +1120,9 @@ class PTA: accounting.sleep(duration) else: transition = state.get_transition(function_name) - total_duration += transition.duration.eval(param_dict, function_args) + total_duration += transition.duration.eval( + dict_to_list(param_dict) + function_args + ) if transition.duration.value_error is not None: total_duration_mae += ( transition.duration.eval_mae(param_dict, function_args) ** 2 |