summaryrefslogtreecommitdiff
path: root/lib/automata.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/automata.py')
-rwxr-xr-xlib/automata.py196
1 files changed, 33 insertions, 163 deletions
diff --git a/lib/automata.py b/lib/automata.py
index 1e47596..b1e5623 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -1,7 +1,12 @@
#!/usr/bin/env python3
"""Classes and helper functions for PTA and other automata."""
-from .functions import AnalyticFunction, NormalizationFunction
+from .functions import (
+ AnalyticFunction,
+ NormalizationFunction,
+ ModelFunction,
+ StaticFunction,
+)
from .parameters import ModelAttribute
from .utils import is_numeric
import itertools
@@ -14,7 +19,7 @@ import yaml
logger = logging.getLogger(__name__)
-def _dict_to_list(input_dict: dict) -> list:
+def dict_to_list(input_dict: dict) -> list:
return [input_dict[x] for x in sorted(input_dict.keys())]
@@ -74,122 +79,15 @@ class SimulationResult:
self.mean_power = 0
-class PTAAttribute:
- u"""
- A single PTA attribute (e.g. power, duration).
-
- A PTA attribute can be described by a static value and an analytic
- function (depending on parameters and function arguments).
-
- It is not specified how value_error and function_error are determined --
- at the moment, they do not use cross validation.
-
- :param value: static value, typically in µW/µs/pJ
- :param value_error: mean absolute error of value (optional)
- :param function: AnalyticFunction for parameter-aware prediction (optional)
- :param function_error: mean absolute error of function (optional)
- """
-
- def __init__(
- self,
- value: float = 0,
- function: AnalyticFunction = None,
- value_error=None,
- function_error=None,
- ):
- self.value = value
- self.function = function
- self.value_error = value_error
- self.function_error = function_error
-
- def __repr__(self):
- if self.function is not None:
- return "PTAATtribute<{:.0f}, {}>".format(
- self.value, self.function.model_function
- )
- return "PTAATtribute<{:.0f}, None>".format(self.value)
-
- def eval(self, param_dict=dict(), args=list()):
- """
- Return attribute for given `param_dict` and `args` value.
-
- Uses `function` if set and usable for the given `param_dict` and
- `value` otherwise.
- """
- param_list = _dict_to_list(param_dict)
- if self.function and self.function.is_predictable(param_list):
- return self.function.eval(param_list, args)
- return self.value
-
- def eval_mae(self, param_dict=dict(), args=list()):
- """
- Return attribute mean absolute error for given `param_dict` and `args` value.
-
- Uses `function_error` if `function` is set and usable for the given `param_dict` and `value_error` otherwise.
- """
- param_list = _dict_to_list(param_dict)
- if self.function and self.function.is_predictable(param_list):
- return self.function_error["mae"]
- return self.value_error["mae"]
-
- def to_json(self):
- ret = {"static": self.value, "static_error": self.value_error}
- if self.function:
- ret["function"] = {
- "raw": self.function.model_function,
- "regression_args": list(self.function.model_args),
- }
- ret["function_error"] = self.function_error
- return ret
-
- @classmethod
- def from_json(cls, json_input: dict, parameters: dict):
- ret = cls()
- if "static" in json_input:
- ret.value = json_input["static"]
- if "static_error" in json_input:
- ret.value_error = json_input["static_error"]
- if "function" in json_input:
- ret.function = AnalyticFunction(
- json_input["function"]["raw"],
- parameters,
- 0,
- regression_args=json_input["function"]["regression_args"],
- )
- if "function_error" in json_input:
- ret.function_error = json_input["function_error"]
- return ret
-
- @classmethod
- def from_json_maybe(cls, json_wrapped: dict, attribute: str, parameters: dict):
- if type(json_wrapped) is dict and attribute in json_wrapped:
- return cls.from_json(json_wrapped[attribute], parameters)
- return cls()
-
-
-def _json_function_to_analytic_function(base, attribute: str, parameters: list):
- if attribute in base and "function" in base[attribute]:
- base = base[attribute]["function"]
- return AnalyticFunction(
- base["raw"], parameters, 0, regression_args=base["regression_args"]
- )
- return None
-
-
class State:
"""A single PTA state."""
- def __init__(
- self,
- name: str,
- power: PTAAttribute = PTAAttribute(),
- power_function: AnalyticFunction = None,
- ):
+ def __init__(self, name: str, power: ModelFunction = StaticFunction(0)):
u"""
Create a new PTA state.
:param name: state name
- :param power: state power PTAAttribute in µW, default static 0 / parameterized None
+ :param power: state power ModelFunction in µW, default static StaticFunction(0)
:param power_function: Legacy support
"""
self.name = name
@@ -197,13 +95,7 @@ class State:
self.outgoing_transitions = {}
if type(self.power) is float or type(self.power) is int:
- self.power = PTAAttribute(self.power)
-
- if power_function is not None:
- if type(power_function) is AnalyticFunction:
- self.power.function = power_function
- else:
- raise ValueError("power_function must be None or AnalyticFunction")
+ self.power = StaticFunction(self.power)
def __repr__(self):
return "State<{:s}, {}>".format(self.name, self.power)
@@ -220,7 +112,7 @@ class State:
:param param_dict: current parameters
:returns: energy spent in pJ
"""
- return self.power.eval(param_dict) * duration
+ return self.power.eval(dict_to_list(param_dict)) * duration
def set_random_energy_model(self, static_model=True):
u"""Set a random static state power between 0 µW and 50 mW."""
@@ -417,12 +309,9 @@ class Transition:
orig_state: State,
dest_state: State,
name: str,
- energy: PTAAttribute = PTAAttribute(),
- energy_function: AnalyticFunction = None,
- duration: PTAAttribute = PTAAttribute(),
- duration_function: AnalyticFunction = None,
- timeout: PTAAttribute = PTAAttribute(),
- timeout_function: AnalyticFunction = None,
+ energy: ModelFunction = StaticFunction(0),
+ duration: ModelFunction = StaticFunction(0),
+ timeout: ModelFunction = StaticFunction(0),
is_interrupt: bool = False,
arguments: list = [],
argument_values: list = [],
@@ -457,22 +346,13 @@ class Transition:
self.codegen = codegen
if type(self.energy) is float or type(self.energy) is int:
- self.energy = PTAAttribute(self.energy)
- if energy_function is not None:
- if type(energy_function) is AnalyticFunction:
- self.energy.function = energy_function
+ self.energy = StaticFunction(self.energy)
if type(self.duration) is float or type(self.duration) is int:
- self.duration = PTAAttribute(self.duration)
- if duration_function is not None:
- if type(duration_function) is AnalyticFunction:
- self.duration.function = duration_function
+ self.duration = StaticFunction(self.duration)
if type(self.timeout) is float or type(self.timeout) is int:
- self.timeout = PTAAttribute(self.timeout)
- if timeout_function is not None:
- if type(timeout_function) is AnalyticFunction:
- self.timeout.function = timeout_function
+ self.timeout = StaticFunction(self.timeout)
for handler in self.return_value_handlers:
if "formula" in handler:
@@ -487,7 +367,7 @@ class Transition:
:returns: transition duration in µs
"""
- return self.duration.eval(param_dict, args)
+ return self.duration.eval(dict_to_list(param_dict) + args)
def get_energy(self, param_dict: dict = {}, args: list = []) -> float:
u"""
@@ -496,15 +376,14 @@ class Transition:
:param param_dict: current parameter values
:param args: function arguments
"""
- return self.energy.eval(param_dict, args)
+ return self.energy.eval(dict_to_list(param_dict) + args)
def set_random_energy_model(self, static_model=True):
self.energy.value = int(np.random.sample() * 50000)
self.duration.value = int(np.random.sample() * 50000)
- if self.is_interrupt:
- self.timeout.value = int(np.random.sample() * 50000)
+ self.timeout.value = int(np.random.sample() * 50000)
- def get_timeout(self, param_dict: dict = {}) -> float:
+ def get_timeout(self, param_dict: dict = {}, args: list = list()) -> float:
u"""
Return transition timeout in µs.
@@ -513,7 +392,7 @@ class Transition:
:param param_dict: current parameter values
:param args: function arguments
"""
- return self.timeout.eval(param_dict)
+ return self.timeout.eval(dict_to_list(param_dict) + args)
def get_params_after_transition(self, param_dict: dict, args: list = []) -> dict:
"""
@@ -703,9 +582,7 @@ class PTA:
kwargs[key] = json_input[key]
pta = cls(**kwargs)
for name, state in json_input["state"].items():
- pta.add_state(
- name, power=PTAAttribute.from_json_maybe(state, "power", pta.parameters)
- )
+ pta.add_state(name, power=ModelFunction.from_json_maybe(state, "power"))
for transition in json_input["transitions"]:
kwargs = dict()
for key in [
@@ -730,15 +607,9 @@ class PTA:
origin,
transition["destination"],
transition["name"],
- duration=PTAAttribute.from_json_maybe(
- transition, "duration", pta.parameters
- ),
- energy=PTAAttribute.from_json_maybe(
- transition, "energy", pta.parameters
- ),
- timeout=PTAAttribute.from_json_maybe(
- transition, "timeout", pta.parameters
- ),
+ duration=ModelFunction.from_json_maybe(transition, "duration"),
+ energy=ModelFunction.from_json_maybe(transition, "energy"),
+ timeout=ModelFunction.from_json_maybe(transition, "timeout"),
**kwargs
)
@@ -762,9 +633,7 @@ class PTA:
pta = cls(**kwargs)
for name, state in json_input["state"].items():
- pta.add_state(
- name, power=PTAAttribute(value=float(state["power"]["static"]))
- )
+ pta.add_state(name, power=StaticFunction(float(state["power"]["static"])))
for trans_name in sorted(json_input["transition"].keys()):
transition = json_input["transition"][trans_name]
@@ -818,8 +687,7 @@ class PTA:
if "state" in yaml_input:
for state_name, state in yaml_input["state"].items():
pta.add_state(
- state_name,
- power=PTAAttribute.from_json_maybe(state, "power", pta.parameters),
+ state_name, power=ModelFunction.from_json_maybe(state, "power")
)
for trans_name in sorted(yaml_input["transition"].keys()):
@@ -902,7 +770,7 @@ class PTA:
and kwargs["power_function"] is not None
):
kwargs["power_function"] = AnalyticFunction(
- kwargs["power_function"], self.parameters, 0
+ None, kwargs["power_function"], self.parameters, 0
)
self.state[state_name] = State(state_name, **kwargs)
@@ -925,7 +793,7 @@ class PTA:
and kwargs[key] is not None
and type(kwargs[key]) != AnalyticFunction
):
- kwargs[key] = AnalyticFunction(kwargs[key], self.parameters, 0)
+ kwargs[key] = AnalyticFunction(None, kwargs[key], self.parameters, 0)
new_transition = Transition(orig_state, dest_state, function_name, **kwargs)
self.transitions.append(new_transition)
@@ -1252,7 +1120,9 @@ class PTA:
accounting.sleep(duration)
else:
transition = state.get_transition(function_name)
- total_duration += transition.duration.eval(param_dict, function_args)
+ total_duration += transition.duration.eval(
+ dict_to_list(param_dict) + function_args
+ )
if transition.duration.value_error is not None:
total_duration_mae += (
transition.duration.eval_mae(param_dict, function_args) ** 2