summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/aspectc.py35
-rwxr-xr-xlib/automata.py705
-rw-r--r--lib/codegen.py389
-rw-r--r--lib/cycles_to_energy.py139
-rw-r--r--lib/data_parameters.py337
-rw-r--r--lib/dfatool.py1797
-rw-r--r--lib/functions.py250
-rw-r--r--lib/harness.py489
-rwxr-xr-xlib/ipython_energymodel_prelude.py8
-rwxr-xr-xlib/keysightdlog.py136
-rw-r--r--lib/lex.py104
-rw-r--r--lib/modular_arithmetic.py57
-rw-r--r--lib/parameters.py450
-rwxr-xr-xlib/plotter.py261
-rwxr-xr-xlib/protocol_benchmarks.py1654
-rw-r--r--lib/pubcode/__init__.py2
-rw-r--r--lib/pubcode/code128.py283
-rw-r--r--lib/runner.py156
-rw-r--r--lib/size_to_radio_energy.py315
-rw-r--r--lib/sly/__init__.py3
-rw-r--r--lib/sly/ast.py9
-rw-r--r--lib/sly/docparse.py15
-rw-r--r--lib/sly/lex.py178
-rw-r--r--lib/sly/yacc.py867
-rw-r--r--lib/utils.py82
25 files changed, 5558 insertions, 3163 deletions
diff --git a/lib/aspectc.py b/lib/aspectc.py
index 3229057..f4a102c 100644
--- a/lib/aspectc.py
+++ b/lib/aspectc.py
@@ -1,5 +1,6 @@
import xml.etree.ElementTree as ET
+
class AspectCClass:
"""
C++ class information provided by the AspectC++ repo.acp
@@ -19,6 +20,7 @@ class AspectCClass:
for function in self.functions:
self.function[function.name] = function
+
class AspectCFunction:
"""
C++ function informationed provided by the AspectC++ repo.acp
@@ -53,16 +55,23 @@ class AspectCFunction:
:param function_node: `xml.etree.ElementTree.Element` node
"""
- name = function_node.get('name')
- kind = function_node.get('kind')
- function_id = function_node.get('id')
+ name = function_node.get("name")
+ kind = function_node.get("kind")
+ function_id = function_node.get("id")
return_type = None
argument_types = list()
- for type_node in function_node.findall('result_type/Type'):
- return_type = type_node.get('signature')
- for type_node in function_node.findall('arg_types/Type'):
- argument_types.append(type_node.get('signature'))
- return cls(name = name, kind = kind, function_id = function_id, argument_types = argument_types, return_type = return_type)
+ for type_node in function_node.findall("result_type/Type"):
+ return_type = type_node.get("signature")
+ for type_node in function_node.findall("arg_types/Type"):
+ argument_types.append(type_node.get("signature"))
+ return cls(
+ name=name,
+ kind=kind,
+ function_id=function_id,
+ argument_types=argument_types,
+ return_type=return_type,
+ )
+
class Repo:
"""
@@ -85,11 +94,13 @@ class Repo:
def _load_classes(self):
self.classes = list()
- for class_node in self.root.findall('root/Namespace[@name="::"]/children/Class'):
- name = class_node.get('name')
- class_id = class_node.get('id')
+ for class_node in self.root.findall(
+ 'root/Namespace[@name="::"]/children/Class'
+ ):
+ name = class_node.get("name")
+ class_id = class_node.get("id")
functions = list()
- for function_node in class_node.findall('children/Function'):
+ for function_node in class_node.findall("children/Function"):
function = AspectCFunction.from_function_node(function_node)
functions.append(function)
self.classes.append(AspectCClass(name, class_id, functions))
diff --git a/lib/automata.py b/lib/automata.py
index b7668c5..b3318e0 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -28,7 +28,15 @@ class SimulationResult:
:param mean_power: mean power during run in W
"""
- def __init__(self, duration: float, energy: float, end_state, parameters, duration_mae: float = None, energy_mae: float = None):
+ def __init__(
+ self,
+ duration: float,
+ energy: float,
+ end_state,
+ parameters,
+ duration_mae: float = None,
+ energy_mae: float = None,
+ ):
u"""
Create a new SimulationResult.
@@ -77,7 +85,13 @@ class PTAAttribute:
:param function_error: mean absolute error of function (optional)
"""
- def __init__(self, value: float = 0, function: AnalyticFunction = None, value_error=None, function_error=None):
+ def __init__(
+ self,
+ value: float = 0,
+ function: AnalyticFunction = None,
+ value_error=None,
+ function_error=None,
+ ):
self.value = value
self.function = function
self.value_error = value_error
@@ -85,8 +99,10 @@ class PTAAttribute:
def __repr__(self):
if self.function is not None:
- return 'PTAATtribute<{:.0f}, {}>'.format(self.value, self.function._model_str)
- return 'PTAATtribute<{:.0f}, None>'.format(self.value)
+ return "PTAATtribute<{:.0f}, {}>".format(
+ self.value, self.function._model_str
+ )
+ return "PTAATtribute<{:.0f}, None>".format(self.value)
def eval(self, param_dict=dict(), args=list()):
"""
@@ -108,33 +124,38 @@ class PTAAttribute:
"""
param_list = _dict_to_list(param_dict)
if self.function and self.function.is_predictable(param_list):
- return self.function_error['mae']
- return self.value_error['mae']
+ return self.function_error["mae"]
+ return self.value_error["mae"]
def to_json(self):
ret = {
- 'static': self.value,
- 'static_error': self.value_error,
+ "static": self.value,
+ "static_error": self.value_error,
}
if self.function:
- ret['function'] = {
- 'raw': self.function._model_str,
- 'regression_args': list(self.function._regression_args)
+ ret["function"] = {
+ "raw": self.function._model_str,
+ "regression_args": list(self.function._regression_args),
}
- ret['function_error'] = self.function_error
+ ret["function_error"] = self.function_error
return ret
@classmethod
def from_json(cls, json_input: dict, parameters: dict):
ret = cls()
- if 'static' in json_input:
- ret.value = json_input['static']
- if 'static_error' in json_input:
- ret.value_error = json_input['static_error']
- if 'function' in json_input:
- ret.function = AnalyticFunction(json_input['function']['raw'], parameters, 0, regression_args=json_input['function']['regression_args'])
- if 'function_error' in json_input:
- ret.function_error = json_input['function_error']
+ if "static" in json_input:
+ ret.value = json_input["static"]
+ if "static_error" in json_input:
+ ret.value_error = json_input["static_error"]
+ if "function" in json_input:
+ ret.function = AnalyticFunction(
+ json_input["function"]["raw"],
+ parameters,
+ 0,
+ regression_args=json_input["function"]["regression_args"],
+ )
+ if "function_error" in json_input:
+ ret.function_error = json_input["function_error"]
return ret
@classmethod
@@ -145,16 +166,23 @@ class PTAAttribute:
def _json_function_to_analytic_function(base, attribute: str, parameters: list):
- if attribute in base and 'function' in base[attribute]:
- base = base[attribute]['function']
- return AnalyticFunction(base['raw'], parameters, 0, regression_args=base['regression_args'])
+ if attribute in base and "function" in base[attribute]:
+ base = base[attribute]["function"]
+ return AnalyticFunction(
+ base["raw"], parameters, 0, regression_args=base["regression_args"]
+ )
return None
class State:
"""A single PTA state."""
- def __init__(self, name: str, power: PTAAttribute = PTAAttribute(), power_function: AnalyticFunction = None):
+ def __init__(
+ self,
+ name: str,
+ power: PTAAttribute = PTAAttribute(),
+ power_function: AnalyticFunction = None,
+ ):
u"""
Create a new PTA state.
@@ -173,10 +201,10 @@ class State:
if type(power_function) is AnalyticFunction:
self.power.function = power_function
else:
- raise ValueError('power_function must be None or AnalyticFunction')
+ raise ValueError("power_function must be None or AnalyticFunction")
def __repr__(self):
- return 'State<{:s}, {}>'.format(self.name, self.power)
+ return "State<{:s}, {}>".format(self.name, self.power)
def add_outgoing_transition(self, new_transition: object):
"""Add a new outgoing transition."""
@@ -206,7 +234,11 @@ class State:
try:
return self.outgoing_transitions[transition_name]
except KeyError:
- raise ValueError('State {} has no outgoing transition called {}'.format(self.name, transition_name)) from None
+ raise ValueError(
+ "State {} has no outgoing transition called {}".format(
+ self.name, transition_name
+ )
+ ) from None
def has_interrupt_transitions(self) -> bool:
"""Return whether this state has any outgoing interrupt transitions."""
@@ -224,7 +256,9 @@ class State:
:param parameters: current parameter values
:returns: Transition object
"""
- interrupts = filter(lambda x: x.is_interrupt, self.outgoing_transitions.values())
+ interrupts = filter(
+ lambda x: x.is_interrupt, self.outgoing_transitions.values()
+ )
interrupts = sorted(interrupts, key=lambda x: x.get_timeout(parameters))
return interrupts[0]
@@ -246,15 +280,32 @@ class State:
# TODO parametergewahrer Trace-Filter, z.B. "setHeaterDuration nur wenn bme680 power mode => FORCED und GAS_ENABLED"
# A '$' entry in trace_filter indicates that the trace should (successfully) terminate here regardless of `depth`.
- if trace_filter is not None and next(filter(lambda x: x == '$', map(lambda x: x[0], trace_filter)), None) is not None:
+ if (
+ trace_filter is not None
+ and next(
+ filter(lambda x: x == "$", map(lambda x: x[0], trace_filter)), None
+ )
+ is not None
+ ):
yield []
# there may be other entries in trace_filter that still yield results.
if depth == 0:
for trans in self.outgoing_transitions.values():
- if trace_filter is not None and len(list(filter(lambda x: x == trans.name, map(lambda x: x[0], trace_filter)))) == 0:
+ if (
+ trace_filter is not None
+ and len(
+ list(
+ filter(
+ lambda x: x == trans.name,
+ map(lambda x: x[0], trace_filter),
+ )
+ )
+ )
+ == 0
+ ):
continue
if with_arguments:
- if trans.argument_combination == 'cartesian':
+ if trans.argument_combination == "cartesian":
for args in itertools.product(*trans.argument_values):
if sleep:
yield [(None, sleep), (trans, args)]
@@ -273,18 +324,35 @@ class State:
yield [(trans,)]
else:
for trans in self.outgoing_transitions.values():
- if trace_filter is not None and next(filter(lambda x: x == trans.name, map(lambda x: x[0], trace_filter)), None) is None:
+ if (
+ trace_filter is not None
+ and next(
+ filter(
+ lambda x: x == trans.name, map(lambda x: x[0], trace_filter)
+ ),
+ None,
+ )
+ is None
+ ):
continue
if trace_filter is not None:
- new_trace_filter = map(lambda x: x[1:], filter(lambda x: x[0] == trans.name, trace_filter))
+ new_trace_filter = map(
+ lambda x: x[1:],
+ filter(lambda x: x[0] == trans.name, trace_filter),
+ )
new_trace_filter = list(filter(len, new_trace_filter))
if len(new_trace_filter) == 0:
new_trace_filter = None
else:
new_trace_filter = None
- for suffix in trans.destination.dfs(depth - 1, with_arguments=with_arguments, trace_filter=new_trace_filter, sleep=sleep):
+ for suffix in trans.destination.dfs(
+ depth - 1,
+ with_arguments=with_arguments,
+ trace_filter=new_trace_filter,
+ sleep=sleep,
+ ):
if with_arguments:
- if trans.argument_combination == 'cartesian':
+ if trans.argument_combination == "cartesian":
for args in itertools.product(*trans.argument_values):
if sleep:
new_suffix = [(None, sleep), (trans, args)]
@@ -314,10 +382,7 @@ class State:
def to_json(self) -> dict:
"""Return JSON encoding of this state object."""
- ret = {
- 'name': self.name,
- 'power': self.power.to_json()
- }
+ ret = {"name": self.name, "power": self.power.to_json()}
return ret
@@ -345,19 +410,27 @@ class Transition:
:param codegen: todo
"""
- def __init__(self, orig_state: State, dest_state: State, name: str,
- energy: PTAAttribute = PTAAttribute(), energy_function: AnalyticFunction = None,
- duration: PTAAttribute = PTAAttribute(), duration_function: AnalyticFunction = None,
- timeout: PTAAttribute = PTAAttribute(), timeout_function: AnalyticFunction = None,
- is_interrupt: bool = False,
- arguments: list = [],
- argument_values: list = [],
- argument_combination: str = 'cartesian', # or 'zip'
- param_update_function=None,
- arg_to_param_map: dict = None,
- set_param=None,
- return_value_handlers: list = [],
- codegen=dict()):
+ def __init__(
+ self,
+ orig_state: State,
+ dest_state: State,
+ name: str,
+ energy: PTAAttribute = PTAAttribute(),
+ energy_function: AnalyticFunction = None,
+ duration: PTAAttribute = PTAAttribute(),
+ duration_function: AnalyticFunction = None,
+ timeout: PTAAttribute = PTAAttribute(),
+ timeout_function: AnalyticFunction = None,
+ is_interrupt: bool = False,
+ arguments: list = [],
+ argument_values: list = [],
+ argument_combination: str = "cartesian", # or 'zip'
+ param_update_function=None,
+ arg_to_param_map: dict = None,
+ set_param=None,
+ return_value_handlers: list = [],
+ codegen=dict(),
+ ):
"""
Create a new transition between two PTA states.
@@ -400,8 +473,8 @@ class Transition:
self.timeout.function = timeout_function
for handler in self.return_value_handlers:
- if 'formula' in handler:
- handler['formula'] = NormalizationFunction(handler['formula'])
+ if "formula" in handler:
+ handler["formula"] = NormalizationFunction(handler["formula"])
def get_duration(self, param_dict: dict = {}, args: list = []) -> float:
u"""
@@ -465,25 +538,25 @@ class Transition:
def to_json(self) -> dict:
"""Return JSON encoding of this transition object."""
ret = {
- 'name': self.name,
- 'origin': self.origin.name,
- 'destination': self.destination.name,
- 'is_interrupt': self.is_interrupt,
- 'arguments': self.arguments,
- 'argument_values': self.argument_values,
- 'argument_combination': self.argument_combination,
- 'arg_to_param_map': self.arg_to_param_map,
- 'set_param': self.set_param,
- 'duration': self.duration.to_json(),
- 'energy': self.energy.to_json(),
- 'timeout': self.timeout.to_json()
+ "name": self.name,
+ "origin": self.origin.name,
+ "destination": self.destination.name,
+ "is_interrupt": self.is_interrupt,
+ "arguments": self.arguments,
+ "argument_values": self.argument_values,
+ "argument_combination": self.argument_combination,
+ "arg_to_param_map": self.arg_to_param_map,
+ "set_param": self.set_param,
+ "duration": self.duration.to_json(),
+ "energy": self.energy.to_json(),
+ "timeout": self.timeout.to_json(),
}
return ret
def _json_get_static(base, attribute: str):
if attribute in base:
- return base[attribute]['static']
+ return base[attribute]["static"]
return 0
@@ -505,10 +578,15 @@ class PTA:
:param transitions: list of `Transition` objects
"""
- def __init__(self, state_names: list = [],
- accepting_states: list = None,
- parameters: list = [], initial_param_values: list = None,
- codegen: dict = {}, parameter_normalization: dict = None):
+ def __init__(
+ self,
+ state_names: list = [],
+ accepting_states: list = None,
+ parameters: list = [],
+ initial_param_values: list = None,
+ codegen: dict = {},
+ parameter_normalization: dict = None,
+ ):
"""
Return a new PTA object.
@@ -524,7 +602,9 @@ class PTA:
`enum`: maps enum descriptors (keys) to parameter values. Note that the mapping is not required to correspond to the driver API.
`formula`: maps an argument or return value (passed as `param`) to a parameter value. Must be a string describing a valid python lambda function. NumPy is available as `np`.
"""
- self.state = dict([[state_name, State(state_name)] for state_name in state_names])
+ self.state = dict(
+ [[state_name, State(state_name)] for state_name in state_names]
+ )
self.accepting_states = accepting_states.copy() if accepting_states else None
self.parameters = parameters.copy()
self.parameter_normalization = parameter_normalization
@@ -535,13 +615,15 @@ class PTA:
self.initial_param_values = [None for x in self.parameters]
self.transitions = []
- if 'UNINITIALIZED' not in state_names:
- self.state['UNINITIALIZED'] = State('UNINITIALIZED')
+ if "UNINITIALIZED" not in state_names:
+ self.state["UNINITIALIZED"] = State("UNINITIALIZED")
if self.parameter_normalization:
for normalization_spec in self.parameter_normalization.values():
- if 'formula' in normalization_spec:
- normalization_spec['formula'] = NormalizationFunction(normalization_spec['formula'])
+ if "formula" in normalization_spec:
+ normalization_spec["formula"] = NormalizationFunction(
+ normalization_spec["formula"]
+ )
def normalize_parameter(self, parameter_name: str, parameter_value) -> float:
"""
@@ -553,11 +635,23 @@ class PTA:
:param parameter_name: parameter name.
:param parameter_value: parameter value.
"""
- if parameter_value is not None and self.parameter_normalization is not None and parameter_name in self.parameter_normalization:
- if 'enum' in self.parameter_normalization[parameter_name] and parameter_value in self.parameter_normalization[parameter_name]['enum']:
- return self.parameter_normalization[parameter_name]['enum'][parameter_value]
- if 'formula' in self.parameter_normalization[parameter_name]:
- normalization_formula = self.parameter_normalization[parameter_name]['formula']
+ if (
+ parameter_value is not None
+ and self.parameter_normalization is not None
+ and parameter_name in self.parameter_normalization
+ ):
+ if (
+ "enum" in self.parameter_normalization[parameter_name]
+ and parameter_value
+ in self.parameter_normalization[parameter_name]["enum"]
+ ):
+ return self.parameter_normalization[parameter_name]["enum"][
+ parameter_value
+ ]
+ if "formula" in self.parameter_normalization[parameter_name]:
+ normalization_formula = self.parameter_normalization[parameter_name][
+ "formula"
+ ]
return normalization_formula.eval(parameter_value)
return parameter_value
@@ -580,8 +674,8 @@ class PTA:
@classmethod
def from_file(cls, model_file: str):
"""Return PTA loaded from the provided JSON or YAML file."""
- with open(model_file, 'r') as f:
- if '.json' in model_file:
+ with open(model_file, "r") as f:
+ if ".json" in model_file:
return cls.from_json(json.load(f))
else:
return cls.from_yaml(yaml.safe_load(f))
@@ -593,36 +687,58 @@ class PTA:
Compatible with the to_json method.
"""
- if 'transition' in json_input:
+ if "transition" in json_input:
return cls.from_legacy_json(json_input)
kwargs = dict()
- for key in ('state_names', 'parameters', 'initial_param_values', 'accepting_states'):
+ for key in (
+ "state_names",
+ "parameters",
+ "initial_param_values",
+ "accepting_states",
+ ):
if key in json_input:
kwargs[key] = json_input[key]
pta = cls(**kwargs)
- for name, state in json_input['state'].items():
- pta.add_state(name, power=PTAAttribute.from_json_maybe(state, 'power', pta.parameters))
- for transition in json_input['transitions']:
+ for name, state in json_input["state"].items():
+ pta.add_state(
+ name, power=PTAAttribute.from_json_maybe(state, "power", pta.parameters)
+ )
+ for transition in json_input["transitions"]:
kwargs = dict()
- for key in ['arguments', 'argument_values', 'argument_combination', 'is_interrupt', 'set_param']:
+ for key in [
+ "arguments",
+ "argument_values",
+ "argument_combination",
+ "is_interrupt",
+ "set_param",
+ ]:
if key in transition:
kwargs[key] = transition[key]
# arg_to_param_map uses integer indices. This is not supported by JSON
- if 'arg_to_param_map' in transition:
- kwargs['arg_to_param_map'] = dict()
- for arg_index, param_name in transition['arg_to_param_map'].items():
- kwargs['arg_to_param_map'][int(arg_index)] = param_name
- origins = transition['origin']
+ if "arg_to_param_map" in transition:
+ kwargs["arg_to_param_map"] = dict()
+ for arg_index, param_name in transition["arg_to_param_map"].items():
+ kwargs["arg_to_param_map"][int(arg_index)] = param_name
+ origins = transition["origin"]
if type(origins) != list:
origins = [origins]
for origin in origins:
- pta.add_transition(origin, transition['destination'],
- transition['name'],
- duration=PTAAttribute.from_json_maybe(transition, 'duration', pta.parameters),
- energy=PTAAttribute.from_json_maybe(transition, 'energy', pta.parameters),
- timeout=PTAAttribute.from_json_maybe(transition, 'timeout', pta.parameters),
- **kwargs)
+ pta.add_transition(
+ origin,
+ transition["destination"],
+ transition["name"],
+ duration=PTAAttribute.from_json_maybe(
+ transition, "duration", pta.parameters
+ ),
+ energy=PTAAttribute.from_json_maybe(
+ transition, "energy", pta.parameters
+ ),
+ timeout=PTAAttribute.from_json_maybe(
+ transition, "timeout", pta.parameters
+ ),
+ **kwargs
+ )
return pta
@@ -634,37 +750,45 @@ class PTA:
Compatible with the legacy dfatool/perl format.
"""
kwargs = {
- 'parameters': list(),
- 'initial_param_values': list(),
+ "parameters": list(),
+ "initial_param_values": list(),
}
- for param in sorted(json_input['parameter'].keys()):
- kwargs['parameters'].append(param)
- kwargs['initial_param_values'].append(json_input['parameter'][param]['default'])
+ for param in sorted(json_input["parameter"].keys()):
+ kwargs["parameters"].append(param)
+ kwargs["initial_param_values"].append(
+ json_input["parameter"][param]["default"]
+ )
pta = cls(**kwargs)
- for name, state in json_input['state'].items():
- pta.add_state(name, power=PTAAttribute(value=float(state['power']['static'])))
+ for name, state in json_input["state"].items():
+ pta.add_state(
+ name, power=PTAAttribute(value=float(state["power"]["static"]))
+ )
- for trans_name in sorted(json_input['transition'].keys()):
- transition = json_input['transition'][trans_name]
- destination = transition['destination']
+ for trans_name in sorted(json_input["transition"].keys()):
+ transition = json_input["transition"][trans_name]
+ destination = transition["destination"]
arguments = list()
argument_values = list()
is_interrupt = False
- if transition['level'] == 'epilogue':
+ if transition["level"] == "epilogue":
is_interrupt = True
if type(destination) == list:
destination = destination[0]
- for arg in transition['parameters']:
- arguments.append(arg['name'])
- argument_values.append(arg['values'])
- for origin in transition['origins']:
- pta.add_transition(origin, destination, trans_name,
- arguments=arguments,
- argument_values=argument_values,
- is_interrupt=is_interrupt)
+ for arg in transition["parameters"]:
+ arguments.append(arg["name"])
+ argument_values.append(arg["values"])
+ for origin in transition["origins"]:
+ pta.add_transition(
+ origin,
+ destination,
+ trans_name,
+ arguments=arguments,
+ argument_values=argument_values,
+ is_interrupt=is_interrupt,
+ )
return pta
@@ -674,68 +798,79 @@ class PTA:
kwargs = dict()
- if 'parameters' in yaml_input:
- kwargs['parameters'] = yaml_input['parameters']
+ if "parameters" in yaml_input:
+ kwargs["parameters"] = yaml_input["parameters"]
- if 'initial_param_values' in yaml_input:
- kwargs['initial_param_values'] = yaml_input['initial_param_values']
+ if "initial_param_values" in yaml_input:
+ kwargs["initial_param_values"] = yaml_input["initial_param_values"]
- if 'states' in yaml_input:
- kwargs['state_names'] = yaml_input['states']
+ if "states" in yaml_input:
+ kwargs["state_names"] = yaml_input["states"]
# else: set to UNINITIALIZED by class constructor
- if 'codegen' in yaml_input:
- kwargs['codegen'] = yaml_input['codegen']
+ if "codegen" in yaml_input:
+ kwargs["codegen"] = yaml_input["codegen"]
- if 'parameter_normalization' in yaml_input:
- kwargs['parameter_normalization'] = yaml_input['parameter_normalization']
+ if "parameter_normalization" in yaml_input:
+ kwargs["parameter_normalization"] = yaml_input["parameter_normalization"]
pta = cls(**kwargs)
- if 'state' in yaml_input:
- for state_name, state in yaml_input['state'].items():
- pta.add_state(state_name, power=PTAAttribute.from_json_maybe(state, 'power', pta.parameters))
+ if "state" in yaml_input:
+ for state_name, state in yaml_input["state"].items():
+ pta.add_state(
+ state_name,
+ power=PTAAttribute.from_json_maybe(state, "power", pta.parameters),
+ )
- for trans_name in sorted(yaml_input['transition'].keys()):
+ for trans_name in sorted(yaml_input["transition"].keys()):
kwargs = dict()
- transition = yaml_input['transition'][trans_name]
+ transition = yaml_input["transition"][trans_name]
arguments = list()
argument_values = list()
arg_to_param_map = dict()
- if 'arguments' in transition:
- for i, argument in enumerate(transition['arguments']):
- arguments.append(argument['name'])
- argument_values.append(argument['values'])
- if 'parameter' in argument:
- arg_to_param_map[i] = argument['parameter']
- if 'argument_combination' in transition:
- kwargs['argument_combination'] = transition['argument_combination']
- if 'set_param' in transition:
- kwargs['set_param'] = transition['set_param']
- if 'is_interrupt' in transition:
- kwargs['is_interrupt'] = transition['is_interrupt']
- if 'return_value' in transition:
- kwargs['return_value_handlers'] = transition['return_value']
- if 'codegen' in transition:
- kwargs['codegen'] = transition['codegen']
- if 'loop' in transition:
- for state_name in transition['loop']:
- pta.add_transition(state_name, state_name, trans_name,
- arguments=arguments,
- argument_values=argument_values,
- arg_to_param_map=arg_to_param_map,
- **kwargs)
+ if "arguments" in transition:
+ for i, argument in enumerate(transition["arguments"]):
+ arguments.append(argument["name"])
+ argument_values.append(argument["values"])
+ if "parameter" in argument:
+ arg_to_param_map[i] = argument["parameter"]
+ if "argument_combination" in transition:
+ kwargs["argument_combination"] = transition["argument_combination"]
+ if "set_param" in transition:
+ kwargs["set_param"] = transition["set_param"]
+ if "is_interrupt" in transition:
+ kwargs["is_interrupt"] = transition["is_interrupt"]
+ if "return_value" in transition:
+ kwargs["return_value_handlers"] = transition["return_value"]
+ if "codegen" in transition:
+ kwargs["codegen"] = transition["codegen"]
+ if "loop" in transition:
+ for state_name in transition["loop"]:
+ pta.add_transition(
+ state_name,
+ state_name,
+ trans_name,
+ arguments=arguments,
+ argument_values=argument_values,
+ arg_to_param_map=arg_to_param_map,
+ **kwargs
+ )
else:
- if 'src' not in transition:
- transition['src'] = ['UNINITIALIZED']
- if 'dst' not in transition:
- transition['dst'] = 'UNINITIALIZED'
- for origin in transition['src']:
- pta.add_transition(origin, transition['dst'], trans_name,
- arguments=arguments,
- argument_values=argument_values,
- arg_to_param_map=arg_to_param_map,
- **kwargs)
+ if "src" not in transition:
+ transition["src"] = ["UNINITIALIZED"]
+ if "dst" not in transition:
+ transition["dst"] = "UNINITIALIZED"
+ for origin in transition["src"]:
+ pta.add_transition(
+ origin,
+ transition["dst"],
+ trans_name,
+ arguments=arguments,
+ argument_values=argument_values,
+ arg_to_param_map=arg_to_param_map,
+ **kwargs
+ )
return pta
@@ -746,11 +881,13 @@ class PTA:
Compatible with the from_json method.
"""
ret = {
- 'parameters': self.parameters,
- 'initial_param_values': self.initial_param_values,
- 'state': dict([[state.name, state.to_json()] for state in self.state.values()]),
- 'transitions': [trans.to_json() for trans in self.transitions],
- 'accepting_states': self.accepting_states,
+ "parameters": self.parameters,
+ "initial_param_values": self.initial_param_values,
+ "state": dict(
+ [[state.name, state.to_json()] for state in self.state.values()]
+ ),
+ "transitions": [trans.to_json() for trans in self.transitions],
+ "accepting_states": self.accepting_states,
}
return ret
@@ -760,12 +897,19 @@ class PTA:
See the State() documentation for acceptable arguments.
"""
- if 'power_function' in kwargs and type(kwargs['power_function']) != AnalyticFunction and kwargs['power_function'] is not None:
- kwargs['power_function'] = AnalyticFunction(kwargs['power_function'],
- self.parameters, 0)
+ if (
+ "power_function" in kwargs
+ and type(kwargs["power_function"]) != AnalyticFunction
+ and kwargs["power_function"] is not None
+ ):
+ kwargs["power_function"] = AnalyticFunction(
+ kwargs["power_function"], self.parameters, 0
+ )
self.state[state_name] = State(state_name, **kwargs)
- def add_transition(self, orig_state: str, dest_state: str, function_name: str, **kwargs):
+ def add_transition(
+ self, orig_state: str, dest_state: str, function_name: str, **kwargs
+ ):
"""
Add function_name as new transition from orig_state to dest_state.
@@ -776,8 +920,12 @@ class PTA:
"""
orig_state = self.state[orig_state]
dest_state = self.state[dest_state]
- for key in ('duration_function', 'energy_function', 'timeout_function'):
- if key in kwargs and kwargs[key] is not None and type(kwargs[key]) != AnalyticFunction:
+ for key in ("duration_function", "energy_function", "timeout_function"):
+ if (
+ key in kwargs
+ and kwargs[key] is not None
+ and type(kwargs[key]) != AnalyticFunction
+ ):
kwargs[key] = AnalyticFunction(kwargs[key], self.parameters, 0)
new_transition = Transition(orig_state, dest_state, function_name, **kwargs)
@@ -824,7 +972,12 @@ class PTA:
return self.get_unique_transitions().index(transition)
def get_initial_param_dict(self):
- return dict([[self.parameters[i], self.initial_param_values[i]] for i in range(len(self.parameters))])
+ return dict(
+ [
+ [self.parameters[i], self.initial_param_values[i]]
+ for i in range(len(self.parameters))
+ ]
+ )
def set_random_energy_model(self, static_model=True):
u"""
@@ -841,18 +994,24 @@ class PTA:
def get_most_expensive_state(self):
max_state = None
for state in self.state.values():
- if state.name != 'UNINITIALIZED' and (max_state is None or state.power.value > max_state.power.value):
+ if state.name != "UNINITIALIZED" and (
+ max_state is None or state.power.value > max_state.power.value
+ ):
max_state = state
return max_state
def get_least_expensive_state(self):
min_state = None
for state in self.state.values():
- if state.name != 'UNINITIALIZED' and (min_state is None or state.power.value < min_state.power.value):
+ if state.name != "UNINITIALIZED" and (
+ min_state is None or state.power.value < min_state.power.value
+ ):
min_state = state
return min_state
- def min_duration_until_energy_overflow(self, energy_granularity=1e-12, max_energy_value=2 ** 32 - 1):
+ def min_duration_until_energy_overflow(
+ self, energy_granularity=1e-12, max_energy_value=2 ** 32 - 1
+ ):
"""
Return minimum duration (in s) until energy counter overflow during online accounting.
@@ -862,7 +1021,9 @@ class PTA:
max_power_state = self.get_most_expensive_state()
if max_power_state.has_interrupt_transitions():
- raise RuntimeWarning('state with maximum power consumption has outgoing interrupt transitions, results will be inaccurate')
+ raise RuntimeWarning(
+ "state with maximum power consumption has outgoing interrupt transitions, results will be inaccurate"
+ )
# convert from µW to W
max_power = max_power_state.power.value * 1e-6
@@ -870,7 +1031,9 @@ class PTA:
min_duration = max_energy_value * energy_granularity / max_power
return min_duration
- def max_duration_until_energy_overflow(self, energy_granularity=1e-12, max_energy_value=2 ** 32 - 1):
+ def max_duration_until_energy_overflow(
+ self, energy_granularity=1e-12, max_energy_value=2 ** 32 - 1
+ ):
"""
Return maximum duration (in s) until energy counter overflow during online accounting.
@@ -880,7 +1043,9 @@ class PTA:
min_power_state = self.get_least_expensive_state()
if min_power_state.has_interrupt_transitions():
- raise RuntimeWarning('state with maximum power consumption has outgoing interrupt transitions, results will be inaccurate')
+ raise RuntimeWarning(
+ "state with maximum power consumption has outgoing interrupt transitions, results will be inaccurate"
+ )
# convert from µW to W
min_power = min_power_state.power.value * 1e-6
@@ -904,14 +1069,19 @@ class PTA:
for i, argument in enumerate(transition.arguments):
if len(transition.argument_values[i]) <= 2:
continue
- if transition.argument_combination == 'zip':
+ if transition.argument_combination == "zip":
continue
values_are_numeric = True
for value in transition.argument_values[i]:
- if not is_numeric(self.normalize_parameter(transition.arg_to_param_map[i], value)):
+ if not is_numeric(
+ self.normalize_parameter(transition.arg_to_param_map[i], value)
+ ):
values_are_numeric = False
if values_are_numeric and len(transition.argument_values[i]) > 2:
- transition.argument_values[i] = [transition.argument_values[i][0], transition.argument_values[i][-1]]
+ transition.argument_values[i] = [
+ transition.argument_values[i][0],
+ transition.argument_values[i][-1],
+ ]
def _dfs_with_param(self, generator, param_dict):
for trace in generator:
@@ -921,13 +1091,23 @@ class PTA:
transition, arguments = elem
if transition is not None:
param = transition.get_params_after_transition(param, arguments)
- ret.append((transition, arguments, self.normalize_parameters(param)))
+ ret.append(
+ (transition, arguments, self.normalize_parameters(param))
+ )
else:
# parameters have already been normalized
ret.append((transition, arguments, param))
yield ret
- def bfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED', param_dict: dict = None, with_parameters: bool = False, transition_filter=None, state_filter=None):
+ def bfs(
+ self,
+ depth: int = 10,
+ orig_state: str = "UNINITIALIZED",
+ param_dict: dict = None,
+ with_parameters: bool = False,
+ transition_filter=None,
+ state_filter=None,
+ ):
"""
Return a generator object for breadth-first search of traces starting at orig_state.
@@ -968,7 +1148,14 @@ class PTA:
yield new_trace
state_queue.put((new_trace, transition.destination))
- def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED', param_dict: dict = None, with_parameters: bool = False, **kwargs):
+ def dfs(
+ self,
+ depth: int = 10,
+ orig_state: str = "UNINITIALIZED",
+ param_dict: dict = None,
+ with_parameters: bool = False,
+ **kwargs
+ ):
"""
Return a generator object for depth-first search starting at orig_state.
@@ -994,12 +1181,14 @@ class PTA:
if with_parameters and not param_dict:
param_dict = self.get_initial_param_dict()
- if with_parameters and 'with_arguments' not in kwargs:
+ if with_parameters and "with_arguments" not in kwargs:
raise ValueError("with_parameters = True requires with_arguments = True")
if self.accepting_states:
- generator = filter(lambda x: x[-1][0].destination.name in self.accepting_states,
- self.state[orig_state].dfs(depth, **kwargs))
+ generator = filter(
+ lambda x: x[-1][0].destination.name in self.accepting_states,
+ self.state[orig_state].dfs(depth, **kwargs),
+ )
else:
generator = self.state[orig_state].dfs(depth, **kwargs)
@@ -1008,7 +1197,13 @@ class PTA:
else:
return generator
- def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED', orig_param=None, accounting=None):
+ def simulate(
+ self,
+ trace: list,
+ orig_state: str = "UNINITIALIZED",
+ orig_param=None,
+ accounting=None,
+ ):
u"""
Simulate a single run through the PTA and return total energy, duration, final state, and resulting parameters.
@@ -1021,10 +1216,10 @@ class PTA:
:returns: SimulationResult with duration in s, total energy in J, end state, and final parameters
"""
- total_duration = 0.
- total_duration_mae = 0.
- total_energy = 0.
- total_energy_error = 0.
+ total_duration = 0.0
+ total_duration_mae = 0.0
+ total_energy = 0.0
+ total_energy_error = 0.0
if type(orig_state) is State:
state = orig_state
else:
@@ -1032,7 +1227,12 @@ class PTA:
if orig_param:
param_dict = orig_param.copy()
else:
- param_dict = dict([[self.parameters[i], self.initial_param_values[i]] for i in range(len(self.parameters))])
+ param_dict = dict(
+ [
+ [self.parameters[i], self.initial_param_values[i]]
+ for i in range(len(self.parameters))
+ ]
+ )
for function in trace:
if isinstance(function[0], Transition):
function_name = function[0].name
@@ -1040,11 +1240,13 @@ class PTA:
else:
function_name = function[0]
function_args = function[1:]
- if function_name is None or function_name == '_':
+ if function_name is None or function_name == "_":
duration = function_args[0]
total_energy += state.get_energy(duration, param_dict)
if state.power.value_error is not None:
- total_energy_error += (duration * state.power.eval_mae(param_dict, function_args))**2
+ total_energy_error += (
+ duration * state.power.eval_mae(param_dict, function_args)
+ ) ** 2
total_duration += duration
# assumption: sleep is near-exact and does not contribute to the duration error
if accounting is not None:
@@ -1053,15 +1255,21 @@ class PTA:
transition = state.get_transition(function_name)
total_duration += transition.duration.eval(param_dict, function_args)
if transition.duration.value_error is not None:
- total_duration_mae += transition.duration.eval_mae(param_dict, function_args)**2
+ total_duration_mae += (
+ transition.duration.eval_mae(param_dict, function_args) ** 2
+ )
total_energy += transition.get_energy(param_dict, function_args)
if transition.energy.value_error is not None:
- total_energy_error += transition.energy.eval_mae(param_dict, function_args)**2
- param_dict = transition.get_params_after_transition(param_dict, function_args)
+ total_energy_error += (
+ transition.energy.eval_mae(param_dict, function_args) ** 2
+ )
+ param_dict = transition.get_params_after_transition(
+ param_dict, function_args
+ )
state = transition.destination
if accounting is not None:
accounting.pass_transition(transition)
- while (state.has_interrupt_transitions()):
+ while state.has_interrupt_transitions():
transition = state.get_next_interrupt(param_dict)
duration = transition.get_timeout(param_dict)
total_duration += duration
@@ -1072,45 +1280,82 @@ class PTA:
param_dict = transition.get_params_after_transition(param_dict)
state = transition.destination
- return SimulationResult(total_duration, total_energy, state, param_dict, duration_mae=np.sqrt(total_duration_mae), energy_mae=np.sqrt(total_energy_error))
+ return SimulationResult(
+ total_duration,
+ total_energy,
+ state,
+ param_dict,
+ duration_mae=np.sqrt(total_duration_mae),
+ energy_mae=np.sqrt(total_energy_error),
+ )
def update(self, static_model, param_model, static_error=None, analytic_error=None):
for state in self.state.values():
- if state.name != 'UNINITIALIZED':
+ if state.name != "UNINITIALIZED":
try:
- state.power.value = static_model(state.name, 'power')
+ state.power.value = static_model(state.name, "power")
if static_error is not None:
- state.power.value_error = static_error[state.name]['power']
- if param_model(state.name, 'power'):
- state.power.function = param_model(state.name, 'power')['function']
+ state.power.value_error = static_error[state.name]["power"]
+ if param_model(state.name, "power"):
+ state.power.function = param_model(state.name, "power")[
+ "function"
+ ]
if analytic_error is not None:
- state.power.function_error = analytic_error[state.name]['power']
+ state.power.function_error = analytic_error[state.name][
+ "power"
+ ]
except KeyError:
- print('[W] skipping model update of state {} due to missing data'.format(state.name))
+ print(
+ "[W] skipping model update of state {} due to missing data".format(
+ state.name
+ )
+ )
pass
for transition in self.transitions:
try:
- transition.duration.value = static_model(transition.name, 'duration')
- if param_model(transition.name, 'duration'):
- transition.duration.function = param_model(transition.name, 'duration')['function']
+ transition.duration.value = static_model(transition.name, "duration")
+ if param_model(transition.name, "duration"):
+ transition.duration.function = param_model(
+ transition.name, "duration"
+ )["function"]
if analytic_error is not None:
- transition.duration.function_error = analytic_error[transition.name]['duration']
- transition.energy.value = static_model(transition.name, 'energy')
- if param_model(transition.name, 'energy'):
- transition.energy.function = param_model(transition.name, 'energy')['function']
+ transition.duration.function_error = analytic_error[
+ transition.name
+ ]["duration"]
+ transition.energy.value = static_model(transition.name, "energy")
+ if param_model(transition.name, "energy"):
+ transition.energy.function = param_model(transition.name, "energy")[
+ "function"
+ ]
if analytic_error is not None:
- transition.energy.function_error = analytic_error[transition.name]['energy']
+ transition.energy.function_error = analytic_error[
+ transition.name
+ ]["energy"]
if transition.is_interrupt:
- transition.timeout.value = static_model(transition.name, 'timeout')
- if param_model(transition.name, 'timeout'):
- transition.timeout.function = param_model(transition.name, 'timeout')['function']
+ transition.timeout.value = static_model(transition.name, "timeout")
+ if param_model(transition.name, "timeout"):
+ transition.timeout.function = param_model(
+ transition.name, "timeout"
+ )["function"]
if analytic_error is not None:
- transition.timeout.function_error = analytic_error[transition.name]['timeout']
+ transition.timeout.function_error = analytic_error[
+ transition.name
+ ]["timeout"]
if static_error is not None:
- transition.duration.value_error = static_error[transition.name]['duration']
- transition.energy.value_error = static_error[transition.name]['energy']
- transition.timeout.value_error = static_error[transition.name]['timeout']
+ transition.duration.value_error = static_error[transition.name][
+ "duration"
+ ]
+ transition.energy.value_error = static_error[transition.name][
+ "energy"
+ ]
+ transition.timeout.value_error = static_error[transition.name][
+ "timeout"
+ ]
except KeyError:
- print('[W] skipping model update of transition {} due to missing data'.format(transition.name))
+ print(
+ "[W] skipping model update of transition {} due to missing data".format(
+ transition.name
+ )
+ )
pass
diff --git a/lib/codegen.py b/lib/codegen.py
index e0bf45f..62776fd 100644
--- a/lib/codegen.py
+++ b/lib/codegen.py
@@ -60,38 +60,46 @@ class ClassFunction:
self.body = body
def get_definition(self):
- return '{} {}({});'.format(self.return_type, self.name, ', '.join(self.arguments))
+ return "{} {}({});".format(
+ self.return_type, self.name, ", ".join(self.arguments)
+ )
def get_implementation(self):
if self.body is None:
- return ''
- return '{} {}::{}({}) {{\n{}}}\n'.format(self.return_type, self.class_name, self.name, ', '.join(self.arguments), self.body)
+ return ""
+ return "{} {}::{}({}) {{\n{}}}\n".format(
+ self.return_type,
+ self.class_name,
+ self.name,
+ ", ".join(self.arguments),
+ self.body,
+ )
def get_accountingmethod(method):
"""Return AccountingMethod class for method."""
- if method == 'static_state_immediate':
+ if method == "static_state_immediate":
return StaticStateOnlyAccountingImmediateCalculation
- if method == 'static_state':
+ if method == "static_state":
return StaticStateOnlyAccounting
- if method == 'static_statetransition_immediate':
+ if method == "static_statetransition_immediate":
return StaticAccountingImmediateCalculation
- if method == 'static_statetransition':
+ if method == "static_statetransition":
return StaticAccounting
- raise ValueError('Unknown accounting method: {}'.format(method))
+ raise ValueError("Unknown accounting method: {}".format(method))
def get_simulated_accountingmethod(method):
"""Return SimulatedAccountingMethod class for method."""
- if method == 'static_state_immediate':
+ if method == "static_state_immediate":
return SimulatedStaticStateOnlyAccountingImmediateCalculation
- if method == 'static_statetransition_immediate':
+ if method == "static_statetransition_immediate":
return SimulatedStaticAccountingImmediateCalculation
- if method == 'static_state':
+ if method == "static_state":
return SimulatedStaticStateOnlyAccounting
- if method == 'static_statetransition':
+ if method == "static_statetransition":
return SimulatedStaticAccounting
- raise ValueError('Unknown accounting method: {}'.format(method))
+ raise ValueError("Unknown accounting method: {}".format(method))
class SimulatedAccountingMethod:
@@ -104,7 +112,18 @@ class SimulatedAccountingMethod:
* variable size for accounting of durations, power and energy values
"""
- def __init__(self, pta: PTA, timer_freq_hz, timer_type, ts_type, power_type, energy_type, ts_granularity=1e-6, power_granularity=1e-6, energy_granularity=1e-12):
+ def __init__(
+ self,
+ pta: PTA,
+ timer_freq_hz,
+ timer_type,
+ ts_type,
+ power_type,
+ energy_type,
+ ts_granularity=1e-6,
+ power_granularity=1e-6,
+ energy_granularity=1e-12,
+ ):
"""
Simulate Online Accounting for a given PTA.
@@ -121,7 +140,7 @@ class SimulatedAccountingMethod:
self.ts_class = simulate_int_type(ts_type)
self.power_class = simulate_int_type(power_type)
self.energy_class = simulate_int_type(energy_type)
- self.current_state = pta.state['UNINITIALIZED']
+ self.current_state = pta.state["UNINITIALIZED"]
self.ts_granularity = ts_granularity
self.power_granularity = power_granularity
@@ -137,7 +156,13 @@ class SimulatedAccountingMethod:
Does not use Module types and therefore does not consider overflows or data-type limitations"""
if self.energy_granularity == self.power_granularity * self.ts_granularity:
return power * time
- return int(power * self.power_granularity * time * self.ts_granularity / self.energy_granularity)
+ return int(
+ power
+ * self.power_granularity
+ * time
+ * self.ts_granularity
+ / self.energy_granularity
+ )
def _sleep_duration(self, duration_us):
u"""
@@ -202,11 +227,11 @@ class SimulatedStaticAccountingImmediateCalculation(SimulatedAccountingMethod):
def sleep(self, duration_us):
time = self._sleep_duration(duration_us)
- print('sleep duration is {}'.format(time))
+ print("sleep duration is {}".format(time))
power = int(self.current_state.power.value)
- print('power is {}'.format(power))
+ print("power is {}".format(power))
energy = self._energy_from_power_and_time(time, power)
- print('energy is {}'.format(energy))
+ print("energy is {}".format(energy))
self.energy += energy
def pass_transition(self, transition: Transition):
@@ -232,7 +257,7 @@ class SimulatedStaticAccounting(SimulatedAccountingMethod):
self.time_in_state[state_name] = self.ts_class(0)
self.transition_count = list()
for transition in pta.transitions:
- self.transition_count.append(simulate_int_type('uint16_t')(0))
+ self.transition_count.append(simulate_int_type("uint16_t")(0))
def sleep(self, duration_us):
self.time_in_state[self.current_state.name] += self._sleep_duration(duration_us)
@@ -245,7 +270,9 @@ class SimulatedStaticAccounting(SimulatedAccountingMethod):
pta = self.pta
energy = self.energy_class(0)
for state in pta.state.values():
- energy += self._energy_from_power_and_time(self.time_in_state[state.name], int(state.power.value))
+ energy += self._energy_from_power_and_time(
+ self.time_in_state[state.name], int(state.power.value)
+ )
for i, transition in enumerate(pta.transitions):
energy += self.transition_count[i] * int(transition.energy.value)
return energy.val
@@ -275,7 +302,9 @@ class SimulatedStaticStateOnlyAccounting(SimulatedAccountingMethod):
pta = self.pta
energy = self.energy_class(0)
for state in pta.state.values():
- energy += self._energy_from_power_and_time(self.time_in_state[state.name], int(state.power.value))
+ energy += self._energy_from_power_and_time(
+ self.time_in_state[state.name], int(state.power.value)
+ )
return energy.val
@@ -290,32 +319,50 @@ class AccountingMethod:
self.public_functions = list()
def pre_transition_hook(self, transition):
- return ''
+ return ""
def init_code(self):
- return ''
+ return ""
def get_includes(self):
return map(lambda x: '#include "{}"'.format(x), self.include_paths)
class StaticStateOnlyAccountingImmediateCalculation(AccountingMethod):
- def __init__(self, class_name: str, pta: PTA, ts_type='unsigned int', power_type='unsigned int', energy_type='unsigned long'):
+ def __init__(
+ self,
+ class_name: str,
+ pta: PTA,
+ ts_type="unsigned int",
+ power_type="unsigned int",
+ energy_type="unsigned long",
+ ):
super().__init__(class_name, pta)
self.ts_type = ts_type
- self.include_paths.append('driver/uptime.h')
- self.private_variables.append('unsigned char lastState;')
- self.private_variables.append('{} lastStateChange;'.format(ts_type))
- self.private_variables.append('{} totalEnergy;'.format(energy_type))
- self.private_variables.append(array_template.format(
- type=power_type,
- name='state_power',
- length=len(pta.state),
- elements=', '.join(map(lambda state_name: '{:.0f}'.format(pta.state[state_name].power), pta.get_state_names()))
- ))
+ self.include_paths.append("driver/uptime.h")
+ self.private_variables.append("unsigned char lastState;")
+ self.private_variables.append("{} lastStateChange;".format(ts_type))
+ self.private_variables.append("{} totalEnergy;".format(energy_type))
+ self.private_variables.append(
+ array_template.format(
+ type=power_type,
+ name="state_power",
+ length=len(pta.state),
+ elements=", ".join(
+ map(
+ lambda state_name: "{:.0f}".format(pta.state[state_name].power),
+ pta.get_state_names(),
+ )
+ ),
+ )
+ )
get_energy_function = """return totalEnergy;"""
- self.public_functions.append(ClassFunction(class_name, energy_type, 'getEnergy', list(), get_energy_function))
+ self.public_functions.append(
+ ClassFunction(
+ class_name, energy_type, "getEnergy", list(), get_energy_function
+ )
+ )
def pre_transition_hook(self, transition):
return """
@@ -323,30 +370,50 @@ class StaticStateOnlyAccountingImmediateCalculation(AccountingMethod):
totalEnergy += (now - lastStateChange) * state_power[lastState];
lastStateChange = now;
lastState = {};
- """.format(self.pta.get_state_id(transition.destination))
+ """.format(
+ self.pta.get_state_id(transition.destination)
+ )
def init_code(self):
return """
totalEnergy = 0;
lastStateChange = 0;
lastState = 0;
- """.format(num_states=len(self.pta.state))
+ """.format(
+ num_states=len(self.pta.state)
+ )
class StaticStateOnlyAccounting(AccountingMethod):
- def __init__(self, class_name: str, pta: PTA, ts_type='unsigned int', power_type='unsigned int', energy_type='unsigned long'):
+ def __init__(
+ self,
+ class_name: str,
+ pta: PTA,
+ ts_type="unsigned int",
+ power_type="unsigned int",
+ energy_type="unsigned long",
+ ):
super().__init__(class_name, pta)
self.ts_type = ts_type
- self.include_paths.append('driver/uptime.h')
- self.private_variables.append('unsigned char lastState;')
- self.private_variables.append('{} lastStateChange;'.format(ts_type))
- self.private_variables.append(array_template.format(
- type=power_type,
- name='state_power',
- length=len(pta.state),
- elements=', '.join(map(lambda state_name: '{:.0f}'.format(pta.state[state_name].power), pta.get_state_names()))
- ))
- self.private_variables.append('{} timeInState[{}];'.format(ts_type, len(pta.state)))
+ self.include_paths.append("driver/uptime.h")
+ self.private_variables.append("unsigned char lastState;")
+ self.private_variables.append("{} lastStateChange;".format(ts_type))
+ self.private_variables.append(
+ array_template.format(
+ type=power_type,
+ name="state_power",
+ length=len(pta.state),
+ elements=", ".join(
+ map(
+ lambda state_name: "{:.0f}".format(pta.state[state_name].power),
+ pta.get_state_names(),
+ )
+ ),
+ )
+ )
+ self.private_variables.append(
+ "{} timeInState[{}];".format(ts_type, len(pta.state))
+ )
get_energy_function = """
{energy_type} total_energy = 0;
@@ -354,8 +421,14 @@ class StaticStateOnlyAccounting(AccountingMethod):
total_energy += timeInState[i] * state_power[i];
}}
return total_energy;
- """.format(energy_type=energy_type, num_states=len(pta.state))
- self.public_functions.append(ClassFunction(class_name, energy_type, 'getEnergy', list(), get_energy_function))
+ """.format(
+ energy_type=energy_type, num_states=len(pta.state)
+ )
+ self.public_functions.append(
+ ClassFunction(
+ class_name, energy_type, "getEnergy", list(), get_energy_function
+ )
+ )
def pre_transition_hook(self, transition):
return """
@@ -363,7 +436,9 @@ class StaticStateOnlyAccounting(AccountingMethod):
timeInState[lastState] += now - lastStateChange;
lastStateChange = now;
lastState = {};
- """.format(self.pta.get_state_id(transition.destination))
+ """.format(
+ self.pta.get_state_id(transition.destination)
+ )
def init_code(self):
return """
@@ -372,30 +447,59 @@ class StaticStateOnlyAccounting(AccountingMethod):
}}
lastState = 0;
lastStateChange = 0;
- """.format(num_states=len(self.pta.state))
+ """.format(
+ num_states=len(self.pta.state)
+ )
class StaticAccounting(AccountingMethod):
- def __init__(self, class_name: str, pta: PTA, ts_type='unsigned int', power_type='unsigned int', energy_type='unsigned long'):
+ def __init__(
+ self,
+ class_name: str,
+ pta: PTA,
+ ts_type="unsigned int",
+ power_type="unsigned int",
+ energy_type="unsigned long",
+ ):
super().__init__(class_name, pta)
self.ts_type = ts_type
- self.include_paths.append('driver/uptime.h')
- self.private_variables.append('unsigned char lastState;')
- self.private_variables.append('{} lastStateChange;'.format(ts_type))
- self.private_variables.append(array_template.format(
- type=power_type,
- name='state_power',
- length=len(pta.state),
- elements=', '.join(map(lambda state_name: '{:.0f}'.format(pta.state[state_name].power), pta.get_state_names()))
- ))
- self.private_variables.append(array_template.format(
- type=energy_type,
- name='transition_energy',
- length=len(pta.get_unique_transitions()),
- elements=', '.join(map(lambda transition: '{:.0f}'.format(transition.energy), pta.get_unique_transitions()))
- ))
- self.private_variables.append('{} timeInState[{}];'.format(ts_type, len(pta.state)))
- self.private_variables.append('{} transitionCount[{}];'.format('unsigned int', len(pta.get_unique_transitions())))
+ self.include_paths.append("driver/uptime.h")
+ self.private_variables.append("unsigned char lastState;")
+ self.private_variables.append("{} lastStateChange;".format(ts_type))
+ self.private_variables.append(
+ array_template.format(
+ type=power_type,
+ name="state_power",
+ length=len(pta.state),
+ elements=", ".join(
+ map(
+ lambda state_name: "{:.0f}".format(pta.state[state_name].power),
+ pta.get_state_names(),
+ )
+ ),
+ )
+ )
+ self.private_variables.append(
+ array_template.format(
+ type=energy_type,
+ name="transition_energy",
+ length=len(pta.get_unique_transitions()),
+ elements=", ".join(
+ map(
+ lambda transition: "{:.0f}".format(transition.energy),
+ pta.get_unique_transitions(),
+ )
+ ),
+ )
+ )
+ self.private_variables.append(
+ "{} timeInState[{}];".format(ts_type, len(pta.state))
+ )
+ self.private_variables.append(
+ "{} transitionCount[{}];".format(
+ "unsigned int", len(pta.get_unique_transitions())
+ )
+ )
get_energy_function = """
{energy_type} total_energy = 0;
@@ -406,8 +510,16 @@ class StaticAccounting(AccountingMethod):
total_energy += transitionCount[i] * transition_energy[i];
}}
return total_energy;
- """.format(energy_type=energy_type, num_states=len(pta.state), num_transitions=len(pta.get_unique_transitions()))
- self.public_functions.append(ClassFunction(class_name, energy_type, 'getEnergy', list(), get_energy_function))
+ """.format(
+ energy_type=energy_type,
+ num_states=len(pta.state),
+ num_transitions=len(pta.get_unique_transitions()),
+ )
+ self.public_functions.append(
+ ClassFunction(
+ class_name, energy_type, "getEnergy", list(), get_energy_function
+ )
+ )
def pre_transition_hook(self, transition):
return """
@@ -416,7 +528,10 @@ class StaticAccounting(AccountingMethod):
transitionCount[{}]++;
lastStateChange = now;
lastState = {};
- """.format(self.pta.get_unique_transition_id(transition), self.pta.get_state_id(transition.destination))
+ """.format(
+ self.pta.get_unique_transition_id(transition),
+ self.pta.get_state_id(transition.destination),
+ )
def init_code(self):
return """
@@ -428,28 +543,53 @@ class StaticAccounting(AccountingMethod):
}}
lastState = 0;
lastStateChange = 0;
- """.format(num_states=len(self.pta.state), num_transitions=len(self.pta.get_unique_transitions()))
+ """.format(
+ num_states=len(self.pta.state),
+ num_transitions=len(self.pta.get_unique_transitions()),
+ )
class StaticAccountingImmediateCalculation(AccountingMethod):
- def __init__(self, class_name: str, pta: PTA, ts_type='unsigned int', power_type='unsigned int', energy_type='unsigned long'):
+ def __init__(
+ self,
+ class_name: str,
+ pta: PTA,
+ ts_type="unsigned int",
+ power_type="unsigned int",
+ energy_type="unsigned long",
+ ):
super().__init__(class_name, pta)
self.ts_type = ts_type
- self.include_paths.append('driver/uptime.h')
- self.private_variables.append('unsigned char lastState;')
- self.private_variables.append('{} lastStateChange;'.format(ts_type))
- self.private_variables.append('{} totalEnergy;'.format(energy_type))
- self.private_variables.append(array_template.format(
- type=power_type,
- name='state_power',
- length=len(pta.state),
- elements=', '.join(map(lambda state_name: '{:.0f}'.format(pta.state[state_name].power), pta.get_state_names()))
- ))
+ self.include_paths.append("driver/uptime.h")
+ self.private_variables.append("unsigned char lastState;")
+ self.private_variables.append("{} lastStateChange;".format(ts_type))
+ self.private_variables.append("{} totalEnergy;".format(energy_type))
+ self.private_variables.append(
+ array_template.format(
+ type=power_type,
+ name="state_power",
+ length=len(pta.state),
+ elements=", ".join(
+ map(
+ lambda state_name: "{:.0f}".format(pta.state[state_name].power),
+ pta.get_state_names(),
+ )
+ ),
+ )
+ )
get_energy_function = """
return totalEnergy;
- """.format(energy_type=energy_type, num_states=len(pta.state), num_transitions=len(pta.get_unique_transitions()))
- self.public_functions.append(ClassFunction(class_name, energy_type, 'getEnergy', list(), get_energy_function))
+ """.format(
+ energy_type=energy_type,
+ num_states=len(pta.state),
+ num_transitions=len(pta.get_unique_transitions()),
+ )
+ self.public_functions.append(
+ ClassFunction(
+ class_name, energy_type, "getEnergy", list(), get_energy_function
+ )
+ )
def pre_transition_hook(self, transition):
return """
@@ -458,21 +598,26 @@ class StaticAccountingImmediateCalculation(AccountingMethod):
totalEnergy += {};
lastStateChange = now;
lastState = {};
- """.format(transition.energy, self.pta.get_state_id(transition.destination))
+ """.format(
+ transition.energy, self.pta.get_state_id(transition.destination)
+ )
def init_code(self):
return """
lastState = 0;
lastStateChange = 0;
- """.format(num_states=len(self.pta.state), num_transitions=len(self.pta.get_unique_transitions()))
+ """.format(
+ num_states=len(self.pta.state),
+ num_transitions=len(self.pta.get_unique_transitions()),
+ )
class MultipassDriver:
"""Generate C++ header and no-op implementation for a multipass driver based on a DFA model."""
def __init__(self, name, pta, class_info, enum=dict(), accounting=AccountingMethod):
- self.impl = ''
- self.header = ''
+ self.impl = ""
+ self.header = ""
self.name = name
self.pta = pta
self.class_info = class_info
@@ -484,35 +629,53 @@ class MultipassDriver:
private_variables = list()
public_variables = list()
- public_functions.append(ClassFunction(self.name, '', self.name, list(), accounting.init_code()))
+ public_functions.append(
+ ClassFunction(self.name, "", self.name, list(), accounting.init_code())
+ )
for transition in self.pta.get_unique_transitions():
- if transition.name == 'getEnergy':
+ if transition.name == "getEnergy":
continue
# XXX right now we only verify whether both functions have the
# same number of arguments. This breaks in many overloading cases.
function_info = self.class_info.function[transition.name]
for function_candidate in self.class_info.functions:
- if function_candidate.name == transition.name and len(function_candidate.argument_types) == len(transition.arguments):
+ if function_candidate.name == transition.name and len(
+ function_candidate.argument_types
+ ) == len(transition.arguments):
function_info = function_candidate
function_arguments = list()
for i in range(len(transition.arguments)):
- function_arguments.append('{} {}'.format(function_info.argument_types[i], transition.arguments[i]))
+ function_arguments.append(
+ "{} {}".format(
+ function_info.argument_types[i], transition.arguments[i]
+ )
+ )
function_body = accounting.pre_transition_hook(transition)
- if function_info.return_type != 'void':
- function_body += 'return 0;\n'
+ if function_info.return_type != "void":
+ function_body += "return 0;\n"
- public_functions.append(ClassFunction(self.name, function_info.return_type, transition.name, function_arguments, function_body))
+ public_functions.append(
+ ClassFunction(
+ self.name,
+ function_info.return_type,
+ transition.name,
+ function_arguments,
+ function_body,
+ )
+ )
enums = list()
for enum_name in self.enum.keys():
- enums.append('enum {} {{ {} }};'.format(enum_name, ', '.join(self.enum[enum_name])))
+ enums.append(
+ "enum {} {{ {} }};".format(enum_name, ", ".join(self.enum[enum_name]))
+ )
if accounting:
includes.extend(accounting.get_includes())
@@ -522,11 +685,21 @@ class MultipassDriver:
public_variables.extend(accounting.public_variables)
self.header = header_template.format(
- name=self.name, name_lower=self.name.lower(),
- includes='\n'.join(includes),
- private_variables='\n'.join(private_variables),
- public_variables='\n'.join(public_variables),
- public_functions='\n'.join(map(lambda x: x.get_definition(), public_functions)),
- private_functions='',
- enums='\n'.join(enums))
- self.impl = implementation_template.format(name=self.name, name_lower=self.name.lower(), functions='\n\n'.join(map(lambda x: x.get_implementation(), public_functions)))
+ name=self.name,
+ name_lower=self.name.lower(),
+ includes="\n".join(includes),
+ private_variables="\n".join(private_variables),
+ public_variables="\n".join(public_variables),
+ public_functions="\n".join(
+ map(lambda x: x.get_definition(), public_functions)
+ ),
+ private_functions="",
+ enums="\n".join(enums),
+ )
+ self.impl = implementation_template.format(
+ name=self.name,
+ name_lower=self.name.lower(),
+ functions="\n\n".join(
+ map(lambda x: x.get_implementation(), public_functions)
+ ),
+ )
diff --git a/lib/cycles_to_energy.py b/lib/cycles_to_energy.py
index 35f9199..a8e88b8 100644
--- a/lib/cycles_to_energy.py
+++ b/lib/cycles_to_energy.py
@@ -5,86 +5,80 @@ Contains classes for some embedded CPUs/MCUs. Given a configuration, each
class can convert a cycle count to an energy consumption.
"""
+
def get_class(cpu_name):
"""Return model class for cpu_name."""
- if cpu_name == 'MSP430':
+ if cpu_name == "MSP430":
return MSP430
- if cpu_name == 'ATMega168':
+ if cpu_name == "ATMega168":
return ATMega168
- if cpu_name == 'ATMega328':
+ if cpu_name == "ATMega328":
return ATMega328
- if cpu_name == 'ATTiny88':
+ if cpu_name == "ATTiny88":
return ATTiny88
- if cpu_name == 'esp8266':
+ if cpu_name == "esp8266":
return ESP8266
+
def _param_list_to_dict(device, param_list):
param_dict = dict()
for i, parameter in enumerate(sorted(device.parameters.keys())):
param_dict[parameter] = param_list[i]
return param_dict
+
class MSP430:
- name = 'MSP430'
+ name = "MSP430"
parameters = {
- 'cpu_freq': [1e6, 4e6, 8e6, 12e6, 16e6],
- 'memory' : ['unified', 'fram0', 'fram50', 'fram66', 'fram75', 'fram100', 'ram'],
- 'voltage': [2.2, 3.0],
- }
- default_params = {
- 'cpu_freq': 4e6,
- 'memory' : 'unified',
- 'voltage': 3
+ "cpu_freq": [1e6, 4e6, 8e6, 12e6, 16e6],
+ "memory": ["unified", "fram0", "fram50", "fram66", "fram75", "fram100", "ram"],
+ "voltage": [2.2, 3.0],
}
+ default_params = {"cpu_freq": 4e6, "memory": "unified", "voltage": 3}
current_by_mem = {
- 'unified' : [210, 640, 1220, 1475, 1845],
- 'fram0' : [370, 1280, 2510, 2080, 2650],
- 'fram50' : [240, 745, 1440, 1575, 1990],
- 'fram66' : [200, 560, 1070, 1300, 1620],
- 'fram75' : [170, 480, 890, 1155, 1420],
- 'fram100' : [110, 235, 420, 640, 730],
- 'ram' : [130, 320, 585, 890, 1070],
+ "unified": [210, 640, 1220, 1475, 1845],
+ "fram0": [370, 1280, 2510, 2080, 2650],
+ "fram50": [240, 745, 1440, 1575, 1990],
+ "fram66": [200, 560, 1070, 1300, 1620],
+ "fram75": [170, 480, 890, 1155, 1420],
+ "fram100": [110, 235, 420, 640, 730],
+ "ram": [130, 320, 585, 890, 1070],
}
def get_current(params):
if type(params) != dict:
return MSP430.get_current(_param_list_to_dict(MSP430, params))
- cpu_freq_index = MSP430.parameters['cpu_freq'].index(params['cpu_freq'])
+ cpu_freq_index = MSP430.parameters["cpu_freq"].index(params["cpu_freq"])
- return MSP430.current_by_mem[params['memory']][cpu_freq_index] * 1e-6
+ return MSP430.current_by_mem[params["memory"]][cpu_freq_index] * 1e-6
def get_power(params):
if type(params) != dict:
return MSP430.get_energy(_param_list_to_dict(MSP430, params))
- return MSP430.get_current(params) * params['voltage']
+ return MSP430.get_current(params) * params["voltage"]
def get_energy_per_cycle(params):
if type(params) != dict:
return MSP430.get_energy_per_cycle(_param_list_to_dict(MSP430, params))
- return MSP430.get_power(params) / params['cpu_freq']
+ return MSP430.get_power(params) / params["cpu_freq"]
+
class ATMega168:
- name = 'ATMega168'
- parameters = {
- 'cpu_freq': [1e6, 4e6, 8e6],
- 'voltage': [2, 3, 5]
- }
- default_params = {
- 'cpu_freq': 4e6,
- 'voltage': 3
- }
+ name = "ATMega168"
+ parameters = {"cpu_freq": [1e6, 4e6, 8e6], "voltage": [2, 3, 5]}
+ default_params = {"cpu_freq": 4e6, "voltage": 3}
def get_current(params):
if type(params) != dict:
return ATMega168.get_current(_param_list_to_dict(ATMega168, params))
- if params['cpu_freq'] == 1e6 and params['voltage'] <= 2:
+ if params["cpu_freq"] == 1e6 and params["voltage"] <= 2:
return 0.5e-3
- if params['cpu_freq'] == 4e6 and params['voltage'] <= 3:
+ if params["cpu_freq"] == 4e6 and params["voltage"] <= 3:
return 3.5e-3
- if params['cpu_freq'] == 8e6 and params['voltage'] <= 5:
+ if params["cpu_freq"] == 8e6 and params["voltage"] <= 5:
return 12e-3
return None
@@ -92,37 +86,34 @@ class ATMega168:
if type(params) != dict:
return ATMega168.get_energy(_param_list_to_dict(ATMega168, params))
- return ATMega168.get_current(params) * params['voltage']
+ return ATMega168.get_current(params) * params["voltage"]
def get_energy_per_cycle(params):
if type(params) != dict:
- return ATMega168.get_energy_per_cycle(_param_list_to_dict(ATMega168, params))
+ return ATMega168.get_energy_per_cycle(
+ _param_list_to_dict(ATMega168, params)
+ )
+
+ return ATMega168.get_power(params) / params["cpu_freq"]
- return ATMega168.get_power(params) / params['cpu_freq']
class ATMega328:
- name = 'ATMega328'
- parameters = {
- 'cpu_freq': [1e6, 4e6, 8e6],
- 'voltage': [2, 3, 5]
- }
- default_params = {
- 'cpu_freq': 4e6,
- 'voltage': 3
- }
+ name = "ATMega328"
+ parameters = {"cpu_freq": [1e6, 4e6, 8e6], "voltage": [2, 3, 5]}
+ default_params = {"cpu_freq": 4e6, "voltage": 3}
# Source: ATMega328P Datasheet p.316 / table 28.2.4
def get_current(params):
if type(params) != dict:
return ATMega328.get_current(_param_list_to_dict(ATMega328, params))
# specified for 2V
- if params['cpu_freq'] == 1e6 and params['voltage'] >= 1.8:
+ if params["cpu_freq"] == 1e6 and params["voltage"] >= 1.8:
return 0.3e-3
# specified for 3V
- if params['cpu_freq'] == 4e6 and params['voltage'] >= 1.8:
+ if params["cpu_freq"] == 4e6 and params["voltage"] >= 1.8:
return 1.7e-3
# specified for 5V
- if params['cpu_freq'] == 8e6 and params['voltage'] >= 2.5:
+ if params["cpu_freq"] == 8e6 and params["voltage"] >= 2.5:
return 5.2e-3
return None
@@ -130,26 +121,26 @@ class ATMega328:
if type(params) != dict:
return ATMega328.get_energy(_param_list_to_dict(ATMega328, params))
- return ATMega328.get_current(params) * params['voltage']
+ return ATMega328.get_current(params) * params["voltage"]
def get_energy_per_cycle(params):
if type(params) != dict:
- return ATMega328.get_energy_per_cycle(_param_list_to_dict(ATMega328, params))
+ return ATMega328.get_energy_per_cycle(
+ _param_list_to_dict(ATMega328, params)
+ )
+
+ return ATMega328.get_power(params) / params["cpu_freq"]
- return ATMega328.get_power(params) / params['cpu_freq']
class ESP8266:
# Source: ESP8266EX Datasheet, table 5-2 (v2017.11) / table 3-4 (v2018.11)
# Taken at 3.0V
- name = 'ESP8266'
+ name = "ESP8266"
parameters = {
- 'cpu_freq': [80e6],
- 'voltage': [2.5, 3.0, 3.3, 3.6] # min / ... / typ / max
- }
- default_params = {
- 'cpu_freq': 80e6,
- 'voltage': 3.0
+ "cpu_freq": [80e6],
+ "voltage": [2.5, 3.0, 3.3, 3.6], # min / ... / typ / max
}
+ default_params = {"cpu_freq": 80e6, "voltage": 3.0}
def get_current(params):
if type(params) != dict:
@@ -161,33 +152,27 @@ class ESP8266:
if type(params) != dict:
return ESP8266.get_power(_param_list_to_dict(ESP8266, params))
- return ESP8266.get_current(params) * params['voltage']
+ return ESP8266.get_current(params) * params["voltage"]
def get_energy_per_cycle(params):
if type(params) != dict:
return ESP8266.get_energy_per_cycle(_param_list_to_dict(ESP8266, params))
- return ESP8266.get_power(params) / params['cpu_freq']
+ return ESP8266.get_power(params) / params["cpu_freq"]
+
class ATTiny88:
- name = 'ATTiny88'
- parameters = {
- 'cpu_freq': [1e6, 4e6, 8e6],
- 'voltage': [2, 3, 5]
- }
- default_params = {
- 'cpu_freq' : 4e6,
- 'voltage' : 3
- }
+ name = "ATTiny88"
+ parameters = {"cpu_freq": [1e6, 4e6, 8e6], "voltage": [2, 3, 5]}
+ default_params = {"cpu_freq": 4e6, "voltage": 3}
def get_current(params):
if type(params) != dict:
return ATTiny88.get_current(_param_list_to_dict(ATTiny88, params))
- if params['cpu_freq'] == 1e6 and params['voltage'] <= 2:
+ if params["cpu_freq"] == 1e6 and params["voltage"] <= 2:
return 0.2e-3
- if params['cpu_freq'] == 4e6 and params['voltage'] <= 3:
+ if params["cpu_freq"] == 4e6 and params["voltage"] <= 3:
return 1.4e-3
- if params['cpu_freq'] == 8e6 and params['voltage'] <= 5:
+ if params["cpu_freq"] == 8e6 and params["voltage"] <= 5:
return 4.5e-3
return None
-
diff --git a/lib/data_parameters.py b/lib/data_parameters.py
index 3b7a148..1150b71 100644
--- a/lib/data_parameters.py
+++ b/lib/data_parameters.py
@@ -10,6 +10,7 @@ from . import cycles_to_energy, size_to_radio_energy, utils
import numpy as np
import ubjson
+
def _string_value_length(json):
if type(json) == str:
return len(json)
@@ -22,6 +23,7 @@ def _string_value_length(json):
return 0
+
# TODO distinguish between int and uint, which is not visible from the
# data value alone
def _int_value_length(json):
@@ -40,18 +42,21 @@ def _int_value_length(json):
return 0
+
def _string_key_length(json):
if type(json) == dict:
return sum(map(len, json.keys())) + sum(map(_string_key_length, json.values()))
return 0
+
def _num_keys(json):
if type(json) == dict:
return len(json.keys()) + sum(map(_num_keys, json.values()))
return 0
+
def _num_of_type(json, wanted_type):
ret = 0
if type(json) == wanted_type:
@@ -65,16 +70,17 @@ def _num_of_type(json, wanted_type):
return ret
+
def json_to_param(json):
"""Return numeric parameters describing the structure of JSON data."""
ret = dict()
- ret['strlen_keys'] = _string_key_length(json)
- ret['strlen_values'] = _string_value_length(json)
- ret['bytelen_int'] = _int_value_length(json)
- ret['num_int'] = _num_of_type(json, int)
- ret['num_float'] = _num_of_type(json, float)
- ret['num_str'] = _num_of_type(json, str)
+ ret["strlen_keys"] = _string_key_length(json)
+ ret["strlen_values"] = _string_value_length(json)
+ ret["bytelen_int"] = _int_value_length(json)
+ ret["num_int"] = _num_of_type(json, int)
+ ret["num_float"] = _num_of_type(json, float)
+ ret["num_str"] = _num_of_type(json, str)
return ret
@@ -127,16 +133,16 @@ class Protolog:
# bogus data
if val > 10_000_000:
return np.nan
- for val in data['nop']:
+ for val in data["nop"]:
# bogus data
if val > 10_000_000:
return np.nan
# All measurements in data[key] cover the same instructions, so they
# should be identical -> it's safe to take the median.
# However, we leave out the first measurement as it is often bogus.
- if key == 'nop':
- return np.median(data['nop'][1:])
- return max(0, int(np.median(data[key][1:]) - np.median(data['nop'][1:])))
+ if key == "nop":
+ return np.median(data["nop"][1:])
+ return max(0, int(np.median(data[key][1:]) - np.median(data["nop"][1:])))
def _median_callcycles(data):
ret = dict()
@@ -146,37 +152,44 @@ class Protolog:
idem = lambda x: x
datamap = [
- ['bss_nop', 'bss_size_nop', idem],
- ['bss_ser', 'bss_size_ser', idem],
- ['bss_serdes', 'bss_size_serdes', idem],
- ['callcycles_raw', 'callcycles', idem],
- ['callcycles_median', 'callcycles', _median_callcycles],
+ ["bss_nop", "bss_size_nop", idem],
+ ["bss_ser", "bss_size_ser", idem],
+ ["bss_serdes", "bss_size_serdes", idem],
+ ["callcycles_raw", "callcycles", idem],
+ ["callcycles_median", "callcycles", _median_callcycles],
# Used to remove nop cycles from callcycles_median
- ['cycles_nop', 'cycles', lambda x: Protolog._median_cycles(x, 'nop')],
- ['cycles_ser', 'cycles', lambda x: Protolog._median_cycles(x, 'ser')],
- ['cycles_des', 'cycles', lambda x: Protolog._median_cycles(x, 'des')],
- ['cycles_enc', 'cycles', lambda x: Protolog._median_cycles(x, 'enc')],
- ['cycles_dec', 'cycles', lambda x: Protolog._median_cycles(x, 'dec')],
- #['cycles_ser_arr', 'cycles', lambda x: np.array(x['ser'][1:]) - np.mean(x['nop'][1:])],
- #['cycles_des_arr', 'cycles', lambda x: np.array(x['des'][1:]) - np.mean(x['nop'][1:])],
- #['cycles_enc_arr', 'cycles', lambda x: np.array(x['enc'][1:]) - np.mean(x['nop'][1:])],
- #['cycles_dec_arr', 'cycles', lambda x: np.array(x['dec'][1:]) - np.mean(x['nop'][1:])],
- ['data_nop', 'data_size_nop', idem],
- ['data_ser', 'data_size_ser', idem],
- ['data_serdes', 'data_size_serdes', idem],
- ['heap_ser', 'heap_usage_ser', idem],
- ['heap_des', 'heap_usage_des', idem],
- ['serialized_size', 'serialized_size', idem],
- ['stack_alloc_ser', 'stack_online_ser', lambda x: x['allocated']],
- ['stack_set_ser', 'stack_online_ser', lambda x: x['used']],
- ['stack_alloc_des', 'stack_online_des', lambda x: x['allocated']],
- ['stack_set_des', 'stack_online_des', lambda x: x['used']],
- ['text_nop', 'text_size_nop', idem],
- ['text_ser', 'text_size_ser', idem],
- ['text_serdes', 'text_size_serdes', idem],
+ ["cycles_nop", "cycles", lambda x: Protolog._median_cycles(x, "nop")],
+ ["cycles_ser", "cycles", lambda x: Protolog._median_cycles(x, "ser")],
+ ["cycles_des", "cycles", lambda x: Protolog._median_cycles(x, "des")],
+ ["cycles_enc", "cycles", lambda x: Protolog._median_cycles(x, "enc")],
+ ["cycles_dec", "cycles", lambda x: Protolog._median_cycles(x, "dec")],
+ # ['cycles_ser_arr', 'cycles', lambda x: np.array(x['ser'][1:]) - np.mean(x['nop'][1:])],
+ # ['cycles_des_arr', 'cycles', lambda x: np.array(x['des'][1:]) - np.mean(x['nop'][1:])],
+ # ['cycles_enc_arr', 'cycles', lambda x: np.array(x['enc'][1:]) - np.mean(x['nop'][1:])],
+ # ['cycles_dec_arr', 'cycles', lambda x: np.array(x['dec'][1:]) - np.mean(x['nop'][1:])],
+ ["data_nop", "data_size_nop", idem],
+ ["data_ser", "data_size_ser", idem],
+ ["data_serdes", "data_size_serdes", idem],
+ ["heap_ser", "heap_usage_ser", idem],
+ ["heap_des", "heap_usage_des", idem],
+ ["serialized_size", "serialized_size", idem],
+ ["stack_alloc_ser", "stack_online_ser", lambda x: x["allocated"]],
+ ["stack_set_ser", "stack_online_ser", lambda x: x["used"]],
+ ["stack_alloc_des", "stack_online_des", lambda x: x["allocated"]],
+ ["stack_set_des", "stack_online_des", lambda x: x["used"]],
+ ["text_nop", "text_size_nop", idem],
+ ["text_ser", "text_size_ser", idem],
+ ["text_serdes", "text_size_serdes", idem],
]
- def __init__(self, logfile, cpu_conf = None, cpu_conf_str = None, radio_conf = None, radio_conf_str = None):
+ def __init__(
+ self,
+ logfile,
+ cpu_conf=None,
+ cpu_conf_str=None,
+ radio_conf=None,
+ radio_conf_str=None,
+ ):
"""
Load and enrich raw protobench log data.
@@ -185,116 +198,177 @@ class Protolog:
"""
self.cpu = None
self.radio = None
- with open(logfile, 'rb') as f:
+ with open(logfile, "rb") as f:
self.data = ubjson.load(f)
self.libraries = set()
self.architectures = set()
self.aggregate = dict()
for arch_lib in self.data.keys():
- arch, lib, libopts = arch_lib.split(':')
- library = lib + ':' + libopts
+ arch, lib, libopts = arch_lib.split(":")
+ library = lib + ":" + libopts
for benchmark in self.data[arch_lib].keys():
for benchmark_item in self.data[arch_lib][benchmark].keys():
subv = self.data[arch_lib][benchmark][benchmark_item]
for aggregate_label, data_label, getter in Protolog.datamap:
try:
- self.add_datapoint(arch, library, (benchmark, benchmark_item), subv, aggregate_label, data_label, getter)
+ self.add_datapoint(
+ arch,
+ library,
+ (benchmark, benchmark_item),
+ subv,
+ aggregate_label,
+ data_label,
+ getter,
+ )
except KeyError:
pass
except TypeError as e:
- print('TypeError in {} {} {} {}: {} -> {}'.format(
- arch_lib, benchmark, benchmark_item, aggregate_label,
- subv[data_label]['v'], str(e)))
+ print(
+ "TypeError in {} {} {} {}: {} -> {}".format(
+ arch_lib,
+ benchmark,
+ benchmark_item,
+ aggregate_label,
+ subv[data_label]["v"],
+ str(e),
+ )
+ )
pass
try:
- codegen = codegen_for_lib(lib, libopts.split(','), subv['data'])
+ codegen = codegen_for_lib(lib, libopts.split(","), subv["data"])
if codegen.max_serialized_bytes != None:
- self.add_datapoint(arch, library, (benchmark, benchmark_item), subv, 'buffer_size', data_label, lambda x: codegen.max_serialized_bytes)
+ self.add_datapoint(
+ arch,
+ library,
+ (benchmark, benchmark_item),
+ subv,
+ "buffer_size",
+ data_label,
+ lambda x: codegen.max_serialized_bytes,
+ )
else:
- self.add_datapoint(arch, library, (benchmark, benchmark_item), subv, 'buffer_size', data_label, lambda x: 0)
+ self.add_datapoint(
+ arch,
+ library,
+ (benchmark, benchmark_item),
+ subv,
+ "buffer_size",
+ data_label,
+ lambda x: 0,
+ )
except:
# avro's codegen will raise RuntimeError("Unsupported Schema") on unsupported data. Other libraries may just silently ignore it.
- self.add_datapoint(arch, library, (benchmark, benchmark_item), subv, 'buffer_size', data_label, lambda x: 0)
- #self.aggregate[(benchmark, benchmark_item)][arch][lib][aggregate_label] = getter(value[data_label]['v'])
-
+ self.add_datapoint(
+ arch,
+ library,
+ (benchmark, benchmark_item),
+ subv,
+ "buffer_size",
+ data_label,
+ lambda x: 0,
+ )
+ # self.aggregate[(benchmark, benchmark_item)][arch][lib][aggregate_label] = getter(value[data_label]['v'])
for key in self.aggregate.keys():
for arch in self.aggregate[key].keys():
for lib, val in self.aggregate[key][arch].items():
try:
- val['cycles_encser'] = val['cycles_enc'] + val['cycles_ser']
+ val["cycles_encser"] = val["cycles_enc"] + val["cycles_ser"]
except KeyError:
pass
try:
- val['cycles_desdec'] = val['cycles_des'] + val['cycles_dec']
+ val["cycles_desdec"] = val["cycles_des"] + val["cycles_dec"]
except KeyError:
pass
try:
- for line in val['callcycles_median'].keys():
- val['callcycles_median'][line] -= val['cycles_nop']
+ for line in val["callcycles_median"].keys():
+ val["callcycles_median"][line] -= val["cycles_nop"]
except KeyError:
pass
try:
- val['data_serdes_delta'] = val['data_serdes'] - val['data_nop']
+ val["data_serdes_delta"] = val["data_serdes"] - val["data_nop"]
except KeyError:
pass
try:
- val['data_serdes_delta_nobuf'] = val['data_serdes'] - val['data_nop'] - val['buffer_size']
+ val["data_serdes_delta_nobuf"] = (
+ val["data_serdes"] - val["data_nop"] - val["buffer_size"]
+ )
except KeyError:
pass
try:
- val['bss_serdes_delta'] = val['bss_serdes'] - val['bss_nop']
+ val["bss_serdes_delta"] = val["bss_serdes"] - val["bss_nop"]
except KeyError:
pass
try:
- val['bss_serdes_delta_nobuf'] = val['bss_serdes'] - val['bss_nop'] - val['buffer_size']
+ val["bss_serdes_delta_nobuf"] = (
+ val["bss_serdes"] - val["bss_nop"] - val["buffer_size"]
+ )
except KeyError:
pass
try:
- val['text_serdes_delta'] = val['text_serdes'] - val['text_nop']
+ val["text_serdes_delta"] = val["text_serdes"] - val["text_nop"]
except KeyError:
pass
try:
- val['total_dmem_ser'] = val['stack_alloc_ser']
- val['written_dmem_ser'] = val['stack_set_ser']
- val['total_dmem_ser'] += val['heap_ser']
- val['written_dmem_ser'] += val['heap_ser']
+ val["total_dmem_ser"] = val["stack_alloc_ser"]
+ val["written_dmem_ser"] = val["stack_set_ser"]
+ val["total_dmem_ser"] += val["heap_ser"]
+ val["written_dmem_ser"] += val["heap_ser"]
except KeyError:
pass
try:
- val['total_dmem_des'] = val['stack_alloc_des']
- val['written_dmem_des'] = val['stack_set_des']
- val['total_dmem_des'] += val['heap_des']
- val['written_dmem_des'] += val['heap_des']
+ val["total_dmem_des"] = val["stack_alloc_des"]
+ val["written_dmem_des"] = val["stack_set_des"]
+ val["total_dmem_des"] += val["heap_des"]
+ val["written_dmem_des"] += val["heap_des"]
except KeyError:
pass
try:
- val['total_dmem_serdes'] = max(val['total_dmem_ser'], val['total_dmem_des'])
+ val["total_dmem_serdes"] = max(
+ val["total_dmem_ser"], val["total_dmem_des"]
+ )
except KeyError:
pass
try:
- val['text_ser_delta'] = val['text_ser'] - val['text_nop']
- val['text_serdes_delta'] = val['text_serdes'] - val['text_nop']
+ val["text_ser_delta"] = val["text_ser"] - val["text_nop"]
+ val["text_serdes_delta"] = val["text_serdes"] - val["text_nop"]
except KeyError:
pass
try:
- val['bss_ser_delta'] = val['bss_ser'] - val['bss_nop']
- val['bss_serdes_delta'] = val['bss_serdes'] - val['bss_nop']
+ val["bss_ser_delta"] = val["bss_ser"] - val["bss_nop"]
+ val["bss_serdes_delta"] = val["bss_serdes"] - val["bss_nop"]
except KeyError:
pass
try:
- val['data_ser_delta'] = val['data_ser'] - val['data_nop']
- val['data_serdes_delta'] = val['data_serdes'] - val['data_nop']
+ val["data_ser_delta"] = val["data_ser"] - val["data_nop"]
+ val["data_serdes_delta"] = val["data_serdes"] - val["data_nop"]
except KeyError:
pass
try:
- val['allmem_ser'] = val['text_ser'] + val['data_ser'] + val['bss_ser'] + val['total_dmem_ser'] - val['buffer_size']
- val['allmem_serdes'] = val['text_serdes'] + val['data_serdes'] + val['bss_serdes'] + val['total_dmem_serdes'] - val['buffer_size']
+ val["allmem_ser"] = (
+ val["text_ser"]
+ + val["data_ser"]
+ + val["bss_ser"]
+ + val["total_dmem_ser"]
+ - val["buffer_size"]
+ )
+ val["allmem_serdes"] = (
+ val["text_serdes"]
+ + val["data_serdes"]
+ + val["bss_serdes"]
+ + val["total_dmem_serdes"]
+ - val["buffer_size"]
+ )
except KeyError:
pass
try:
- val['smem_serdes'] = val['text_serdes'] + val['data_serdes'] + val['bss_serdes'] - val['buffer_size']
+ val["smem_serdes"] = (
+ val["text_serdes"]
+ + val["data_serdes"]
+ + val["bss_serdes"]
+ - val["buffer_size"]
+ )
except KeyError:
pass
@@ -303,7 +377,7 @@ class Protolog:
if cpu_conf:
self.cpu_conf = cpu_conf
- cpu = self.cpu = cycles_to_energy.get_class(cpu_conf['model'])
+ cpu = self.cpu = cycles_to_energy.get_class(cpu_conf["model"])
for key, value in cpu.default_params.items():
if not key in cpu_conf:
cpu_conf[key] = value
@@ -312,48 +386,102 @@ class Protolog:
for lib, val in self.aggregate[key][arch].items():
# All energy data is stored in nanojoules (nJ)
try:
- val['energy_enc'] = int(val['cycles_enc'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9)
+ val["energy_enc"] = int(
+ val["cycles_enc"]
+ * cpu.get_power(cpu_conf)
+ / cpu_conf["cpu_freq"]
+ * 1e9
+ )
except KeyError:
pass
except ValueError:
- print('cycles_enc is NaN for {} -> {} -> {}'.format(arch, lib, key))
+ print(
+ "cycles_enc is NaN for {} -> {} -> {}".format(
+ arch, lib, key
+ )
+ )
try:
- val['energy_ser'] = int(val['cycles_ser'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9)
+ val["energy_ser"] = int(
+ val["cycles_ser"]
+ * cpu.get_power(cpu_conf)
+ / cpu_conf["cpu_freq"]
+ * 1e9
+ )
except KeyError:
pass
except ValueError:
- print('cycles_ser is NaN for {} -> {} -> {}'.format(arch, lib, key))
+ print(
+ "cycles_ser is NaN for {} -> {} -> {}".format(
+ arch, lib, key
+ )
+ )
try:
- val['energy_encser'] = int(val['cycles_encser'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9)
+ val["energy_encser"] = int(
+ val["cycles_encser"]
+ * cpu.get_power(cpu_conf)
+ / cpu_conf["cpu_freq"]
+ * 1e9
+ )
except KeyError:
pass
except ValueError:
- print('cycles_encser is NaN for {} -> {} -> {}'.format(arch, lib, key))
+ print(
+ "cycles_encser is NaN for {} -> {} -> {}".format(
+ arch, lib, key
+ )
+ )
try:
- val['energy_des'] = int(val['cycles_des'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9)
+ val["energy_des"] = int(
+ val["cycles_des"]
+ * cpu.get_power(cpu_conf)
+ / cpu_conf["cpu_freq"]
+ * 1e9
+ )
except KeyError:
pass
except ValueError:
- print('cycles_des is NaN for {} -> {} -> {}'.format(arch, lib, key))
+ print(
+ "cycles_des is NaN for {} -> {} -> {}".format(
+ arch, lib, key
+ )
+ )
try:
- val['energy_dec'] = int(val['cycles_dec'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9)
+ val["energy_dec"] = int(
+ val["cycles_dec"]
+ * cpu.get_power(cpu_conf)
+ / cpu_conf["cpu_freq"]
+ * 1e9
+ )
except KeyError:
pass
except ValueError:
- print('cycles_dec is NaN for {} -> {} -> {}'.format(arch, lib, key))
+ print(
+ "cycles_dec is NaN for {} -> {} -> {}".format(
+ arch, lib, key
+ )
+ )
try:
- val['energy_desdec'] = int(val['cycles_desdec'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9)
+ val["energy_desdec"] = int(
+ val["cycles_desdec"]
+ * cpu.get_power(cpu_conf)
+ / cpu_conf["cpu_freq"]
+ * 1e9
+ )
except KeyError:
pass
except ValueError:
- print('cycles_desdec is NaN for {} -> {} -> {}'.format(arch, lib, key))
+ print(
+ "cycles_desdec is NaN for {} -> {} -> {}".format(
+ arch, lib, key
+ )
+ )
if radio_conf_str:
radio_conf = utils.parse_conf_str(radio_conf_str)
if radio_conf:
self.radio_conf = radio_conf
- radio = self.radio = size_to_radio_energy.get_class(radio_conf['model'])
+ radio = self.radio = size_to_radio_energy.get_class(radio_conf["model"])
for key, value in radio.default_params.items():
if not key in radio_conf:
radio_conf[key] = value
@@ -361,17 +489,22 @@ class Protolog:
for arch in self.aggregate[key].keys():
for lib, val in self.aggregate[key][arch].items():
try:
- radio_conf['txbytes'] = val['serialized_size']
- if radio_conf['txbytes'] > 0:
- val['energy_tx'] = int(radio.get_energy(radio_conf) * 1e9)
+ radio_conf["txbytes"] = val["serialized_size"]
+ if radio_conf["txbytes"] > 0:
+ val["energy_tx"] = int(
+ radio.get_energy(radio_conf) * 1e9
+ )
else:
- val['energy_tx'] = 0
- val['energy_encsertx'] = val['energy_encser'] + val['energy_tx']
- val['energy_desdecrx'] = val['energy_desdec'] + val['energy_tx']
+ val["energy_tx"] = 0
+ val["energy_encsertx"] = (
+ val["energy_encser"] + val["energy_tx"]
+ )
+ val["energy_desdecrx"] = (
+ val["energy_desdec"] + val["energy_tx"]
+ )
except KeyError:
pass
-
def add_datapoint(self, arch, lib, key, value, aggregate_label, data_label, getter):
"""
Set self.aggregate[key][arch][lib][aggregate_Label] = getter(value[data_label]['v']).
@@ -379,7 +512,7 @@ class Protolog:
Additionally, add lib to self.libraries and arch to self.architectures
key usually is ('benchmark name', 'sub-benchmark index').
"""
- if data_label in value and 'v' in value[data_label]:
+ if data_label in value and "v" in value[data_label]:
self.architectures.add(arch)
self.libraries.add(lib)
if not key in self.aggregate:
@@ -388,4 +521,6 @@ class Protolog:
self.aggregate[key][arch] = dict()
if not lib in self.aggregate[key][arch]:
self.aggregate[key][arch][lib] = dict()
- self.aggregate[key][arch][lib][aggregate_label] = getter(value[data_label]['v'])
+ self.aggregate[key][arch][lib][aggregate_label] = getter(
+ value[data_label]["v"]
+ )
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 8fb41a5..56f0f2d 100644
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -15,12 +15,19 @@ from multiprocessing import Pool
from .functions import analytic
from .functions import AnalyticFunction
from .parameters import ParamStats
-from .utils import vprint, is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple
+from .utils import (
+ vprint,
+ is_numeric,
+ soft_cast_int,
+ param_slice_eq,
+ remove_index_from_tuple,
+)
from .utils import by_name_to_by_param, match_parameter_values, running_mean
try:
from .pubcode import Code128
import zbar
+
zbar_available = True
except ImportError:
zbar_available = False
@@ -47,25 +54,25 @@ def gplearn_to_function(function_str: str):
inv -- 1 / x if |x| > 0.001, otherwise 0
"""
eval_globals = {
- 'add': lambda x, y: x + y,
- 'sub': lambda x, y: x - y,
- 'mul': lambda x, y: x * y,
- 'div': lambda x, y: np.divide(x, y) if np.abs(y) > 0.001 else 1.,
- 'sqrt': lambda x: np.sqrt(np.abs(x)),
- 'log': lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 0.,
- 'inv': lambda x: 1. / x if np.abs(x) > 0.001 else 0.,
+ "add": lambda x, y: x + y,
+ "sub": lambda x, y: x - y,
+ "mul": lambda x, y: x * y,
+ "div": lambda x, y: np.divide(x, y) if np.abs(y) > 0.001 else 1.0,
+ "sqrt": lambda x: np.sqrt(np.abs(x)),
+ "log": lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 0.0,
+ "inv": lambda x: 1.0 / x if np.abs(x) > 0.001 else 0.0,
}
last_arg_index = 0
for i in range(0, 100):
- if function_str.find('X{:d}'.format(i)) >= 0:
+ if function_str.find("X{:d}".format(i)) >= 0:
last_arg_index = i
arg_list = []
for i in range(0, last_arg_index + 1):
- arg_list.append('X{:d}'.format(i))
+ arg_list.append("X{:d}".format(i))
- eval_str = 'lambda {}, *whatever: {}'.format(','.join(arg_list), function_str)
+ eval_str = "lambda {}, *whatever: {}".format(",".join(arg_list), function_str)
print(eval_str)
return eval(eval_str, eval_globals)
@@ -123,32 +130,35 @@ def regression_measures(predicted: np.ndarray, actual: np.ndarray):
count -- Number of values
"""
if type(predicted) != np.ndarray:
- raise ValueError('first arg must be ndarray, is {}'.format(type(predicted)))
+ raise ValueError("first arg must be ndarray, is {}".format(type(predicted)))
if type(actual) != np.ndarray:
- raise ValueError('second arg must be ndarray, is {}'.format(type(actual)))
+ raise ValueError("second arg must be ndarray, is {}".format(type(actual)))
deviations = predicted - actual
# mean = np.mean(actual)
if len(deviations) == 0:
return {}
measures = {
- 'mae': np.mean(np.abs(deviations), dtype=np.float64),
- 'msd': np.mean(deviations**2, dtype=np.float64),
- 'rmsd': np.sqrt(np.mean(deviations**2), dtype=np.float64),
- 'ssr': np.sum(deviations**2, dtype=np.float64),
- 'rsq': r2_score(actual, predicted),
- 'count': len(actual),
+ "mae": np.mean(np.abs(deviations), dtype=np.float64),
+ "msd": np.mean(deviations ** 2, dtype=np.float64),
+ "rmsd": np.sqrt(np.mean(deviations ** 2), dtype=np.float64),
+ "ssr": np.sum(deviations ** 2, dtype=np.float64),
+ "rsq": r2_score(actual, predicted),
+ "count": len(actual),
}
# rsq_quotient = np.sum((actual - mean)**2, dtype=np.float64) * np.sum((predicted - mean)**2, dtype=np.float64)
if np.all(actual != 0):
- measures['mape'] = np.mean(np.abs(deviations / actual)) * 100 # bad measure
+ measures["mape"] = np.mean(np.abs(deviations / actual)) * 100 # bad measure
else:
- measures['mape'] = np.nan
+ measures["mape"] = np.nan
if np.all(np.abs(predicted) + np.abs(actual) != 0):
- measures['smape'] = np.mean(np.abs(deviations) / ((np.abs(predicted) + np.abs(actual)) / 2)) * 100
+ measures["smape"] = (
+ np.mean(np.abs(deviations) / ((np.abs(predicted) + np.abs(actual)) / 2))
+ * 100
+ )
else:
- measures['smape'] = np.nan
+ measures["smape"] = np.nan
# if np.all(rsq_quotient != 0):
# measures['rsq'] = (np.sum((actual - mean) * (predicted - mean), dtype=np.float64)**2) / rsq_quotient
@@ -177,7 +187,7 @@ class KeysightCSV:
with open(filename) as f:
for _ in range(4):
next(f)
- reader = csv.reader(f, delimiter=',')
+ reader = csv.reader(f, delimiter=",")
for i, row in enumerate(reader):
timestamps[i] = float(row[0])
currents[i] = float(row[2]) * -1
@@ -266,29 +276,35 @@ class CrossValidator:
}
}
"""
- ret = {
- 'by_name': dict()
- }
+ ret = {"by_name": dict()}
for name in self.names:
- ret['by_name'][name] = dict()
- for attribute in self.by_name[name]['attributes']:
- ret['by_name'][name][attribute] = {
- 'mae_list': list(),
- 'smape_list': list()
+ ret["by_name"][name] = dict()
+ for attribute in self.by_name[name]["attributes"]:
+ ret["by_name"][name][attribute] = {
+ "mae_list": list(),
+ "smape_list": list(),
}
for _ in range(count):
res = self._single_montecarlo(model_getter)
for name in self.names:
- for attribute in self.by_name[name]['attributes']:
- ret['by_name'][name][attribute]['mae_list'].append(res['by_name'][name][attribute]['mae'])
- ret['by_name'][name][attribute]['smape_list'].append(res['by_name'][name][attribute]['smape'])
+ for attribute in self.by_name[name]["attributes"]:
+ ret["by_name"][name][attribute]["mae_list"].append(
+ res["by_name"][name][attribute]["mae"]
+ )
+ ret["by_name"][name][attribute]["smape_list"].append(
+ res["by_name"][name][attribute]["smape"]
+ )
for name in self.names:
- for attribute in self.by_name[name]['attributes']:
- ret['by_name'][name][attribute]['mae'] = np.mean(ret['by_name'][name][attribute]['mae_list'])
- ret['by_name'][name][attribute]['smape'] = np.mean(ret['by_name'][name][attribute]['smape_list'])
+ for attribute in self.by_name[name]["attributes"]:
+ ret["by_name"][name][attribute]["mae"] = np.mean(
+ ret["by_name"][name][attribute]["mae_list"]
+ )
+ ret["by_name"][name][attribute]["smape"] = np.mean(
+ ret["by_name"][name][attribute]["smape_list"]
+ )
return ret
@@ -296,77 +312,87 @@ class CrossValidator:
training = dict()
validation = dict()
for name in self.names:
- training[name] = {
- 'attributes': self.by_name[name]['attributes']
- }
- validation[name] = {
- 'attributes': self.by_name[name]['attributes']
- }
+ training[name] = {"attributes": self.by_name[name]["attributes"]}
+ validation[name] = {"attributes": self.by_name[name]["attributes"]}
- if 'isa' in self.by_name[name]:
- training[name]['isa'] = self.by_name[name]['isa']
- validation[name]['isa'] = self.by_name[name]['isa']
+ if "isa" in self.by_name[name]:
+ training[name]["isa"] = self.by_name[name]["isa"]
+ validation[name]["isa"] = self.by_name[name]["isa"]
- data_count = len(self.by_name[name]['param'])
+ data_count = len(self.by_name[name]["param"])
training_subset, validation_subset = _xv_partition_montecarlo(data_count)
- for attribute in self.by_name[name]['attributes']:
+ for attribute in self.by_name[name]["attributes"]:
self.by_name[name][attribute] = np.array(self.by_name[name][attribute])
- training[name][attribute] = self.by_name[name][attribute][training_subset]
- validation[name][attribute] = self.by_name[name][attribute][validation_subset]
+ training[name][attribute] = self.by_name[name][attribute][
+ training_subset
+ ]
+ validation[name][attribute] = self.by_name[name][attribute][
+ validation_subset
+ ]
# We can't use slice syntax for 'param', which may contain strings and other odd values
- training[name]['param'] = list()
- validation[name]['param'] = list()
+ training[name]["param"] = list()
+ validation[name]["param"] = list()
for idx in training_subset:
- training[name]['param'].append(self.by_name[name]['param'][idx])
+ training[name]["param"].append(self.by_name[name]["param"][idx])
for idx in validation_subset:
- validation[name]['param'].append(self.by_name[name]['param'][idx])
+ validation[name]["param"].append(self.by_name[name]["param"][idx])
- training_data = self.model_class(training, self.parameters, self.arg_count, verbose=False)
+ training_data = self.model_class(
+ training, self.parameters, self.arg_count, verbose=False
+ )
training_model = model_getter(training_data)
- validation_data = self.model_class(validation, self.parameters, self.arg_count, verbose=False)
+ validation_data = self.model_class(
+ validation, self.parameters, self.arg_count, verbose=False
+ )
return validation_data.assess(training_model)
def _preprocess_mimosa(measurement):
- setup = measurement['setup']
- mim = MIMOSA(float(setup['mimosa_voltage']), int(setup['mimosa_shunt']), with_traces=measurement['with_traces'])
+ setup = measurement["setup"]
+ mim = MIMOSA(
+ float(setup["mimosa_voltage"]),
+ int(setup["mimosa_shunt"]),
+ with_traces=measurement["with_traces"],
+ )
try:
- charges, triggers = mim.load_data(measurement['content'])
+ charges, triggers = mim.load_data(measurement["content"])
trigidx = mim.trigger_edges(triggers)
except EOFError as e:
- mim.errors.append('MIMOSA logfile error: {}'.format(e))
+ mim.errors.append("MIMOSA logfile error: {}".format(e))
trigidx = list()
if len(trigidx) == 0:
- mim.errors.append('MIMOSA log has no triggers')
+ mim.errors.append("MIMOSA log has no triggers")
return {
- 'fileno': measurement['fileno'],
- 'info': measurement['info'],
- 'has_datasource_error': len(mim.errors) > 0,
- 'datasource_errors': mim.errors,
- 'expected_trace': measurement['expected_trace'],
- 'repeat_id': measurement['repeat_id'],
+ "fileno": measurement["fileno"],
+ "info": measurement["info"],
+ "has_datasource_error": len(mim.errors) > 0,
+ "datasource_errors": mim.errors,
+ "expected_trace": measurement["expected_trace"],
+ "repeat_id": measurement["repeat_id"],
}
- cal_edges = mim.calibration_edges(running_mean(mim.currents_nocal(charges[0:trigidx[0]]), 10))
+ cal_edges = mim.calibration_edges(
+ running_mean(mim.currents_nocal(charges[0 : trigidx[0]]), 10)
+ )
calfunc, caldata = mim.calibration_function(charges, cal_edges)
vcalfunc = np.vectorize(calfunc, otypes=[np.float64])
processed_data = {
- 'fileno': measurement['fileno'],
- 'info': measurement['info'],
- 'triggers': len(trigidx),
- 'first_trig': trigidx[0] * 10,
- 'calibration': caldata,
- 'energy_trace': mim.analyze_states(charges, trigidx, vcalfunc),
- 'has_datasource_error': len(mim.errors) > 0,
- 'datasource_errors': mim.errors,
+ "fileno": measurement["fileno"],
+ "info": measurement["info"],
+ "triggers": len(trigidx),
+ "first_trig": trigidx[0] * 10,
+ "calibration": caldata,
+ "energy_trace": mim.analyze_states(charges, trigidx, vcalfunc),
+ "has_datasource_error": len(mim.errors) > 0,
+ "datasource_errors": mim.errors,
}
- for key in ['expected_trace', 'repeat_id']:
+ for key in ["expected_trace", "repeat_id"]:
if key in measurement:
processed_data[key] = measurement[key]
@@ -374,22 +400,28 @@ def _preprocess_mimosa(measurement):
def _preprocess_etlog(measurement):
- setup = measurement['setup']
- etlog = EnergyTraceLog(float(setup['voltage']), int(setup['state_duration']), measurement['transition_names'])
+ setup = measurement["setup"]
+ etlog = EnergyTraceLog(
+ float(setup["voltage"]),
+ int(setup["state_duration"]),
+ measurement["transition_names"],
+ )
try:
- etlog.load_data(measurement['content'])
- states_and_transitions = etlog.analyze_states(measurement['expected_trace'], measurement['repeat_id'])
+ etlog.load_data(measurement["content"])
+ states_and_transitions = etlog.analyze_states(
+ measurement["expected_trace"], measurement["repeat_id"]
+ )
except EOFError as e:
- etlog.errors.append('EnergyTrace logfile error: {}'.format(e))
+ etlog.errors.append("EnergyTrace logfile error: {}".format(e))
processed_data = {
- 'fileno': measurement['fileno'],
- 'repeat_id': measurement['repeat_id'],
- 'info': measurement['info'],
- 'expected_trace': measurement['expected_trace'],
- 'energy_trace': states_and_transitions,
- 'has_datasource_error': len(etlog.errors) > 0,
- 'datasource_errors': etlog.errors,
+ "fileno": measurement["fileno"],
+ "repeat_id": measurement["repeat_id"],
+ "info": measurement["info"],
+ "expected_trace": measurement["expected_trace"],
+ "energy_trace": states_and_transitions,
+ "has_datasource_error": len(etlog.errors) > 0,
+ "datasource_errors": etlog.errors,
}
return processed_data
@@ -421,35 +453,40 @@ class TimingData:
for trace_group in self.traces_by_fileno:
for trace in trace_group:
# TimingHarness logs states, but does not aggregate any data for them at the moment -> throw all states away
- transitions = list(filter(lambda x: x['isa'] == 'transition', trace['trace']))
- self.traces.append({
- 'id': trace['id'],
- 'trace': transitions,
- })
+ transitions = list(
+ filter(lambda x: x["isa"] == "transition", trace["trace"])
+ )
+ self.traces.append(
+ {"id": trace["id"], "trace": transitions,}
+ )
for i, trace in enumerate(self.traces):
- trace['orig_id'] = trace['id']
- trace['id'] = i
- for log_entry in trace['trace']:
- paramkeys = sorted(log_entry['parameter'].keys())
- if 'param' not in log_entry['offline_aggregates']:
- log_entry['offline_aggregates']['param'] = list()
- if 'duration' in log_entry['offline_aggregates']:
- for i in range(len(log_entry['offline_aggregates']['duration'])):
+ trace["orig_id"] = trace["id"]
+ trace["id"] = i
+ for log_entry in trace["trace"]:
+ paramkeys = sorted(log_entry["parameter"].keys())
+ if "param" not in log_entry["offline_aggregates"]:
+ log_entry["offline_aggregates"]["param"] = list()
+ if "duration" in log_entry["offline_aggregates"]:
+ for i in range(len(log_entry["offline_aggregates"]["duration"])):
paramvalues = list()
for paramkey in paramkeys:
- if type(log_entry['parameter'][paramkey]) is list:
- paramvalues.append(soft_cast_int(log_entry['parameter'][paramkey][i]))
+ if type(log_entry["parameter"][paramkey]) is list:
+ paramvalues.append(
+ soft_cast_int(log_entry["parameter"][paramkey][i])
+ )
else:
- paramvalues.append(soft_cast_int(log_entry['parameter'][paramkey]))
- if arg_support_enabled and 'args' in log_entry:
- paramvalues.extend(map(soft_cast_int, log_entry['args']))
- log_entry['offline_aggregates']['param'].append(paramvalues)
+ paramvalues.append(
+ soft_cast_int(log_entry["parameter"][paramkey])
+ )
+ if arg_support_enabled and "args" in log_entry:
+ paramvalues.extend(map(soft_cast_int, log_entry["args"]))
+ log_entry["offline_aggregates"]["param"].append(paramvalues)
def _preprocess_0(self):
for filename in self.filenames:
- with open(filename, 'r') as f:
+ with open(filename, "r") as f:
log_data = json.load(f)
- self.traces_by_fileno.extend(log_data['traces'])
+ self.traces_by_fileno.extend(log_data["traces"])
self._concatenate_analyzed_traces()
def get_preprocessed_data(self, verbose=True):
@@ -470,17 +507,25 @@ class TimingData:
def sanity_check_aggregate(aggregate):
for key in aggregate:
- if 'param' not in aggregate[key]:
- raise RuntimeError('aggregate[{}][param] does not exist'.format(key))
- if 'attributes' not in aggregate[key]:
- raise RuntimeError('aggregate[{}][attributes] does not exist'.format(key))
- for attribute in aggregate[key]['attributes']:
+ if "param" not in aggregate[key]:
+ raise RuntimeError("aggregate[{}][param] does not exist".format(key))
+ if "attributes" not in aggregate[key]:
+ raise RuntimeError("aggregate[{}][attributes] does not exist".format(key))
+ for attribute in aggregate[key]["attributes"]:
if attribute not in aggregate[key]:
- raise RuntimeError('aggregate[{}][{}] does not exist, even though it is contained in aggregate[{}][attributes]'.format(key, attribute, key))
- param_len = len(aggregate[key]['param'])
+ raise RuntimeError(
+ "aggregate[{}][{}] does not exist, even though it is contained in aggregate[{}][attributes]".format(
+ key, attribute, key
+ )
+ )
+ param_len = len(aggregate[key]["param"])
attr_len = len(aggregate[key][attribute])
if param_len != attr_len:
- raise RuntimeError('parameter mismatch: len(aggregate[{}][param]) == {} != len(aggregate[{}][{}]) == {}'.format(key, param_len, key, attribute, attr_len))
+ raise RuntimeError(
+ "parameter mismatch: len(aggregate[{}][param]) == {} != len(aggregate[{}][{}]) == {}".format(
+ key, param_len, key, attribute, attr_len
+ )
+ )
class RawData:
@@ -559,11 +604,11 @@ class RawData:
with tarfile.open(filenames[0]) as tf:
for member in tf.getmembers():
- if member.name == 'ptalog.json' and self.version == 0:
+ if member.name == "ptalog.json" and self.version == 0:
self.version = 1
# might also be version 2
# depends on whether *.etlog exists or not
- elif '.etlog' in member.name:
+ elif ".etlog" in member.name:
self.version = 2
break
@@ -572,18 +617,18 @@ class RawData:
self.load_cache()
def set_cache_file(self):
- cache_key = hashlib.sha256('!'.join(self.filenames).encode()).hexdigest()
- self.cache_dir = os.path.dirname(self.filenames[0]) + '/cache'
- self.cache_file = '{}/{}.json'.format(self.cache_dir, cache_key)
+ cache_key = hashlib.sha256("!".join(self.filenames).encode()).hexdigest()
+ self.cache_dir = os.path.dirname(self.filenames[0]) + "/cache"
+ self.cache_file = "{}/{}.json".format(self.cache_dir, cache_key)
def load_cache(self):
if os.path.exists(self.cache_file):
- with open(self.cache_file, 'r') as f:
+ with open(self.cache_file, "r") as f:
cache_data = json.load(f)
- self.traces = cache_data['traces']
- self.preprocessing_stats = cache_data['preprocessing_stats']
- if 'pta' in cache_data:
- self.pta = cache_data['pta']
+ self.traces = cache_data["traces"]
+ self.preprocessing_stats = cache_data["preprocessing_stats"]
+ if "pta" in cache_data:
+ self.pta = cache_data["pta"]
self.preprocessed = True
def save_cache(self):
@@ -593,30 +638,30 @@ class RawData:
os.mkdir(self.cache_dir)
except FileExistsError:
pass
- with open(self.cache_file, 'w') as f:
+ with open(self.cache_file, "w") as f:
cache_data = {
- 'traces': self.traces,
- 'preprocessing_stats': self.preprocessing_stats,
- 'pta': self.pta,
+ "traces": self.traces,
+ "preprocessing_stats": self.preprocessing_stats,
+ "pta": self.pta,
}
json.dump(cache_data, f)
def _state_is_too_short(self, online, offline, state_duration, next_transition):
# We cannot control when an interrupt causes a state to be left
- if next_transition['plan']['level'] == 'epilogue':
+ if next_transition["plan"]["level"] == "epilogue":
return False
# Note: state_duration is stored as ms, not us
- return offline['us'] < state_duration * 500
+ return offline["us"] < state_duration * 500
def _state_is_too_long(self, online, offline, state_duration, prev_transition):
# If the previous state was left by an interrupt, we may have some
# waiting time left over. So it's okay if the current state is longer
# than expected.
- if prev_transition['plan']['level'] == 'epilogue':
+ if prev_transition["plan"]["level"] == "epilogue":
return False
# state_duration is stored as ms, not us
- return offline['us'] > state_duration * 1500
+ return offline["us"] > state_duration * 1500
def _measurement_is_valid_2(self, processed_data):
"""
@@ -642,8 +687,8 @@ class RawData:
"""
# Check for low-level parser errors
- if processed_data['has_datasource_error']:
- processed_data['error'] = '; '.join(processed_data['datasource_errors'])
+ if processed_data["has_datasource_error"]:
+ processed_data["error"] = "; ".join(processed_data["datasource_errors"])
return False
# Note that the low-level parser (EnergyTraceLog) already checks
@@ -680,26 +725,27 @@ class RawData:
- uW_mean_delta_prev: Differenz zwischen uW_mean und uW_mean des vorherigen Zustands
- uW_mean_delta_next: Differenz zwischen uW_mean und uW_mean des Folgezustands
"""
- setup = self.setup_by_fileno[processed_data['fileno']]
- if 'expected_trace' in processed_data:
- traces = processed_data['expected_trace']
+ setup = self.setup_by_fileno[processed_data["fileno"]]
+ if "expected_trace" in processed_data:
+ traces = processed_data["expected_trace"]
else:
- traces = self.traces_by_fileno[processed_data['fileno']]
- state_duration = setup['state_duration']
+ traces = self.traces_by_fileno[processed_data["fileno"]]
+ state_duration = setup["state_duration"]
# Check MIMOSA error
- if processed_data['has_datasource_error']:
- processed_data['error'] = '; '.join(processed_data['datasource_errors'])
+ if processed_data["has_datasource_error"]:
+ processed_data["error"] = "; ".join(processed_data["datasource_errors"])
return False
# Check trigger count
sched_trigger_count = 0
for run in traces:
- sched_trigger_count += len(run['trace'])
- if sched_trigger_count != processed_data['triggers']:
- processed_data['error'] = 'got {got:d} trigger edges, expected {exp:d}'.format(
- got=processed_data['triggers'],
- exp=sched_trigger_count
+ sched_trigger_count += len(run["trace"])
+ if sched_trigger_count != processed_data["triggers"]:
+ processed_data[
+ "error"
+ ] = "got {got:d} trigger edges, expected {exp:d}".format(
+ got=processed_data["triggers"], exp=sched_trigger_count
)
return False
# Check state durations. Very short or long states can indicate a
@@ -707,62 +753,102 @@ class RawData:
# triggers elsewhere
online_datapoints = []
for run_idx, run in enumerate(traces):
- for trace_part_idx in range(len(run['trace'])):
+ for trace_part_idx in range(len(run["trace"])):
online_datapoints.append((run_idx, trace_part_idx))
for offline_idx, online_ref in enumerate(online_datapoints):
online_run_idx, online_trace_part_idx = online_ref
- offline_trace_part = processed_data['energy_trace'][offline_idx]
- online_trace_part = traces[online_run_idx]['trace'][online_trace_part_idx]
+ offline_trace_part = processed_data["energy_trace"][offline_idx]
+ online_trace_part = traces[online_run_idx]["trace"][online_trace_part_idx]
if self._parameter_names is None:
- self._parameter_names = sorted(online_trace_part['parameter'].keys())
-
- if sorted(online_trace_part['parameter'].keys()) != self._parameter_names:
- processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) has inconsistent parameter set: should be {param_want:s}, is {param_is:s}'.format(
- off_idx=offline_idx, on_idx=online_run_idx,
+ self._parameter_names = sorted(online_trace_part["parameter"].keys())
+
+ if sorted(online_trace_part["parameter"].keys()) != self._parameter_names:
+ processed_data[
+ "error"
+ ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) has inconsistent parameter set: should be {param_want:s}, is {param_is:s}".format(
+ off_idx=offline_idx,
+ on_idx=online_run_idx,
on_sub=online_trace_part_idx,
- on_name=online_trace_part['name'],
+ on_name=online_trace_part["name"],
param_want=self._parameter_names,
- param_is=sorted(online_trace_part['parameter'].keys())
+ param_is=sorted(online_trace_part["parameter"].keys()),
)
- if online_trace_part['isa'] != offline_trace_part['isa']:
- processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) claims to be {off_isa:s}, but should be {on_isa:s}'.format(
- off_idx=offline_idx, on_idx=online_run_idx,
+ if online_trace_part["isa"] != offline_trace_part["isa"]:
+ processed_data[
+ "error"
+ ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) claims to be {off_isa:s}, but should be {on_isa:s}".format(
+ off_idx=offline_idx,
+ on_idx=online_run_idx,
on_sub=online_trace_part_idx,
- on_name=online_trace_part['name'],
- off_isa=offline_trace_part['isa'],
- on_isa=online_trace_part['isa'])
+ on_name=online_trace_part["name"],
+ off_isa=offline_trace_part["isa"],
+ on_isa=online_trace_part["isa"],
+ )
return False
# Clipping in UNINITIALIZED (offline_idx == 0) can happen during
# calibration and is handled by MIMOSA
- if offline_idx != 0 and offline_trace_part['clip_rate'] != 0 and not self.ignore_clipping:
- processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) was clipping {clip:f}% of the time'.format(
- off_idx=offline_idx, on_idx=online_run_idx,
+ if (
+ offline_idx != 0
+ and offline_trace_part["clip_rate"] != 0
+ and not self.ignore_clipping
+ ):
+ processed_data[
+ "error"
+ ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) was clipping {clip:f}% of the time".format(
+ off_idx=offline_idx,
+ on_idx=online_run_idx,
on_sub=online_trace_part_idx,
- on_name=online_trace_part['name'],
- clip=offline_trace_part['clip_rate'] * 100,
+ on_name=online_trace_part["name"],
+ clip=offline_trace_part["clip_rate"] * 100,
)
return False
- if online_trace_part['isa'] == 'state' and online_trace_part['name'] != 'UNINITIALIZED' and len(traces[online_run_idx]['trace']) > online_trace_part_idx + 1:
- online_prev_transition = traces[online_run_idx]['trace'][online_trace_part_idx - 1]
- online_next_transition = traces[online_run_idx]['trace'][online_trace_part_idx + 1]
+ if (
+ online_trace_part["isa"] == "state"
+ and online_trace_part["name"] != "UNINITIALIZED"
+ and len(traces[online_run_idx]["trace"]) > online_trace_part_idx + 1
+ ):
+ online_prev_transition = traces[online_run_idx]["trace"][
+ online_trace_part_idx - 1
+ ]
+ online_next_transition = traces[online_run_idx]["trace"][
+ online_trace_part_idx + 1
+ ]
try:
- if self._state_is_too_short(online_trace_part, offline_trace_part, state_duration, online_next_transition):
- processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) is too short (duration = {dur:d} us)'.format(
- off_idx=offline_idx, on_idx=online_run_idx,
+ if self._state_is_too_short(
+ online_trace_part,
+ offline_trace_part,
+ state_duration,
+ online_next_transition,
+ ):
+ processed_data[
+ "error"
+ ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) is too short (duration = {dur:d} us)".format(
+ off_idx=offline_idx,
+ on_idx=online_run_idx,
on_sub=online_trace_part_idx,
- on_name=online_trace_part['name'],
- dur=offline_trace_part['us'])
+ on_name=online_trace_part["name"],
+ dur=offline_trace_part["us"],
+ )
return False
- if self._state_is_too_long(online_trace_part, offline_trace_part, state_duration, online_prev_transition):
- processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) is too long (duration = {dur:d} us)'.format(
- off_idx=offline_idx, on_idx=online_run_idx,
+ if self._state_is_too_long(
+ online_trace_part,
+ offline_trace_part,
+ state_duration,
+ online_prev_transition,
+ ):
+ processed_data[
+ "error"
+ ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) is too long (duration = {dur:d} us)".format(
+ off_idx=offline_idx,
+ on_idx=online_run_idx,
on_sub=online_trace_part_idx,
- on_name=online_trace_part['name'],
- dur=offline_trace_part['us'])
+ on_name=online_trace_part["name"],
+ dur=offline_trace_part["us"],
+ )
return False
except KeyError:
pass
@@ -775,136 +861,169 @@ class RawData:
# (appends data from measurement['energy_trace'])
# If measurement['expected_trace'] exists, it is edited in place instead
online_datapoints = []
- if 'expected_trace' in measurement:
- traces = measurement['expected_trace']
- traces = self.traces_by_fileno[measurement['fileno']]
+ if "expected_trace" in measurement:
+ traces = measurement["expected_trace"]
+ traces = self.traces_by_fileno[measurement["fileno"]]
else:
- traces = self.traces_by_fileno[measurement['fileno']]
+ traces = self.traces_by_fileno[measurement["fileno"]]
for run_idx, run in enumerate(traces):
- for trace_part_idx in range(len(run['trace'])):
+ for trace_part_idx in range(len(run["trace"])):
online_datapoints.append((run_idx, trace_part_idx))
for offline_idx, online_ref in enumerate(online_datapoints):
online_run_idx, online_trace_part_idx = online_ref
- offline_trace_part = measurement['energy_trace'][offline_idx]
- online_trace_part = traces[online_run_idx]['trace'][online_trace_part_idx]
+ offline_trace_part = measurement["energy_trace"][offline_idx]
+ online_trace_part = traces[online_run_idx]["trace"][online_trace_part_idx]
- if 'offline' not in online_trace_part:
- online_trace_part['offline'] = [offline_trace_part]
+ if "offline" not in online_trace_part:
+ online_trace_part["offline"] = [offline_trace_part]
else:
- online_trace_part['offline'].append(offline_trace_part)
+ online_trace_part["offline"].append(offline_trace_part)
- paramkeys = sorted(online_trace_part['parameter'].keys())
+ paramkeys = sorted(online_trace_part["parameter"].keys())
paramvalues = list()
for paramkey in paramkeys:
- if type(online_trace_part['parameter'][paramkey]) is list:
- paramvalues.append(soft_cast_int(online_trace_part['parameter'][paramkey][measurement['repeat_id']]))
+ if type(online_trace_part["parameter"][paramkey]) is list:
+ paramvalues.append(
+ soft_cast_int(
+ online_trace_part["parameter"][paramkey][
+ measurement["repeat_id"]
+ ]
+ )
+ )
else:
- paramvalues.append(soft_cast_int(online_trace_part['parameter'][paramkey]))
+ paramvalues.append(
+ soft_cast_int(online_trace_part["parameter"][paramkey])
+ )
# NB: Unscheduled transitions do not have an 'args' field set.
# However, they should only be caused by interrupts, and
# interrupts don't have args anyways.
- if arg_support_enabled and 'args' in online_trace_part:
- paramvalues.extend(map(soft_cast_int, online_trace_part['args']))
-
- if 'offline_aggregates' not in online_trace_part:
- online_trace_part['offline_attributes'] = ['power', 'duration', 'energy']
- online_trace_part['offline_aggregates'] = {
- 'power': [],
- 'duration': [],
- 'power_std': [],
- 'energy': [],
- 'paramkeys': [],
- 'param': [],
+ if arg_support_enabled and "args" in online_trace_part:
+ paramvalues.extend(map(soft_cast_int, online_trace_part["args"]))
+
+ if "offline_aggregates" not in online_trace_part:
+ online_trace_part["offline_attributes"] = [
+ "power",
+ "duration",
+ "energy",
+ ]
+ online_trace_part["offline_aggregates"] = {
+ "power": [],
+ "duration": [],
+ "power_std": [],
+ "energy": [],
+ "paramkeys": [],
+ "param": [],
}
- if online_trace_part['isa'] == 'transition':
- online_trace_part['offline_attributes'].extend(['rel_energy_prev', 'rel_energy_next', 'timeout'])
- online_trace_part['offline_aggregates']['rel_energy_prev'] = []
- online_trace_part['offline_aggregates']['rel_energy_next'] = []
- online_trace_part['offline_aggregates']['timeout'] = []
+ if online_trace_part["isa"] == "transition":
+ online_trace_part["offline_attributes"].extend(
+ ["rel_energy_prev", "rel_energy_next", "timeout"]
+ )
+ online_trace_part["offline_aggregates"]["rel_energy_prev"] = []
+ online_trace_part["offline_aggregates"]["rel_energy_next"] = []
+ online_trace_part["offline_aggregates"]["timeout"] = []
# Note: All state/transitions are 20us "too long" due to injected
# active wait states. These are needed to work around MIMOSA's
# relatively low sample rate of 100 kHz (10us) and removed here.
- online_trace_part['offline_aggregates']['power'].append(
- offline_trace_part['uW_mean'])
- online_trace_part['offline_aggregates']['duration'].append(
- offline_trace_part['us'] - 20)
- online_trace_part['offline_aggregates']['power_std'].append(
- offline_trace_part['uW_std'])
- online_trace_part['offline_aggregates']['energy'].append(
- offline_trace_part['uW_mean'] * (offline_trace_part['us'] - 20))
- online_trace_part['offline_aggregates']['paramkeys'].append(paramkeys)
- online_trace_part['offline_aggregates']['param'].append(paramvalues)
- if online_trace_part['isa'] == 'transition':
- online_trace_part['offline_aggregates']['rel_energy_prev'].append(
- offline_trace_part['uW_mean_delta_prev'] * (offline_trace_part['us'] - 20))
- online_trace_part['offline_aggregates']['rel_energy_next'].append(
- offline_trace_part['uW_mean_delta_next'] * (offline_trace_part['us'] - 20))
- online_trace_part['offline_aggregates']['timeout'].append(
- offline_trace_part['timeout'])
+ online_trace_part["offline_aggregates"]["power"].append(
+ offline_trace_part["uW_mean"]
+ )
+ online_trace_part["offline_aggregates"]["duration"].append(
+ offline_trace_part["us"] - 20
+ )
+ online_trace_part["offline_aggregates"]["power_std"].append(
+ offline_trace_part["uW_std"]
+ )
+ online_trace_part["offline_aggregates"]["energy"].append(
+ offline_trace_part["uW_mean"] * (offline_trace_part["us"] - 20)
+ )
+ online_trace_part["offline_aggregates"]["paramkeys"].append(paramkeys)
+ online_trace_part["offline_aggregates"]["param"].append(paramvalues)
+ if online_trace_part["isa"] == "transition":
+ online_trace_part["offline_aggregates"]["rel_energy_prev"].append(
+ offline_trace_part["uW_mean_delta_prev"]
+ * (offline_trace_part["us"] - 20)
+ )
+ online_trace_part["offline_aggregates"]["rel_energy_next"].append(
+ offline_trace_part["uW_mean_delta_next"]
+ * (offline_trace_part["us"] - 20)
+ )
+ online_trace_part["offline_aggregates"]["timeout"].append(
+ offline_trace_part["timeout"]
+ )
def _merge_online_and_etlog(self, measurement):
# Edits self.traces_by_fileno[measurement['fileno']][*]['trace'][*]['offline']
# and self.traces_by_fileno[measurement['fileno']][*]['trace'][*]['offline_aggregates'] in place
# (appends data from measurement['energy_trace'])
online_datapoints = []
- traces = self.traces_by_fileno[measurement['fileno']]
+ traces = self.traces_by_fileno[measurement["fileno"]]
for run_idx, run in enumerate(traces):
- for trace_part_idx in range(len(run['trace'])):
+ for trace_part_idx in range(len(run["trace"])):
online_datapoints.append((run_idx, trace_part_idx))
for offline_idx, online_ref in enumerate(online_datapoints):
online_run_idx, online_trace_part_idx = online_ref
- offline_trace_part = measurement['energy_trace'][offline_idx]
- online_trace_part = traces[online_run_idx]['trace'][online_trace_part_idx]
+ offline_trace_part = measurement["energy_trace"][offline_idx]
+ online_trace_part = traces[online_run_idx]["trace"][online_trace_part_idx]
- if 'offline' not in online_trace_part:
- online_trace_part['offline'] = [offline_trace_part]
+ if "offline" not in online_trace_part:
+ online_trace_part["offline"] = [offline_trace_part]
else:
- online_trace_part['offline'].append(offline_trace_part)
+ online_trace_part["offline"].append(offline_trace_part)
- paramkeys = sorted(online_trace_part['parameter'].keys())
+ paramkeys = sorted(online_trace_part["parameter"].keys())
paramvalues = list()
for paramkey in paramkeys:
- if type(online_trace_part['parameter'][paramkey]) is list:
- paramvalues.append(soft_cast_int(online_trace_part['parameter'][paramkey][measurement['repeat_id']]))
+ if type(online_trace_part["parameter"][paramkey]) is list:
+ paramvalues.append(
+ soft_cast_int(
+ online_trace_part["parameter"][paramkey][
+ measurement["repeat_id"]
+ ]
+ )
+ )
else:
- paramvalues.append(soft_cast_int(online_trace_part['parameter'][paramkey]))
+ paramvalues.append(
+ soft_cast_int(online_trace_part["parameter"][paramkey])
+ )
# NB: Unscheduled transitions do not have an 'args' field set.
# However, they should only be caused by interrupts, and
# interrupts don't have args anyways.
- if arg_support_enabled and 'args' in online_trace_part:
- paramvalues.extend(map(soft_cast_int, online_trace_part['args']))
-
- if 'offline_aggregates' not in online_trace_part:
- online_trace_part['offline_aggregates'] = {
- 'offline_attributes': ['power', 'duration', 'energy'],
- 'duration': list(),
- 'power': list(),
- 'power_std': list(),
- 'energy': list(),
- 'paramkeys': list(),
- 'param': list()
+ if arg_support_enabled and "args" in online_trace_part:
+ paramvalues.extend(map(soft_cast_int, online_trace_part["args"]))
+
+ if "offline_aggregates" not in online_trace_part:
+ online_trace_part["offline_aggregates"] = {
+ "offline_attributes": ["power", "duration", "energy"],
+ "duration": list(),
+ "power": list(),
+ "power_std": list(),
+ "energy": list(),
+ "paramkeys": list(),
+ "param": list(),
}
- offline_aggregates = online_trace_part['offline_aggregates']
+ offline_aggregates = online_trace_part["offline_aggregates"]
# if online_trace_part['isa'] == 'transitions':
# online_trace_part['offline_attributes'].extend(['rel_energy_prev', 'rel_energy_next'])
# offline_aggregates['rel_energy_prev'] = list()
# offline_aggregates['rel_energy_next'] = list()
- offline_aggregates['duration'].append(offline_trace_part['s'] * 1e6)
- offline_aggregates['power'].append(offline_trace_part['W_mean'] * 1e6)
- offline_aggregates['power_std'].append(offline_trace_part['W_std'] * 1e6)
- offline_aggregates['energy'].append(offline_trace_part['W_mean'] * offline_trace_part['s'] * 1e12)
- offline_aggregates['paramkeys'].append(paramkeys)
- offline_aggregates['param'].append(paramvalues)
+ offline_aggregates["duration"].append(offline_trace_part["s"] * 1e6)
+ offline_aggregates["power"].append(offline_trace_part["W_mean"] * 1e6)
+ offline_aggregates["power_std"].append(offline_trace_part["W_std"] * 1e6)
+ offline_aggregates["energy"].append(
+ offline_trace_part["W_mean"] * offline_trace_part["s"] * 1e12
+ )
+ offline_aggregates["paramkeys"].append(paramkeys)
+ offline_aggregates["param"].append(paramvalues)
# if online_trace_part['isa'] == 'transition':
# offline_aggregates['rel_energy_prev'].append(offline_trace_part['W_mean_delta_prev'] * offline_trace_part['s'] * 1e12)
@@ -922,8 +1041,8 @@ class RawData:
for trace in list_of_traces:
trace_output.extend(trace.copy())
for i, trace in enumerate(trace_output):
- trace['orig_id'] = trace['id']
- trace['id'] = i
+ trace["orig_id"] = trace["id"]
+ trace["id"] = i
return trace_output
def get_preprocessed_data(self, verbose=True):
@@ -1000,25 +1119,29 @@ class RawData:
if version == 0:
with tarfile.open(filename) as tf:
- self.setup_by_fileno.append(json.load(tf.extractfile('setup.json')))
- self.traces_by_fileno.append(json.load(tf.extractfile('src/apps/DriverEval/DriverLog.json')))
+ self.setup_by_fileno.append(json.load(tf.extractfile("setup.json")))
+ self.traces_by_fileno.append(
+ json.load(tf.extractfile("src/apps/DriverEval/DriverLog.json"))
+ )
for member in tf.getmembers():
_, extension = os.path.splitext(member.name)
- if extension == '.mim':
- offline_data.append({
- 'content': tf.extractfile(member).read(),
- 'fileno': i,
- 'info': member,
- 'setup': self.setup_by_fileno[i],
- 'with_traces': self.with_traces,
- })
+ if extension == ".mim":
+ offline_data.append(
+ {
+ "content": tf.extractfile(member).read(),
+ "fileno": i,
+ "info": member,
+ "setup": self.setup_by_fileno[i],
+ "with_traces": self.with_traces,
+ }
+ )
elif version == 1:
new_filenames = list()
with tarfile.open(filename) as tf:
- ptalog = json.load(tf.extractfile(tf.getmember('ptalog.json')))
- self.pta = ptalog['pta']
+ ptalog = json.load(tf.extractfile(tf.getmember("ptalog.json")))
+ self.pta = ptalog["pta"]
# Benchmark code may be too large to be executed in a single
# run, so benchmarks (a benchmark is basically a list of DFA runs)
@@ -1043,33 +1166,37 @@ class RawData:
# ptalog['files'][0][0] is its first iteration/repetition,
# ptalog['files'][0][1] the second, etc.
- for j, traces in enumerate(ptalog['traces']):
- new_filenames.append('{}#{}'.format(filename, j))
+ for j, traces in enumerate(ptalog["traces"]):
+ new_filenames.append("{}#{}".format(filename, j))
self.traces_by_fileno.append(traces)
- self.setup_by_fileno.append({
- 'mimosa_voltage': ptalog['configs'][j]['voltage'],
- 'mimosa_shunt': ptalog['configs'][j]['shunt'],
- 'state_duration': ptalog['opt']['sleep'],
- })
- for repeat_id, mim_file in enumerate(ptalog['files'][j]):
+ self.setup_by_fileno.append(
+ {
+ "mimosa_voltage": ptalog["configs"][j]["voltage"],
+ "mimosa_shunt": ptalog["configs"][j]["shunt"],
+ "state_duration": ptalog["opt"]["sleep"],
+ }
+ )
+ for repeat_id, mim_file in enumerate(ptalog["files"][j]):
member = tf.getmember(mim_file)
- offline_data.append({
- 'content': tf.extractfile(member).read(),
- 'fileno': j,
- 'info': member,
- 'setup': self.setup_by_fileno[j],
- 'repeat_id': repeat_id,
- 'expected_trace': ptalog['traces'][j],
- 'with_traces': self.with_traces,
- })
+ offline_data.append(
+ {
+ "content": tf.extractfile(member).read(),
+ "fileno": j,
+ "info": member,
+ "setup": self.setup_by_fileno[j],
+ "repeat_id": repeat_id,
+ "expected_trace": ptalog["traces"][j],
+ "with_traces": self.with_traces,
+ }
+ )
self.filenames = new_filenames
elif version == 2:
new_filenames = list()
with tarfile.open(filename) as tf:
- ptalog = json.load(tf.extractfile(tf.getmember('ptalog.json')))
- self.pta = ptalog['pta']
+ ptalog = json.load(tf.extractfile(tf.getmember("ptalog.json")))
+ self.pta = ptalog["pta"]
# Benchmark code may be too large to be executed in a single
# run, so benchmarks (a benchmark is basically a list of DFA runs)
@@ -1103,32 +1230,45 @@ class RawData:
# to an invalid measurement and thus power[b] corresponding
# to duration[C]. At the moment, this is harmless, but in the
# future it might not be.
- if 'offline_aggregates' in ptalog['traces'][0][0]['trace'][0]:
- for trace_group in ptalog['traces']:
+ if "offline_aggregates" in ptalog["traces"][0][0]["trace"][0]:
+ for trace_group in ptalog["traces"]:
for trace in trace_group:
- for state_or_transition in trace['trace']:
- offline_aggregates = state_or_transition.pop('offline_aggregates', None)
+ for state_or_transition in trace["trace"]:
+ offline_aggregates = state_or_transition.pop(
+ "offline_aggregates", None
+ )
if offline_aggregates:
- state_or_transition['online_aggregates'] = offline_aggregates
+ state_or_transition[
+ "online_aggregates"
+ ] = offline_aggregates
- for j, traces in enumerate(ptalog['traces']):
- new_filenames.append('{}#{}'.format(filename, j))
+ for j, traces in enumerate(ptalog["traces"]):
+ new_filenames.append("{}#{}".format(filename, j))
self.traces_by_fileno.append(traces)
- self.setup_by_fileno.append({
- 'voltage': ptalog['configs'][j]['voltage'],
- 'state_duration': ptalog['opt']['sleep'],
- })
- for repeat_id, etlog_file in enumerate(ptalog['files'][j]):
+ self.setup_by_fileno.append(
+ {
+ "voltage": ptalog["configs"][j]["voltage"],
+ "state_duration": ptalog["opt"]["sleep"],
+ }
+ )
+ for repeat_id, etlog_file in enumerate(ptalog["files"][j]):
member = tf.getmember(etlog_file)
- offline_data.append({
- 'content': tf.extractfile(member).read(),
- 'fileno': j,
- 'info': member,
- 'setup': self.setup_by_fileno[j],
- 'repeat_id': repeat_id,
- 'expected_trace': ptalog['traces'][j],
- 'transition_names': list(map(lambda x: x['name'], ptalog['pta']['transitions']))
- })
+ offline_data.append(
+ {
+ "content": tf.extractfile(member).read(),
+ "fileno": j,
+ "info": member,
+ "setup": self.setup_by_fileno[j],
+ "repeat_id": repeat_id,
+ "expected_trace": ptalog["traces"][j],
+ "transition_names": list(
+ map(
+ lambda x: x["name"],
+ ptalog["pta"]["transitions"],
+ )
+ ),
+ }
+ )
self.filenames = new_filenames
# TODO remove 'offline_aggregates' from pre-parse data and place
# it under 'online_aggregates' or similar instead. This way, if
@@ -1145,52 +1285,69 @@ class RawData:
num_valid = 0
for measurement in measurements:
- if 'energy_trace' not in measurement:
- vprint(self.verbose, '[W] Skipping {ar:s}/{m:s}: {e:s}'.format(
- ar=self.filenames[measurement['fileno']],
- m=measurement['info'].name,
- e='; '.join(measurement['datasource_errors'])))
+ if "energy_trace" not in measurement:
+ vprint(
+ self.verbose,
+ "[W] Skipping {ar:s}/{m:s}: {e:s}".format(
+ ar=self.filenames[measurement["fileno"]],
+ m=measurement["info"].name,
+ e="; ".join(measurement["datasource_errors"]),
+ ),
+ )
continue
if version == 0:
# Strip the last state (it is not part of the scheduled measurement)
- measurement['energy_trace'].pop()
+ measurement["energy_trace"].pop()
elif version == 1:
# The first online measurement is the UNINITIALIZED state. In v1,
# it is not part of the expected PTA trace -> remove it.
- measurement['energy_trace'].pop(0)
+ measurement["energy_trace"].pop(0)
if version == 0 or version == 1:
if self._measurement_is_valid_01(measurement):
self._merge_online_and_offline(measurement)
num_valid += 1
else:
- vprint(self.verbose, '[W] Skipping {ar:s}/{m:s}: {e:s}'.format(
- ar=self.filenames[measurement['fileno']],
- m=measurement['info'].name,
- e=measurement['error']))
+ vprint(
+ self.verbose,
+ "[W] Skipping {ar:s}/{m:s}: {e:s}".format(
+ ar=self.filenames[measurement["fileno"]],
+ m=measurement["info"].name,
+ e=measurement["error"],
+ ),
+ )
elif version == 2:
if self._measurement_is_valid_2(measurement):
self._merge_online_and_etlog(measurement)
num_valid += 1
else:
- vprint(self.verbose, '[W] Skipping {ar:s}/{m:s}: {e:s}'.format(
- ar=self.filenames[measurement['fileno']],
- m=measurement['info'].name,
- e=measurement['error']))
- vprint(self.verbose, '[I] {num_valid:d}/{num_total:d} measurements are valid'.format(
- num_valid=num_valid,
- num_total=len(measurements)))
+ vprint(
+ self.verbose,
+ "[W] Skipping {ar:s}/{m:s}: {e:s}".format(
+ ar=self.filenames[measurement["fileno"]],
+ m=measurement["info"].name,
+ e=measurement["error"],
+ ),
+ )
+ vprint(
+ self.verbose,
+ "[I] {num_valid:d}/{num_total:d} measurements are valid".format(
+ num_valid=num_valid, num_total=len(measurements)
+ ),
+ )
if version == 0:
self.traces = self._concatenate_traces(self.traces_by_fileno)
elif version == 1:
- self.traces = self._concatenate_traces(map(lambda x: x['expected_trace'], measurements))
+ self.traces = self._concatenate_traces(
+ map(lambda x: x["expected_trace"], measurements)
+ )
self.traces = self._concatenate_traces(self.traces_by_fileno)
elif version == 2:
self.traces = self._concatenate_traces(self.traces_by_fileno)
self.preprocessing_stats = {
- 'num_runs': len(measurements),
- 'num_valid': num_valid
+ "num_runs": len(measurements),
+ "num_valid": num_valid,
}
@@ -1207,16 +1364,33 @@ class ParallelParamFit:
self.fit_queue = []
self.by_param = by_param
- def enqueue(self, state_or_tran, attribute, param_index, param_name, safe_functions_enabled=False, param_filter=None):
+ def enqueue(
+ self,
+ state_or_tran,
+ attribute,
+ param_index,
+ param_name,
+ safe_functions_enabled=False,
+ param_filter=None,
+ ):
"""
Add state_or_tran/attribute/param_name to fit queue.
This causes fit() to compute the best-fitting function for this model part.
"""
- self.fit_queue.append({
- 'key': [state_or_tran, attribute, param_name, param_filter],
- 'args': [self.by_param, state_or_tran, attribute, param_index, safe_functions_enabled, param_filter]
- })
+ self.fit_queue.append(
+ {
+ "key": [state_or_tran, attribute, param_name, param_filter],
+ "args": [
+ self.by_param,
+ state_or_tran,
+ attribute,
+ param_index,
+ safe_functions_enabled,
+ param_filter,
+ ],
+ }
+ )
def fit(self):
"""
@@ -1236,13 +1410,17 @@ def _try_fits_parallel(arg):
Must be a global function as it is called from a multiprocessing Pool.
"""
- return {
- 'key': arg['key'],
- 'result': _try_fits(*arg['args'])
- }
+ return {"key": arg["key"], "result": _try_fits(*arg["args"])}
-def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled=False, param_filter: dict = None):
+def _try_fits(
+ by_param,
+ state_or_tran,
+ model_attribute,
+ param_index,
+ safe_functions_enabled=False,
+ param_filter: dict = None,
+):
"""
Determine goodness-of-fit for prediction of `by_param[(state_or_tran, *)][model_attribute]` dependence on `param_index` using various functions.
@@ -1281,22 +1459,28 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
function_names = list(functions.keys())
for function_name in function_names:
function_object = functions[function_name]
- if is_numeric(param_key[1][param_index]) and not function_object.is_valid(param_key[1][param_index]):
+ if is_numeric(param_key[1][param_index]) and not function_object.is_valid(
+ param_key[1][param_index]
+ ):
functions.pop(function_name, None)
raw_results = dict()
raw_results_by_param = dict()
- ref_results = {
- 'mean': list(),
- 'median': list()
- }
+ ref_results = {"mean": list(), "median": list()}
results = dict()
results_by_param = dict()
seen_parameter_combinations = set()
# for each parameter combination:
- for param_key in filter(lambda x: x[0] == state_or_tran and remove_index_from_tuple(x[1], param_index) not in seen_parameter_combinations and len(by_param[x]['param']) and match_parameter_values(by_param[x]['param'][0], param_filter), by_param.keys()):
+ for param_key in filter(
+ lambda x: x[0] == state_or_tran
+ and remove_index_from_tuple(x[1], param_index)
+ not in seen_parameter_combinations
+ and len(by_param[x]["param"])
+ and match_parameter_values(by_param[x]["param"][0], param_filter),
+ by_param.keys(),
+ ):
X = []
Y = []
num_valid = 0
@@ -1304,10 +1488,14 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
# Ensure that each parameter combination is only optimized once. Otherwise, with parameters (1, 2, 5), (1, 3, 5), (1, 4, 5) and param_index == 1,
# the parameter combination (1, *, 5) would be optimized three times, both wasting time and biasing results towards more frequently occuring combinations of non-param_index parameters
- seen_parameter_combinations.add(remove_index_from_tuple(param_key[1], param_index))
+ seen_parameter_combinations.add(
+ remove_index_from_tuple(param_key[1], param_index)
+ )
# for each value of the parameter denoted by param_index (all other parameters remain the same):
- for k, v in filter(lambda kv: param_slice_eq(kv[0], param_key, param_index), by_param.items()):
+ for k, v in filter(
+ lambda kv: param_slice_eq(kv[0], param_key, param_index), by_param.items()
+ ):
num_total += 1
if is_numeric(k[1][param_index]):
num_valid += 1
@@ -1324,7 +1512,9 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
if function_name not in raw_results:
raw_results[function_name] = dict()
error_function = param_function.error_function
- res = optimize.least_squares(error_function, [0, 1], args=(X, Y), xtol=2e-15)
+ res = optimize.least_squares(
+ error_function, [0, 1], args=(X, Y), xtol=2e-15
+ )
measures = regression_measures(param_function.eval(res.x, X), Y)
raw_results_by_param[other_parameters][function_name] = measures
for measure, error_rate in measures.items():
@@ -1333,38 +1523,37 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
raw_results[function_name][measure].append(error_rate)
# print(function_name, res, measures)
mean_measures = aggregate_measures(np.mean(Y), Y)
- ref_results['mean'].append(mean_measures['rmsd'])
- raw_results_by_param[other_parameters]['mean'] = mean_measures
+ ref_results["mean"].append(mean_measures["rmsd"])
+ raw_results_by_param[other_parameters]["mean"] = mean_measures
median_measures = aggregate_measures(np.median(Y), Y)
- ref_results['median'].append(median_measures['rmsd'])
- raw_results_by_param[other_parameters]['median'] = median_measures
+ ref_results["median"].append(median_measures["rmsd"])
+ raw_results_by_param[other_parameters]["median"] = median_measures
- if not len(ref_results['mean']):
+ if not len(ref_results["mean"]):
# Insufficient data for fitting
# print('[W] Insufficient data for fitting {}/{}/{}'.format(state_or_tran, model_attribute, param_index))
- return {
- 'best': None,
- 'best_rmsd': np.inf,
- 'results': results
- }
+ return {"best": None, "best_rmsd": np.inf, "results": results}
- for other_parameter_combination, other_parameter_results in raw_results_by_param.items():
+ for (
+ other_parameter_combination,
+ other_parameter_results,
+ ) in raw_results_by_param.items():
best_fit_val = np.inf
best_fit_name = None
results = dict()
for function_name, result in other_parameter_results.items():
if len(result) > 0:
results[function_name] = result
- rmsd = result['rmsd']
+ rmsd = result["rmsd"]
if rmsd < best_fit_val:
best_fit_val = rmsd
best_fit_name = function_name
results_by_param[other_parameter_combination] = {
- 'best': best_fit_name,
- 'best_rmsd': best_fit_val,
- 'mean_rmsd': results['mean']['rmsd'],
- 'median_rmsd': results['median']['rmsd'],
- 'results': results
+ "best": best_fit_name,
+ "best_rmsd": best_fit_val,
+ "mean_rmsd": results["mean"]["rmsd"],
+ "median_rmsd": results["median"]["rmsd"],
+ "results": results,
}
best_fit_val = np.inf
@@ -1375,26 +1564,26 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
results[function_name] = {}
for measure in result.keys():
results[function_name][measure] = np.mean(result[measure])
- rmsd = results[function_name]['rmsd']
+ rmsd = results[function_name]["rmsd"]
if rmsd < best_fit_val:
best_fit_val = rmsd
best_fit_name = function_name
return {
- 'best': best_fit_name,
- 'best_rmsd': best_fit_val,
- 'mean_rmsd': np.mean(ref_results['mean']),
- 'median_rmsd': np.mean(ref_results['median']),
- 'results': results,
- 'results_by_other_param': results_by_param
+ "best": best_fit_name,
+ "best_rmsd": best_fit_val,
+ "mean_rmsd": np.mean(ref_results["mean"]),
+ "median_rmsd": np.mean(ref_results["median"]),
+ "results": results,
+ "results_by_other_param": results_by_param,
}
def _num_args_from_by_name(by_name):
num_args = dict()
for key, value in by_name.items():
- if 'args' in value:
- num_args[key] = len(value['args'][0])
+ if "args" in value:
+ num_args[key] = len(value["args"][0])
return num_args
@@ -1413,19 +1602,44 @@ def get_fit_result(results, name, attribute, verbose=False, param_filter: dict =
"""
fit_result = dict()
for result in results:
- if result['key'][0] == name and result['key'][1] == attribute and result['key'][3] == param_filter and result['result']['best'] is not None: # dürfte an ['best'] != None liegen-> Fit für gefilterten Kram schlägt fehl?
- this_result = result['result']
- if this_result['best_rmsd'] >= min(this_result['mean_rmsd'], this_result['median_rmsd']):
- vprint(verbose, '[I] Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})'.format(
- name, attribute, result['key'][2], this_result['best_rmsd'],
- this_result['mean_rmsd'], this_result['median_rmsd']))
+ if (
+ result["key"][0] == name
+ and result["key"][1] == attribute
+ and result["key"][3] == param_filter
+ and result["result"]["best"] is not None
+ ): # dürfte an ['best'] != None liegen-> Fit für gefilterten Kram schlägt fehl?
+ this_result = result["result"]
+ if this_result["best_rmsd"] >= min(
+ this_result["mean_rmsd"], this_result["median_rmsd"]
+ ):
+ vprint(
+ verbose,
+ "[I] Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
+ name,
+ attribute,
+ result["key"][2],
+ this_result["best_rmsd"],
+ this_result["mean_rmsd"],
+ this_result["median_rmsd"],
+ ),
+ )
# See notes on depends_on_param
- elif this_result['best_rmsd'] >= 0.8 * min(this_result['mean_rmsd'], this_result['median_rmsd']):
- vprint(verbose, '[I] Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})'.format(
- name, attribute, result['key'][2], this_result['best_rmsd'],
- this_result['mean_rmsd'], this_result['median_rmsd']))
+ elif this_result["best_rmsd"] >= 0.8 * min(
+ this_result["mean_rmsd"], this_result["median_rmsd"]
+ ):
+ vprint(
+ verbose,
+ "[I] Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
+ name,
+ attribute,
+ result["key"][2],
+ this_result["best_rmsd"],
+ this_result["mean_rmsd"],
+ this_result["median_rmsd"],
+ ),
+ )
else:
- fit_result[result['key'][2]] = this_result
+ fit_result[result["key"][2]] = this_result
return fit_result
@@ -1471,7 +1685,15 @@ class AnalyticModel:
assess -- calculate model quality
"""
- def __init__(self, by_name, parameters, arg_count=None, function_override=dict(), verbose=True, use_corrcoef=False):
+ def __init__(
+ self,
+ by_name,
+ parameters,
+ arg_count=None,
+ function_override=dict(),
+ verbose=True,
+ use_corrcoef=False,
+ ):
"""
Create a new AnalyticModel and compute parameter statistics.
@@ -1521,19 +1743,29 @@ class AnalyticModel:
if self._num_args is None:
self._num_args = _num_args_from_by_name(by_name)
- self.stats = ParamStats(self.by_name, self.by_param, self.parameters, self._num_args, verbose=verbose, use_corrcoef=use_corrcoef)
+ self.stats = ParamStats(
+ self.by_name,
+ self.by_param,
+ self.parameters,
+ self._num_args,
+ verbose=verbose,
+ use_corrcoef=use_corrcoef,
+ )
def _get_model_from_dict(self, model_dict, model_function):
model = {}
for name, elem in model_dict.items():
model[name] = {}
- for key in elem['attributes']:
+ for key in elem["attributes"]:
try:
model[name][key] = model_function(elem[key])
except RuntimeWarning:
- vprint(self.verbose, '[W] Got no data for {} {}'.format(name, key))
+ vprint(self.verbose, "[W] Got no data for {} {}".format(name, key))
except FloatingPointError as fpe:
- vprint(self.verbose, '[W] Got no data for {} {}: {}'.format(name, key, fpe))
+ vprint(
+ self.verbose,
+ "[W] Got no data for {} {}: {}".format(name, key, fpe),
+ )
return model
def param_index(self, param_name):
@@ -1596,22 +1828,28 @@ class AnalyticModel:
model_function(name, attribute, param=parameter values) -> model value.
model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
"""
- if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache:
- return self.cache['fitted_model_getter'], self.cache['fitted_info_getter']
+ if "fitted_model_getter" in self.cache and "fitted_info_getter" in self.cache:
+ return self.cache["fitted_model_getter"], self.cache["fitted_info_getter"]
static_model = self._get_model_from_dict(self.by_name, np.median)
param_model = dict([[name, {}] for name in self.by_name.keys()])
paramfit = ParallelParamFit(self.by_param)
for name in self.by_name.keys():
- for attribute in self.by_name[name]['attributes']:
+ for attribute in self.by_name[name]["attributes"]:
for param_index, param in enumerate(self.parameters):
if self.stats.depends_on_param(name, attribute, param):
paramfit.enqueue(name, attribute, param_index, param, False)
if arg_support_enabled and name in self._num_args:
for arg_index in range(self._num_args[name]):
if self.stats.depends_on_arg(name, attribute, arg_index):
- paramfit.enqueue(name, attribute, len(self.parameters) + arg_index, arg_index, False)
+ paramfit.enqueue(
+ name,
+ attribute,
+ len(self.parameters) + arg_index,
+ arg_index,
+ False,
+ )
paramfit.fit()
@@ -1619,8 +1857,10 @@ class AnalyticModel:
num_args = 0
if name in self._num_args:
num_args = self._num_args[name]
- for attribute in self.by_name[name]['attributes']:
- fit_result = get_fit_result(paramfit.results, name, attribute, self.verbose)
+ for attribute in self.by_name[name]["attributes"]:
+ fit_result = get_fit_result(
+ paramfit.results, name, attribute, self.verbose
+ )
if (name, attribute) in self.function_override:
function_str = self.function_override[(name, attribute)]
@@ -1628,25 +1868,27 @@ class AnalyticModel:
x.fit(self.by_param, name, attribute)
if x.fit_success:
param_model[name][attribute] = {
- 'fit_result': fit_result,
- 'function': x
+ "fit_result": fit_result,
+ "function": x,
}
elif len(fit_result.keys()):
- x = analytic.function_powerset(fit_result, self.parameters, num_args)
+ x = analytic.function_powerset(
+ fit_result, self.parameters, num_args
+ )
x.fit(self.by_param, name, attribute)
if x.fit_success:
param_model[name][attribute] = {
- 'fit_result': fit_result,
- 'function': x
+ "fit_result": fit_result,
+ "function": x,
}
def model_getter(name, key, **kwargs):
- if 'arg' in kwargs and 'param' in kwargs:
- kwargs['param'].extend(map(soft_cast_int, kwargs['arg']))
+ if "arg" in kwargs and "param" in kwargs:
+ kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
if key in param_model[name]:
- param_list = kwargs['param']
- param_function = param_model[name][key]['function']
+ param_list = kwargs["param"]
+ param_function = param_model[name][key]["function"]
if param_function.is_predictable(param_list):
return param_function.eval(param_list)
return static_model[name][key]
@@ -1656,8 +1898,8 @@ class AnalyticModel:
return param_model[name][key]
return None
- self.cache['fitted_model_getter'] = model_getter
- self.cache['fitted_info_getter'] = info_getter
+ self.cache["fitted_model_getter"] = model_getter
+ self.cache["fitted_info_getter"] = info_getter
return model_getter, info_getter
@@ -1677,13 +1919,22 @@ class AnalyticModel:
detailed_results = {}
for name, elem in sorted(self.by_name.items()):
detailed_results[name] = {}
- for attribute in elem['attributes']:
- predicted_data = np.array(list(map(lambda i: model_function(name, attribute, param=elem['param'][i]), range(len(elem[attribute])))))
+ for attribute in elem["attributes"]:
+ predicted_data = np.array(
+ list(
+ map(
+ lambda i: model_function(
+ name, attribute, param=elem["param"][i]
+ ),
+ range(len(elem[attribute])),
+ )
+ )
+ )
measures = regression_measures(predicted_data, elem[attribute])
detailed_results[name][attribute] = measures
return {
- 'by_name': detailed_results,
+ "by_name": detailed_results,
}
def to_json(self):
@@ -1695,25 +1946,28 @@ def _add_trace_data_to_aggregate(aggregate, key, element):
# Only cares about element['isa'], element['offline_aggregates'], and
# element['plan']['level']
if key not in aggregate:
- aggregate[key] = {
- 'isa': element['isa']
- }
- for datakey in element['offline_aggregates'].keys():
+ aggregate[key] = {"isa": element["isa"]}
+ for datakey in element["offline_aggregates"].keys():
aggregate[key][datakey] = []
- if element['isa'] == 'state':
- aggregate[key]['attributes'] = ['power']
+ if element["isa"] == "state":
+ aggregate[key]["attributes"] = ["power"]
else:
# TODO do not hardcode values
- aggregate[key]['attributes'] = ['duration', 'energy', 'rel_energy_prev', 'rel_energy_next']
+ aggregate[key]["attributes"] = [
+ "duration",
+ "energy",
+ "rel_energy_prev",
+ "rel_energy_next",
+ ]
# Uncomment this line if you also want to analyze mean transition power
# aggrgate[key]['attributes'].append('power')
- if 'plan' in element and element['plan']['level'] == 'epilogue':
- aggregate[key]['attributes'].insert(0, 'timeout')
- attributes = aggregate[key]['attributes'].copy()
+ if "plan" in element and element["plan"]["level"] == "epilogue":
+ aggregate[key]["attributes"].insert(0, "timeout")
+ attributes = aggregate[key]["attributes"].copy()
for attribute in attributes:
- if attribute not in element['offline_aggregates']:
- aggregate[key]['attributes'].remove(attribute)
- for datakey, dataval in element['offline_aggregates'].items():
+ if attribute not in element["offline_aggregates"]:
+ aggregate[key]["attributes"].remove(attribute)
+ for datakey, dataval in element["offline_aggregates"].items():
aggregate[key][datakey].extend(dataval)
@@ -1771,16 +2025,20 @@ def pta_trace_to_aggregate(traces, ignore_trace_indexes=[]):
"""
arg_count = dict()
by_name = dict()
- parameter_names = sorted(traces[0]['trace'][0]['parameter'].keys())
+ parameter_names = sorted(traces[0]["trace"][0]["parameter"].keys())
for run in traces:
- if run['id'] not in ignore_trace_indexes:
- for elem in run['trace']:
- if elem['isa'] == 'transition' and not elem['name'] in arg_count and 'args' in elem:
- arg_count[elem['name']] = len(elem['args'])
- if elem['name'] != 'UNINITIALIZED':
- _add_trace_data_to_aggregate(by_name, elem['name'], elem)
+ if run["id"] not in ignore_trace_indexes:
+ for elem in run["trace"]:
+ if (
+ elem["isa"] == "transition"
+ and not elem["name"] in arg_count
+ and "args" in elem
+ ):
+ arg_count[elem["name"]] = len(elem["args"])
+ if elem["name"] != "UNINITIALIZED":
+ _add_trace_data_to_aggregate(by_name, elem["name"], elem)
for elem in by_name.values():
- for key in elem['attributes']:
+ for key in elem["attributes"]:
elem[key] = np.array(elem[key])
return by_name, parameter_names, arg_count
@@ -1817,7 +2075,19 @@ class PTAModel:
- rel_energy_next: transition energy relative to next state mean power in pJ
"""
- def __init__(self, by_name, parameters, arg_count, traces=[], ignore_trace_indexes=[], discard_outliers=None, function_override={}, verbose=True, use_corrcoef=False, pta=None):
+ def __init__(
+ self,
+ by_name,
+ parameters,
+ arg_count,
+ traces=[],
+ ignore_trace_indexes=[],
+ discard_outliers=None,
+ function_override={},
+ verbose=True,
+ use_corrcoef=False,
+ pta=None,
+ ):
"""
Prepare a new PTA energy model.
@@ -1854,9 +2124,16 @@ class PTAModel:
self._num_args = arg_count
self._use_corrcoef = use_corrcoef
self.traces = traces
- self.stats = ParamStats(self.by_name, self.by_param, self._parameter_names, self._num_args, self._use_corrcoef, verbose=verbose)
+ self.stats = ParamStats(
+ self.by_name,
+ self.by_param,
+ self._parameter_names,
+ self._num_args,
+ self._use_corrcoef,
+ verbose=verbose,
+ )
self.cache = {}
- np.seterr('raise')
+ np.seterr("raise")
self._outlier_threshold = discard_outliers
self.function_override = function_override.copy()
self.verbose = verbose
@@ -1866,7 +2143,7 @@ class PTAModel:
def _aggregate_to_ndarray(self, aggregate):
for elem in aggregate.values():
- for key in elem['attributes']:
+ for key in elem["attributes"]:
elem[key] = np.array(elem[key])
# This heuristic is very similar to the "function is not much better than
@@ -1884,13 +2161,16 @@ class PTAModel:
model = {}
for name, elem in model_dict.items():
model[name] = {}
- for key in elem['attributes']:
+ for key in elem["attributes"]:
try:
model[name][key] = model_function(elem[key])
except RuntimeWarning:
- vprint(self.verbose, '[W] Got no data for {} {}'.format(name, key))
+ vprint(self.verbose, "[W] Got no data for {} {}".format(name, key))
except FloatingPointError as fpe:
- vprint(self.verbose, '[W] Got no data for {} {}: {}'.format(name, key, fpe))
+ vprint(
+ self.verbose,
+ "[W] Got no data for {} {}: {}".format(name, key, fpe),
+ )
return model
def get_static(self, use_mean=False):
@@ -1953,63 +2233,110 @@ class PTAModel:
model_function(name, attribute, param=parameter values) -> model value.
model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
"""
- if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache:
- return self.cache['fitted_model_getter'], self.cache['fitted_info_getter']
+ if "fitted_model_getter" in self.cache and "fitted_info_getter" in self.cache:
+ return self.cache["fitted_model_getter"], self.cache["fitted_info_getter"]
static_model = self._get_model_from_dict(self.by_name, np.median)
- param_model = dict([[state_or_tran, {}] for state_or_tran in self.by_name.keys()])
+ param_model = dict(
+ [[state_or_tran, {}] for state_or_tran in self.by_name.keys()]
+ )
paramfit = ParallelParamFit(self.by_param)
for state_or_tran in self.by_name.keys():
- for model_attribute in self.by_name[state_or_tran]['attributes']:
+ for model_attribute in self.by_name[state_or_tran]["attributes"]:
fit_results = {}
for parameter_index, parameter_name in enumerate(self._parameter_names):
- if self.depends_on_param(state_or_tran, model_attribute, parameter_name):
- paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled)
- for codependent_param_dict in self.stats.codependent_parameter_value_dicts(state_or_tran, model_attribute, parameter_name):
- paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled, codependent_param_dict)
- if arg_support_enabled and self.by_name[state_or_tran]['isa'] == 'transition':
+ if self.depends_on_param(
+ state_or_tran, model_attribute, parameter_name
+ ):
+ paramfit.enqueue(
+ state_or_tran,
+ model_attribute,
+ parameter_index,
+ parameter_name,
+ safe_functions_enabled,
+ )
+ for (
+ codependent_param_dict
+ ) in self.stats.codependent_parameter_value_dicts(
+ state_or_tran, model_attribute, parameter_name
+ ):
+ paramfit.enqueue(
+ state_or_tran,
+ model_attribute,
+ parameter_index,
+ parameter_name,
+ safe_functions_enabled,
+ codependent_param_dict,
+ )
+ if (
+ arg_support_enabled
+ and self.by_name[state_or_tran]["isa"] == "transition"
+ ):
for arg_index in range(self._num_args[state_or_tran]):
- if self.depends_on_arg(state_or_tran, model_attribute, arg_index):
- paramfit.enqueue(state_or_tran, model_attribute, len(self._parameter_names) + arg_index, arg_index, safe_functions_enabled)
+ if self.depends_on_arg(
+ state_or_tran, model_attribute, arg_index
+ ):
+ paramfit.enqueue(
+ state_or_tran,
+ model_attribute,
+ len(self._parameter_names) + arg_index,
+ arg_index,
+ safe_functions_enabled,
+ )
paramfit.fit()
for state_or_tran in self.by_name.keys():
num_args = 0
- if arg_support_enabled and self.by_name[state_or_tran]['isa'] == 'transition':
+ if (
+ arg_support_enabled
+ and self.by_name[state_or_tran]["isa"] == "transition"
+ ):
num_args = self._num_args[state_or_tran]
- for model_attribute in self.by_name[state_or_tran]['attributes']:
- fit_results = get_fit_result(paramfit.results, state_or_tran, model_attribute, self.verbose)
+ for model_attribute in self.by_name[state_or_tran]["attributes"]:
+ fit_results = get_fit_result(
+ paramfit.results, state_or_tran, model_attribute, self.verbose
+ )
for parameter_name in self._parameter_names:
- if self.depends_on_param(state_or_tran, model_attribute, parameter_name):
- for codependent_param_dict in self.stats.codependent_parameter_value_dicts(state_or_tran, model_attribute, parameter_name):
+ if self.depends_on_param(
+ state_or_tran, model_attribute, parameter_name
+ ):
+ for (
+ codependent_param_dict
+ ) in self.stats.codependent_parameter_value_dicts(
+ state_or_tran, model_attribute, parameter_name
+ ):
pass
# FIXME get_fit_result hat ja gar keinen Parameter als Argument...
if (state_or_tran, model_attribute) in self.function_override:
- function_str = self.function_override[(state_or_tran, model_attribute)]
+ function_str = self.function_override[
+ (state_or_tran, model_attribute)
+ ]
x = AnalyticFunction(function_str, self._parameter_names, num_args)
x.fit(self.by_param, state_or_tran, model_attribute)
if x.fit_success:
param_model[state_or_tran][model_attribute] = {
- 'fit_result': fit_results,
- 'function': x
+ "fit_result": fit_results,
+ "function": x,
}
elif len(fit_results.keys()):
- x = analytic.function_powerset(fit_results, self._parameter_names, num_args)
+ x = analytic.function_powerset(
+ fit_results, self._parameter_names, num_args
+ )
x.fit(self.by_param, state_or_tran, model_attribute)
if x.fit_success:
param_model[state_or_tran][model_attribute] = {
- 'fit_result': fit_results,
- 'function': x
+ "fit_result": fit_results,
+ "function": x,
}
def model_getter(name, key, **kwargs):
- if 'arg' in kwargs and 'param' in kwargs:
- kwargs['param'].extend(map(soft_cast_int, kwargs['arg']))
+ if "arg" in kwargs and "param" in kwargs:
+ kwargs["param"].extend(map(soft_cast_int, kwargs["arg"]))
if key in param_model[name]:
- param_list = kwargs['param']
- param_function = param_model[name][key]['function']
+ param_list = kwargs["param"]
+ param_function = param_model[name][key]["function"]
if param_function.is_predictable(param_list):
return param_function.eval(param_list)
return static_model[name][key]
@@ -2019,8 +2346,8 @@ class PTAModel:
return param_model[name][key]
return None
- self.cache['fitted_model_getter'] = model_getter
- self.cache['fitted_info_getter'] = info_getter
+ self.cache["fitted_model_getter"] = model_getter
+ self.cache["fitted_info_getter"] = info_getter
return model_getter, info_getter
@@ -2029,16 +2356,32 @@ class PTAModel:
static_quality = self.assess(static_model)
param_model, param_info = self.get_fitted()
analytic_quality = self.assess(param_model)
- self.pta.update(static_model, param_info, static_error=static_quality['by_name'], analytic_error=analytic_quality['by_name'])
+ self.pta.update(
+ static_model,
+ param_info,
+ static_error=static_quality["by_name"],
+ analytic_error=analytic_quality["by_name"],
+ )
return self.pta.to_json()
def states(self):
"""Return sorted list of state names."""
- return sorted(list(filter(lambda k: self.by_name[k]['isa'] == 'state', self.by_name.keys())))
+ return sorted(
+ list(
+ filter(lambda k: self.by_name[k]["isa"] == "state", self.by_name.keys())
+ )
+ )
def transitions(self):
"""Return sorted list of transition names."""
- return sorted(list(filter(lambda k: self.by_name[k]['isa'] == 'transition', self.by_name.keys())))
+ return sorted(
+ list(
+ filter(
+ lambda k: self.by_name[k]["isa"] == "transition",
+ self.by_name.keys(),
+ )
+ )
+ )
def states_and_transitions(self):
"""Return list of states and transition names."""
@@ -2050,7 +2393,7 @@ class PTAModel:
return self._parameter_names
def attributes(self, state_or_trans):
- return self.by_name[state_or_trans]['attributes']
+ return self.by_name[state_or_trans]["attributes"]
def assess(self, model_function):
"""
@@ -2068,16 +2411,23 @@ class PTAModel:
detailed_results = {}
for name, elem in sorted(self.by_name.items()):
detailed_results[name] = {}
- for key in elem['attributes']:
- predicted_data = np.array(list(map(lambda i: model_function(name, key, param=elem['param'][i]), range(len(elem[key])))))
+ for key in elem["attributes"]:
+ predicted_data = np.array(
+ list(
+ map(
+ lambda i: model_function(name, key, param=elem["param"][i]),
+ range(len(elem[key])),
+ )
+ )
+ )
measures = regression_measures(predicted_data, elem[key])
detailed_results[name][key] = measures
- return {
- 'by_name': detailed_results
- }
+ return {"by_name": detailed_results}
- def assess_states(self, model_function, model_attribute='power', distribution: dict = None):
+ def assess_states(
+ self, model_function, model_attribute="power", distribution: dict = None
+ ):
"""
Calculate overall model error assuming equal distribution of states
"""
@@ -2089,7 +2439,9 @@ class PTAModel:
distribution = dict(map(lambda x: [x, 1 / num_states], self.states()))
if not np.isclose(sum(distribution.values()), 1):
- raise ValueError('distribution must be a probability distribution with sum 1')
+ raise ValueError(
+ "distribution must be a probability distribution with sum 1"
+ )
# total_value = None
# try:
@@ -2097,7 +2449,17 @@ class PTAModel:
# except KeyError:
# pass
- total_error = np.sqrt(sum(map(lambda x: np.square(model_quality['by_name'][x][model_attribute]['mae'] * distribution[x]), self.states())))
+ total_error = np.sqrt(
+ sum(
+ map(
+ lambda x: np.square(
+ model_quality["by_name"][x][model_attribute]["mae"]
+ * distribution[x]
+ ),
+ self.states(),
+ )
+ )
+ )
return total_error
def assess_on_traces(self, model_function):
@@ -2118,44 +2480,72 @@ class PTAModel:
real_timeout_list = []
for trace in self.traces:
- if trace['id'] not in self.ignore_trace_indexes:
- for rep_id in range(len(trace['trace'][0]['offline'])):
- model_energy = 0.
- real_energy = 0.
- model_rel_energy = 0.
- model_state_energy = 0.
- model_duration = 0.
- real_duration = 0.
- model_timeout = 0.
- real_timeout = 0.
- for i, trace_part in enumerate(trace['trace']):
- name = trace_part['name']
- prev_name = trace['trace'][i - 1]['name']
- isa = trace_part['isa']
- if name != 'UNINITIALIZED':
+ if trace["id"] not in self.ignore_trace_indexes:
+ for rep_id in range(len(trace["trace"][0]["offline"])):
+ model_energy = 0.0
+ real_energy = 0.0
+ model_rel_energy = 0.0
+ model_state_energy = 0.0
+ model_duration = 0.0
+ real_duration = 0.0
+ model_timeout = 0.0
+ real_timeout = 0.0
+ for i, trace_part in enumerate(trace["trace"]):
+ name = trace_part["name"]
+ prev_name = trace["trace"][i - 1]["name"]
+ isa = trace_part["isa"]
+ if name != "UNINITIALIZED":
try:
- param = trace_part['offline_aggregates']['param'][rep_id]
- prev_param = trace['trace'][i - 1]['offline_aggregates']['param'][rep_id]
- power = trace_part['offline'][rep_id]['uW_mean']
- duration = trace_part['offline'][rep_id]['us']
- prev_duration = trace['trace'][i - 1]['offline'][rep_id]['us']
+ param = trace_part["offline_aggregates"]["param"][
+ rep_id
+ ]
+ prev_param = trace["trace"][i - 1][
+ "offline_aggregates"
+ ]["param"][rep_id]
+ power = trace_part["offline"][rep_id]["uW_mean"]
+ duration = trace_part["offline"][rep_id]["us"]
+ prev_duration = trace["trace"][i - 1]["offline"][
+ rep_id
+ ]["us"]
real_energy += power * duration
- if isa == 'state':
- model_energy += model_function(name, 'power', param=param) * duration
+ if isa == "state":
+ model_energy += (
+ model_function(name, "power", param=param)
+ * duration
+ )
else:
- model_energy += model_function(name, 'energy', param=param)
+ model_energy += model_function(
+ name, "energy", param=param
+ )
# If i == 1, the previous state was UNINITIALIZED, for which we do not have model data
if i == 1:
- model_rel_energy += model_function(name, 'energy', param=param)
+ model_rel_energy += model_function(
+ name, "energy", param=param
+ )
else:
- model_rel_energy += model_function(prev_name, 'power', param=prev_param) * (prev_duration + duration)
- model_state_energy += model_function(prev_name, 'power', param=prev_param) * (prev_duration + duration)
- model_rel_energy += model_function(name, 'rel_energy_prev', param=param)
+ model_rel_energy += model_function(
+ prev_name, "power", param=prev_param
+ ) * (prev_duration + duration)
+ model_state_energy += model_function(
+ prev_name, "power", param=prev_param
+ ) * (prev_duration + duration)
+ model_rel_energy += model_function(
+ name, "rel_energy_prev", param=param
+ )
real_duration += duration
- model_duration += model_function(name, 'duration', param=param)
- if 'plan' in trace_part and trace_part['plan']['level'] == 'epilogue':
- real_timeout += trace_part['offline'][rep_id]['timeout']
- model_timeout += model_function(name, 'timeout', param=param)
+ model_duration += model_function(
+ name, "duration", param=param
+ )
+ if (
+ "plan" in trace_part
+ and trace_part["plan"]["level"] == "epilogue"
+ ):
+ real_timeout += trace_part["offline"][rep_id][
+ "timeout"
+ ]
+ model_timeout += model_function(
+ name, "timeout", param=param
+ )
except KeyError:
# if states/transitions have been removed via --filter-param, this is harmless
pass
@@ -2169,11 +2559,21 @@ class PTAModel:
model_timeout_list.append(model_timeout)
return {
- 'duration_by_trace': regression_measures(np.array(model_duration_list), np.array(real_duration_list)),
- 'energy_by_trace': regression_measures(np.array(model_energy_list), np.array(real_energy_list)),
- 'timeout_by_trace': regression_measures(np.array(model_timeout_list), np.array(real_timeout_list)),
- 'rel_energy_by_trace': regression_measures(np.array(model_rel_energy_list), np.array(real_energy_list)),
- 'state_energy_by_trace': regression_measures(np.array(model_state_energy_list), np.array(real_energy_list)),
+ "duration_by_trace": regression_measures(
+ np.array(model_duration_list), np.array(real_duration_list)
+ ),
+ "energy_by_trace": regression_measures(
+ np.array(model_energy_list), np.array(real_energy_list)
+ ),
+ "timeout_by_trace": regression_measures(
+ np.array(model_timeout_list), np.array(real_timeout_list)
+ ),
+ "rel_energy_by_trace": regression_measures(
+ np.array(model_rel_energy_list), np.array(real_energy_list)
+ ),
+ "state_energy_by_trace": regression_measures(
+ np.array(model_state_energy_list), np.array(real_energy_list)
+ ),
}
@@ -2230,17 +2630,19 @@ class EnergyTraceLog:
"""
if not zbar_available:
- self.errors.append('zbar module is not available. Try "apt install python3-zbar"')
+ self.errors.append(
+ 'zbar module is not available. Try "apt install python3-zbar"'
+ )
return list()
- lines = log_data.decode('ascii').split('\n')
- data_count = sum(map(lambda x: len(x) > 0 and x[0] != '#', lines))
- data_lines = filter(lambda x: len(x) > 0 and x[0] != '#', lines)
+ lines = log_data.decode("ascii").split("\n")
+ data_count = sum(map(lambda x: len(x) > 0 and x[0] != "#", lines))
+ data_lines = filter(lambda x: len(x) > 0 and x[0] != "#", lines)
data = np.empty((data_count, 4))
for i, line in enumerate(data_lines):
- fields = line.split(' ')
+ fields = line.split(" ")
if len(fields) == 4:
timestamp, current, voltage, total_energy = map(int, fields)
elif len(fields) == 5:
@@ -2252,15 +2654,26 @@ class EnergyTraceLog:
self.interval_start_timestamp = data[:-1, 0] * 1e-6
self.interval_duration = (data[1:, 0] - data[:-1, 0]) * 1e-6
- self.interval_power = ((data[1:, 3] - data[:-1, 3]) * 1e-9) / ((data[1:, 0] - data[:-1, 0]) * 1e-6)
+ self.interval_power = ((data[1:, 3] - data[:-1, 3]) * 1e-9) / (
+ (data[1:, 0] - data[:-1, 0]) * 1e-6
+ )
m_duration_us = data[-1, 0] - data[0, 0]
self.sample_rate = data_count / (m_duration_us * 1e-6)
- vprint(self.verbose, 'got {} samples with {} seconds of log data ({} Hz)'.format(data_count, m_duration_us * 1e-6, self.sample_rate))
+ vprint(
+ self.verbose,
+ "got {} samples with {} seconds of log data ({} Hz)".format(
+ data_count, m_duration_us * 1e-6, self.sample_rate
+ ),
+ )
- return self.interval_start_timestamp, self.interval_duration, self.interval_power
+ return (
+ self.interval_start_timestamp,
+ self.interval_duration,
+ self.interval_power,
+ )
def ts_to_index(self, timestamp):
"""
@@ -2279,7 +2692,12 @@ class EnergyTraceLog:
mid_index = left_index + (right_index - left_index) // 2
# I'm feeling lucky
- if timestamp > self.interval_start_timestamp[mid_index] and timestamp <= self.interval_start_timestamp[mid_index] + self.interval_duration[mid_index]:
+ if (
+ timestamp > self.interval_start_timestamp[mid_index]
+ and timestamp
+ <= self.interval_start_timestamp[mid_index]
+ + self.interval_duration[mid_index]
+ ):
return mid_index
if timestamp <= self.interval_start_timestamp[mid_index]:
@@ -2322,16 +2740,29 @@ class EnergyTraceLog:
expected_transitions = list()
for trace_number, trace in enumerate(traces):
- for state_or_transition_number, state_or_transition in enumerate(trace['trace']):
- if state_or_transition['isa'] == 'transition':
+ for state_or_transition_number, state_or_transition in enumerate(
+ trace["trace"]
+ ):
+ if state_or_transition["isa"] == "transition":
try:
- expected_transitions.append((
- state_or_transition['name'],
- state_or_transition['online_aggregates']['duration'][offline_index] * 1e-6
- ))
+ expected_transitions.append(
+ (
+ state_or_transition["name"],
+ state_or_transition["online_aggregates"]["duration"][
+ offline_index
+ ]
+ * 1e-6,
+ )
+ )
except IndexError:
- self.errors.append('Entry #{} ("{}") in trace #{} has no duration entry for offline_index/repeat_id {}'.format(
- state_or_transition_number, state_or_transition['name'], trace_number, offline_index))
+ self.errors.append(
+ 'Entry #{} ("{}") in trace #{} has no duration entry for offline_index/repeat_id {}'.format(
+ state_or_transition_number,
+ state_or_transition["name"],
+ trace_number,
+ offline_index,
+ )
+ )
return energy_trace
next_barcode = first_sync
@@ -2342,51 +2773,101 @@ class EnergyTraceLog:
print('[!!!] did not find transition "{}"'.format(name))
break
next_barcode = end + self.state_duration + duration
- vprint(self.verbose, '{} barcode "{}" area: {:0.2f} .. {:0.2f} / {:0.2f} seconds'.format(offline_index, bc, start, stop, end))
+ vprint(
+ self.verbose,
+ '{} barcode "{}" area: {:0.2f} .. {:0.2f} / {:0.2f} seconds'.format(
+ offline_index, bc, start, stop, end
+ ),
+ )
if bc != name:
- vprint(self.verbose, '[!!!] mismatch: expected "{}", got "{}"'.format(name, bc))
- vprint(self.verbose, '{} estimated transition area: {:0.3f} .. {:0.3f} seconds'.format(offline_index, end, end + duration))
+ vprint(
+ self.verbose,
+ '[!!!] mismatch: expected "{}", got "{}"'.format(name, bc),
+ )
+ vprint(
+ self.verbose,
+ "{} estimated transition area: {:0.3f} .. {:0.3f} seconds".format(
+ offline_index, end, end + duration
+ ),
+ )
transition_start_index = self.ts_to_index(end)
transition_done_index = self.ts_to_index(end + duration) + 1
state_start_index = transition_done_index
- state_done_index = self.ts_to_index(end + duration + self.state_duration) + 1
-
- vprint(self.verbose, '{} estimated transitionindex: {:0.3f} .. {:0.3f} seconds'.format(offline_index, transition_start_index / self.sample_rate, transition_done_index / self.sample_rate))
+ state_done_index = (
+ self.ts_to_index(end + duration + self.state_duration) + 1
+ )
- energy_trace.append({
- 'isa': 'transition',
- 'W_mean': np.mean(self.interval_power[transition_start_index: transition_done_index]),
- 'W_std': np.std(self.interval_power[transition_start_index: transition_done_index]),
- 's': duration,
- 's_coarse': self.interval_start_timestamp[transition_done_index] - self.interval_start_timestamp[transition_start_index]
+ vprint(
+ self.verbose,
+ "{} estimated transitionindex: {:0.3f} .. {:0.3f} seconds".format(
+ offline_index,
+ transition_start_index / self.sample_rate,
+ transition_done_index / self.sample_rate,
+ ),
+ )
- })
+ energy_trace.append(
+ {
+ "isa": "transition",
+ "W_mean": np.mean(
+ self.interval_power[
+ transition_start_index:transition_done_index
+ ]
+ ),
+ "W_std": np.std(
+ self.interval_power[
+ transition_start_index:transition_done_index
+ ]
+ ),
+ "s": duration,
+ "s_coarse": self.interval_start_timestamp[transition_done_index]
+ - self.interval_start_timestamp[transition_start_index],
+ }
+ )
if len(energy_trace) > 1:
- energy_trace[-1]['W_mean_delta_prev'] = energy_trace[-1]['W_mean'] - energy_trace[-2]['W_mean']
+ energy_trace[-1]["W_mean_delta_prev"] = (
+ energy_trace[-1]["W_mean"] - energy_trace[-2]["W_mean"]
+ )
- energy_trace.append({
- 'isa': 'state',
- 'W_mean': np.mean(self.interval_power[state_start_index: state_done_index]),
- 'W_std': np.std(self.interval_power[state_start_index: state_done_index]),
- 's': self.state_duration,
- 's_coarse': self.interval_start_timestamp[state_done_index] - self.interval_start_timestamp[state_start_index]
- })
+ energy_trace.append(
+ {
+ "isa": "state",
+ "W_mean": np.mean(
+ self.interval_power[state_start_index:state_done_index]
+ ),
+ "W_std": np.std(
+ self.interval_power[state_start_index:state_done_index]
+ ),
+ "s": self.state_duration,
+ "s_coarse": self.interval_start_timestamp[state_done_index]
+ - self.interval_start_timestamp[state_start_index],
+ }
+ )
- energy_trace[-2]['W_mean_delta_next'] = energy_trace[-2]['W_mean'] - energy_trace[-1]['W_mean']
+ energy_trace[-2]["W_mean_delta_next"] = (
+ energy_trace[-2]["W_mean"] - energy_trace[-1]["W_mean"]
+ )
expected_transition_count = len(expected_transitions)
recovered_transition_ount = len(energy_trace) // 2
if expected_transition_count != recovered_transition_ount:
- self.errors.append('Expected {:d} transitions, got {:d}'.format(expected_transition_count, recovered_transition_ount))
+ self.errors.append(
+ "Expected {:d} transitions, got {:d}".format(
+ expected_transition_count, recovered_transition_ount
+ )
+ )
return energy_trace
def find_first_sync(self):
# LED Power is approx. self.led_power W, use self.led_power/2 W above surrounding median as threshold
- sync_threshold_power = np.median(self.interval_power[: int(3 * self.sample_rate)]) + self.led_power / 3
+ sync_threshold_power = (
+ np.median(self.interval_power[: int(3 * self.sample_rate)])
+ + self.led_power / 3
+ )
for i, ts in enumerate(self.interval_start_timestamp):
if ts > 2 and self.interval_power[i] > sync_threshold_power:
return self.interval_start_timestamp[i - 300]
@@ -2410,26 +2891,56 @@ class EnergyTraceLog:
lookaround = int(0.1 * self.sample_rate)
# LED Power is approx. self.led_power W, use self.led_power/2 W above surrounding median as threshold
- sync_threshold_power = np.median(self.interval_power[start_position - lookaround: start_position + lookaround]) + self.led_power / 3
+ sync_threshold_power = (
+ np.median(
+ self.interval_power[
+ start_position - lookaround : start_position + lookaround
+ ]
+ )
+ + self.led_power / 3
+ )
- vprint(self.verbose, 'looking for barcode starting at {:0.2f} s, threshold is {:0.1f} mW'.format(start_ts, sync_threshold_power * 1e3))
+ vprint(
+ self.verbose,
+ "looking for barcode starting at {:0.2f} s, threshold is {:0.1f} mW".format(
+ start_ts, sync_threshold_power * 1e3
+ ),
+ )
sync_area_start = None
sync_start_ts = None
sync_area_end = None
sync_end_ts = None
for i, ts in enumerate(self.interval_start_timestamp):
- if sync_area_start is None and ts >= start_ts and self.interval_power[i] > sync_threshold_power:
+ if (
+ sync_area_start is None
+ and ts >= start_ts
+ and self.interval_power[i] > sync_threshold_power
+ ):
sync_area_start = i - 300
sync_start_ts = ts
- if sync_area_start is not None and sync_area_end is None and ts > sync_start_ts + self.min_barcode_duration and (ts > sync_start_ts + self.max_barcode_duration or abs(sync_threshold_power - self.interval_power[i]) > self.led_power):
+ if (
+ sync_area_start is not None
+ and sync_area_end is None
+ and ts > sync_start_ts + self.min_barcode_duration
+ and (
+ ts > sync_start_ts + self.max_barcode_duration
+ or abs(sync_threshold_power - self.interval_power[i])
+ > self.led_power
+ )
+ ):
sync_area_end = i
sync_end_ts = ts
break
- barcode_data = self.interval_power[sync_area_start: sync_area_end]
+ barcode_data = self.interval_power[sync_area_start:sync_area_end]
- vprint(self.verbose, 'barcode search area: {:0.2f} .. {:0.2f} seconds ({} samples)'.format(sync_start_ts, sync_end_ts, len(barcode_data)))
+ vprint(
+ self.verbose,
+ "barcode search area: {:0.2f} .. {:0.2f} seconds ({} samples)".format(
+ sync_start_ts, sync_end_ts, len(barcode_data)
+ ),
+ )
bc, start, stop, padding_bits = self.find_barcode_in_power_data(barcode_data)
@@ -2439,7 +2950,9 @@ class EnergyTraceLog:
start_ts = self.interval_start_timestamp[sync_area_start + start]
stop_ts = self.interval_start_timestamp[sync_area_start + stop]
- end_ts = stop_ts + self.module_duration * padding_bits + self.quiet_zone_duration
+ end_ts = (
+ stop_ts + self.module_duration * padding_bits + self.quiet_zone_duration
+ )
# barcode content, barcode start timestamp, barcode stop timestamp, barcode end (stop + padding) timestamp
return bc, start_ts, stop_ts, end_ts
@@ -2455,7 +2968,9 @@ class EnergyTraceLog:
# -> Create a black and white (not grayscale) image to avoid this.
# Unfortunately, this decreases resilience against background noise
# (e.g. a not-exactly-idle peripheral device or CPU interrupts).
- image_data = np.around(1 - ((barcode_data - min_power) / (max_power - min_power)))
+ image_data = np.around(
+ 1 - ((barcode_data - min_power) / (max_power - min_power))
+ )
image_data *= 255
# zbar only returns the complete barcode position if it is at least
@@ -2469,12 +2984,12 @@ class EnergyTraceLog:
# img = Image.frombytes('L', (width, height), image_data).resize((width, 100))
# img.save('/tmp/test-{}.png'.format(os.getpid()))
- zbimg = zbar.Image(width, height, 'Y800', image_data)
+ zbimg = zbar.Image(width, height, "Y800", image_data)
scanner = zbar.ImageScanner()
- scanner.parse_config('enable')
+ scanner.parse_config("enable")
if scanner.scan(zbimg):
- sym, = zbimg.symbols
+ (sym,) = zbimg.symbols
content = sym.data
try:
sym_start = sym.location[1][0]
@@ -2482,7 +2997,7 @@ class EnergyTraceLog:
sym_start = 0
sym_end = sym.location[0][0]
- match = re.fullmatch(r'T(\d+)', content)
+ match = re.fullmatch(r"T(\d+)", content)
if match:
content = self.transition_names[int(match.group(1))]
@@ -2490,7 +3005,7 @@ class EnergyTraceLog:
# additional non-barcode padding (encoded as LED off / image white).
# Calculate the amount of extra bits to determine the offset until
# the transition starts.
- padding_bits = len(Code128(sym.data, charset='B').modules) % 8
+ padding_bits = len(Code128(sym.data, charset="B").modules) % 8
# sym_start leaves out the first two bars, but we don't do anything about that here
# sym_end leaves out the last three bars, each of which is one padding bit long.
@@ -2499,7 +3014,7 @@ class EnergyTraceLog:
return content, sym_start, sym_end, padding_bits
else:
- vprint(self.verbose, 'unable to find barcode')
+ vprint(self.verbose, "unable to find barcode")
return None, None, None, None
@@ -2555,15 +3070,15 @@ class MIMOSA:
:returns: (numpy array of charges (pJ per 10µs), numpy array of triggers (0/1 int, per 10µs))
"""
- num_bytes = tf.getmember('/tmp/mimosa//mimosa_scale_1.tmp').size
+ num_bytes = tf.getmember("/tmp/mimosa//mimosa_scale_1.tmp").size
charges = np.ndarray(shape=(int(num_bytes / 4)), dtype=np.int32)
triggers = np.ndarray(shape=(int(num_bytes / 4)), dtype=np.int8)
- with tf.extractfile('/tmp/mimosa//mimosa_scale_1.tmp') as f:
+ with tf.extractfile("/tmp/mimosa//mimosa_scale_1.tmp") as f:
content = f.read()
- iterator = struct.iter_unpack('<I', content)
+ iterator = struct.iter_unpack("<I", content)
i = 0
for word in iterator:
- charges[i] = (word[0] >> 4)
+ charges[i] = word[0] >> 4
triggers[i] = (word[0] & 0x08) >> 3
i += 1
return charges, triggers
@@ -2616,7 +3131,7 @@ class MIMOSA:
trigidx = []
if len(triggers) < 1000000:
- self.errors.append('MIMOSA log is too short')
+ self.errors.append("MIMOSA log is too short")
return trigidx
prevtrig = triggers[999999]
@@ -2625,13 +3140,17 @@ class MIMOSA:
# something went wrong and are unable to determine when the first
# transition starts.
if prevtrig != 0:
- self.errors.append('Unable to find start of first transition (log starts with trigger == {} != 0)'.format(prevtrig))
+ self.errors.append(
+ "Unable to find start of first transition (log starts with trigger == {} != 0)".format(
+ prevtrig
+ )
+ )
# if the last trigger is high (i.e., trigger/buzzer pin is active when the benchmark ends),
# it terminated in the middle of a transition -- meaning that it was not
# measured in its entirety.
if triggers[-1] != 0:
- self.errors.append('Log ends during a transition'.format(prevtrig))
+ self.errors.append("Log ends during a transition".format(prevtrig))
# the device is reset for MIMOSA calibration in the first 10s and may
# send bogus interrupts -> bogus triggers
@@ -2663,11 +3182,23 @@ class MIMOSA:
for i in range(100000, len(currents)):
if r1idx == 0 and currents[i] > ua_r1 * 0.6:
r1idx = i
- elif r1idx != 0 and r2idx == 0 and i > (r1idx + 180000) and currents[i] < ua_r1 * 0.4:
+ elif (
+ r1idx != 0
+ and r2idx == 0
+ and i > (r1idx + 180000)
+ and currents[i] < ua_r1 * 0.4
+ ):
r2idx = i
# 2s disconnected, 2s r1, 2s r2 with r1 < r2 -> ua_r1 > ua_r2
# allow 5ms buffer in both directions to account for bouncing relais contacts
- return r1idx - 180500, r1idx - 500, r1idx + 500, r2idx - 500, r2idx + 500, r2idx + 180500
+ return (
+ r1idx - 180500,
+ r1idx - 500,
+ r1idx + 500,
+ r2idx - 500,
+ r2idx + 500,
+ r2idx + 180500,
+ )
def calibration_function(self, charges, cal_edges):
u"""
@@ -2711,7 +3242,7 @@ class MIMOSA:
if cal_r2_mean > cal_0_mean:
b_lower = (ua_r2 - 0) / (cal_r2_mean - cal_0_mean)
else:
- vprint(self.verbose, '[W] 0 uA == %.f uA during calibration' % (ua_r2))
+ vprint(self.verbose, "[W] 0 uA == %.f uA during calibration" % (ua_r2))
b_lower = 0
b_upper = (ua_r1 - ua_r2) / (cal_r1_mean - cal_r2_mean)
@@ -2726,7 +3257,9 @@ class MIMOSA:
return 0
else:
return charge * b_lower + a_lower
+
else:
+
def calfunc(charge):
if charge < cal_0_mean:
return 0
@@ -2736,19 +3269,19 @@ class MIMOSA:
return charge * b_upper + a_upper + ua_r2
caldata = {
- 'edges': [x * 10 for x in cal_edges],
- 'offset': cal_0_mean,
- 'offset2': cal_r2_mean,
- 'slope_low': b_lower,
- 'slope_high': b_upper,
- 'add_low': a_lower,
- 'add_high': a_upper,
- 'r0_err_uW': np.mean(self.currents_nocal(chg_r0)) * self.voltage,
- 'r0_std_uW': np.std(self.currents_nocal(chg_r0)) * self.voltage,
- 'r1_err_uW': (np.mean(self.currents_nocal(chg_r1)) - ua_r1) * self.voltage,
- 'r1_std_uW': np.std(self.currents_nocal(chg_r1)) * self.voltage,
- 'r2_err_uW': (np.mean(self.currents_nocal(chg_r2)) - ua_r2) * self.voltage,
- 'r2_std_uW': np.std(self.currents_nocal(chg_r2)) * self.voltage,
+ "edges": [x * 10 for x in cal_edges],
+ "offset": cal_0_mean,
+ "offset2": cal_r2_mean,
+ "slope_low": b_lower,
+ "slope_high": b_upper,
+ "add_low": a_lower,
+ "add_high": a_upper,
+ "r0_err_uW": np.mean(self.currents_nocal(chg_r0)) * self.voltage,
+ "r0_std_uW": np.std(self.currents_nocal(chg_r0)) * self.voltage,
+ "r1_err_uW": (np.mean(self.currents_nocal(chg_r1)) - ua_r1) * self.voltage,
+ "r1_std_uW": np.std(self.currents_nocal(chg_r1)) * self.voltage,
+ "r2_err_uW": (np.mean(self.currents_nocal(chg_r2)) - ua_r2) * self.voltage,
+ "r2_std_uW": np.std(self.currents_nocal(chg_r2)) * self.voltage,
}
# print("if charge < %f : return 0" % cal_0_mean)
@@ -2843,51 +3376,59 @@ class MIMOSA:
statelist = []
prevsubidx = 0
for subidx in subst:
- statelist.append({
- 'duration': (subidx - prevsubidx) * 10,
- 'uW_mean': np.mean(range_ua[prevsubidx: subidx] * self.voltage),
- 'uW_std': np.std(range_ua[prevsubidx: subidx] * self.voltage),
- })
+ statelist.append(
+ {
+ "duration": (subidx - prevsubidx) * 10,
+ "uW_mean": np.mean(
+ range_ua[prevsubidx:subidx] * self.voltage
+ ),
+ "uW_std": np.std(
+ range_ua[prevsubidx:subidx] * self.voltage
+ ),
+ }
+ )
prevsubidx = subidx
substates = {
- 'threshold': thr,
- 'states': statelist,
+ "threshold": thr,
+ "states": statelist,
}
- isa = 'state'
+ isa = "state"
if not is_state:
- isa = 'transition'
+ isa = "transition"
data = {
- 'isa': isa,
- 'clip_rate': np.mean(range_raw == 65535),
- 'raw_mean': np.mean(range_raw),
- 'raw_std': np.std(range_raw),
- 'uW_mean': np.mean(range_ua * self.voltage),
- 'uW_std': np.std(range_ua * self.voltage),
- 'us': (idx - previdx) * 10,
+ "isa": isa,
+ "clip_rate": np.mean(range_raw == 65535),
+ "raw_mean": np.mean(range_raw),
+ "raw_std": np.std(range_raw),
+ "uW_mean": np.mean(range_ua * self.voltage),
+ "uW_std": np.std(range_ua * self.voltage),
+ "us": (idx - previdx) * 10,
}
if self.with_traces:
- data['uW'] = range_ua * self.voltage
+ data["uW"] = range_ua * self.voltage
- if 'states' in substates:
- data['substates'] = substates
- ssum = np.sum(list(map(lambda x: x['duration'], substates['states'])))
- if ssum != data['us']:
- vprint(self.verbose, "ERR: duration %d vs %d" % (data['us'], ssum))
+ if "states" in substates:
+ data["substates"] = substates
+ ssum = np.sum(list(map(lambda x: x["duration"], substates["states"])))
+ if ssum != data["us"]:
+ vprint(self.verbose, "ERR: duration %d vs %d" % (data["us"], ssum))
- if isa == 'transition':
+ if isa == "transition":
# subtract average power of previous state
# (that is, the state from which this transition originates)
- data['uW_mean_delta_prev'] = data['uW_mean'] - iterdata[-1]['uW_mean']
+ data["uW_mean_delta_prev"] = data["uW_mean"] - iterdata[-1]["uW_mean"]
# placeholder to avoid extra cases in the analysis
- data['uW_mean_delta_next'] = data['uW_mean']
- data['timeout'] = iterdata[-1]['us']
+ data["uW_mean_delta_next"] = data["uW_mean"]
+ data["timeout"] = iterdata[-1]["us"]
elif len(iterdata) > 0:
# subtract average power of next state
# (the state into which this transition leads)
- iterdata[-1]['uW_mean_delta_next'] = iterdata[-1]['uW_mean'] - data['uW_mean']
+ iterdata[-1]["uW_mean_delta_next"] = (
+ iterdata[-1]["uW_mean"] - data["uW_mean"]
+ )
iterdata.append(data)
diff --git a/lib/functions.py b/lib/functions.py
index 2451ef6..6d8daa4 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -12,6 +12,7 @@ from .utils import is_numeric, vprint
arg_support_enabled = True
+
def powerset(iterable):
"""
Return powerset of `iterable` elements.
@@ -19,7 +20,8 @@ def powerset(iterable):
Example: `powerset([1, 2])` -> `[(), (1), (2), (1, 2)]`
"""
s = list(iterable)
- return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
+ return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
+
class ParamFunction:
"""
@@ -82,6 +84,7 @@ class ParamFunction:
"""
return self._param_function(P, X) - y
+
class NormalizationFunction:
"""
Wrapper for parameter normalization functions used in YAML PTA/DFA models.
@@ -95,7 +98,7 @@ class NormalizationFunction:
`param` and return a float.
"""
self._function_str = function_str
- self._function = eval('lambda param: ' + function_str)
+ self._function = eval("lambda param: " + function_str)
def eval(self, param_value: float) -> float:
"""
@@ -105,6 +108,7 @@ class NormalizationFunction:
"""
return self._function(param_value)
+
class AnalyticFunction:
"""
A multi-dimensional model function, generated from a string, which can be optimized using regression.
@@ -114,7 +118,9 @@ class AnalyticFunction:
packet length.
"""
- def __init__(self, function_str, parameters, num_args, verbose = True, regression_args = None):
+ def __init__(
+ self, function_str, parameters, num_args, verbose=True, regression_args=None
+ ):
"""
Create a new AnalyticFunction object from a function string.
@@ -143,22 +149,30 @@ class AnalyticFunction:
self.verbose = verbose
if type(function_str) == str:
- num_vars_re = re.compile(r'regression_arg\(([0-9]+)\)')
+ num_vars_re = re.compile(r"regression_arg\(([0-9]+)\)")
num_vars = max(map(int, num_vars_re.findall(function_str))) + 1
for i in range(len(parameters)):
- if rawfunction.find('parameter({})'.format(parameters[i])) >= 0:
+ if rawfunction.find("parameter({})".format(parameters[i])) >= 0:
self._dependson[i] = True
- rawfunction = rawfunction.replace('parameter({})'.format(parameters[i]), 'model_param[{:d}]'.format(i))
+ rawfunction = rawfunction.replace(
+ "parameter({})".format(parameters[i]),
+ "model_param[{:d}]".format(i),
+ )
for i in range(0, num_args):
- if rawfunction.find('function_arg({:d})'.format(i)) >= 0:
+ if rawfunction.find("function_arg({:d})".format(i)) >= 0:
self._dependson[len(parameters) + i] = True
- rawfunction = rawfunction.replace('function_arg({:d})'.format(i), 'model_param[{:d}]'.format(len(parameters) + i))
+ rawfunction = rawfunction.replace(
+ "function_arg({:d})".format(i),
+ "model_param[{:d}]".format(len(parameters) + i),
+ )
for i in range(num_vars):
- rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i))
+ rawfunction = rawfunction.replace(
+ "regression_arg({:d})".format(i), "reg_param[{:d}]".format(i)
+ )
self._function_str = rawfunction
- self._function = eval('lambda reg_param, model_param: ' + rawfunction)
+ self._function = eval("lambda reg_param, model_param: " + rawfunction)
else:
- self._function_str = 'raise ValueError'
+ self._function_str = "raise ValueError"
self._function = function_str
if regression_args:
@@ -217,7 +231,12 @@ class AnalyticFunction:
else:
X[i].extend([np.nan] * len(val[model_attribute]))
elif key[0] == state_or_tran and len(key[1]) != dimension:
- vprint(self.verbose, '[W] Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.'.format(state_or_tran, model_attribute, len(key[1]), dimension))
+ vprint(
+ self.verbose,
+ "[W] Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.".format(
+ state_or_tran, model_attribute, len(key[1]), dimension
+ ),
+ )
X = np.array(X)
Y = np.array(Y)
@@ -237,21 +256,40 @@ class AnalyticFunction:
argument values are present, they must come after parameter values
in the order of their appearance in the function signature.
"""
- X, Y, num_valid, num_total = self.get_fit_data(by_param, state_or_tran, model_attribute)
+ X, Y, num_valid, num_total = self.get_fit_data(
+ by_param, state_or_tran, model_attribute
+ )
if num_valid > 2:
error_function = lambda P, X, y: self._function(P, X) - y
try:
- res = optimize.least_squares(error_function, self._regression_args, args=(X, Y), xtol=2e-15)
+ res = optimize.least_squares(
+ error_function, self._regression_args, args=(X, Y), xtol=2e-15
+ )
except ValueError as err:
- vprint(self.verbose, '[W] Fit failed for {}/{}: {} (function: {})'.format(state_or_tran, model_attribute, err, self._model_str))
+ vprint(
+ self.verbose,
+ "[W] Fit failed for {}/{}: {} (function: {})".format(
+ state_or_tran, model_attribute, err, self._model_str
+ ),
+ )
return
if res.status > 0:
self._regression_args = res.x
self.fit_success = True
else:
- vprint(self.verbose, '[W] Fit failed for {}/{}: {} (function: {})'.format(state_or_tran, model_attribute, res.message, self._model_str))
+ vprint(
+ self.verbose,
+ "[W] Fit failed for {}/{}: {} (function: {})".format(
+ state_or_tran, model_attribute, res.message, self._model_str
+ ),
+ )
else:
- vprint(self.verbose, '[W] Insufficient amount of valid parameter keys, cannot fit {}/{}'.format(state_or_tran, model_attribute))
+ vprint(
+ self.verbose,
+ "[W] Insufficient amount of valid parameter keys, cannot fit {}/{}".format(
+ state_or_tran, model_attribute
+ ),
+ )
def is_predictable(self, param_list):
"""
@@ -268,7 +306,7 @@ class AnalyticFunction:
return False
return True
- def eval(self, param_list, arg_list = []):
+ def eval(self, param_list, arg_list=[]):
"""
Evaluate model function with specified param/arg values.
@@ -280,6 +318,7 @@ class AnalyticFunction:
return self._function(param_list, arg_list)
return self._function(self._regression_args, param_list)
+
class analytic:
"""
Utilities for analytic description of parameter-dependent model attributes and regression analysis.
@@ -292,28 +331,28 @@ class analytic:
_num0_8 = np.vectorize(lambda x: 8 - bin(int(x)).count("1"))
_num0_16 = np.vectorize(lambda x: 16 - bin(int(x)).count("1"))
_num1 = np.vectorize(lambda x: bin(int(x)).count("1"))
- _safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.)
- _safe_inv = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.)
+ _safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0)
+ _safe_inv = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.0)
_safe_sqrt = np.vectorize(lambda x: np.sqrt(np.abs(x)))
_function_map = {
- 'linear' : lambda x: x,
- 'logarithmic' : np.log,
- 'logarithmic1' : lambda x: np.log(x + 1),
- 'exponential' : np.exp,
- 'square' : lambda x : x ** 2,
- 'inverse' : lambda x : 1 / x,
- 'sqrt' : lambda x: np.sqrt(np.abs(x)),
- 'num0_8' : _num0_8,
- 'num0_16' : _num0_16,
- 'num1' : _num1,
- 'safe_log' : lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.,
- 'safe_inv' : lambda x: 1 / x if np.abs(x) > 0.001 else 1.,
- 'safe_sqrt': lambda x: np.sqrt(np.abs(x)),
+ "linear": lambda x: x,
+ "logarithmic": np.log,
+ "logarithmic1": lambda x: np.log(x + 1),
+ "exponential": np.exp,
+ "square": lambda x: x ** 2,
+ "inverse": lambda x: 1 / x,
+ "sqrt": lambda x: np.sqrt(np.abs(x)),
+ "num0_8": _num0_8,
+ "num0_16": _num0_16,
+ "num1": _num1,
+ "safe_log": lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0,
+ "safe_inv": lambda x: 1 / x if np.abs(x) > 0.001 else 1.0,
+ "safe_sqrt": lambda x: np.sqrt(np.abs(x)),
}
@staticmethod
- def functions(safe_functions_enabled = False):
+ def functions(safe_functions_enabled=False):
"""
Retrieve pre-defined set of regression function candidates.
@@ -329,74 +368,87 @@ class analytic:
variables are expected.
"""
functions = {
- 'linear' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param,
+ "linear": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * model_param,
lambda model_param: True,
- 2
+ 2,
),
- 'logarithmic' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.log(model_param),
+ "logarithmic": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * np.log(model_param),
lambda model_param: model_param > 0,
- 2
+ 2,
),
- 'logarithmic1' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.log(model_param + 1),
+ "logarithmic1": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * np.log(model_param + 1),
lambda model_param: model_param > -1,
- 2
+ 2,
),
- 'exponential' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.exp(model_param),
+ "exponential": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * np.exp(model_param),
lambda model_param: model_param <= 64,
- 2
+ 2,
),
#'polynomial' : lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param + reg_param[2] * model_param ** 2,
- 'square' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param ** 2,
+ "square": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * model_param ** 2,
lambda model_param: True,
- 2
+ 2,
),
- 'inverse' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] / model_param,
+ "inverse": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] / model_param,
lambda model_param: model_param != 0,
- 2
+ 2,
),
- 'sqrt' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.sqrt(model_param),
+ "sqrt": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * np.sqrt(model_param),
lambda model_param: model_param >= 0,
- 2
+ 2,
),
- 'num0_8' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._num0_8(model_param),
+ "num0_8": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * analytic._num0_8(model_param),
lambda model_param: True,
- 2
+ 2,
),
- 'num0_16' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._num0_16(model_param),
+ "num0_16": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * analytic._num0_16(model_param),
lambda model_param: True,
- 2
+ 2,
),
- 'num1' : ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._num1(model_param),
+ "num1": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * analytic._num1(model_param),
lambda model_param: True,
- 2
+ 2,
),
}
if safe_functions_enabled:
- functions['safe_log'] = ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_log(model_param),
+ functions["safe_log"] = ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * analytic._safe_log(model_param),
lambda model_param: True,
- 2
+ 2,
)
- functions['safe_inv'] = ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_inv(model_param),
+ functions["safe_inv"] = ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * analytic._safe_inv(model_param),
lambda model_param: True,
- 2
+ 2,
)
- functions['safe_sqrt'] = ParamFunction(
- lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_sqrt(model_param),
+ functions["safe_sqrt"] = ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * analytic._safe_sqrt(model_param),
lambda model_param: True,
- 2
+ 2,
)
return functions
@@ -404,27 +456,27 @@ class analytic:
@staticmethod
def _fmap(reference_type, reference_name, function_type):
"""Map arg/parameter name and best-fit function name to function text suitable for AnalyticFunction."""
- ref_str = '{}({})'.format(reference_type,reference_name)
- if function_type == 'linear':
+ ref_str = "{}({})".format(reference_type, reference_name)
+ if function_type == "linear":
return ref_str
- if function_type == 'logarithmic':
- return 'np.log({})'.format(ref_str)
- if function_type == 'logarithmic1':
- return 'np.log({} + 1)'.format(ref_str)
- if function_type == 'exponential':
- return 'np.exp({})'.format(ref_str)
- if function_type == 'exponential':
- return 'np.exp({})'.format(ref_str)
- if function_type == 'square':
- return '({})**2'.format(ref_str)
- if function_type == 'inverse':
- return '1/({})'.format(ref_str)
- if function_type == 'sqrt':
- return 'np.sqrt({})'.format(ref_str)
- return 'analytic._{}({})'.format(function_type, ref_str)
+ if function_type == "logarithmic":
+ return "np.log({})".format(ref_str)
+ if function_type == "logarithmic1":
+ return "np.log({} + 1)".format(ref_str)
+ if function_type == "exponential":
+ return "np.exp({})".format(ref_str)
+ if function_type == "exponential":
+ return "np.exp({})".format(ref_str)
+ if function_type == "square":
+ return "({})**2".format(ref_str)
+ if function_type == "inverse":
+ return "1/({})".format(ref_str)
+ if function_type == "sqrt":
+ return "np.sqrt({})".format(ref_str)
+ return "analytic._{}({})".format(function_type, ref_str)
@staticmethod
- def function_powerset(fit_results, parameter_names, num_args = 0):
+ def function_powerset(fit_results, parameter_names, num_args=0):
"""
Combine per-parameter regression results into a single multi-dimensional function.
@@ -443,14 +495,22 @@ class analytic:
Returns an AnalyticFunction instantce corresponding to the combined
function.
"""
- buf = '0'
+ buf = "0"
arg_idx = 0
for combination in powerset(fit_results.items()):
- buf += ' + regression_arg({:d})'.format(arg_idx)
+ buf += " + regression_arg({:d})".format(arg_idx)
arg_idx += 1
for function_item in combination:
if arg_support_enabled and is_numeric(function_item[0]):
- buf += ' * {}'.format(analytic._fmap('function_arg', function_item[0], function_item[1]['best']))
+ buf += " * {}".format(
+ analytic._fmap(
+ "function_arg", function_item[0], function_item[1]["best"]
+ )
+ )
else:
- buf += ' * {}'.format(analytic._fmap('parameter', function_item[0], function_item[1]['best']))
+ buf += " * {}".format(
+ analytic._fmap(
+ "parameter", function_item[0], function_item[1]["best"]
+ )
+ )
return AnalyticFunction(buf, parameter_names, num_args)
diff --git a/lib/harness.py b/lib/harness.py
index 54518e3..3b279c0 100644
--- a/lib/harness.py
+++ b/lib/harness.py
@@ -24,7 +24,16 @@ class TransitionHarness:
primitive values (-> set by the return value of the current run, not necessarily constan)
* `args`: function arguments, if isa == 'transition'
"""
- def __init__(self, gpio_pin = None, gpio_mode = 'around', pta = None, log_return_values = False, repeat = 0, post_transition_delay_us = 0):
+
+ def __init__(
+ self,
+ gpio_pin=None,
+ gpio_mode="around",
+ pta=None,
+ log_return_values=False,
+ repeat=0,
+ post_transition_delay_us=0,
+ ):
"""
Create a new TransitionHarness
@@ -47,7 +56,14 @@ class TransitionHarness:
self.reset()
def copy(self):
- new_object = __class__(gpio_pin = self.gpio_pin, gpio_mode = self.gpio_mode, pta = self.pta, log_return_values = self.log_return_values, repeat = self.repeat, post_transition_delay_us = self.post_transition_delay_us)
+ new_object = __class__(
+ gpio_pin=self.gpio_pin,
+ gpio_mode=self.gpio_mode,
+ pta=self.pta,
+ log_return_values=self.log_return_values,
+ repeat=self.repeat,
+ post_transition_delay_us=self.post_transition_delay_us,
+ )
new_object.traces = self.traces.copy()
new_object.trace_id = self.trace_id
return new_object
@@ -62,12 +78,16 @@ class TransitionHarness:
of the current benchmark iteration. Resets `done` and `synced`,
"""
for trace in self.traces:
- for state_or_transition in trace['trace']:
- if 'return_values' in state_or_transition:
- state_or_transition['return_values'] = state_or_transition['return_values'][:undo_from]
- for param_name in state_or_transition['parameter'].keys():
- if type(state_or_transition['parameter'][param_name]) is list:
- state_or_transition['parameter'][param_name] = state_or_transition['parameter'][param_name][:undo_from]
+ for state_or_transition in trace["trace"]:
+ if "return_values" in state_or_transition:
+ state_or_transition["return_values"] = state_or_transition[
+ "return_values"
+ ][:undo_from]
+ for param_name in state_or_transition["parameter"].keys():
+ if type(state_or_transition["parameter"][param_name]) is list:
+ state_or_transition["parameter"][
+ param_name
+ ] = state_or_transition["parameter"][param_name][:undo_from]
def reset(self):
"""
@@ -95,33 +115,32 @@ class TransitionHarness:
def global_code(self):
"""Return global (pre-`main()`) C++ code needed for tracing."""
- ret = ''
+ ret = ""
if self.gpio_pin != None:
- ret += '#define PTALOG_GPIO {}\n'.format(self.gpio_pin)
- if self.gpio_mode == 'before':
- ret += '#define PTALOG_GPIO_BEFORE\n'
- elif self.gpio_mode == 'bar':
- ret += '#define PTALOG_GPIO_BAR\n'
+ ret += "#define PTALOG_GPIO {}\n".format(self.gpio_pin)
+ if self.gpio_mode == "before":
+ ret += "#define PTALOG_GPIO_BEFORE\n"
+ elif self.gpio_mode == "bar":
+ ret += "#define PTALOG_GPIO_BAR\n"
if self.log_return_values:
- ret += '#define PTALOG_WITH_RETURNVALUES\n'
- ret += 'uint16_t transition_return_value;\n'
+ ret += "#define PTALOG_WITH_RETURNVALUES\n"
+ ret += "uint16_t transition_return_value;\n"
ret += '#include "object/ptalog.h"\n'
if self.gpio_pin != None:
- ret += 'PTALog ptalog({});\n'.format(self.gpio_pin)
+ ret += "PTALog ptalog({});\n".format(self.gpio_pin)
else:
- ret += 'PTALog ptalog;\n'
+ ret += "PTALog ptalog;\n"
return ret
- def start_benchmark(self, benchmark_id = 0):
+ def start_benchmark(self, benchmark_id=0):
"""Return C++ code to signal benchmark start to harness."""
- return 'ptalog.startBenchmark({:d});\n'.format(benchmark_id)
+ return "ptalog.startBenchmark({:d});\n".format(benchmark_id)
def start_trace(self):
"""Prepare a new trace/run in the internal `.traces` structure."""
- self.traces.append({
- 'id' : self.trace_id,
- 'trace' : list(),
- })
+ self.traces.append(
+ {"id": self.trace_id, "trace": list(),}
+ )
self.trace_id += 1
def append_state(self, state_name, param):
@@ -131,13 +150,11 @@ class TransitionHarness:
:param state_name: state name
:param param: parameter dict
"""
- self.traces[-1]['trace'].append({
- 'name': state_name,
- 'isa': 'state',
- 'parameter': param,
- })
+ self.traces[-1]["trace"].append(
+ {"name": state_name, "isa": "state", "parameter": param,}
+ )
- def append_transition(self, transition_name, param, args = []):
+ def append_transition(self, transition_name, param, args=[]):
"""
Append a transition to the current run in the internal `.traces` structure.
@@ -145,122 +162,188 @@ class TransitionHarness:
:param param: parameter dict
:param args: function arguments (optional)
"""
- self.traces[-1]['trace'].append({
- 'name': transition_name,
- 'isa': 'transition',
- 'parameter': param,
- 'args' : args,
- })
+ self.traces[-1]["trace"].append(
+ {
+ "name": transition_name,
+ "isa": "transition",
+ "parameter": param,
+ "args": args,
+ }
+ )
def start_run(self):
"""Return C++ code used to start a new run/trace."""
- return 'ptalog.reset();\n'
+ return "ptalog.reset();\n"
def _pass_transition_call(self, transition_id):
- if self.gpio_mode == 'bar':
- barcode_bits = Code128('T{}'.format(transition_id), charset='B').modules
+ if self.gpio_mode == "bar":
+ barcode_bits = Code128("T{}".format(transition_id), charset="B").modules
if len(barcode_bits) % 8 != 0:
barcode_bits.extend([1] * (8 - (len(barcode_bits) % 8)))
- barcode_bytes = [255 - int("".join(map(str, reversed(barcode_bits[i:i+8]))), 2) for i in range(0, len(barcode_bits), 8)]
- inline_array = "".join(map(lambda s: '\\x{:02x}'.format(s), barcode_bytes))
- return 'ptalog.startTransition("{}", {});\n'.format(inline_array, len(barcode_bytes))
+ barcode_bytes = [
+ 255 - int("".join(map(str, reversed(barcode_bits[i : i + 8]))), 2)
+ for i in range(0, len(barcode_bits), 8)
+ ]
+ inline_array = "".join(map(lambda s: "\\x{:02x}".format(s), barcode_bytes))
+ return 'ptalog.startTransition("{}", {});\n'.format(
+ inline_array, len(barcode_bytes)
+ )
else:
- return 'ptalog.startTransition();\n'
+ return "ptalog.startTransition();\n"
- def pass_transition(self, transition_id, transition_code, transition: object = None):
+ def pass_transition(
+ self, transition_id, transition_code, transition: object = None
+ ):
"""
Return C++ code used to pass a transition, including the corresponding function call.
Tracks which transition has been executed and optionally its return value. May also inject a delay, if
`post_transition_delay_us` is set.
"""
- ret = 'ptalog.passTransition({:d});\n'.format(transition_id)
+ ret = "ptalog.passTransition({:d});\n".format(transition_id)
ret += self._pass_transition_call(transition_id)
- if self.log_return_values and transition and len(transition.return_value_handlers):
- ret += 'transition_return_value = {}\n'.format(transition_code)
- ret += 'ptalog.logReturn(transition_return_value);\n'
+ if (
+ self.log_return_values
+ and transition
+ and len(transition.return_value_handlers)
+ ):
+ ret += "transition_return_value = {}\n".format(transition_code)
+ ret += "ptalog.logReturn(transition_return_value);\n"
else:
- ret += '{}\n'.format(transition_code)
+ ret += "{}\n".format(transition_code)
if self.post_transition_delay_us:
- ret += 'arch.delay_us({});\n'.format(self.post_transition_delay_us)
- ret += 'ptalog.stopTransition();\n'
+ ret += "arch.delay_us({});\n".format(self.post_transition_delay_us)
+ ret += "ptalog.stopTransition();\n"
return ret
- def stop_run(self, num_traces = 0):
- return 'ptalog.dump({:d});\n'.format(num_traces)
+ def stop_run(self, num_traces=0):
+ return "ptalog.dump({:d});\n".format(num_traces)
def stop_benchmark(self):
- return 'ptalog.stopBenchmark();\n'
+ return "ptalog.stopBenchmark();\n"
- def _append_nondeterministic_parameter_value(self, log_data_target, parameter_name, parameter_value):
- if log_data_target['parameter'][parameter_name] is None:
- log_data_target['parameter'][parameter_name] = list()
- log_data_target['parameter'][parameter_name].append(parameter_value)
+ def _append_nondeterministic_parameter_value(
+ self, log_data_target, parameter_name, parameter_value
+ ):
+ if log_data_target["parameter"][parameter_name] is None:
+ log_data_target["parameter"][parameter_name] = list()
+ log_data_target["parameter"][parameter_name].append(parameter_value)
def parser_cb(self, line):
- #print('[HARNESS] got line {}'.format(line))
- if re.match(r'\[PTA\] benchmark stop', line):
+ # print('[HARNESS] got line {}'.format(line))
+ if re.match(r"\[PTA\] benchmark stop", line):
self.repetitions += 1
self.synced = False
if self.repeat > 0 and self.repetitions == self.repeat:
self.done = True
- print('[HARNESS] done')
+ print("[HARNESS] done")
return
- if re.match(r'\[PTA\] benchmark start, id=(\S+)', line):
+ if re.match(r"\[PTA\] benchmark start, id=(\S+)", line):
self.synced = True
- print('[HARNESS] synced, {}/{}'.format(self.repetitions + 1, self.repeat))
+ print("[HARNESS] synced, {}/{}".format(self.repetitions + 1, self.repeat))
if self.synced:
- res = re.match(r'\[PTA\] trace=(\S+) count=(\S+)', line)
+ res = re.match(r"\[PTA\] trace=(\S+) count=(\S+)", line)
if res:
self.trace_id = int(res.group(1))
self.trace_length = int(res.group(2))
self.current_transition_in_trace = 0
if self.log_return_values:
- res = re.match(r'\[PTA\] transition=(\S+) return=(\S+)', line)
+ res = re.match(r"\[PTA\] transition=(\S+) return=(\S+)", line)
else:
- res = re.match(r'\[PTA\] transition=(\S+)', line)
+ res = re.match(r"\[PTA\] transition=(\S+)", line)
if res:
transition_id = int(res.group(1))
# self.traces contains transitions and states, UART output only contains transitions -> use index * 2
try:
- log_data_target = self.traces[self.trace_id]['trace'][self.current_transition_in_trace * 2]
+ log_data_target = self.traces[self.trace_id]["trace"][
+ self.current_transition_in_trace * 2
+ ]
except IndexError:
transition_name = None
if self.pta:
transition_name = self.pta.transitions[transition_id].name
- print('[HARNESS] benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}, name {}) is out of bounds'.format(0, self.trace_id, self.current_transition_in_trace, transition_id, transition_name))
- print(' Offending line: {}'.format(line))
+ print(
+ "[HARNESS] benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}, name {}) is out of bounds".format(
+ 0,
+ self.trace_id,
+ self.current_transition_in_trace,
+ transition_id,
+ transition_name,
+ )
+ )
+ print(" Offending line: {}".format(line))
return
- if log_data_target['isa'] != 'transition':
+ if log_data_target["isa"] != "transition":
self.abort = True
- raise RuntimeError('Log mismatch: Expected transition, got {:s}'.format(log_data_target['isa']))
+ raise RuntimeError(
+ "Log mismatch: Expected transition, got {:s}".format(
+ log_data_target["isa"]
+ )
+ )
if self.pta:
transition = self.pta.transitions[transition_id]
- if transition.name != log_data_target['name']:
+ if transition.name != log_data_target["name"]:
self.abort = True
- raise RuntimeError('Log mismatch: Expected transition {:s}, got transition {:s} -- may have been caused by preceding malformed UART output'.format(log_data_target['name'], transition.name))
+ raise RuntimeError(
+ "Log mismatch: Expected transition {:s}, got transition {:s} -- may have been caused by preceding malformed UART output".format(
+ log_data_target["name"], transition.name
+ )
+ )
if self.log_return_values and len(transition.return_value_handlers):
for handler in transition.return_value_handlers:
- if 'parameter' in handler:
+ if "parameter" in handler:
parameter_value = return_value = int(res.group(2))
- if 'return_values' not in log_data_target:
- log_data_target['return_values'] = list()
- log_data_target['return_values'].append(return_value)
-
- if 'formula' in handler:
- parameter_value = handler['formula'].eval(return_value)
-
- self._append_nondeterministic_parameter_value(log_data_target, handler['parameter'], parameter_value)
- for following_log_data_target in self.traces[self.trace_id]['trace'][(self.current_transition_in_trace * 2 + 1) :]:
- self._append_nondeterministic_parameter_value(following_log_data_target, handler['parameter'], parameter_value)
- if 'apply_from' in handler and any(map(lambda x: x['name'] == handler['apply_from'], self.traces[self.trace_id]['trace'][: (self.current_transition_in_trace * 2 + 1)])):
- for preceding_log_data_target in reversed(self.traces[self.trace_id]['trace'][: (self.current_transition_in_trace * 2)]):
- self._append_nondeterministic_parameter_value(preceding_log_data_target, handler['parameter'], parameter_value)
- if preceding_log_data_target['name'] == handler['apply_from']:
+ if "return_values" not in log_data_target:
+ log_data_target["return_values"] = list()
+ log_data_target["return_values"].append(return_value)
+
+ if "formula" in handler:
+ parameter_value = handler["formula"].eval(
+ return_value
+ )
+
+ self._append_nondeterministic_parameter_value(
+ log_data_target,
+ handler["parameter"],
+ parameter_value,
+ )
+ for following_log_data_target in self.traces[
+ self.trace_id
+ ]["trace"][
+ (self.current_transition_in_trace * 2 + 1) :
+ ]:
+ self._append_nondeterministic_parameter_value(
+ following_log_data_target,
+ handler["parameter"],
+ parameter_value,
+ )
+ if "apply_from" in handler and any(
+ map(
+ lambda x: x["name"] == handler["apply_from"],
+ self.traces[self.trace_id]["trace"][
+ : (self.current_transition_in_trace * 2 + 1)
+ ],
+ )
+ ):
+ for preceding_log_data_target in reversed(
+ self.traces[self.trace_id]["trace"][
+ : (self.current_transition_in_trace * 2)
+ ]
+ ):
+ self._append_nondeterministic_parameter_value(
+ preceding_log_data_target,
+ handler["parameter"],
+ parameter_value,
+ )
+ if (
+ preceding_log_data_target["name"]
+ == handler["apply_from"]
+ ):
break
self.current_transition_in_trace += 1
+
class OnboardTimerHarness(TransitionHarness):
"""TODO
@@ -271,13 +354,25 @@ class OnboardTimerHarness(TransitionHarness):
benchmark iteration.
I.e. `.traces[*]['trace'][*]['offline_aggregates']['duration'] = [..., ...]`
"""
+
def __init__(self, counter_limits, **kwargs):
super().__init__(**kwargs)
self.trace_length = 0
- self.one_cycle_in_us, self.one_overflow_in_us, self.counter_max_overflow = counter_limits
+ (
+ self.one_cycle_in_us,
+ self.one_overflow_in_us,
+ self.counter_max_overflow,
+ ) = counter_limits
def copy(self):
- new_harness = __class__((self.one_cycle_in_us, self.one_overflow_in_us, self.counter_max_overflow), gpio_pin = self.gpio_pin, gpio_mode = self.gpio_mode, pta = self.pta, log_return_values = self.log_return_values, repeat = self.repeat)
+ new_harness = __class__(
+ (self.one_cycle_in_us, self.one_overflow_in_us, self.counter_max_overflow),
+ gpio_pin=self.gpio_pin,
+ gpio_mode=self.gpio_mode,
+ pta=self.pta,
+ log_return_values=self.log_return_values,
+ repeat=self.repeat,
+ )
new_harness.traces = self.traces.copy()
new_harness.trace_id = self.trace_id
return new_harness
@@ -293,123 +388,215 @@ class OnboardTimerHarness(TransitionHarness):
"""
super().undo(undo_from)
for trace in self.traces:
- for state_or_transition in trace['trace']:
- if 'offline_aggregates' in state_or_transition:
- state_or_transition['offline_aggregates']['duration'] = state_or_transition['offline_aggregates']['duration'][:undo_from]
+ for state_or_transition in trace["trace"]:
+ if "offline_aggregates" in state_or_transition:
+ state_or_transition["offline_aggregates"][
+ "duration"
+ ] = state_or_transition["offline_aggregates"]["duration"][
+ :undo_from
+ ]
def global_code(self):
ret = '#include "driver/counter.h"\n'
- ret += '#define PTALOG_TIMING\n'
+ ret += "#define PTALOG_TIMING\n"
ret += super().global_code()
return ret
- def start_benchmark(self, benchmark_id = 0):
- ret = 'counter.start();\n'
- ret += 'counter.stop();\n'
- ret += 'ptalog.passNop(counter);\n'
+ def start_benchmark(self, benchmark_id=0):
+ ret = "counter.start();\n"
+ ret += "counter.stop();\n"
+ ret += "ptalog.passNop(counter);\n"
ret += super().start_benchmark(benchmark_id)
return ret
- def pass_transition(self, transition_id, transition_code, transition: object = None):
- ret = 'ptalog.passTransition({:d});\n'.format(transition_id)
+ def pass_transition(
+ self, transition_id, transition_code, transition: object = None
+ ):
+ ret = "ptalog.passTransition({:d});\n".format(transition_id)
ret += self._pass_transition_call(transition_id)
- ret += 'counter.start();\n'
- if self.log_return_values and transition and len(transition.return_value_handlers):
- ret += 'transition_return_value = {}\n'.format(transition_code)
+ ret += "counter.start();\n"
+ if (
+ self.log_return_values
+ and transition
+ and len(transition.return_value_handlers)
+ ):
+ ret += "transition_return_value = {}\n".format(transition_code)
else:
- ret += '{}\n'.format(transition_code)
- ret += 'counter.stop();\n'
- if self.log_return_values and transition and len(transition.return_value_handlers):
- ret += 'ptalog.logReturn(transition_return_value);\n'
- ret += 'ptalog.stopTransition(counter);\n'
+ ret += "{}\n".format(transition_code)
+ ret += "counter.stop();\n"
+ if (
+ self.log_return_values
+ and transition
+ and len(transition.return_value_handlers)
+ ):
+ ret += "ptalog.logReturn(transition_return_value);\n"
+ ret += "ptalog.stopTransition(counter);\n"
return ret
- def _append_nondeterministic_parameter_value(self, log_data_target, parameter_name, parameter_value):
- if log_data_target['parameter'][parameter_name] is None:
- log_data_target['parameter'][parameter_name] = list()
- log_data_target['parameter'][parameter_name].append(parameter_value)
+ def _append_nondeterministic_parameter_value(
+ self, log_data_target, parameter_name, parameter_value
+ ):
+ if log_data_target["parameter"][parameter_name] is None:
+ log_data_target["parameter"][parameter_name] = list()
+ log_data_target["parameter"][parameter_name].append(parameter_value)
def parser_cb(self, line):
# print('[HARNESS] got line {}'.format(line))
- res = re.match(r'\[PTA\] nop=(\S+)/(\S+)', line)
+ res = re.match(r"\[PTA\] nop=(\S+)/(\S+)", line)
if res:
self.nop_cycles = int(res.group(1))
if int(res.group(2)):
- raise RuntimeError('Counter overflow ({:d}/{:d}) during NOP test, wtf?!'.format(res.group(1), res.group(2)))
- if re.match(r'\[PTA\] benchmark stop', line):
+ raise RuntimeError(
+ "Counter overflow ({:d}/{:d}) during NOP test, wtf?!".format(
+ res.group(1), res.group(2)
+ )
+ )
+ if re.match(r"\[PTA\] benchmark stop", line):
self.repetitions += 1
self.synced = False
if self.repeat > 0 and self.repetitions == self.repeat:
self.done = True
- print('[HARNESS] done')
+ print("[HARNESS] done")
return
# May be repeated, e.g. if the device is reset shortly after start by
# EnergyTrace.
- if re.match(r'\[PTA\] benchmark start, id=(\S+)', line):
+ if re.match(r"\[PTA\] benchmark start, id=(\S+)", line):
self.synced = True
- print('[HARNESS] synced, {}/{}'.format(self.repetitions + 1, self.repeat))
+ print("[HARNESS] synced, {}/{}".format(self.repetitions + 1, self.repeat))
if self.synced:
- res = re.match(r'\[PTA\] trace=(\S+) count=(\S+)', line)
+ res = re.match(r"\[PTA\] trace=(\S+) count=(\S+)", line)
if res:
self.trace_id = int(res.group(1))
self.trace_length = int(res.group(2))
self.current_transition_in_trace = 0
if self.log_return_values:
- res = re.match(r'\[PTA\] transition=(\S+) cycles=(\S+)/(\S+) return=(\S+)', line)
+ res = re.match(
+ r"\[PTA\] transition=(\S+) cycles=(\S+)/(\S+) return=(\S+)", line
+ )
else:
- res = re.match(r'\[PTA\] transition=(\S+) cycles=(\S+)/(\S+)', line)
+ res = re.match(r"\[PTA\] transition=(\S+) cycles=(\S+)/(\S+)", line)
if res:
transition_id = int(res.group(1))
cycles = int(res.group(2))
overflow = int(res.group(3))
if overflow >= self.counter_max_overflow:
self.abort = True
- raise RuntimeError('Counter overflow ({:d}/{:d}) in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d})'.format(cycles, overflow, 0, self.trace_id, self.current_transition_in_trace, transition_id))
- duration_us = cycles * self.one_cycle_in_us + overflow * self.one_overflow_in_us - self.nop_cycles * self.one_cycle_in_us
+ raise RuntimeError(
+ "Counter overflow ({:d}/{:d}) in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d})".format(
+ cycles,
+ overflow,
+ 0,
+ self.trace_id,
+ self.current_transition_in_trace,
+ transition_id,
+ )
+ )
+ duration_us = (
+ cycles * self.one_cycle_in_us
+ + overflow * self.one_overflow_in_us
+ - self.nop_cycles * self.one_cycle_in_us
+ )
if duration_us < 0:
duration_us = 0
# self.traces contains transitions and states, UART output only contains transitions -> use index * 2
try:
- log_data_target = self.traces[self.trace_id]['trace'][self.current_transition_in_trace * 2]
+ log_data_target = self.traces[self.trace_id]["trace"][
+ self.current_transition_in_trace * 2
+ ]
except IndexError:
transition_name = None
if self.pta:
transition_name = self.pta.transitions[transition_id].name
- print('[HARNESS] benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}, name {}) is out of bounds'.format(0, self.trace_id, self.current_transition_in_trace, transition_id, transition_name))
- print(' Offending line: {}'.format(line))
+ print(
+ "[HARNESS] benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}, name {}) is out of bounds".format(
+ 0,
+ self.trace_id,
+ self.current_transition_in_trace,
+ transition_id,
+ transition_name,
+ )
+ )
+ print(" Offending line: {}".format(line))
return
- if log_data_target['isa'] != 'transition':
+ if log_data_target["isa"] != "transition":
self.abort = True
- raise RuntimeError('Log mismatch in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}): Expected transition, got {:s}'.format(0,
- self.trace_id, self.current_transition_in_trace, transition_id, log_data_target['isa']))
+ raise RuntimeError(
+ "Log mismatch in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}): Expected transition, got {:s}".format(
+ 0,
+ self.trace_id,
+ self.current_transition_in_trace,
+ transition_id,
+ log_data_target["isa"],
+ )
+ )
if self.pta:
transition = self.pta.transitions[transition_id]
- if transition.name != log_data_target['name']:
+ if transition.name != log_data_target["name"]:
self.abort = True
- raise RuntimeError('Log mismatch in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}): Expected transition {:s}, got transition {:s} -- may have been caused by preceding maformed UART output'.format(0, self.trace_id, self.current_transition_in_trace, transition_id, log_data_target['name'], transition.name, line))
+ raise RuntimeError(
+ "Log mismatch in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}): Expected transition {:s}, got transition {:s} -- may have been caused by preceding maformed UART output".format(
+ 0,
+ self.trace_id,
+ self.current_transition_in_trace,
+ transition_id,
+ log_data_target["name"],
+ transition.name,
+ line,
+ )
+ )
if self.log_return_values and len(transition.return_value_handlers):
for handler in transition.return_value_handlers:
- if 'parameter' in handler:
+ if "parameter" in handler:
parameter_value = return_value = int(res.group(4))
- if 'return_values' not in log_data_target:
- log_data_target['return_values'] = list()
- log_data_target['return_values'].append(return_value)
-
- if 'formula' in handler:
- parameter_value = handler['formula'].eval(return_value)
-
- self._append_nondeterministic_parameter_value(log_data_target, handler['parameter'], parameter_value)
- for following_log_data_target in self.traces[self.trace_id]['trace'][(self.current_transition_in_trace * 2 + 1) :]:
- self._append_nondeterministic_parameter_value(following_log_data_target, handler['parameter'], parameter_value)
- if 'apply_from' in handler and any(map(lambda x: x['name'] == handler['apply_from'], self.traces[self.trace_id]['trace'][: (self.current_transition_in_trace * 2 + 1)])):
- for preceding_log_data_target in reversed(self.traces[self.trace_id]['trace'][: (self.current_transition_in_trace * 2)]):
- self._append_nondeterministic_parameter_value(preceding_log_data_target, handler['parameter'], parameter_value)
- if preceding_log_data_target['name'] == handler['apply_from']:
+ if "return_values" not in log_data_target:
+ log_data_target["return_values"] = list()
+ log_data_target["return_values"].append(return_value)
+
+ if "formula" in handler:
+ parameter_value = handler["formula"].eval(
+ return_value
+ )
+
+ self._append_nondeterministic_parameter_value(
+ log_data_target,
+ handler["parameter"],
+ parameter_value,
+ )
+ for following_log_data_target in self.traces[
+ self.trace_id
+ ]["trace"][
+ (self.current_transition_in_trace * 2 + 1) :
+ ]:
+ self._append_nondeterministic_parameter_value(
+ following_log_data_target,
+ handler["parameter"],
+ parameter_value,
+ )
+ if "apply_from" in handler and any(
+ map(
+ lambda x: x["name"] == handler["apply_from"],
+ self.traces[self.trace_id]["trace"][
+ : (self.current_transition_in_trace * 2 + 1)
+ ],
+ )
+ ):
+ for preceding_log_data_target in reversed(
+ self.traces[self.trace_id]["trace"][
+ : (self.current_transition_in_trace * 2)
+ ]
+ ):
+ self._append_nondeterministic_parameter_value(
+ preceding_log_data_target,
+ handler["parameter"],
+ parameter_value,
+ )
+ if (
+ preceding_log_data_target["name"]
+ == handler["apply_from"]
+ ):
break
- if 'offline_aggregates' not in log_data_target:
- log_data_target['offline_aggregates'] = {
- 'duration' : list()
- }
- log_data_target['offline_aggregates']['duration'].append(duration_us)
+ if "offline_aggregates" not in log_data_target:
+ log_data_target["offline_aggregates"] = {"duration": list()}
+ log_data_target["offline_aggregates"]["duration"].append(duration_us)
self.current_transition_in_trace += 1
diff --git a/lib/ipython_energymodel_prelude.py b/lib/ipython_energymodel_prelude.py
index 6777b17..0457838 100755
--- a/lib/ipython_energymodel_prelude.py
+++ b/lib/ipython_energymodel_prelude.py
@@ -5,9 +5,11 @@ from dfatool import PTAModel, RawData, soft_cast_int
ignored_trace_indexes = None
-files = '../data/20170125_125433_cc1200.tar ../data/20170125_142420_cc1200.tar ../data/20170125_144957_cc1200.tar ../data/20170125_151149_cc1200.tar ../data/20170125_151824_cc1200.tar ../data/20170125_154019_cc1200.tar'.split(' ')
-#files = '../data/20170116_124500_LM75x.tar ../data/20170116_131306_LM75x.tar'.split(' ')
+files = "../data/20170125_125433_cc1200.tar ../data/20170125_142420_cc1200.tar ../data/20170125_144957_cc1200.tar ../data/20170125_151149_cc1200.tar ../data/20170125_151824_cc1200.tar ../data/20170125_154019_cc1200.tar".split(
+ " "
+)
+# files = '../data/20170116_124500_LM75x.tar ../data/20170116_131306_LM75x.tar'.split(' ')
raw_data = RawData(files)
preprocessed_data = raw_data.get_preprocessed_data()
-model = PTAModel(preprocessed_data, ignore_trace_indexes = ignored_trace_indexes)
+model = PTAModel(preprocessed_data, ignore_trace_indexes=ignored_trace_indexes)
diff --git a/lib/keysightdlog.py b/lib/keysightdlog.py
index 0cf8da1..89264b9 100755
--- a/lib/keysightdlog.py
+++ b/lib/keysightdlog.py
@@ -8,69 +8,74 @@ import struct
import sys
import xml.etree.ElementTree as ET
+
def plot_y(Y, **kwargs):
plot_xy(np.arange(len(Y)), Y, **kwargs)
-def plot_xy(X, Y, xlabel = None, ylabel = None, title = None, output = None):
- fig, ax1 = plt.subplots(figsize=(10,6))
+
+def plot_xy(X, Y, xlabel=None, ylabel=None, title=None, output=None):
+ fig, ax1 = plt.subplots(figsize=(10, 6))
if title != None:
fig.canvas.set_window_title(title)
if xlabel != None:
ax1.set_xlabel(xlabel)
if ylabel != None:
ax1.set_ylabel(ylabel)
- plt.subplots_adjust(left = 0.1, bottom = 0.1, right = 0.99, top = 0.99)
+ plt.subplots_adjust(left=0.1, bottom=0.1, right=0.99, top=0.99)
plt.plot(X, Y, "bo", markersize=2)
if output:
plt.savefig(output)
- with open('{}.txt'.format(output), 'w') as f:
- print('X Y', file=f)
+ with open("{}.txt".format(output), "w") as f:
+ print("X Y", file=f)
for i in range(len(X)):
- print('{} {}'.format(X[i], Y[i]), file=f)
+ print("{} {}".format(X[i], Y[i]), file=f)
else:
plt.show()
+
filename = sys.argv[1]
-with open(filename, 'rb') as logfile:
+with open(filename, "rb") as logfile:
lines = []
- line = ''
+ line = ""
- if '.xz' in filename:
+ if ".xz" in filename:
f = lzma.open(logfile)
else:
f = logfile
- while line != '</dlog>\n':
+ while line != "</dlog>\n":
line = f.readline().decode()
lines.append(line)
- xml_header = ''.join(lines)
+ xml_header = "".join(lines)
raw_header = f.read(8)
data_offset = f.tell()
raw_data = f.read()
- xml_header = xml_header.replace('1ua>', 'X1ua>')
- xml_header = xml_header.replace('2ua>', 'X2ua>')
+ xml_header = xml_header.replace("1ua>", "X1ua>")
+ xml_header = xml_header.replace("2ua>", "X2ua>")
dlog = ET.fromstring(xml_header)
channels = []
- for channel in dlog.findall('channel'):
- channel_id = int(channel.get('id'))
- sense_curr = channel.find('sense_curr').text
- sense_volt = channel.find('sense_volt').text
- model = channel.find('ident').find('model').text
- if sense_volt == '1':
- channels.append((channel_id, model, 'V'))
- if sense_curr == '1':
- channels.append((channel_id, model, 'A'))
+ for channel in dlog.findall("channel"):
+ channel_id = int(channel.get("id"))
+ sense_curr = channel.find("sense_curr").text
+ sense_volt = channel.find("sense_volt").text
+ model = channel.find("ident").find("model").text
+ if sense_volt == "1":
+ channels.append((channel_id, model, "V"))
+ if sense_curr == "1":
+ channels.append((channel_id, model, "A"))
num_channels = len(channels)
- duration = int(dlog.find('frame').find('time').text)
- interval = float(dlog.find('frame').find('tint').text)
+ duration = int(dlog.find("frame").find("time").text)
+ interval = float(dlog.find("frame").find("tint").text)
real_duration = interval * int(len(raw_data) / (4 * num_channels))
- data = np.ndarray(shape=(num_channels, int(len(raw_data) / (4 * num_channels))), dtype=np.float32)
+ data = np.ndarray(
+ shape=(num_channels, int(len(raw_data) / (4 * num_channels))), dtype=np.float32
+ )
- iterator = struct.iter_unpack('>f', raw_data)
+ iterator = struct.iter_unpack(">f", raw_data)
channel_offset = 0
measurement_offset = 0
for value in iterator:
@@ -82,34 +87,59 @@ with open(filename, 'rb') as logfile:
channel_offset += 1
if int(real_duration) != duration:
- print('Measurement duration: {:f} of {:d} seconds at {:f} µs per sample'.format(
- real_duration, duration, interval * 1000000))
+ print(
+ "Measurement duration: {:f} of {:d} seconds at {:f} µs per sample".format(
+ real_duration, duration, interval * 1000000
+ )
+ )
else:
- print('Measurement duration: {:d} seconds at {:f} µs per sample'.format(
- duration, interval * 1000000))
+ print(
+ "Measurement duration: {:d} seconds at {:f} µs per sample".format(
+ duration, interval * 1000000
+ )
+ )
for i, channel in enumerate(channels):
channel_id, channel_model, channel_type = channel
- print('channel {:d} ({:s}): min {:f}, max {:f}, mean {:f} {:s}'.format(
- channel_id, channel_model, np.min(data[i]), np.max(data[i]), np.mean(data[i]),
- channel_type))
+ print(
+ "channel {:d} ({:s}): min {:f}, max {:f}, mean {:f} {:s}".format(
+ channel_id,
+ channel_model,
+ np.min(data[i]),
+ np.max(data[i]),
+ np.mean(data[i]),
+ channel_type,
+ )
+ )
- if i > 0 and channel_type == 'A' and channels[i-1][2] == 'V' and channel_id == channels[i-1][0]:
- power = data[i-1] * data[i]
+ if (
+ i > 0
+ and channel_type == "A"
+ and channels[i - 1][2] == "V"
+ and channel_id == channels[i - 1][0]
+ ):
+ power = data[i - 1] * data[i]
power = 3.6 * data[i]
- print('channel {:d} ({:s}): min {:f}, max {:f}, mean {:f} W'.format(
- channel_id, channel_model, np.min(power), np.max(power), np.mean(power)))
+ print(
+ "channel {:d} ({:s}): min {:f}, max {:f}, mean {:f} W".format(
+ channel_id, channel_model, np.min(power), np.max(power), np.mean(power)
+ )
+ )
min_power = np.min(power)
max_power = np.max(power)
power_border = np.mean([min_power, max_power])
low_power = power[power < power_border]
high_power = power[power >= power_border]
plot_y(power)
- print(' avg low / high power (delta): {:f} / {:f} ({:f}) W'.format(
- np.mean(low_power), np.mean(high_power),
- np.mean(high_power) - np.mean(low_power)))
- #plot_y(low_power)
- #plot_y(high_power)
+ print(
+ " avg low / high power (delta): {:f} / {:f} ({:f}) W".format(
+ np.mean(low_power),
+ np.mean(high_power),
+ np.mean(high_power) - np.mean(low_power),
+ )
+ )
+ # plot_y(low_power)
+ # plot_y(high_power)
high_power_durations = []
current_high_power_duration = 0
for is_hpe in power >= power_border:
@@ -119,12 +149,16 @@ for i, channel in enumerate(channels):
if current_high_power_duration > 0:
high_power_durations.append(current_high_power_duration)
current_high_power_duration = 0
- print(' avg high-power duration: {:f} µs'.format(np.mean(high_power_durations) * 1000000))
-
-#print(xml_header)
-#print(raw_header)
-#print(channels)
-#print(data)
-#print(np.mean(data[0]))
-#print(np.mean(data[1]))
-#print(np.mean(data[0] * data[1]))
+ print(
+ " avg high-power duration: {:f} µs".format(
+ np.mean(high_power_durations) * 1000000
+ )
+ )
+
+# print(xml_header)
+# print(raw_header)
+# print(channels)
+# print(data)
+# print(np.mean(data[0]))
+# print(np.mean(data[1]))
+# print(np.mean(data[0] * data[1]))
diff --git a/lib/lex.py b/lib/lex.py
index b18fa2b..7bb3760 100644
--- a/lib/lex.py
+++ b/lib/lex.py
@@ -3,34 +3,44 @@ from .sly import Lexer, Parser
class TimedWordLexer(Lexer):
tokens = {LPAREN, RPAREN, IDENTIFIER, NUMBER, ARGSEP, FUNCTIONSEP}
- ignore = ' \t'
+ ignore = " \t"
- LPAREN = r'\('
- RPAREN = r'\)'
- IDENTIFIER = r'[a-zA-Z_][a-zA-Z0-9_]*'
- NUMBER = r'[0-9e.]+'
- ARGSEP = r','
- FUNCTIONSEP = r';'
+ LPAREN = r"\("
+ RPAREN = r"\)"
+ IDENTIFIER = r"[a-zA-Z_][a-zA-Z0-9_]*"
+ NUMBER = r"[0-9e.]+"
+ ARGSEP = r","
+ FUNCTIONSEP = r";"
class TimedSequenceLexer(Lexer):
- tokens = {LPAREN, RPAREN, LBRACE, RBRACE, CYCLE, IDENTIFIER, NUMBER, ARGSEP, FUNCTIONSEP}
- ignore = ' \t'
-
- LPAREN = r'\('
- RPAREN = r'\)'
- LBRACE = r'\{'
- RBRACE = r'\}'
- CYCLE = r'cycle'
- IDENTIFIER = r'[a-zA-Z_][a-zA-Z0-9_]*'
- NUMBER = r'[0-9e.]+'
- ARGSEP = r','
- FUNCTIONSEP = r';'
+ tokens = {
+ LPAREN,
+ RPAREN,
+ LBRACE,
+ RBRACE,
+ CYCLE,
+ IDENTIFIER,
+ NUMBER,
+ ARGSEP,
+ FUNCTIONSEP,
+ }
+ ignore = " \t"
+
+ LPAREN = r"\("
+ RPAREN = r"\)"
+ LBRACE = r"\{"
+ RBRACE = r"\}"
+ CYCLE = r"cycle"
+ IDENTIFIER = r"[a-zA-Z_][a-zA-Z0-9_]*"
+ NUMBER = r"[0-9e.]+"
+ ARGSEP = r","
+ FUNCTIONSEP = r";"
def error(self, t):
print("Illegal character '%s'" % t.value[0])
- if t.value[0] == '{' and t.value.find('}'):
- self.index += 1 + t.value.find('}')
+ if t.value[0] == "{" and t.value.find("}"):
+ self.index += 1 + t.value.find("}")
else:
self.index += 1
@@ -38,39 +48,39 @@ class TimedSequenceLexer(Lexer):
class TimedWordParser(Parser):
tokens = TimedWordLexer.tokens
- @_('timedSymbol FUNCTIONSEP timedWord')
+ @_("timedSymbol FUNCTIONSEP timedWord")
def timedWord(self, p):
ret = [p.timedSymbol]
ret.extend(p.timedWord)
return ret
- @_('timedSymbol FUNCTIONSEP', 'timedSymbol')
+ @_("timedSymbol FUNCTIONSEP", "timedSymbol")
def timedWord(self, p):
return [p.timedSymbol]
- @_('IDENTIFIER', 'IDENTIFIER LPAREN RPAREN')
+ @_("IDENTIFIER", "IDENTIFIER LPAREN RPAREN")
def timedSymbol(self, p):
return (p.IDENTIFIER,)
- @_('IDENTIFIER LPAREN args RPAREN')
+ @_("IDENTIFIER LPAREN args RPAREN")
def timedSymbol(self, p):
return (p.IDENTIFIER, *p.args)
- @_('arg ARGSEP args')
+ @_("arg ARGSEP args")
def args(self, p):
ret = [p.arg]
ret.extend(p.args)
return ret
- @_('arg')
+ @_("arg")
def args(self, p):
return [p.arg]
- @_('NUMBER')
+ @_("NUMBER")
def arg(self, p):
return float(p.NUMBER)
- @_('IDENTIFIER')
+ @_("IDENTIFIER")
def arg(self, p):
return p.IDENTIFIER
@@ -78,66 +88,66 @@ class TimedWordParser(Parser):
class TimedSequenceParser(Parser):
tokens = TimedSequenceLexer.tokens
- @_('timedSequenceL', 'timedSequenceW')
+ @_("timedSequenceL", "timedSequenceW")
def timedSequence(self, p):
return p[0]
- @_('loop')
+ @_("loop")
def timedSequenceL(self, p):
return [p.loop]
- @_('loop timedSequenceW')
+ @_("loop timedSequenceW")
def timedSequenceL(self, p):
ret = [p.loop]
ret.extend(p.timedSequenceW)
return ret
- @_('timedWord')
+ @_("timedWord")
def timedSequenceW(self, p):
return [p.timedWord]
- @_('timedWord timedSequenceL')
+ @_("timedWord timedSequenceL")
def timedSequenceW(self, p):
ret = [p.timedWord]
ret.extend(p.timedSequenceL)
return ret
- @_('timedSymbol FUNCTIONSEP timedWord')
+ @_("timedSymbol FUNCTIONSEP timedWord")
def timedWord(self, p):
p.timedWord.word.insert(0, p.timedSymbol)
return p.timedWord
- @_('timedSymbol FUNCTIONSEP')
+ @_("timedSymbol FUNCTIONSEP")
def timedWord(self, p):
return TimedWord(word=[p.timedSymbol])
- @_('CYCLE LPAREN IDENTIFIER RPAREN LBRACE timedWord RBRACE')
+ @_("CYCLE LPAREN IDENTIFIER RPAREN LBRACE timedWord RBRACE")
def loop(self, p):
return Workload(p.IDENTIFIER, p.timedWord)
- @_('IDENTIFIER', 'IDENTIFIER LPAREN RPAREN')
+ @_("IDENTIFIER", "IDENTIFIER LPAREN RPAREN")
def timedSymbol(self, p):
return (p.IDENTIFIER,)
- @_('IDENTIFIER LPAREN args RPAREN')
+ @_("IDENTIFIER LPAREN args RPAREN")
def timedSymbol(self, p):
return (p.IDENTIFIER, *p.args)
- @_('arg ARGSEP args')
+ @_("arg ARGSEP args")
def args(self, p):
ret = [p.arg]
ret.extend(p.args)
return ret
- @_('arg')
+ @_("arg")
def args(self, p):
return [p.arg]
- @_('NUMBER')
+ @_("NUMBER")
def arg(self, p):
return float(p.NUMBER)
- @_('IDENTIFIER')
+ @_("IDENTIFIER")
def arg(self, p):
return p.IDENTIFIER
@@ -165,8 +175,8 @@ class TimedWord:
def __repr__(self):
ret = list()
for symbol in self.word:
- ret.append('{}({})'.format(symbol[0], ', '.join(map(str, symbol[1:]))))
- return 'TimedWord<"{}">'.format('; '.join(ret))
+ ret.append("{}({})".format(symbol[0], ", ".join(map(str, symbol[1:]))))
+ return 'TimedWord<"{}">'.format("; ".join(ret))
class Workload:
@@ -196,5 +206,5 @@ class TimedSequence:
def __repr__(self):
ret = list()
for symbol in self.seq:
- ret.append('{}'.format(symbol))
- return 'TimedSequence(seq=[{}])'.format(', '.join(ret))
+ ret.append("{}".format(symbol))
+ return "TimedSequence(seq=[{}])".format(", ".join(ret))
diff --git a/lib/modular_arithmetic.py b/lib/modular_arithmetic.py
index 0a69b79..c5ed1aa 100644
--- a/lib/modular_arithmetic.py
+++ b/lib/modular_arithmetic.py
@@ -3,6 +3,7 @@
import operator
import functools
+
@functools.total_ordering
class Mod:
"""A class for modular arithmetic, useful to simulate behaviour of uint8 and other limited data types.
@@ -14,20 +15,21 @@ class Mod:
:param val: stored integer value
Param mod: modulus
"""
- __slots__ = ['val','mod']
+
+ __slots__ = ["val", "mod"]
def __init__(self, val, mod):
if isinstance(val, Mod):
val = val.val
if not isinstance(val, int):
- raise ValueError('Value must be integer')
- if not isinstance(mod, int) or mod<=0:
- raise ValueError('Modulo must be positive integer')
+ raise ValueError("Value must be integer")
+ if not isinstance(mod, int) or mod <= 0:
+ raise ValueError("Modulo must be positive integer")
self.val = val % mod
self.mod = mod
def __repr__(self):
- return 'Mod({}, {})'.format(self.val, self.mod)
+ return "Mod({}, {})".format(self.val, self.mod)
def __int__(self):
return self.val
@@ -50,7 +52,7 @@ class Mod:
def _check_operand(self, other):
if not isinstance(other, (int, Mod)):
- raise TypeError('Only integer and Mod operands are supported')
+ raise TypeError("Only integer and Mod operands are supported")
def __pow__(self, other):
self._check_operand(other)
@@ -61,32 +63,46 @@ class Mod:
return Mod(self.mod - self.val, self.mod)
def __pos__(self):
- return self # The unary plus operator does nothing.
+ return self # The unary plus operator does nothing.
def __abs__(self):
- return self # The value is always kept non-negative, so the abs function should do nothing.
+ return self # The value is always kept non-negative, so the abs function should do nothing.
+
# Helper functions to build common operands based on a template.
# They need to be implemented as functions for the closures to work properly.
def _make_op(opname):
- op_fun = getattr(operator, opname) # Fetch the operator by name from the operator module
+ op_fun = getattr(
+ operator, opname
+ ) # Fetch the operator by name from the operator module
+
def op(self, other):
self._check_operand(other)
return Mod(op_fun(self.val, int(other)) % self.mod, self.mod)
+
return op
+
def _make_reflected_op(opname):
op_fun = getattr(operator, opname)
+
def op(self, other):
self._check_operand(other)
return Mod(op_fun(int(other), self.val) % self.mod, self.mod)
+
return op
+
# Build the actual operator overload methods based on the template.
-for opname, reflected_opname in [('__add__', '__radd__'), ('__sub__', '__rsub__'), ('__mul__', '__rmul__')]:
+for opname, reflected_opname in [
+ ("__add__", "__radd__"),
+ ("__sub__", "__rsub__"),
+ ("__mul__", "__rmul__"),
+]:
setattr(Mod, opname, _make_op(opname))
setattr(Mod, reflected_opname, _make_reflected_op(opname))
+
class Uint8(Mod):
__slots__ = []
@@ -94,7 +110,8 @@ class Uint8(Mod):
super().__init__(val, 256)
def __repr__(self):
- return 'Uint8({})'.format(self.val)
+ return "Uint8({})".format(self.val)
+
class Uint16(Mod):
__slots__ = []
@@ -103,7 +120,8 @@ class Uint16(Mod):
super().__init__(val, 65536)
def __repr__(self):
- return 'Uint16({})'.format(self.val)
+ return "Uint16({})".format(self.val)
+
class Uint32(Mod):
__slots__ = []
@@ -112,7 +130,8 @@ class Uint32(Mod):
super().__init__(val, 4294967296)
def __repr__(self):
- return 'Uint32({})'.format(self.val)
+ return "Uint32({})".format(self.val)
+
class Uint64(Mod):
__slots__ = []
@@ -121,7 +140,7 @@ class Uint64(Mod):
super().__init__(val, 18446744073709551616)
def __repr__(self):
- return 'Uint64({})'.format(self.val)
+ return "Uint64({})".format(self.val)
def simulate_int_type(int_type: str) -> Mod:
@@ -131,12 +150,12 @@ def simulate_int_type(int_type: str) -> Mod:
:param int_type: uint8_t / uint16_t / uint32_t / uint64_t
:returns: `Mod` subclass, e.g. Uint8
"""
- if int_type == 'uint8_t':
+ if int_type == "uint8_t":
return Uint8
- if int_type == 'uint16_t':
+ if int_type == "uint16_t":
return Uint16
- if int_type == 'uint32_t':
+ if int_type == "uint32_t":
return Uint32
- if int_type == 'uint64_t':
+ if int_type == "uint64_t":
return Uint64
- raise ValueError('unsupported integer type: {}'.format(int_type))
+ raise ValueError("unsupported integer type: {}".format(int_type))
diff --git a/lib/parameters.py b/lib/parameters.py
index 41e312a..8b562b6 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -21,8 +21,10 @@ def distinct_param_values(by_name, state_or_tran):
write() or similar has not been called yet. Other parameters should always
be initialized when leaving UNINITIALIZED.
"""
- distinct_values = [OrderedDict() for i in range(len(by_name[state_or_tran]['param'][0]))]
- for param_tuple in by_name[state_or_tran]['param']:
+ distinct_values = [
+ OrderedDict() for i in range(len(by_name[state_or_tran]["param"][0]))
+ ]
+ for param_tuple in by_name[state_or_tran]["param"]:
for i in range(len(param_tuple)):
distinct_values[i][param_tuple[i]] = True
@@ -30,8 +32,9 @@ def distinct_param_values(by_name, state_or_tran):
distinct_values = list(map(lambda x: list(x.keys()), distinct_values))
return distinct_values
+
def _depends_on_param(corr_param, std_param, std_lut):
- #if self.use_corrcoef:
+ # if self.use_corrcoef:
if False:
return corr_param > 0.1
elif std_param == 0:
@@ -40,6 +43,7 @@ def _depends_on_param(corr_param, std_param, std_lut):
return False
return std_lut / std_param < 0.5
+
def _reduce_param_matrix(matrix: np.ndarray, parameter_names: list) -> list:
"""
:param matrix: parameter dependence matrix, M[(...)] == 1 iff (model attribute) is influenced by (parameter) for other parameter value indxe == (...)
@@ -53,7 +57,7 @@ def _reduce_param_matrix(matrix: np.ndarray, parameter_names: list) -> list:
# Diese Abbruchbedingung scheint noch nicht so schlau zu sein...
# Mit wird zu viel rausgefiltert (z.B. auto_ack! -> max_retry_count in "bin/analyze-timing.py ../data/20190815_122531_nRF24_no-rx.json" nicht erkannt)
# Ohne wird zu wenig rausgefiltert (auch ganz viele Abhängigkeiten erkannt, bei denen eine Parameter-Abhängigketi immer unabhängig vom Wert der anderen Parameter besteht)
- #if not is_power_of_two(np.count_nonzero(matrix)):
+ # if not is_power_of_two(np.count_nonzero(matrix)):
# # cannot be reliably reduced to a list of parameters
# return list()
@@ -65,20 +69,23 @@ def _reduce_param_matrix(matrix: np.ndarray, parameter_names: list) -> list:
return influential_parameters
for axis in range(matrix.ndim):
- candidate = _reduce_param_matrix(np.all(matrix, axis=axis), remove_index_from_tuple(parameter_names, axis))
+ candidate = _reduce_param_matrix(
+ np.all(matrix, axis=axis), remove_index_from_tuple(parameter_names, axis)
+ )
if len(candidate):
return candidate
return list()
+
def _codependent_parameters(param, lut_by_param_values, std_by_param_values):
"""
Return list of parameters which affect whether a parameter affects a model attribute or not.
"""
return list()
- safe_div = np.vectorize(lambda x,y: 0. if x == 0 else 1 - x/y)
+ safe_div = np.vectorize(lambda x, y: 0.0 if x == 0 else 1 - x / y)
ratio_by_value = safe_div(lut_by_param_values, std_by_param_values)
- err_mode = np.seterr('ignore')
+ err_mode = np.seterr("ignore")
dep_by_value = ratio_by_value > 0.5
np.seterr(**err_mode)
@@ -86,7 +93,10 @@ def _codependent_parameters(param, lut_by_param_values, std_by_param_values):
influencer_parameters = _reduce_param_matrix(dep_by_value, other_param_list)
return influencer_parameters
-def _std_by_param(by_param, all_param_values, state_or_tran, attribute, param_index, verbose = False):
+
+def _std_by_param(
+ by_param, all_param_values, state_or_tran, attribute, param_index, verbose=False
+):
u"""
Calculate standard deviations for a static model where all parameters but `param_index` are constant.
@@ -130,7 +140,10 @@ def _std_by_param(by_param, all_param_values, state_or_tran, attribute, param_in
param_partition = list()
std_list = list()
for k, v in by_param.items():
- if k[0] == state_or_tran and (*k[1][:param_index], *k[1][param_index+1:]) == param_value:
+ if (
+ k[0] == state_or_tran
+ and (*k[1][:param_index], *k[1][param_index + 1 :]) == param_value
+ ):
param_partition.extend(v[attribute])
std_list.append(np.std(v[attribute]))
@@ -143,17 +156,26 @@ def _std_by_param(by_param, all_param_values, state_or_tran, attribute, param_in
lut_matrix[matrix_index] = np.mean(std_list)
# This can (and will) happen in normal operation, e.g. when a transition's
# arguments are combined using 'zip' rather than 'cartesian'.
- #elif len(param_partition) == 1:
+ # elif len(param_partition) == 1:
# vprint(verbose, '[W] parameter value partition for {} contains only one element -- skipping'.format(param_value))
- #else:
+ # else:
# vprint(verbose, '[W] parameter value partition for {} is empty'.format(param_value))
if np.all(np.isnan(stddev_matrix)):
- print('[W] {}/{} parameter #{} has no data partitions -- how did this even happen?'.format(state_or_tran, attribute, param_index))
- print('stddev_matrix = {}'.format(stddev_matrix))
- return stddev_matrix, 0.
+ print(
+ "[W] {}/{} parameter #{} has no data partitions -- how did this even happen?".format(
+ state_or_tran, attribute, param_index
+ )
+ )
+ print("stddev_matrix = {}".format(stddev_matrix))
+ return stddev_matrix, 0.0
+
+ return (
+ stddev_matrix,
+ np.nanmean(stddev_matrix),
+ lut_matrix,
+ ) # np.mean([np.std(partition) for partition in partitions])
- return stddev_matrix, np.nanmean(stddev_matrix), lut_matrix #np.mean([np.std(partition) for partition in partitions])
def _corr_by_param(by_name, state_or_trans, attribute, param_index):
"""
@@ -169,22 +191,46 @@ def _corr_by_param(by_name, state_or_trans, attribute, param_index):
:param param_index: index of parameter in `by_name[state_or_trans]['param']`
"""
if _all_params_are_numeric(by_name[state_or_trans], param_index):
- param_values = np.array(list((map(lambda x: x[param_index], by_name[state_or_trans]['param']))))
+ param_values = np.array(
+ list((map(lambda x: x[param_index], by_name[state_or_trans]["param"])))
+ )
try:
return np.corrcoef(by_name[state_or_trans][attribute], param_values)[0, 1]
except FloatingPointError:
# Typically happens when all parameter values are identical.
# Building a correlation coefficient is pointless in this case
# -> assume no correlation
- return 0.
+ return 0.0
except ValueError:
- print('[!] Exception in _corr_by_param(by_name, state_or_trans={}, attribute={}, param_index={})'.format(state_or_trans, attribute, param_index))
- print('[!] while executing np.corrcoef(by_name[{}][{}]={}, {}))'.format(state_or_trans, attribute, by_name[state_or_trans][attribute], param_values))
+ print(
+ "[!] Exception in _corr_by_param(by_name, state_or_trans={}, attribute={}, param_index={})".format(
+ state_or_trans, attribute, param_index
+ )
+ )
+ print(
+ "[!] while executing np.corrcoef(by_name[{}][{}]={}, {}))".format(
+ state_or_trans,
+ attribute,
+ by_name[state_or_trans][attribute],
+ param_values,
+ )
+ )
raise
else:
- return 0.
-
-def _compute_param_statistics(by_name, by_param, parameter_names, arg_count, state_or_trans, attribute, distinct_values, distinct_values_by_param_index, verbose = False):
+ return 0.0
+
+
+def _compute_param_statistics(
+ by_name,
+ by_param,
+ parameter_names,
+ arg_count,
+ state_or_trans,
+ attribute,
+ distinct_values,
+ distinct_values_by_param_index,
+ verbose=False,
+):
"""
Compute standard deviation and correlation coefficient for various data partitions.
@@ -223,87 +269,140 @@ def _compute_param_statistics(by_name, by_param, parameter_names, arg_count, sta
Only set if state_or_trans appears in arg_count, empty dict otherwise.
"""
ret = {
- 'std_static' : np.std(by_name[state_or_trans][attribute]),
- 'std_param_lut' : np.mean([np.std(by_param[x][attribute]) for x in by_param.keys() if x[0] == state_or_trans]),
- 'std_by_param' : {},
- 'std_by_param_values' : {},
- 'lut_by_param_values' : {},
- 'std_by_arg' : [],
- 'std_by_arg_values' : [],
- 'lut_by_arg_values' : [],
- 'corr_by_param' : {},
- 'corr_by_arg' : [],
- 'depends_on_param' : {},
- 'depends_on_arg' : [],
- 'param_data' : {},
+ "std_static": np.std(by_name[state_or_trans][attribute]),
+ "std_param_lut": np.mean(
+ [
+ np.std(by_param[x][attribute])
+ for x in by_param.keys()
+ if x[0] == state_or_trans
+ ]
+ ),
+ "std_by_param": {},
+ "std_by_param_values": {},
+ "lut_by_param_values": {},
+ "std_by_arg": [],
+ "std_by_arg_values": [],
+ "lut_by_arg_values": [],
+ "corr_by_param": {},
+ "corr_by_arg": [],
+ "depends_on_param": {},
+ "depends_on_arg": [],
+ "param_data": {},
}
- np.seterr('raise')
+ np.seterr("raise")
for param_idx, param in enumerate(parameter_names):
- std_matrix, mean_std, lut_matrix = _std_by_param(by_param, distinct_values_by_param_index, state_or_trans, attribute, param_idx, verbose)
- ret['std_by_param'][param] = mean_std
- ret['std_by_param_values'][param] = std_matrix
- ret['lut_by_param_values'][param] = lut_matrix
- ret['corr_by_param'][param] = _corr_by_param(by_name, state_or_trans, attribute, param_idx)
-
- ret['depends_on_param'][param] = _depends_on_param(ret['corr_by_param'][param], ret['std_by_param'][param], ret['std_param_lut'])
-
- if ret['depends_on_param'][param]:
- ret['param_data'][param] = {
- 'codependent_parameters': _codependent_parameters(param, lut_matrix, std_matrix),
- 'depends_for_codependent_value': dict()
+ std_matrix, mean_std, lut_matrix = _std_by_param(
+ by_param,
+ distinct_values_by_param_index,
+ state_or_trans,
+ attribute,
+ param_idx,
+ verbose,
+ )
+ ret["std_by_param"][param] = mean_std
+ ret["std_by_param_values"][param] = std_matrix
+ ret["lut_by_param_values"][param] = lut_matrix
+ ret["corr_by_param"][param] = _corr_by_param(
+ by_name, state_or_trans, attribute, param_idx
+ )
+
+ ret["depends_on_param"][param] = _depends_on_param(
+ ret["corr_by_param"][param],
+ ret["std_by_param"][param],
+ ret["std_param_lut"],
+ )
+
+ if ret["depends_on_param"][param]:
+ ret["param_data"][param] = {
+ "codependent_parameters": _codependent_parameters(
+ param, lut_matrix, std_matrix
+ ),
+ "depends_for_codependent_value": dict(),
}
# calculate parameter dependence for individual values of codependent parameters
codependent_param_values = list()
- for codependent_param in ret['param_data'][param]['codependent_parameters']:
+ for codependent_param in ret["param_data"][param]["codependent_parameters"]:
codependent_param_values.append(distinct_values[codependent_param])
for combi in itertools.product(*codependent_param_values):
by_name_part = deepcopy(by_name)
- filter_list = list(zip(ret['param_data'][param]['codependent_parameters'], combi))
+ filter_list = list(
+ zip(ret["param_data"][param]["codependent_parameters"], combi)
+ )
filter_aggregate_by_param(by_name_part, parameter_names, filter_list)
by_param_part = by_name_to_by_param(by_name_part)
# there may be no data for this specific parameter value combination
if state_or_trans in by_name_part:
- part_corr = _corr_by_param(by_name_part, state_or_trans, attribute, param_idx)
- part_std_lut = np.mean([np.std(by_param_part[x][attribute]) for x in by_param_part.keys() if x[0] == state_or_trans])
- _, part_std_param, _ = _std_by_param(by_param_part, distinct_values_by_param_index, state_or_trans, attribute, param_idx, verbose)
- ret['param_data'][param]['depends_for_codependent_value'][combi] = _depends_on_param(part_corr, part_std_param, part_std_lut)
+ part_corr = _corr_by_param(
+ by_name_part, state_or_trans, attribute, param_idx
+ )
+ part_std_lut = np.mean(
+ [
+ np.std(by_param_part[x][attribute])
+ for x in by_param_part.keys()
+ if x[0] == state_or_trans
+ ]
+ )
+ _, part_std_param, _ = _std_by_param(
+ by_param_part,
+ distinct_values_by_param_index,
+ state_or_trans,
+ attribute,
+ param_idx,
+ verbose,
+ )
+ ret["param_data"][param]["depends_for_codependent_value"][
+ combi
+ ] = _depends_on_param(part_corr, part_std_param, part_std_lut)
if state_or_trans in arg_count:
for arg_index in range(arg_count[state_or_trans]):
- std_matrix, mean_std, lut_matrix = _std_by_param(by_param, distinct_values_by_param_index, state_or_trans, attribute, len(parameter_names) + arg_index, verbose)
- ret['std_by_arg'].append(mean_std)
- ret['std_by_arg_values'].append(std_matrix)
- ret['lut_by_arg_values'].append(lut_matrix)
- ret['corr_by_arg'].append(_corr_by_param(by_name, state_or_trans, attribute, len(parameter_names) + arg_index))
+ std_matrix, mean_std, lut_matrix = _std_by_param(
+ by_param,
+ distinct_values_by_param_index,
+ state_or_trans,
+ attribute,
+ len(parameter_names) + arg_index,
+ verbose,
+ )
+ ret["std_by_arg"].append(mean_std)
+ ret["std_by_arg_values"].append(std_matrix)
+ ret["lut_by_arg_values"].append(lut_matrix)
+ ret["corr_by_arg"].append(
+ _corr_by_param(
+ by_name, state_or_trans, attribute, len(parameter_names) + arg_index
+ )
+ )
if False:
- ret['depends_on_arg'].append(ret['corr_by_arg'][arg_index] > 0.1)
- elif ret['std_by_arg'][arg_index] == 0:
+ ret["depends_on_arg"].append(ret["corr_by_arg"][arg_index] > 0.1)
+ elif ret["std_by_arg"][arg_index] == 0:
# In general, std_param_lut < std_by_arg. So, if std_by_arg == 0, std_param_lut == 0 follows.
# This means that the variation of arg does not affect the model quality -> no influence
- ret['depends_on_arg'].append(False)
+ ret["depends_on_arg"].append(False)
else:
- ret['depends_on_arg'].append(ret['std_param_lut'] / ret['std_by_arg'][arg_index] < 0.5)
+ ret["depends_on_arg"].append(
+ ret["std_param_lut"] / ret["std_by_arg"][arg_index] < 0.5
+ )
return ret
+
def _compute_param_statistics_parallel(arg):
- return {
- 'key' : arg['key'],
- 'result': _compute_param_statistics(*arg['args'])
- }
+ return {"key": arg["key"], "result": _compute_param_statistics(*arg["args"])}
+
def _all_params_are_numeric(data, param_idx):
"""Check if all `data['param'][*][param_idx]` elements are numeric, as reported by `utils.is_numeric`."""
- param_values = list(map(lambda x: x[param_idx], data['param']))
+ param_values = list(map(lambda x: x[param_idx], data["param"]))
if len(list(filter(is_numeric, param_values))) == len(param_values):
return True
return False
-def prune_dependent_parameters(by_name, parameter_names, correlation_threshold = 0.5):
+
+def prune_dependent_parameters(by_name, parameter_names, correlation_threshold=0.5):
"""
Remove dependent parameters from aggregate.
@@ -320,15 +419,17 @@ def prune_dependent_parameters(by_name, parameter_names, correlation_threshold =
"""
parameter_indices_to_remove = list()
- for parameter_combination in itertools.product(range(len(parameter_names)), range(len(parameter_names))):
+ for parameter_combination in itertools.product(
+ range(len(parameter_names)), range(len(parameter_names))
+ ):
index_1, index_2 = parameter_combination
if index_1 >= index_2:
continue
- parameter_values = [list(), list()] # both parameters have a value
- parameter_values_1 = list() # parameter 1 has a value
- parameter_values_2 = list() # parameter 2 has a value
+ parameter_values = [list(), list()] # both parameters have a value
+ parameter_values_1 = list() # parameter 1 has a value
+ parameter_values_2 = list() # parameter 2 has a value
for name in by_name:
- for measurement in by_name[name]['param']:
+ for measurement in by_name[name]["param"]:
value_1 = measurement[index_1]
value_2 = measurement[index_2]
if is_numeric(value_1):
@@ -342,16 +443,30 @@ def prune_dependent_parameters(by_name, parameter_names, correlation_threshold =
# Calculating the correlation coefficient only makes sense when neither value is constant
if np.std(parameter_values_1) != 0 and np.std(parameter_values_2) != 0:
correlation = np.corrcoef(parameter_values)[0][1]
- if correlation != np.nan and np.abs(correlation) > correlation_threshold:
- print('[!] Parameters {} <-> {} are correlated with coefficcient {}'.format(parameter_names[index_1], parameter_names[index_2], correlation))
+ if (
+ correlation != np.nan
+ and np.abs(correlation) > correlation_threshold
+ ):
+ print(
+ "[!] Parameters {} <-> {} are correlated with coefficcient {}".format(
+ parameter_names[index_1],
+ parameter_names[index_2],
+ correlation,
+ )
+ )
if len(parameter_values_1) < len(parameter_values_2):
index_to_remove = index_1
else:
index_to_remove = index_2
- print(' Removing parameter {}'.format(parameter_names[index_to_remove]))
+ print(
+ " Removing parameter {}".format(
+ parameter_names[index_to_remove]
+ )
+ )
parameter_indices_to_remove.append(index_to_remove)
remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove)
+
def remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove):
"""
Remove parameters listed in `parameter_indices` from aggregate `by_name` and `parameter_names`.
@@ -365,12 +480,13 @@ def remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_
"""
# Start removal from the end of the list to avoid renumbering of list elemenets
- for parameter_index in sorted(parameter_indices_to_remove, reverse = True):
+ for parameter_index in sorted(parameter_indices_to_remove, reverse=True):
for name in by_name:
- for measurement in by_name[name]['param']:
+ for measurement in by_name[name]["param"]:
measurement.pop(parameter_index)
parameter_names.pop(parameter_index)
+
class ParamStats:
"""
:param stats: `stats[state_or_tran][attribute]` = std_static, std_param_lut, ... (see `compute_param_statistics`)
@@ -378,7 +494,15 @@ class ParamStats:
:param distinct_values_by_param_index: `distinct_values[state_or_tran][i]` = [distinct values in aggregate]
"""
- def __init__(self, by_name, by_param, parameter_names, arg_count, use_corrcoef = False, verbose = False):
+ def __init__(
+ self,
+ by_name,
+ by_param,
+ parameter_names,
+ arg_count,
+ use_corrcoef=False,
+ verbose=False,
+ ):
"""
Compute standard deviation and correlation coefficient on parameterized data partitions.
@@ -411,24 +535,40 @@ class ParamStats:
for state_or_tran in by_name.keys():
self.stats[state_or_tran] = dict()
- self.distinct_values_by_param_index[state_or_tran] = distinct_param_values(by_name, state_or_tran)
+ self.distinct_values_by_param_index[state_or_tran] = distinct_param_values(
+ by_name, state_or_tran
+ )
self.distinct_values[state_or_tran] = dict()
for i, param in enumerate(parameter_names):
- self.distinct_values[state_or_tran][param] = self.distinct_values_by_param_index[state_or_tran][i]
- for attribute in by_name[state_or_tran]['attributes']:
- stats_queue.append({
- 'key': [state_or_tran, attribute],
- 'args': [by_name, by_param, parameter_names, arg_count, state_or_tran, attribute, self.distinct_values[state_or_tran], self.distinct_values_by_param_index[state_or_tran], verbose],
- })
+ self.distinct_values[state_or_tran][
+ param
+ ] = self.distinct_values_by_param_index[state_or_tran][i]
+ for attribute in by_name[state_or_tran]["attributes"]:
+ stats_queue.append(
+ {
+ "key": [state_or_tran, attribute],
+ "args": [
+ by_name,
+ by_param,
+ parameter_names,
+ arg_count,
+ state_or_tran,
+ attribute,
+ self.distinct_values[state_or_tran],
+ self.distinct_values_by_param_index[state_or_tran],
+ verbose,
+ ],
+ }
+ )
with Pool() as pool:
stats_results = pool.map(_compute_param_statistics_parallel, stats_queue)
for stats in stats_results:
- state_or_tran, attribute = stats['key']
- self.stats[state_or_tran][attribute] = stats['result']
+ state_or_tran, attribute = stats["key"]
+ self.stats[state_or_tran][attribute] = stats["result"]
- def can_be_fitted(self, state_or_tran = None) -> bool:
+ def can_be_fitted(self, state_or_tran=None) -> bool:
"""
Return whether a sufficient amount of distinct numeric parameter values is available, allowing a parameter-aware model to be generated.
@@ -441,8 +581,27 @@ class ParamStats:
for key in keys:
for param in self._parameter_names:
- if len(list(filter(lambda n: is_numeric(n), self.distinct_values[key][param]))) > 2:
- print(key, param, list(filter(lambda n: is_numeric(n), self.distinct_values[key][param])))
+ if (
+ len(
+ list(
+ filter(
+ lambda n: is_numeric(n),
+ self.distinct_values[key][param],
+ )
+ )
+ )
+ > 2
+ ):
+ print(
+ key,
+ param,
+ list(
+ filter(
+ lambda n: is_numeric(n),
+ self.distinct_values[key][param],
+ )
+ ),
+ )
return True
return False
@@ -456,7 +615,9 @@ class ParamStats:
# TODO
pass
- def has_codependent_parameters(self, state_or_tran: str, attribute: str, param: str) -> bool:
+ def has_codependent_parameters(
+ self, state_or_tran: str, attribute: str, param: str
+ ) -> bool:
"""
Return whether there are parameters which determine whether `param` influences `state_or_tran` `attribute` or not.
@@ -468,7 +629,9 @@ class ParamStats:
return True
return False
- def codependent_parameters(self, state_or_tran: str, attribute: str, param: str) -> list:
+ def codependent_parameters(
+ self, state_or_tran: str, attribute: str, param: str
+ ) -> list:
"""
Return list of parameters which determine whether `param` influences `state_or_tran` `attribute` or not.
@@ -476,12 +639,15 @@ class ParamStats:
:param attribute: model attribute
:param param: parameter name
"""
- if self.stats[state_or_tran][attribute]['depends_on_param'][param]:
- return self.stats[state_or_tran][attribute]['param_data'][param]['codependent_parameters']
+ if self.stats[state_or_tran][attribute]["depends_on_param"][param]:
+ return self.stats[state_or_tran][attribute]["param_data"][param][
+ "codependent_parameters"
+ ]
return list()
-
- def has_codependent_parameters_union(self, state_or_tran: str, attribute: str) -> bool:
+ def has_codependent_parameters_union(
+ self, state_or_tran: str, attribute: str
+ ) -> bool:
"""
Return whether there is a subset of parameters which decides whether `state_or_tran` `attribute` is static or parameter-dependent
@@ -490,11 +656,14 @@ class ParamStats:
"""
depends_on_a_parameter = False
for param in self._parameter_names:
- if self.stats[state_or_tran][attribute]['depends_on_param'][param]:
- print('{}/{} depends on {}'.format(state_or_tran, attribute, param))
+ if self.stats[state_or_tran][attribute]["depends_on_param"][param]:
+ print("{}/{} depends on {}".format(state_or_tran, attribute, param))
depends_on_a_parameter = True
- if len(self.codependent_parameters(state_or_tran, attribute, param)) == 0:
- print('has no codependent parameters')
+ if (
+ len(self.codependent_parameters(state_or_tran, attribute, param))
+ == 0
+ ):
+ print("has no codependent parameters")
# Always depends on this parameter, regardless of other parameters' values
return False
return depends_on_a_parameter
@@ -508,14 +677,21 @@ class ParamStats:
"""
codependent_parameters = set()
for param in self._parameter_names:
- if self.stats[state_or_tran][attribute]['depends_on_param'][param]:
- if len(self.codependent_parameters(state_or_tran, attribute, param)) == 0:
+ if self.stats[state_or_tran][attribute]["depends_on_param"][param]:
+ if (
+ len(self.codependent_parameters(state_or_tran, attribute, param))
+ == 0
+ ):
return list(self._parameter_names)
- for codependent_param in self.codependent_parameters(state_or_tran, attribute, param):
+ for codependent_param in self.codependent_parameters(
+ state_or_tran, attribute, param
+ ):
codependent_parameters.add(codependent_param)
return sorted(codependent_parameters)
- def codependence_by_codependent_param_values(self, state_or_tran: str, attribute: str, param: str) -> dict:
+ def codependence_by_codependent_param_values(
+ self, state_or_tran: str, attribute: str, param: str
+ ) -> dict:
"""
Return dict mapping codependent parameter values to a boolean indicating whether `param` influences `state_or_tran` `attribute`.
@@ -525,11 +701,15 @@ class ParamStats:
:param attribute: model attribute
:param param: parameter name
"""
- if self.stats[state_or_tran][attribute]['depends_on_param'][param]:
- return self.stats[state_or_tran][attribute]['param_data'][param]['depends_for_codependent_value']
+ if self.stats[state_or_tran][attribute]["depends_on_param"][param]:
+ return self.stats[state_or_tran][attribute]["param_data"][param][
+ "depends_for_codependent_value"
+ ]
return dict()
- def codependent_parameter_value_dicts(self, state_or_tran: str, attribute: str, param: str, kind='dynamic'):
+ def codependent_parameter_value_dicts(
+ self, state_or_tran: str, attribute: str, param: str, kind="dynamic"
+ ):
"""
Return dicts of codependent parameter key-value mappings for which `param` influences (or does not influence) `state_or_tran` `attribute`.
@@ -538,16 +718,21 @@ class ParamStats:
:param param: parameter name:
:param kind: 'static' or 'dynamic'. If 'dynamic' (the default), returns codependent parameter values for which `param` influences `attribute`. If 'static', returns codependent parameter values for which `param` does not influence `attribute`
"""
- codependent_parameters = self.stats[state_or_tran][attribute]['param_data'][param]['codependent_parameters']
- codependence_info = self.stats[state_or_tran][attribute]['param_data'][param]['depends_for_codependent_value']
+ codependent_parameters = self.stats[state_or_tran][attribute]["param_data"][
+ param
+ ]["codependent_parameters"]
+ codependence_info = self.stats[state_or_tran][attribute]["param_data"][param][
+ "depends_for_codependent_value"
+ ]
if len(codependent_parameters) == 0:
return
else:
for param_values, is_dynamic in codependence_info.items():
- if (is_dynamic and kind == 'dynamic') or (not is_dynamic and kind == 'static'):
+ if (is_dynamic and kind == "dynamic") or (
+ not is_dynamic and kind == "static"
+ ):
yield dict(zip(codependent_parameters, param_values))
-
def _generic_param_independence_ratio(self, state_or_trans, attribute):
"""
Return the heuristic ratio of parameter independence for state_or_trans and attribute.
@@ -559,9 +744,9 @@ class ParamStats:
if self.use_corrcoef:
# not supported
raise ValueError
- if statistics['std_static'] == 0:
+ if statistics["std_static"] == 0:
return 0
- return statistics['std_param_lut'] / statistics['std_static']
+ return statistics["std_param_lut"] / statistics["std_static"]
def generic_param_dependence_ratio(self, state_or_trans, attribute):
"""
@@ -572,7 +757,9 @@ class ParamStats:
"""
return 1 - self._generic_param_independence_ratio(state_or_trans, attribute)
- def _param_independence_ratio(self, state_or_trans: str, attribute: str, param: str) -> float:
+ def _param_independence_ratio(
+ self, state_or_trans: str, attribute: str, param: str
+ ) -> float:
"""
Return the heuristic ratio of parameter independence for state_or_trans, attribute, and param.
@@ -580,17 +767,19 @@ class ParamStats:
"""
statistics = self.stats[state_or_trans][attribute]
if self.use_corrcoef:
- return 1 - np.abs(statistics['corr_by_param'][param])
- if statistics['std_by_param'][param] == 0:
- if statistics['std_param_lut'] != 0:
+ return 1 - np.abs(statistics["corr_by_param"][param])
+ if statistics["std_by_param"][param] == 0:
+ if statistics["std_param_lut"] != 0:
raise RuntimeError("wat")
# In general, std_param_lut < std_by_param. So, if std_by_param == 0, std_param_lut == 0 follows.
# This means that the variation of param does not affect the model quality -> no influence, return 1
- return 1.
+ return 1.0
- return statistics['std_param_lut'] / statistics['std_by_param'][param]
+ return statistics["std_param_lut"] / statistics["std_by_param"][param]
- def param_dependence_ratio(self, state_or_trans: str, attribute: str, param: str) -> float:
+ def param_dependence_ratio(
+ self, state_or_trans: str, attribute: str, param: str
+ ) -> float:
"""
Return the heuristic ratio of parameter dependence for state_or_trans, attribute, and param.
@@ -607,16 +796,18 @@ class ParamStats:
def _arg_independence_ratio(self, state_or_trans, attribute, arg_index):
statistics = self.stats[state_or_trans][attribute]
if self.use_corrcoef:
- return 1 - np.abs(statistics['corr_by_arg'][arg_index])
- if statistics['std_by_arg'][arg_index] == 0:
- if statistics['std_param_lut'] != 0:
+ return 1 - np.abs(statistics["corr_by_arg"][arg_index])
+ if statistics["std_by_arg"][arg_index] == 0:
+ if statistics["std_param_lut"] != 0:
raise RuntimeError("wat")
# In general, std_param_lut < std_by_arg. So, if std_by_arg == 0, std_param_lut == 0 follows.
# This means that the variation of arg does not affect the model quality -> no influence, return 1
return 1
- return statistics['std_param_lut'] / statistics['std_by_arg'][arg_index]
+ return statistics["std_param_lut"] / statistics["std_by_arg"][arg_index]
- def arg_dependence_ratio(self, state_or_trans: str, attribute: str, arg_index: int) -> float:
+ def arg_dependence_ratio(
+ self, state_or_trans: str, attribute: str, arg_index: int
+ ) -> float:
return 1 - self._arg_independence_ratio(state_or_trans, attribute, arg_index)
# This heuristic is very similar to the "function is not much better than
@@ -625,10 +816,9 @@ class ParamStats:
# --df, 2018-04-18
def depends_on_param(self, state_or_trans, attribute, param):
"""Return whether attribute of state_or_trans depens on param."""
- return self.stats[state_or_trans][attribute]['depends_on_param'][param]
+ return self.stats[state_or_trans][attribute]["depends_on_param"][param]
# See notes on depends_on_param
def depends_on_arg(self, state_or_trans, attribute, arg_index):
"""Return whether attribute of state_or_trans depens on arg_index."""
- return self.stats[state_or_trans][attribute]['depends_on_arg'][arg_index]
-
+ return self.stats[state_or_trans][attribute]["depends_on_arg"][arg_index]
diff --git a/lib/plotter.py b/lib/plotter.py
index deed93a..16c0145 100755
--- a/lib/plotter.py
+++ b/lib/plotter.py
@@ -8,75 +8,89 @@ import re
def is_state(aggregate, name):
"""Return true if name is a state and not UNINITIALIZED."""
- return aggregate[name]['isa'] == 'state' and name != 'UNINITIALIZED'
+ return aggregate[name]["isa"] == "state" and name != "UNINITIALIZED"
def plot_states(model, aggregate):
keys = [key for key in sorted(aggregate.keys()) if is_state(aggregate, key)]
- data = [aggregate[key]['means'] for key in keys]
- mdata = [int(model['state'][key]['power']['static']) for key in keys]
- boxplot(keys, data, 'Zustand', 'µW', modeldata=mdata)
+ data = [aggregate[key]["means"] for key in keys]
+ mdata = [int(model["state"][key]["power"]["static"]) for key in keys]
+ boxplot(keys, data, "Zustand", "µW", modeldata=mdata)
def plot_transitions(model, aggregate):
- keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'transition']
- data = [aggregate[key]['rel_energies'] for key in keys]
- mdata = [int(model['transition'][key]['rel_energy']['static']) for key in keys]
- boxplot(keys, data, 'Transition', 'pJ (rel)', modeldata=mdata)
- data = [aggregate[key]['energies'] for key in keys]
- mdata = [int(model['transition'][key]['energy']['static']) for key in keys]
- boxplot(keys, data, 'Transition', 'pJ', modeldata=mdata)
+ keys = [
+ key for key in sorted(aggregate.keys()) if aggregate[key]["isa"] == "transition"
+ ]
+ data = [aggregate[key]["rel_energies"] for key in keys]
+ mdata = [int(model["transition"][key]["rel_energy"]["static"]) for key in keys]
+ boxplot(keys, data, "Transition", "pJ (rel)", modeldata=mdata)
+ data = [aggregate[key]["energies"] for key in keys]
+ mdata = [int(model["transition"][key]["energy"]["static"]) for key in keys]
+ boxplot(keys, data, "Transition", "pJ", modeldata=mdata)
def plot_states_duration(model, aggregate):
keys = [key for key in sorted(aggregate.keys()) if is_state(aggregate, key)]
- data = [aggregate[key]['durations'] for key in keys]
- boxplot(keys, data, 'Zustand', 'µs')
+ data = [aggregate[key]["durations"] for key in keys]
+ boxplot(keys, data, "Zustand", "µs")
def plot_transitions_duration(model, aggregate):
- keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'transition']
- data = [aggregate[key]['durations'] for key in keys]
- boxplot(keys, data, 'Transition', 'µs')
+ keys = [
+ key for key in sorted(aggregate.keys()) if aggregate[key]["isa"] == "transition"
+ ]
+ data = [aggregate[key]["durations"] for key in keys]
+ boxplot(keys, data, "Transition", "µs")
def plot_transitions_timeout(model, aggregate):
- keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'transition']
- data = [aggregate[key]['timeouts'] for key in keys]
- boxplot(keys, data, 'Timeout', 'µs')
+ keys = [
+ key for key in sorted(aggregate.keys()) if aggregate[key]["isa"] == "transition"
+ ]
+ data = [aggregate[key]["timeouts"] for key in keys]
+ boxplot(keys, data, "Timeout", "µs")
def plot_states_clips(model, aggregate):
keys = [key for key in sorted(aggregate.keys()) if is_state(aggregate, key)]
- data = [np.array([100]) * aggregate[key]['clip_rate'] for key in keys]
- boxplot(keys, data, 'Zustand', '% Clipping')
+ data = [np.array([100]) * aggregate[key]["clip_rate"] for key in keys]
+ boxplot(keys, data, "Zustand", "% Clipping")
def plot_transitions_clips(model, aggregate):
- keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'transition']
- data = [np.array([100]) * aggregate[key]['clip_rate'] for key in keys]
- boxplot(keys, data, 'Transition', '% Clipping')
+ keys = [
+ key for key in sorted(aggregate.keys()) if aggregate[key]["isa"] == "transition"
+ ]
+ data = [np.array([100]) * aggregate[key]["clip_rate"] for key in keys]
+ boxplot(keys, data, "Transition", "% Clipping")
def plot_substate_thresholds(model, aggregate):
keys = [key for key in sorted(aggregate.keys()) if is_state(aggregate, key)]
- data = [aggregate[key]['sub_thresholds'] for key in keys]
- boxplot(keys, data, 'Zustand', 'substate threshold (mW/dmW)')
+ data = [aggregate[key]["sub_thresholds"] for key in keys]
+ boxplot(keys, data, "Zustand", "substate threshold (mW/dmW)")
def plot_histogram(data):
- n, bins, patches = plt.hist(data, 1000, normed=1, facecolor='green', alpha=0.75)
+ n, bins, patches = plt.hist(data, 1000, normed=1, facecolor="green", alpha=0.75)
plt.show()
def plot_states_param(model, aggregate):
- keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'state' and key[0] != 'UNINITIALIZED']
- data = [aggregate[key]['means'] for key in keys]
- mdata = [int(model['state'][key[0]]['power']['static']) for key in keys]
- boxplot(keys, data, 'Transition', 'µW', modeldata=mdata)
-
-
-def plot_attribute(aggregate, attribute, attribute_unit='', key_filter=lambda x: True, **kwargs):
+ keys = [
+ key
+ for key in sorted(aggregate.keys())
+ if aggregate[key]["isa"] == "state" and key[0] != "UNINITIALIZED"
+ ]
+ data = [aggregate[key]["means"] for key in keys]
+ mdata = [int(model["state"][key[0]]["power"]["static"]) for key in keys]
+ boxplot(keys, data, "Transition", "µW", modeldata=mdata)
+
+
+def plot_attribute(
+ aggregate, attribute, attribute_unit="", key_filter=lambda x: True, **kwargs
+):
"""
Boxplot measurements of a single attribute according to the partitioning provided by aggregate.
@@ -94,13 +108,17 @@ def plot_attribute(aggregate, attribute, attribute_unit='', key_filter=lambda x:
def plot_substate_thresholds_p(model, aggregate):
- keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'state' and key[0] != 'UNINITIALIZED']
- data = [aggregate[key]['sub_thresholds'] for key in keys]
- boxplot(keys, data, 'Zustand', '% Clipping')
+ keys = [
+ key
+ for key in sorted(aggregate.keys())
+ if aggregate[key]["isa"] == "state" and key[0] != "UNINITIALIZED"
+ ]
+ data = [aggregate[key]["sub_thresholds"] for key in keys]
+ boxplot(keys, data, "Zustand", "% Clipping")
def plot_y(Y, **kwargs):
- if 'family' in kwargs and kwargs['family']:
+ if "family" in kwargs and kwargs["family"]:
plot_xy(None, Y, **kwargs)
else:
plot_xy(np.arange(len(Y)), Y, **kwargs)
@@ -116,26 +134,39 @@ def plot_xy(X, Y, xlabel=None, ylabel=None, title=None, output=None, family=Fals
ax1.set_ylabel(ylabel)
plt.subplots_adjust(left=0.1, bottom=0.1, right=0.99, top=0.99)
if family:
- cm = plt.get_cmap('brg', len(Y))
+ cm = plt.get_cmap("brg", len(Y))
for i, YY in enumerate(Y):
plt.plot(np.arange(len(YY)), YY, "-", markersize=2, color=cm(i))
else:
plt.plot(X, Y, "bo", markersize=2)
if output:
plt.savefig(output)
- with open('{}.txt'.format(output), 'w') as f:
- print('X Y', file=f)
+ with open("{}.txt".format(output), "w") as f:
+ print("X Y", file=f)
for i in range(len(X)):
- print('{} {}'.format(X[i], Y[i]), file=f)
+ print("{} {}".format(X[i], Y[i]), file=f)
else:
plt.show()
def _param_slice_eq(a, b, index):
- return (*a[1][:index], *a[1][index + 1:]) == (*b[1][:index], *b[1][index + 1:]) and a[0] == b[0]
-
-
-def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel=None, title=None, extra_function=None, output=None):
+ return (*a[1][:index], *a[1][index + 1 :]) == (
+ *b[1][:index],
+ *b[1][index + 1 :],
+ ) and a[0] == b[0]
+
+
+def plot_param(
+ model,
+ state_or_trans,
+ attribute,
+ param_idx,
+ xlabel=None,
+ ylabel=None,
+ title=None,
+ extra_function=None,
+ output=None,
+):
fig, ax1 = plt.subplots(figsize=(10, 6))
if title is not None:
fig.canvas.set_window_title(title)
@@ -147,8 +178,12 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel=
param_name = model.param_name(param_idx)
- function_filename = 'plot_param_{}_{}_{}.txt'.format(state_or_trans, attribute, param_name)
- data_filename_base = 'measurements_{}_{}_{}'.format(state_or_trans, attribute, param_name)
+ function_filename = "plot_param_{}_{}_{}.txt".format(
+ state_or_trans, attribute, param_name
+ )
+ data_filename_base = "measurements_{}_{}_{}".format(
+ state_or_trans, attribute, param_name
+ )
param_model, param_info = model.get_fitted()
@@ -156,16 +191,18 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel=
XX = []
- legend_sanitizer = re.compile(r'[^0-9a-zA-Z]+')
+ legend_sanitizer = re.compile(r"[^0-9a-zA-Z]+")
for k, v in model.by_param.items():
if k[0] == state_or_trans:
- other_param_key = (*k[1][:param_idx], *k[1][param_idx + 1:])
+ other_param_key = (*k[1][:param_idx], *k[1][param_idx + 1 :])
if other_param_key not in by_other_param:
- by_other_param[other_param_key] = {'X': [], 'Y': []}
- by_other_param[other_param_key]['X'].extend([float(k[1][param_idx])] * len(v[attribute]))
- by_other_param[other_param_key]['Y'].extend(v[attribute])
- XX.extend(by_other_param[other_param_key]['X'])
+ by_other_param[other_param_key] = {"X": [], "Y": []}
+ by_other_param[other_param_key]["X"].extend(
+ [float(k[1][param_idx])] * len(v[attribute])
+ )
+ by_other_param[other_param_key]["Y"].extend(v[attribute])
+ XX.extend(by_other_param[other_param_key]["X"])
XX = np.array(XX)
x_range = int((XX.max() - XX.min()) * 10)
@@ -175,22 +212,22 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel=
YY2 = []
YY2_legend = []
- cm = plt.get_cmap('brg', len(by_other_param))
+ cm = plt.get_cmap("brg", len(by_other_param))
for i, k in sorted(enumerate(by_other_param), key=lambda x: x[1]):
v = by_other_param[k]
- v['X'] = np.array(v['X'])
- v['Y'] = np.array(v['Y'])
- plt.plot(v['X'], v['Y'], "ro", color=cm(i), markersize=3)
- YY2_legend.append(legend_sanitizer.sub('_', 'X_{}'.format(k)))
- YY2.append(v['X'])
- YY2_legend.append(legend_sanitizer.sub('_', 'Y_{}'.format(k)))
- YY2.append(v['Y'])
-
- sanitized_k = legend_sanitizer.sub('_', str(k))
- with open('{}_{}.txt'.format(data_filename_base, sanitized_k), 'w') as f:
- print('X Y', file=f)
- for i in range(len(v['X'])):
- print('{} {}'.format(v['X'][i], v['Y'][i]), file=f)
+ v["X"] = np.array(v["X"])
+ v["Y"] = np.array(v["Y"])
+ plt.plot(v["X"], v["Y"], "ro", color=cm(i), markersize=3)
+ YY2_legend.append(legend_sanitizer.sub("_", "X_{}".format(k)))
+ YY2.append(v["X"])
+ YY2_legend.append(legend_sanitizer.sub("_", "Y_{}".format(k)))
+ YY2.append(v["Y"])
+
+ sanitized_k = legend_sanitizer.sub("_", str(k))
+ with open("{}_{}.txt".format(data_filename_base, sanitized_k), "w") as f:
+ print("X Y", file=f)
+ for i in range(len(v["X"])):
+ print("{} {}".format(v["X"][i], v["Y"][i]), file=f)
# x_range = int((v['X'].max() - v['X'].min()) * 10)
# xsp = np.linspace(v['X'].min(), v['X'].max(), x_range)
@@ -201,21 +238,21 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel=
ysp.append(param_model(state_or_trans, attribute, param=xarg))
plt.plot(xsp, ysp, "r-", color=cm(i), linewidth=0.5)
YY.append(ysp)
- YY_legend.append(legend_sanitizer.sub('_', 'regr_{}'.format(k)))
+ YY_legend.append(legend_sanitizer.sub("_", "regr_{}".format(k)))
if extra_function is not None:
ysp = []
- with np.errstate(divide='ignore', invalid='ignore'):
+ with np.errstate(divide="ignore", invalid="ignore"):
for x in xsp:
xarg = [*k[:param_idx], x, *k[param_idx:]]
ysp.append(extra_function(*xarg))
plt.plot(xsp, ysp, "r--", color=cm(i), linewidth=1, dashes=(3, 3))
YY.append(ysp)
- YY_legend.append(legend_sanitizer.sub('_', 'symb_{}'.format(k)))
+ YY_legend.append(legend_sanitizer.sub("_", "symb_{}".format(k)))
- with open(function_filename, 'w') as f:
- print(' '.join(YY_legend), file=f)
+ with open(function_filename, "w") as f:
+ print(" ".join(YY_legend), file=f)
for elem in np.array(YY).T:
- print(' '.join(map(str, elem)), file=f)
+ print(" ".join(map(str, elem)), file=f)
print(data_filename_base, function_filename)
if output:
@@ -224,7 +261,19 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel=
plt.show()
-def plot_param_fit(function, name, fitfunc, funp, parameters, datatype, index, X, Y, xaxis=None, yaxis=None):
+def plot_param_fit(
+ function,
+ name,
+ fitfunc,
+ funp,
+ parameters,
+ datatype,
+ index,
+ X,
+ Y,
+ xaxis=None,
+ yaxis=None,
+):
fig, ax1 = plt.subplots(figsize=(10, 6))
fig.canvas.set_window_title("fit %s" % (function))
plt.subplots_adjust(left=0.14, right=0.99, top=0.99, bottom=0.14)
@@ -244,10 +293,10 @@ def plot_param_fit(function, name, fitfunc, funp, parameters, datatype, index, X
if yaxis is not None:
ax1.set_ylabel(yaxis)
else:
- ax1.set_ylabel('%s %s' % (name, datatype))
+ ax1.set_ylabel("%s %s" % (name, datatype))
- otherparams = list(set(itertools.product(*X[:index], *X[index + 1:])))
- cm = plt.get_cmap('brg', len(otherparams))
+ otherparams = list(set(itertools.product(*X[:index], *X[index + 1 :])))
+ cm = plt.get_cmap("brg", len(otherparams))
for i in range(len(otherparams)):
elem = otherparams[i]
color = cm(i)
@@ -268,18 +317,17 @@ def plot_param_fit(function, name, fitfunc, funp, parameters, datatype, index, X
plt.show()
-def boxplot(ticks, measurements, xlabel='', ylabel='', modeldata=None, output=None):
+def boxplot(ticks, measurements, xlabel="", ylabel="", modeldata=None, output=None):
fig, ax1 = plt.subplots(figsize=(10, 6))
- fig.canvas.set_window_title('DriverEval')
+ fig.canvas.set_window_title("DriverEval")
plt.subplots_adjust(left=0.1, right=0.95, top=0.95, bottom=0.1)
- bp = plt.boxplot(measurements, notch=0, sym='+', vert=1, whis=1.5)
- plt.setp(bp['boxes'], color='black')
- plt.setp(bp['whiskers'], color='black')
- plt.setp(bp['fliers'], color='red', marker='+')
+ bp = plt.boxplot(measurements, notch=0, sym="+", vert=1, whis=1.5)
+ plt.setp(bp["boxes"], color="black")
+ plt.setp(bp["whiskers"], color="black")
+ plt.setp(bp["fliers"], color="red", marker="+")
- ax1.yaxis.grid(True, linestyle='-', which='major', color='lightgrey',
- alpha=0.5)
+ ax1.yaxis.grid(True, linestyle="-", which="major", color="lightgrey", alpha=0.5)
ax1.set_axisbelow(True)
# ax1.set_title('DriverEval')
@@ -294,7 +342,7 @@ def boxplot(ticks, measurements, xlabel='', ylabel='', modeldata=None, output=No
# boxColors = ['darkkhaki', 'royalblue']
medians = list(range(numBoxes))
for i in range(numBoxes):
- box = bp['boxes'][i]
+ box = bp["boxes"][i]
boxX = []
boxY = []
for j in range(5):
@@ -306,21 +354,31 @@ def boxplot(ticks, measurements, xlabel='', ylabel='', modeldata=None, output=No
# boxPolygon = Polygon(boxCoords, facecolor=boxColors[k])
# ax1.add_patch(boxPolygon)
# Now draw the median lines back over what we just filled in
- med = bp['medians'][i]
+ med = bp["medians"][i]
medianX = []
medianY = []
for j in range(2):
medianX.append(med.get_xdata()[j])
medianY.append(med.get_ydata()[j])
- plt.plot(medianX, medianY, 'k')
+ plt.plot(medianX, medianY, "k")
medians[i] = medianY[0]
# Finally, overplot the sample averages, with horizontal alignment
# in the center of each box
- plt.plot([np.average(med.get_xdata())], [np.average(measurements[i])],
- color='w', marker='*', markeredgecolor='k')
+ plt.plot(
+ [np.average(med.get_xdata())],
+ [np.average(measurements[i])],
+ color="w",
+ marker="*",
+ markeredgecolor="k",
+ )
if modeldata:
- plt.plot([np.average(med.get_xdata())], [modeldata[i]],
- color='w', marker='o', markeredgecolor='k')
+ plt.plot(
+ [np.average(med.get_xdata())],
+ [modeldata[i]],
+ color="w",
+ marker="o",
+ markeredgecolor="k",
+ )
pos = np.arange(numBoxes) + 1
upperLabels = [str(np.round(s, 2)) for s in medians]
@@ -330,16 +388,21 @@ def boxplot(ticks, measurements, xlabel='', ylabel='', modeldata=None, output=No
y0, y1 = ax1.get_ylim()
textpos = y0 + (y1 - y0) * 0.97
# ypos = ax1.get_ylim()[0]
- ax1.text(pos[tick], textpos, upperLabels[tick],
- horizontalalignment='center', size='small',
- color='royalblue')
+ ax1.text(
+ pos[tick],
+ textpos,
+ upperLabels[tick],
+ horizontalalignment="center",
+ size="small",
+ color="royalblue",
+ )
if output:
plt.savefig(output)
- with open('{}.txt'.format(output), 'w') as f:
- print('X Y', file=f)
+ with open("{}.txt".format(output), "w") as f:
+ print("X Y", file=f)
for i, data in enumerate(measurements):
for value in data:
- print('{} {}'.format(ticks[i], value), file=f)
+ print("{} {}".format(ticks[i], value), file=f)
else:
plt.show()
diff --git a/lib/protocol_benchmarks.py b/lib/protocol_benchmarks.py
index e82af67..b42e821 100755
--- a/lib/protocol_benchmarks.py
+++ b/lib/protocol_benchmarks.py
@@ -18,40 +18,41 @@ import re
import time
from filelock import FileLock
+
class DummyProtocol:
def __init__(self):
self.max_serialized_bytes = None
- self.enc_buf = ''
- self.dec_buf = ''
- self.dec_buf0 = ''
- self.dec_buf1 = ''
- self.dec_buf2 = ''
+ self.enc_buf = ""
+ self.dec_buf = ""
+ self.dec_buf0 = ""
+ self.dec_buf1 = ""
+ self.dec_buf2 = ""
self.dec_index = 0
self.transition_map = dict()
- def assign_and_kout(self, signature, assignment, transition_args = None):
+ def assign_and_kout(self, signature, assignment, transition_args=None):
self.new_var(signature)
- self.assign_var(assignment, transition_args = transition_args)
+ self.assign_var(assignment, transition_args=transition_args)
self.kout_var()
def new_var(self, signature):
self.dec_index += 1
- self.dec_buf0 += '{} dec_{:d};\n'.format(signature, self.dec_index)
+ self.dec_buf0 += "{} dec_{:d};\n".format(signature, self.dec_index)
- def assign_var(self, assignment, transition_args = None):
- snippet = 'dec_{:d} = {};\n'.format(self.dec_index, assignment)
+ def assign_var(self, assignment, transition_args=None):
+ snippet = "dec_{:d} = {};\n".format(self.dec_index, assignment)
self.dec_buf1 += snippet
if transition_args:
self.add_transition(snippet, transition_args)
def get_var(self):
- return 'dec_{:d}'.format(self.dec_index)
+ return "dec_{:d}".format(self.dec_index)
def kout_var(self):
- self.dec_buf2 += 'kout << dec_{:d};\n'.format(self.dec_index)
+ self.dec_buf2 += "kout << dec_{:d};\n".format(self.dec_index)
def note_unsupported(self, value):
- note = '// value {} has unsupported type {}\n'.format(value, type(value))
+ note = "// value {} has unsupported type {}\n".format(value, type(value))
self.enc_buf += note
self.dec_buf += note
self.dec_buf1 += note
@@ -60,31 +61,31 @@ class DummyProtocol:
return True
def get_encode(self):
- return ''
+ return ""
def get_buffer_declaration(self):
- return ''
+ return ""
def get_buffer_name(self):
return '"none"'
def get_serialize(self):
- return ''
+ return ""
def get_deserialize(self):
- return ''
+ return ""
def get_decode_and_output(self):
- return ''
+ return ""
def get_decode_vars(self):
- return ''
+ return ""
def get_decode(self):
- return ''
+ return ""
def get_decode_output(self):
- return ''
+ return ""
def get_extra_files(self):
return dict()
@@ -101,30 +102,34 @@ class DummyProtocol:
return self.transition_map[code_snippet]
return list()
-class Avro(DummyProtocol):
- def __init__(self, data, strip_schema = False):
+class Avro(DummyProtocol):
+ def __init__(self, data, strip_schema=False):
super().__init__()
self.data = data
self.strip_schema = strip_schema
self.schema = {
- 'namespace' : 'benchmark.avro',
- 'type' : 'record',
- 'name' : 'Benchmark',
- 'fields' : []
+ "namespace": "benchmark.avro",
+ "type": "record",
+ "name": "Benchmark",
+ "fields": [],
}
for key, value in data.items():
- self.add_to_dict(self.schema['fields'], key, value)
+ self.add_to_dict(self.schema["fields"], key, value)
buf = io.BytesIO()
try:
- writer = avro.datafile.DataFileWriter(buf, avro.io.DatumWriter(), avro.schema.Parse(json.dumps(self.schema)))
+ writer = avro.datafile.DataFileWriter(
+ buf, avro.io.DatumWriter(), avro.schema.Parse(json.dumps(self.schema))
+ )
writer.append(data)
writer.flush()
except avro.schema.SchemaParseException:
- raise RuntimeError('Unsupported schema') from None
+ raise RuntimeError("Unsupported schema") from None
self.serialized_data = buf.getvalue()
if strip_schema:
- self.serialized_data = self.serialized_data[self.serialized_data.find(b'}\x00')+2 : ]
+ self.serialized_data = self.serialized_data[
+ self.serialized_data.find(b"}\x00") + 2 :
+ ]
# strip leading 16-byte sync marker
self.serialized_data = self.serialized_data[16:]
# strip trailing 16-byte sync marker
@@ -138,29 +143,30 @@ class Avro(DummyProtocol):
def type_to_type_name(self, type_type):
if type_type == int:
- return 'int'
+ return "int"
if type_type == float:
- return 'float'
+ return "float"
if type_type == str:
- return 'string'
+ return "string"
if type_type == list:
- return 'array'
+ return "array"
if type_type == dict:
- return 'record'
+ return "record"
def add_to_dict(self, fields, key, value):
- new_field = {
- 'name' : key,
- 'type' : self.type_to_type_name(type(value))
- }
- if new_field['type'] == 'array':
- new_field['type'] = {'type' : 'array', 'items' : self.type_to_type_name(type(value[0]))}
- if new_field['type'] == 'record':
- new_field['type'] = {'type' : 'record', 'name': key, 'fields' : []}
+ new_field = {"name": key, "type": self.type_to_type_name(type(value))}
+ if new_field["type"] == "array":
+ new_field["type"] = {
+ "type": "array",
+ "items": self.type_to_type_name(type(value[0])),
+ }
+ if new_field["type"] == "record":
+ new_field["type"] = {"type": "record", "name": key, "fields": []}
for key, value in value.items():
- self.add_to_dict(new_field['type']['fields'], key, value)
+ self.add_to_dict(new_field["type"]["fields"], key, value)
fields.append(new_field)
+
class Thrift(DummyProtocol):
class_index = 1
@@ -169,10 +175,10 @@ class Thrift(DummyProtocol):
super().__init__()
self.data = data
self._field_id = 1
- self.proto_buf = ''
+ self.proto_buf = ""
self.proto_from_json(data)
- with open('/tmp/test.thrift', 'w') as f:
+ with open("/tmp/test.thrift", "w") as f:
f.write(self.proto_buf)
membuf = TCyMemoryBuffer()
@@ -180,7 +186,9 @@ class Thrift(DummyProtocol):
# TODO irgendwo bleibt state übrig -> nur das bei allerersten
# Aufruf geladene Protokoll wird berücksichtigt, dazu nicht passende
# Daten werden nicht serialisiert
- test_thrift = thriftpy.load('/tmp/test.thrift', module_name='test{:d}_thrift'.format(Thrift.class_index))
+ test_thrift = thriftpy.load(
+ "/tmp/test.thrift", module_name="test{:d}_thrift".format(Thrift.class_index)
+ )
Thrift.class_index += 1
benchmark = test_thrift.Benchmark()
@@ -190,7 +198,7 @@ class Thrift(DummyProtocol):
try:
proto.write_struct(benchmark)
except thriftpy.thrift.TDecodeException:
- raise RuntimeError('Unsupported data layout') from None
+ raise RuntimeError("Unsupported data layout") from None
membuf.flush()
self.serialized_data = membuf.getvalue()
@@ -203,31 +211,31 @@ class Thrift(DummyProtocol):
def type_to_type_name(self, value):
type_type = type(value)
if type_type == int:
- return 'i32'
+ return "i32"
if type_type == float:
- return 'double'
+ return "double"
if type_type == str:
- return 'string'
+ return "string"
if type_type == list:
- return 'list<{}>'.format(self.type_to_type_name(value[0]))
+ return "list<{}>".format(self.type_to_type_name(value[0]))
if type_type == dict:
sub_value = list(value.values())[0]
- return 'map<{},{}>'.format('string', self.type_to_type_name(sub_value))
+ return "map<{},{}>".format("string", self.type_to_type_name(sub_value))
def add_to_dict(self, key, value):
key_type = self.type_to_type_name(value)
- self.proto_buf += '{:d}: {} {};\n'.format(self._field_id, key_type, key)
+ self.proto_buf += "{:d}: {} {};\n".format(self._field_id, key_type, key)
self._field_id += 1
def proto_from_json(self, data):
- self.proto_buf += 'struct Benchmark {\n'
+ self.proto_buf += "struct Benchmark {\n"
for key, value in data.items():
self.add_to_dict(key, value)
- self.proto_buf += '}\n'
+ self.proto_buf += "}\n"
-class ArduinoJSON(DummyProtocol):
- def __init__(self, data, bufsize = 255, int_type = 'uint16_t', float_type = 'float'):
+class ArduinoJSON(DummyProtocol):
+ def __init__(self, data, bufsize=255, int_type="uint16_t", float_type="float"):
super().__init__()
self.data = data
self.max_serialized_bytes = self.get_serialized_length() + 2
@@ -235,9 +243,12 @@ class ArduinoJSON(DummyProtocol):
self.bufsize = bufsize
self.int_type = int_type
self.float_type = float_type
- self.enc_buf += self.add_transition('ArduinoJson::StaticJsonBuffer<{:d}> jsonBuffer;\n'.format(bufsize), [bufsize])
- self.enc_buf += 'ArduinoJson::JsonObject& root = jsonBuffer.createObject();\n'
- self.from_json(data, 'root')
+ self.enc_buf += self.add_transition(
+ "ArduinoJson::StaticJsonBuffer<{:d}> jsonBuffer;\n".format(bufsize),
+ [bufsize],
+ )
+ self.enc_buf += "ArduinoJson::JsonObject& root = jsonBuffer.createObject();\n"
+ self.from_json(data, "root")
def get_serialized_length(self):
return len(json.dumps(self.data))
@@ -249,24 +260,33 @@ class ArduinoJSON(DummyProtocol):
return self.enc_buf
def get_buffer_declaration(self):
- return 'char buf[{:d}];\n'.format(self.max_serialized_bytes)
+ return "char buf[{:d}];\n".format(self.max_serialized_bytes)
def get_buffer_name(self):
- return 'buf'
+ return "buf"
def get_length_var(self):
- return 'serialized_size'
+ return "serialized_size"
def get_serialize(self):
- return self.add_transition('uint16_t serialized_size = root.printTo(buf);\n', [self.max_serialized_bytes])
+ return self.add_transition(
+ "uint16_t serialized_size = root.printTo(buf);\n",
+ [self.max_serialized_bytes],
+ )
def get_deserialize(self):
- ret = self.add_transition('ArduinoJson::StaticJsonBuffer<{:d}> jsonBuffer;\n'.format(self.bufsize), [self.bufsize])
- ret += self.add_transition('ArduinoJson::JsonObject& root = jsonBuffer.parseObject(buf);\n', [self.max_serialized_bytes])
+ ret = self.add_transition(
+ "ArduinoJson::StaticJsonBuffer<{:d}> jsonBuffer;\n".format(self.bufsize),
+ [self.bufsize],
+ )
+ ret += self.add_transition(
+ "ArduinoJson::JsonObject& root = jsonBuffer.parseObject(buf);\n",
+ [self.max_serialized_bytes],
+ )
return ret
def get_decode_and_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf + 'kout << endl;\n';
+ return 'kout << dec << "dec:";\n' + self.dec_buf + "kout << endl;\n"
def get_decode_vars(self):
return self.dec_buf0
@@ -275,93 +295,158 @@ class ArduinoJSON(DummyProtocol):
return self.dec_buf1
def get_decode_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n';
+ return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n"
def add_to_list(self, enc_node, dec_node, offset, value):
if type(value) == str:
- if len(value) and value[0] == '$':
- self.enc_buf += '{}.add({});\n'.format(enc_node, value[1:])
- self.dec_buf += 'kout << {}[{:d}].as<{}>();\n'.format(dec_node, offset, self.int_type)
- self.assign_and_kout(self.int_type, '{}[{:d}].as<{}>()'.format(dec_node, offset, self.int_type))
+ if len(value) and value[0] == "$":
+ self.enc_buf += "{}.add({});\n".format(enc_node, value[1:])
+ self.dec_buf += "kout << {}[{:d}].as<{}>();\n".format(
+ dec_node, offset, self.int_type
+ )
+ self.assign_and_kout(
+ self.int_type,
+ "{}[{:d}].as<{}>()".format(dec_node, offset, self.int_type),
+ )
else:
- self.enc_buf += self.add_transition('{}.add("{}");\n'.format(enc_node, value), [len(value)])
- self.dec_buf += 'kout << {}[{:d}].as<const char *>();\n'.format(dec_node, offset)
- self.assign_and_kout('char const*', '{}[{:d}].as<char const *>()'.format(dec_node, offset), transition_args = [len(value)])
+ self.enc_buf += self.add_transition(
+ '{}.add("{}");\n'.format(enc_node, value), [len(value)]
+ )
+ self.dec_buf += "kout << {}[{:d}].as<const char *>();\n".format(
+ dec_node, offset
+ )
+ self.assign_and_kout(
+ "char const*",
+ "{}[{:d}].as<char const *>()".format(dec_node, offset),
+ transition_args=[len(value)],
+ )
elif type(value) == list:
- child = enc_node + 'l'
+ child = enc_node + "l"
while child in self.children:
- child += '_'
- self.enc_buf += 'ArduinoJson::JsonArray& {} = {}.createNestedArray();\n'.format(
- child, enc_node)
+ child += "_"
+ self.enc_buf += "ArduinoJson::JsonArray& {} = {}.createNestedArray();\n".format(
+ child, enc_node
+ )
self.children.add(child)
self.from_json(value, child)
elif type(value) == dict:
- child = enc_node + 'o'
+ child = enc_node + "o"
while child in self.children:
- child += '_'
- self.enc_buf += 'ArduinoJson::JsonObject& {} = {}.createNestedObject();\n'.format(
- child, enc_node)
+ child += "_"
+ self.enc_buf += "ArduinoJson::JsonObject& {} = {}.createNestedObject();\n".format(
+ child, enc_node
+ )
self.children.add(child)
self.from_json(value, child)
elif type(value) == float:
- self.enc_buf += '{}.add({});\n'.format(enc_node, value)
- self.dec_buf += 'kout << {}[{:d}].as<{}>();\n'.format(dec_node, offset, self.float_type)
- self.assign_and_kout(self.float_type, '{}[{:d}].as<{}>()'.format(dec_node, offset, self.float_type))
+ self.enc_buf += "{}.add({});\n".format(enc_node, value)
+ self.dec_buf += "kout << {}[{:d}].as<{}>();\n".format(
+ dec_node, offset, self.float_type
+ )
+ self.assign_and_kout(
+ self.float_type,
+ "{}[{:d}].as<{}>()".format(dec_node, offset, self.float_type),
+ )
elif type(value) == int:
- self.enc_buf += '{}.add({});\n'.format(enc_node, value)
- self.dec_buf += 'kout << {}[{:d}].as<{}>();\n'.format(dec_node, offset, self.int_type)
- self.assign_and_kout(self.int_type, '{}[{:d}].as<{}>()'.format(dec_node, offset, self.int_type))
+ self.enc_buf += "{}.add({});\n".format(enc_node, value)
+ self.dec_buf += "kout << {}[{:d}].as<{}>();\n".format(
+ dec_node, offset, self.int_type
+ )
+ self.assign_and_kout(
+ self.int_type,
+ "{}[{:d}].as<{}>()".format(dec_node, offset, self.int_type),
+ )
else:
self.note_unsupported(value)
def add_to_dict(self, enc_node, dec_node, key, value):
if type(value) == str:
- if len(value) and value[0] == '$':
- self.enc_buf += self.add_transition('{}["{}"] = {};\n'.format(enc_node, key, value[1:]), [len(key)])
- self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format(dec_node, key, self.int_type)
- self.assign_and_kout(self.int_type, '{}["{}"].as<{}>()'.format(dec_node, key, self.int_type))
+ if len(value) and value[0] == "$":
+ self.enc_buf += self.add_transition(
+ '{}["{}"] = {};\n'.format(enc_node, key, value[1:]), [len(key)]
+ )
+ self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format(
+ dec_node, key, self.int_type
+ )
+ self.assign_and_kout(
+ self.int_type,
+ '{}["{}"].as<{}>()'.format(dec_node, key, self.int_type),
+ )
else:
- self.enc_buf += self.add_transition('{}["{}"] = "{}";\n'.format(enc_node, key, value), [len(key), len(value)])
- self.dec_buf += 'kout << {}["{}"].as<const char *>();\n'.format(dec_node, key)
- self.assign_and_kout('char const*', '{}["{}"].as<const char *>()'.format(dec_node, key), transition_args = [len(key), len(value)])
+ self.enc_buf += self.add_transition(
+ '{}["{}"] = "{}";\n'.format(enc_node, key, value),
+ [len(key), len(value)],
+ )
+ self.dec_buf += 'kout << {}["{}"].as<const char *>();\n'.format(
+ dec_node, key
+ )
+ self.assign_and_kout(
+ "char const*",
+ '{}["{}"].as<const char *>()'.format(dec_node, key),
+ transition_args=[len(key), len(value)],
+ )
elif type(value) == list:
- child = enc_node + 'l'
+ child = enc_node + "l"
while child in self.children:
- child += '_'
- self.enc_buf += self.add_transition('ArduinoJson::JsonArray& {} = {}.createNestedArray("{}");\n'.format(
- child, enc_node, key), [len(key)])
+ child += "_"
+ self.enc_buf += self.add_transition(
+ 'ArduinoJson::JsonArray& {} = {}.createNestedArray("{}");\n'.format(
+ child, enc_node, key
+ ),
+ [len(key)],
+ )
self.children.add(child)
self.from_json(value, child, '{}["{}"]'.format(dec_node, key))
elif type(value) == dict:
- child = enc_node + 'o'
+ child = enc_node + "o"
while child in self.children:
- child += '_'
- self.enc_buf += self.add_transition('ArduinoJson::JsonObject& {} = {}.createNestedObject("{}");\n'.format(
- child, enc_node, key), [len(key)])
+ child += "_"
+ self.enc_buf += self.add_transition(
+ 'ArduinoJson::JsonObject& {} = {}.createNestedObject("{}");\n'.format(
+ child, enc_node, key
+ ),
+ [len(key)],
+ )
self.children.add(child)
self.from_json(value, child, '{}["{}"]'.format(dec_node, key))
elif type(value) == float:
- self.enc_buf += self.add_transition('{}["{}"] = {};\n'.format(enc_node, key, value), [len(key)])
- self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format(dec_node, key, self.float_type)
- self.assign_and_kout(self.float_type, '{}["{}"].as<{}>()'.format(dec_node, key, self.float_type), transition_args = [len(key)])
+ self.enc_buf += self.add_transition(
+ '{}["{}"] = {};\n'.format(enc_node, key, value), [len(key)]
+ )
+ self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format(
+ dec_node, key, self.float_type
+ )
+ self.assign_and_kout(
+ self.float_type,
+ '{}["{}"].as<{}>()'.format(dec_node, key, self.float_type),
+ transition_args=[len(key)],
+ )
elif type(value) == int:
- self.enc_buf += self.add_transition('{}["{}"] = {};\n'.format(enc_node, key, value), [len(key)])
- self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format(dec_node, key, self.int_type)
- self.assign_and_kout(self.int_type, '{}["{}"].as<{}>()'.format(dec_node, key, self.int_type), transition_args = [len(key)])
+ self.enc_buf += self.add_transition(
+ '{}["{}"] = {};\n'.format(enc_node, key, value), [len(key)]
+ )
+ self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format(
+ dec_node, key, self.int_type
+ )
+ self.assign_and_kout(
+ self.int_type,
+ '{}["{}"].as<{}>()'.format(dec_node, key, self.int_type),
+ transition_args=[len(key)],
+ )
else:
self.note_unsupported(value)
- def from_json(self, data, enc_node = 'root', dec_node = 'root'):
+ def from_json(self, data, enc_node="root", dec_node="root"):
if type(data) == dict:
for key in sorted(data.keys()):
self.add_to_dict(enc_node, dec_node, key, data[key])
@@ -371,8 +456,16 @@ class ArduinoJSON(DummyProtocol):
class CapnProtoC(DummyProtocol):
-
- def __init__(self, data, max_serialized_bytes = 128, packed = False, trail = ['benchmark'], int_type = 'uint16_t', float_type = 'float', dec_index = 0):
+ def __init__(
+ self,
+ data,
+ max_serialized_bytes=128,
+ packed=False,
+ trail=["benchmark"],
+ int_type="uint16_t",
+ float_type="float",
+ dec_index=0,
+ ):
super().__init__()
self.data = data
self.max_serialized_bytes = max_serialized_bytes
@@ -384,169 +477,190 @@ class CapnProtoC(DummyProtocol):
self.float_type = float_type
self.proto_float_type = self.float_type_to_proto_type(float_type)
self.dec_index = dec_index
- self.trail_name = '_'.join(map(lambda x: x.capitalize(), trail))
- self.proto_buf = ''
- self.enc_buf += 'struct {} {};\n'.format(self.trail_name, self.name)
- self.cc_tail = ''
+ self.trail_name = "_".join(map(lambda x: x.capitalize(), trail))
+ self.proto_buf = ""
+ self.enc_buf += "struct {} {};\n".format(self.trail_name, self.name)
+ self.cc_tail = ""
self.key_counter = 0
self.from_json(data)
def int_type_to_proto_type(self, int_type):
- sign = ''
- if int_type[0] == 'u':
- sign = 'U'
- if '8' in int_type:
+ sign = ""
+ if int_type[0] == "u":
+ sign = "U"
+ if "8" in int_type:
self.int_bits = 8
- return sign + 'Int8'
- if '16' in int_type:
+ return sign + "Int8"
+ if "16" in int_type:
self.int_bits = 16
- return sign + 'Int16'
- if '32' in int_type:
+ return sign + "Int16"
+ if "32" in int_type:
self.int_bits = 32
- return sign + 'Int32'
+ return sign + "Int32"
self.int_bits = 64
- return sign + 'Int64'
+ return sign + "Int64"
def float_type_to_proto_type(self, float_type):
- if float_type == 'float':
+ if float_type == "float":
self.float_bits = 32
- return 'Float32'
+ return "Float32"
self.float_bits = 64
- return 'Float64'
+ return "Float64"
def is_ascii(self):
return False
def get_proto(self):
- return '@0xad5b236043de2389;\n\n' + self.proto_buf
+ return "@0xad5b236043de2389;\n\n" + self.proto_buf
def get_extra_files(self):
- return {
- 'capnp_c_bench.capnp' : self.get_proto()
- }
+ return {"capnp_c_bench.capnp": self.get_proto()}
def get_buffer_declaration(self):
- ret = 'uint8_t buf[{:d}];\n'.format(self.max_serialized_bytes)
- ret += 'uint16_t serialized_size;\n'
+ ret = "uint8_t buf[{:d}];\n".format(self.max_serialized_bytes)
+ ret += "uint16_t serialized_size;\n"
return ret
def get_buffer_name(self):
- return 'buf'
+ return "buf"
def get_encode(self):
- ret = 'struct capn c;\n'
- ret += 'capn_init_malloc(&c);\n'
- ret += 'capn_ptr cr = capn_root(&c);\n'
- ret += 'struct capn_segment *cs = cr.seg;\n\n'
- ret += '{}_ptr {}_ptr = new_{}(cs);\n'.format(
- self.trail_name, self.name, self.trail_name)
+ ret = "struct capn c;\n"
+ ret += "capn_init_malloc(&c);\n"
+ ret += "capn_ptr cr = capn_root(&c);\n"
+ ret += "struct capn_segment *cs = cr.seg;\n\n"
+ ret += "{}_ptr {}_ptr = new_{}(cs);\n".format(
+ self.trail_name, self.name, self.trail_name
+ )
- tail = 'write_{}(&{}, {}_ptr);\n'.format(
- self.trail_name, self.name, self.name)
- tail += 'capn_setp(cr, 0, {}_ptr.p);\n'.format(self.name)
+ tail = "write_{}(&{}, {}_ptr);\n".format(self.trail_name, self.name, self.name)
+ tail += "capn_setp(cr, 0, {}_ptr.p);\n".format(self.name)
return ret + self.enc_buf + self.cc_tail + tail
def get_serialize(self):
- ret = 'serialized_size = capn_write_mem(&c, buf, sizeof(buf), {:d});\n'.format(self.packed)
- ret += 'capn_free(&c);\n'
+ ret = "serialized_size = capn_write_mem(&c, buf, sizeof(buf), {:d});\n".format(
+ self.packed
+ )
+ ret += "capn_free(&c);\n"
return ret
def get_deserialize(self):
- ret = 'struct capn c;\n'
- ret += 'capn_init_mem(&c, buf, serialized_size, 0);\n'
+ ret = "struct capn c;\n"
+ ret += "capn_init_mem(&c, buf, serialized_size, 0);\n"
return ret
def get_decode_and_output(self):
- ret = '{}_ptr {}_ptr;\n'.format(self.trail_name, self.name)
- ret += '{}_ptr.p = capn_getp(capn_root(&c), 0, 1);\n'.format(self.name)
- ret += 'struct {} {};\n'.format(self.trail_name, self.name)
+ ret = "{}_ptr {}_ptr;\n".format(self.trail_name, self.name)
+ ret += "{}_ptr.p = capn_getp(capn_root(&c), 0, 1);\n".format(self.name)
+ ret += "struct {} {};\n".format(self.trail_name, self.name)
ret += 'kout << dec << "dec:";\n'
ret += self.dec_buf
- ret += 'kout << endl;\n'
- ret += 'capn_free(&c);\n'
+ ret += "kout << endl;\n"
+ ret += "capn_free(&c);\n"
return ret
def get_decode_vars(self):
return self.dec_buf0
def get_decode(self):
- ret = '{}_ptr {}_ptr;\n'.format(self.trail_name, self.name)
- ret += '{}_ptr.p = capn_getp(capn_root(&c), 0, 1);\n'.format(self.name)
- ret += 'struct {} {};\n'.format(self.trail_name, self.name)
+ ret = "{}_ptr {}_ptr;\n".format(self.trail_name, self.name)
+ ret += "{}_ptr.p = capn_getp(capn_root(&c), 0, 1);\n".format(self.name)
+ ret += "struct {} {};\n".format(self.trail_name, self.name)
ret += self.dec_buf1
- ret += 'capn_free(&c);\n'
+ ret += "capn_free(&c);\n"
return ret
def get_decode_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n';
+ return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n"
def get_length_var(self):
- return 'serialized_size'
+ return "serialized_size"
def add_field(self, fieldtype, key, value):
- extra = ''
+ extra = ""
texttype = self.proto_int_type
if fieldtype == str:
- texttype = 'Text'
+ texttype = "Text"
elif fieldtype == float:
texttype = self.proto_float_type
elif fieldtype == dict:
texttype = key.capitalize()
if type(value) == list:
- texttype = 'List({})'.format(texttype)
+ texttype = "List({})".format(texttype)
- self.proto_buf += '{} @{:d} :{};\n'.format(
- key, self.key_counter, texttype)
+ self.proto_buf += "{} @{:d} :{};\n".format(key, self.key_counter, texttype)
self.key_counter += 1
if fieldtype == str:
- self.enc_buf += 'capn_text {}_text;\n'.format(key)
- self.enc_buf += '{}_text.len = {:d};\n'.format(key, len(value))
+ self.enc_buf += "capn_text {}_text;\n".format(key)
+ self.enc_buf += "{}_text.len = {:d};\n".format(key, len(value))
self.enc_buf += '{}_text.str = "{}";\n'.format(key, value)
- self.enc_buf += '{}_text.seg = NULL;\n'.format(key)
- self.enc_buf += '{}.{} = {}_text;\n\n'.format(self.name, key, key)
- self.dec_buf += 'kout << {}.{}.str;\n'.format(self.name, key)
- self.assign_and_kout('char const *', '{}.{}.str'.format(self.name, key))
+ self.enc_buf += "{}_text.seg = NULL;\n".format(key)
+ self.enc_buf += "{}.{} = {}_text;\n\n".format(self.name, key, key)
+ self.dec_buf += "kout << {}.{}.str;\n".format(self.name, key)
+ self.assign_and_kout("char const *", "{}.{}.str".format(self.name, key))
elif fieldtype == dict:
- pass # content is handled recursively in add_to_dict
+ pass # content is handled recursively in add_to_dict
elif type(value) == list:
if type(value[0]) == float:
- self.enc_buf += self.add_transition('{}.{} = capn_new_list{:d}(cs, {:d});\n'.format(
- self.name, key, self.float_bits, len(value)), [len(value)])
+ self.enc_buf += self.add_transition(
+ "{}.{} = capn_new_list{:d}(cs, {:d});\n".format(
+ self.name, key, self.float_bits, len(value)
+ ),
+ [len(value)],
+ )
for i, elem in enumerate(value):
- self.enc_buf += 'capn_set{:d}({}.{}, {:d}, capn_from_f{:d}({:f}));\n'.format(
- self.float_bits, self.name, key, i, self.float_bits, elem)
- self.dec_buf += 'kout << capn_to_f{:d}(capn_get{:d}({}.{}, {:d}));\n'.format(
- self.float_bits, self.float_bits, self.name, key, i)
- self.assign_and_kout(self.float_type, 'capn_to_f{:d}(capn_get{:d}({}.{}, {:d}))'.format(self.float_bits, self.float_bits, self.name, key, i))
+ self.enc_buf += "capn_set{:d}({}.{}, {:d}, capn_from_f{:d}({:f}));\n".format(
+ self.float_bits, self.name, key, i, self.float_bits, elem
+ )
+ self.dec_buf += "kout << capn_to_f{:d}(capn_get{:d}({}.{}, {:d}));\n".format(
+ self.float_bits, self.float_bits, self.name, key, i
+ )
+ self.assign_and_kout(
+ self.float_type,
+ "capn_to_f{:d}(capn_get{:d}({}.{}, {:d}))".format(
+ self.float_bits, self.float_bits, self.name, key, i
+ ),
+ )
else:
- self.enc_buf += self.add_transition('{}.{} = capn_new_list{:d}(cs, {:d});\n'.format(
- self.name, key, self.int_bits, len(value)), [len(value)])
+ self.enc_buf += self.add_transition(
+ "{}.{} = capn_new_list{:d}(cs, {:d});\n".format(
+ self.name, key, self.int_bits, len(value)
+ ),
+ [len(value)],
+ )
for i, elem in enumerate(value):
- self.enc_buf += 'capn_set{:d}({}.{}, {:d}, {:d});\n'.format(
- self.int_bits, self.name, key, i, elem)
- self.dec_buf += 'kout << capn_get{:d}({}.{}, {:d});\n'.format(
- self.int_bits, self.name, key, i)
- self.assign_and_kout(self.int_type, 'capn_get{:d}({}.{}, {:d})'.format(self.int_bits, self.name, key, i))
+ self.enc_buf += "capn_set{:d}({}.{}, {:d}, {:d});\n".format(
+ self.int_bits, self.name, key, i, elem
+ )
+ self.dec_buf += "kout << capn_get{:d}({}.{}, {:d});\n".format(
+ self.int_bits, self.name, key, i
+ )
+ self.assign_and_kout(
+ self.int_type,
+ "capn_get{:d}({}.{}, {:d})".format(
+ self.int_bits, self.name, key, i
+ ),
+ )
elif fieldtype == float:
- self.enc_buf += '{}.{} = {};\n\n'.format(self.name, key, value)
- self.dec_buf += 'kout << {}.{};\n'.format(self.name, key)
- self.assign_and_kout(self.float_type, '{}.{}'.format(self.name, key))
+ self.enc_buf += "{}.{} = {};\n\n".format(self.name, key, value)
+ self.dec_buf += "kout << {}.{};\n".format(self.name, key)
+ self.assign_and_kout(self.float_type, "{}.{}".format(self.name, key))
elif fieldtype == int:
- self.enc_buf += '{}.{} = {};\n\n'.format(self.name, key, value)
- self.dec_buf += 'kout << {}.{};\n'.format(self.name, key)
- self.assign_and_kout(self.int_type, '{}.{}'.format(self.name, key))
+ self.enc_buf += "{}.{} = {};\n\n".format(self.name, key, value)
+ self.dec_buf += "kout << {}.{};\n".format(self.name, key)
+ self.assign_and_kout(self.int_type, "{}.{}".format(self.name, key))
else:
self.note_unsupported(value)
def add_to_dict(self, key, value):
if type(value) == str:
- if len(value) and value[0] == '$':
+ if len(value) and value[0] == "$":
self.add_field(int, key, value[1:])
else:
self.add_field(str, key, value)
@@ -555,19 +669,35 @@ class CapnProtoC(DummyProtocol):
elif type(value) == dict:
trail = list(self.trail)
trail.append(key)
- nested = CapnProtoC(value, trail = trail, int_type = self.int_type, float_type = self.float_type, dec_index = self.dec_index)
+ nested = CapnProtoC(
+ value,
+ trail=trail,
+ int_type=self.int_type,
+ float_type=self.float_type,
+ dec_index=self.dec_index,
+ )
self.add_field(dict, key, value)
- self.enc_buf += '{}.{} = new_{}_{}(cs);\n'.format(
- self.name, key, self.trail_name, key.capitalize())
+ self.enc_buf += "{}.{} = new_{}_{}(cs);\n".format(
+ self.name, key, self.trail_name, key.capitalize()
+ )
self.enc_buf += nested.enc_buf
- self.enc_buf += 'write_{}_{}(&{}, {}.{});\n'.format(
- self.trail_name, key.capitalize(), key, self.name, key)
- self.dec_buf += 'struct {}_{} {};\n'.format(self.trail_name, key.capitalize(), key)
- self.dec_buf += 'read_{}_{}(&{}, {}.{});\n'.format(self.trail_name, key.capitalize(), key, self.name, key)
+ self.enc_buf += "write_{}_{}(&{}, {}.{});\n".format(
+ self.trail_name, key.capitalize(), key, self.name, key
+ )
+ self.dec_buf += "struct {}_{} {};\n".format(
+ self.trail_name, key.capitalize(), key
+ )
+ self.dec_buf += "read_{}_{}(&{}, {}.{});\n".format(
+ self.trail_name, key.capitalize(), key, self.name, key
+ )
self.dec_buf += nested.dec_buf
self.dec_buf0 += nested.dec_buf0
- self.dec_buf1 += 'struct {}_{} {};\n'.format(self.trail_name, key.capitalize(), key)
- self.dec_buf1 += 'read_{}_{}(&{}, {}.{});\n'.format(self.trail_name, key.capitalize(), key, self.name, key)
+ self.dec_buf1 += "struct {}_{} {};\n".format(
+ self.trail_name, key.capitalize(), key
+ )
+ self.dec_buf1 += "read_{}_{}(&{}, {}.{});\n".format(
+ self.trail_name, key.capitalize(), key, self.name, key
+ )
self.dec_buf1 += nested.dec_buf1
self.dec_buf2 += nested.dec_buf2
self.dec_index = nested.dec_index
@@ -576,19 +706,19 @@ class CapnProtoC(DummyProtocol):
self.add_field(type(value), key, value)
def from_json(self, data):
- self.proto_buf += 'struct {} {{\n'.format(self.name.capitalize())
+ self.proto_buf += "struct {} {{\n".format(self.name.capitalize())
if type(data) == dict:
for key in sorted(data.keys()):
self.add_to_dict(key, data[key])
- self.proto_buf += '}\n'
+ self.proto_buf += "}\n"
-class ManualJSON(DummyProtocol):
+class ManualJSON(DummyProtocol):
def __init__(self, data):
super().__init__()
self.data = data
self.max_serialized_bytes = self.get_serialized_length() + 2
- self.buf = 'BufferOutput<> bout(buf);\n'
+ self.buf = "BufferOutput<> bout(buf);\n"
self.buf += 'bout << "{";\n'
self.from_json(data)
self.buf += 'bout << "}";\n'
@@ -603,23 +733,25 @@ class ManualJSON(DummyProtocol):
return True
def get_buffer_declaration(self):
- return 'char buf[{:d}];\n'.format(self.max_serialized_bytes);
+ return "char buf[{:d}];\n".format(self.max_serialized_bytes)
def get_buffer_name(self):
- return 'buf'
+ return "buf"
def get_encode(self):
return self.buf
def get_length_var(self):
- return 'bout.size()'
+ return "bout.size()"
def add_to_list(self, value, is_last):
if type(value) == str:
- if len(value) and value[0] == '$':
- self.buf += 'bout << dec << {}'.format(value[1:])
+ if len(value) and value[0] == "$":
+ self.buf += "bout << dec << {}".format(value[1:])
else:
- self.buf += self.add_transition('bout << "\\"{}\\""'.format(value), [len(value)])
+ self.buf += self.add_transition(
+ 'bout << "\\"{}\\""'.format(value), [len(value)]
+ )
elif type(value) == list:
self.buf += 'bout << "[";\n'
@@ -632,36 +764,48 @@ class ManualJSON(DummyProtocol):
self.buf += 'bout << "}"'
else:
- self.buf += 'bout << {}'.format(value)
+ self.buf += "bout << {}".format(value)
if is_last:
- self.buf += ';\n';
+ self.buf += ";\n"
else:
self.buf += ' << ",";\n'
def add_to_dict(self, key, value, is_last):
if type(value) == str:
- if len(value) and value[0] == '$':
- self.buf += self.add_transition('bout << "\\"{}\\":" << dec << {}'.format(key, value[1:]), [len(key)])
+ if len(value) and value[0] == "$":
+ self.buf += self.add_transition(
+ 'bout << "\\"{}\\":" << dec << {}'.format(key, value[1:]),
+ [len(key)],
+ )
else:
- self.buf += self.add_transition('bout << "\\"{}\\":\\"{}\\""'.format(key, value), [len(key), len(value)])
+ self.buf += self.add_transition(
+ 'bout << "\\"{}\\":\\"{}\\""'.format(key, value),
+ [len(key), len(value)],
+ )
elif type(value) == list:
- self.buf += self.add_transition('bout << "\\"{}\\":[";\n'.format(key), [len(key)])
+ self.buf += self.add_transition(
+ 'bout << "\\"{}\\":[";\n'.format(key), [len(key)]
+ )
self.from_json(value)
self.buf += 'bout << "]"'
elif type(value) == dict:
# '{{' is an escaped '{' character
- self.buf += self.add_transition('bout << "\\"{}\\":{{";\n'.format(key), [len(key)])
+ self.buf += self.add_transition(
+ 'bout << "\\"{}\\":{{";\n'.format(key), [len(key)]
+ )
self.from_json(value)
self.buf += 'bout << "}"'
else:
- self.buf += self.add_transition('bout << "\\"{}\\":" << {}'.format(key, value), [len(key)])
+ self.buf += self.add_transition(
+ 'bout << "\\"{}\\":" << {}'.format(key, value), [len(key)]
+ )
if is_last:
- self.buf += ';\n'
+ self.buf += ";\n"
else:
self.buf += ' << ",";\n'
@@ -674,74 +818,74 @@ class ManualJSON(DummyProtocol):
for i, elem in enumerate(data):
self.add_to_list(elem, i == len(data) - 1)
-class ModernJSON(DummyProtocol):
- def __init__(self, data, output_format = 'json'):
+class ModernJSON(DummyProtocol):
+ def __init__(self, data, output_format="json"):
super().__init__()
self.data = data
self.output_format = output_format
- self.buf = 'nlohmann::json js;\n'
+ self.buf = "nlohmann::json js;\n"
self.from_json(data)
def is_ascii(self):
- if self.output_format == 'json':
+ if self.output_format == "json":
return True
return False
def get_buffer_name(self):
- return 'out'
+ return "out"
def get_encode(self):
return self.buf
def get_serialize(self):
- if self.output_format == 'json':
- return 'std::string out = js.dump();\n'
- elif self.output_format == 'bson':
- return 'std::vector<std::uint8_t> out = nlohmann::json::to_bson(js);\n'
- elif self.output_format == 'cbor':
- return 'std::vector<std::uint8_t> out = nlohmann::json::to_cbor(js);\n'
- elif self.output_format == 'msgpack':
- return 'std::vector<std::uint8_t> out = nlohmann::json::to_msgpack(js);\n'
- elif self.output_format == 'ubjson':
- return 'std::vector<std::uint8_t> out = nlohmann::json::to_ubjson(js);\n'
+ if self.output_format == "json":
+ return "std::string out = js.dump();\n"
+ elif self.output_format == "bson":
+ return "std::vector<std::uint8_t> out = nlohmann::json::to_bson(js);\n"
+ elif self.output_format == "cbor":
+ return "std::vector<std::uint8_t> out = nlohmann::json::to_cbor(js);\n"
+ elif self.output_format == "msgpack":
+ return "std::vector<std::uint8_t> out = nlohmann::json::to_msgpack(js);\n"
+ elif self.output_format == "ubjson":
+ return "std::vector<std::uint8_t> out = nlohmann::json::to_ubjson(js);\n"
else:
- raise ValueError('invalid output format {}'.format(self.output_format))
+ raise ValueError("invalid output format {}".format(self.output_format))
def get_serialized_length(self):
- if self.output_format == 'json':
+ if self.output_format == "json":
return len(json.dumps(self.data))
- elif self.output_format == 'bson':
+ elif self.output_format == "bson":
return len(bson.BSON.encode(self.data))
- elif self.output_format == 'cbor':
+ elif self.output_format == "cbor":
return len(cbor.dumps(self.data))
- elif self.output_format == 'msgpack':
+ elif self.output_format == "msgpack":
return len(msgpack.dumps(self.data))
- elif self.output_format == 'ubjson':
+ elif self.output_format == "ubjson":
return len(ubjson.dumpb(self.data))
else:
- raise ValueError('invalid output format {}'.format(self.output_format))
+ raise ValueError("invalid output format {}".format(self.output_format))
def can_get_serialized_length(self):
return True
def get_length_var(self):
- return 'out.size()'
+ return "out.size()"
def add_to_list(self, prefix, index, value):
if type(value) == str:
- if len(value) and value[0] == '$':
+ if len(value) and value[0] == "$":
self.buf += value[1:]
- self.buf += '{}[{:d}] = {};\n'.format(prefix, index, value[1:])
+ self.buf += "{}[{:d}] = {};\n".format(prefix, index, value[1:])
else:
self.buf += '{}[{:d}] = "{}";\n'.format(prefix, index, value)
else:
- self.buf += '{}[{:d}] = {};\n'.format(prefix, index, value)
+ self.buf += "{}[{:d}] = {};\n".format(prefix, index, value)
def add_to_dict(self, prefix, key, value):
if type(value) == str:
- if len(value) and value[0] == '$':
+ if len(value) and value[0] == "$":
self.buf += '{}["{}"] = {};\n'.format(prefix, key, value[1:])
else:
self.buf += '{}["{}"] = "{}";\n'.format(prefix, key, value)
@@ -755,7 +899,7 @@ class ModernJSON(DummyProtocol):
else:
self.buf += '{}["{}"] = {};\n'.format(prefix, key, value)
- def from_json(self, data, prefix = 'js'):
+ def from_json(self, data, prefix="js"):
if type(data) == dict:
for key in sorted(data.keys()):
self.add_to_dict(prefix, key, data[key])
@@ -763,17 +907,20 @@ class ModernJSON(DummyProtocol):
for i, elem in enumerate(data):
self.add_to_list(prefix, i, elem)
-class MPack(DummyProtocol):
- def __init__(self, data, int_type = 'uint16_t', float_type = 'float'):
+class MPack(DummyProtocol):
+ def __init__(self, data, int_type="uint16_t", float_type="float"):
super().__init__()
self.data = data
self.max_serialized_bytes = self.get_serialized_length() + 2
self.int_type = int_type
self.float_type = float_type
- self.enc_buf += 'mpack_writer_t writer;\n'
- self.enc_buf += self.add_transition('mpack_writer_init(&writer, buf, sizeof(buf));\n', [self.max_serialized_bytes])
- self.dec_buf0 += 'char strbuf[16];\n'
+ self.enc_buf += "mpack_writer_t writer;\n"
+ self.enc_buf += self.add_transition(
+ "mpack_writer_init(&writer, buf, sizeof(buf));\n",
+ [self.max_serialized_bytes],
+ )
+ self.dec_buf0 += "char strbuf[16];\n"
self.from_json(data)
def get_serialized_length(self):
@@ -786,34 +933,37 @@ class MPack(DummyProtocol):
return False
def get_buffer_declaration(self):
- ret = 'char buf[{:d}];\n'.format(self.max_serialized_bytes)
- ret += 'uint16_t serialized_size;\n'
+ ret = "char buf[{:d}];\n".format(self.max_serialized_bytes)
+ ret += "uint16_t serialized_size;\n"
return ret
def get_buffer_name(self):
- return 'buf'
+ return "buf"
def get_encode(self):
return self.enc_buf
def get_serialize(self):
- ret = 'serialized_size = mpack_writer_buffer_used(&writer);\n'
+ ret = "serialized_size = mpack_writer_buffer_used(&writer);\n"
# OptionalTimingAnalysis and other wrappers only wrap lines ending with a ;
# We therefore deliberately do not use proper if { ... } syntax here
# to make sure that these two statements are timed as well.
- ret += 'if (mpack_writer_destroy(&writer) != mpack_ok) '
+ ret += "if (mpack_writer_destroy(&writer) != mpack_ok) "
ret += 'kout << "Encoding failed" << endl;\n'
return ret
def get_deserialize(self):
- ret = 'mpack_reader_t reader;\n'
- ret += self.add_transition('mpack_reader_init_data(&reader, buf, serialized_size);\n', [self.max_serialized_bytes])
+ ret = "mpack_reader_t reader;\n"
+ ret += self.add_transition(
+ "mpack_reader_init_data(&reader, buf, serialized_size);\n",
+ [self.max_serialized_bytes],
+ )
return ret
def get_decode_and_output(self):
ret = 'kout << dec << "dec:";\n'
- ret += 'char strbuf[16];\n'
- return ret + self.dec_buf + 'kout << endl;\n'
+ ret += "char strbuf[16];\n"
+ return ret + self.dec_buf + "kout << endl;\n"
def get_decode_vars(self):
return self.dec_buf0
@@ -822,65 +972,99 @@ class MPack(DummyProtocol):
return self.dec_buf1
def get_decode_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n';
+ return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n"
def get_length_var(self):
- return 'serialized_size'
+ return "serialized_size"
def add_value(self, value):
if type(value) == str:
- if len(value) and value[0] == '$':
- self.enc_buf += 'mpack_write(&writer, {});\n'.format(value[1:])
- self.dec_buf += 'kout << mpack_expect_uint(&reader);\n'
- self.assign_and_kout(self.int_type, 'mpack_expect_uint(&reader)')
+ if len(value) and value[0] == "$":
+ self.enc_buf += "mpack_write(&writer, {});\n".format(value[1:])
+ self.dec_buf += "kout << mpack_expect_uint(&reader);\n"
+ self.assign_and_kout(self.int_type, "mpack_expect_uint(&reader)")
else:
- self.enc_buf += self.add_transition('mpack_write_cstr_or_nil(&writer, "{}");\n'.format(value), [len(value)])
- self.dec_buf += 'mpack_expect_cstr(&reader, strbuf, sizeof(strbuf));\n'
- self.dec_buf += 'kout << strbuf;\n'
- self.dec_buf1 += self.add_transition('mpack_expect_cstr(&reader, strbuf, sizeof(strbuf));\n', [len(value)])
- self.dec_buf2 += 'kout << strbuf;\n'
+ self.enc_buf += self.add_transition(
+ 'mpack_write_cstr_or_nil(&writer, "{}");\n'.format(value),
+ [len(value)],
+ )
+ self.dec_buf += "mpack_expect_cstr(&reader, strbuf, sizeof(strbuf));\n"
+ self.dec_buf += "kout << strbuf;\n"
+ self.dec_buf1 += self.add_transition(
+ "mpack_expect_cstr(&reader, strbuf, sizeof(strbuf));\n",
+ [len(value)],
+ )
+ self.dec_buf2 += "kout << strbuf;\n"
elif type(value) == list:
self.from_json(value)
elif type(value) == dict:
self.from_json(value)
elif type(value) == int:
- self.enc_buf += 'mpack_write(&writer, ({}){:d});\n'.format(self.int_type, value)
- self.dec_buf += 'kout << mpack_expect_uint(&reader);\n'
- self.assign_and_kout(self.int_type, 'mpack_expect_uint(&reader)')
+ self.enc_buf += "mpack_write(&writer, ({}){:d});\n".format(
+ self.int_type, value
+ )
+ self.dec_buf += "kout << mpack_expect_uint(&reader);\n"
+ self.assign_and_kout(self.int_type, "mpack_expect_uint(&reader)")
elif type(value) == float:
- self.enc_buf += 'mpack_write(&writer, ({}){:f});\n'.format(self.float_type, value)
- self.dec_buf += 'kout << mpack_expect_float(&reader);\n'
- self.assign_and_kout(self.float_type, 'mpack_expect_float(&reader)')
+ self.enc_buf += "mpack_write(&writer, ({}){:f});\n".format(
+ self.float_type, value
+ )
+ self.dec_buf += "kout << mpack_expect_float(&reader);\n"
+ self.assign_and_kout(self.float_type, "mpack_expect_float(&reader)")
else:
self.note_unsupported(value)
def from_json(self, data):
if type(data) == dict:
- self.enc_buf += self.add_transition('mpack_start_map(&writer, {:d});\n'.format(len(data)), [len(data)])
- self.dec_buf += 'mpack_expect_map_max(&reader, {:d});\n'.format(len(data))
- self.dec_buf1 += self.add_transition('mpack_expect_map_max(&reader, {:d});\n'.format(len(data)), [len(data)])
+ self.enc_buf += self.add_transition(
+ "mpack_start_map(&writer, {:d});\n".format(len(data)), [len(data)]
+ )
+ self.dec_buf += "mpack_expect_map_max(&reader, {:d});\n".format(len(data))
+ self.dec_buf1 += self.add_transition(
+ "mpack_expect_map_max(&reader, {:d});\n".format(len(data)), [len(data)]
+ )
for key in sorted(data.keys()):
- self.enc_buf += self.add_transition('mpack_write_cstr(&writer, "{}");\n'.format(key), [len(key)])
+ self.enc_buf += self.add_transition(
+ 'mpack_write_cstr(&writer, "{}");\n'.format(key), [len(key)]
+ )
self.dec_buf += 'mpack_expect_cstr_match(&reader, "{}");\n'.format(key)
- self.dec_buf1 += self.add_transition('mpack_expect_cstr_match(&reader, "{}");\n'.format(key), [len(key)])
+ self.dec_buf1 += self.add_transition(
+ 'mpack_expect_cstr_match(&reader, "{}");\n'.format(key), [len(key)]
+ )
self.add_value(data[key])
- self.enc_buf += 'mpack_finish_map(&writer);\n'
- self.dec_buf += 'mpack_done_map(&reader);\n'
- self.dec_buf1 += 'mpack_done_map(&reader);\n'
+ self.enc_buf += "mpack_finish_map(&writer);\n"
+ self.dec_buf += "mpack_done_map(&reader);\n"
+ self.dec_buf1 += "mpack_done_map(&reader);\n"
if type(data) == list:
- self.enc_buf += self.add_transition('mpack_start_array(&writer, {:d});\n'.format(len(data)), [len(data)])
- self.dec_buf += 'mpack_expect_array_max(&reader, {:d});\n'.format(len(data))
- self.dec_buf1 += self.add_transition('mpack_expect_array_max(&reader, {:d});\n'.format(len(data)), [len(data)])
+ self.enc_buf += self.add_transition(
+ "mpack_start_array(&writer, {:d});\n".format(len(data)), [len(data)]
+ )
+ self.dec_buf += "mpack_expect_array_max(&reader, {:d});\n".format(len(data))
+ self.dec_buf1 += self.add_transition(
+ "mpack_expect_array_max(&reader, {:d});\n".format(len(data)),
+ [len(data)],
+ )
for elem in data:
- self.add_value(elem);
- self.enc_buf += 'mpack_finish_array(&writer);\n'
- self.dec_buf += 'mpack_done_array(&reader);\n'
- self.dec_buf1 += 'mpack_done_array(&reader);\n'
+ self.add_value(elem)
+ self.enc_buf += "mpack_finish_array(&writer);\n"
+ self.dec_buf += "mpack_done_array(&reader);\n"
+ self.dec_buf1 += "mpack_done_array(&reader);\n"
-class NanoPB(DummyProtocol):
- def __init__(self, data, max_serialized_bytes = 256, cardinality = 'required', use_maps = False, max_string_length = None, cc_prefix = '', name = 'Benchmark',
- int_type = 'uint16_t', float_type = 'float', dec_index = 0):
+class NanoPB(DummyProtocol):
+ def __init__(
+ self,
+ data,
+ max_serialized_bytes=256,
+ cardinality="required",
+ use_maps=False,
+ max_string_length=None,
+ cc_prefix="",
+ name="Benchmark",
+ int_type="uint16_t",
+ float_type="float",
+ dec_index=0,
+ ):
super().__init__()
self.data = data
self.max_serialized_bytes = max_serialized_bytes
@@ -895,60 +1079,62 @@ class NanoPB(DummyProtocol):
self.proto_float_type = self.float_type_to_proto_type(float_type)
self.dec_index = dec_index
self.fieldnum = 1
- self.proto_head = 'syntax = "proto2";\nimport "src/app/prototest/nanopb.proto";\n\n'
- self.proto_fields = ''
- self.proto_options = ''
+ self.proto_head = (
+ 'syntax = "proto2";\nimport "src/app/prototest/nanopb.proto";\n\n'
+ )
+ self.proto_fields = ""
+ self.proto_options = ""
self.sub_protos = []
- self.cc_encoders = ''
+ self.cc_encoders = ""
self.from_json(data)
def is_ascii(self):
return False
def int_type_to_proto_type(self, int_type):
- sign = 'u'
- if int_type[0] != 'u':
- sign = ''
- if '64' in int_type:
+ sign = "u"
+ if int_type[0] != "u":
+ sign = ""
+ if "64" in int_type:
self.int_bits = 64
- return sign + 'int64'
+ return sign + "int64"
# Protocol Buffers only have 32 and 64 bit integers, so we default to 32
self.int_bits = 32
- return sign + 'int32'
+ return sign + "int32"
def float_type_to_proto_type(self, float_type):
- if float_type == 'float':
+ if float_type == "float":
self.float_bits = 32
else:
self.float_bits = 64
return float_type
def get_buffer_declaration(self):
- ret = 'uint8_t buf[{:d}];\n'.format(self.max_serialized_bytes)
- ret += 'uint16_t serialized_size;\n'
+ ret = "uint8_t buf[{:d}];\n".format(self.max_serialized_bytes)
+ ret += "uint16_t serialized_size;\n"
return ret + self.get_cc_functions()
def get_buffer_name(self):
- return 'buf'
+ return "buf"
def get_serialize(self):
- ret = 'pb_ostream_t stream = pb_ostream_from_buffer(buf, sizeof(buf));\n'
- ret += 'pb_encode(&stream, Benchmark_fields, &msg);\n'
- ret += 'serialized_size = stream.bytes_written;\n'
+ ret = "pb_ostream_t stream = pb_ostream_from_buffer(buf, sizeof(buf));\n"
+ ret += "pb_encode(&stream, Benchmark_fields, &msg);\n"
+ ret += "serialized_size = stream.bytes_written;\n"
return ret
def get_deserialize(self):
- ret = 'Benchmark msg = Benchmark_init_zero;\n'
- ret += 'pb_istream_t stream = pb_istream_from_buffer(buf, serialized_size);\n'
+ ret = "Benchmark msg = Benchmark_init_zero;\n"
+ ret += "pb_istream_t stream = pb_istream_from_buffer(buf, serialized_size);\n"
# OptionalTimingAnalysis and other wrappers only wrap lines ending with a ;
# We therefore deliberately do not use proper if { ... } syntax here
# to make sure that these two statements are timed as well.
- ret += 'if (pb_decode(&stream, Benchmark_fields, &msg) == false) '
+ ret += "if (pb_decode(&stream, Benchmark_fields, &msg) == false) "
ret += 'kout << "deserialized failed" << endl;\n'
return ret
def get_decode_and_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf + 'kout << endl;\n'
+ return 'kout << dec << "dec:";\n' + self.dec_buf + "kout << endl;\n"
def get_decode_vars(self):
return self.dec_buf0
@@ -957,100 +1143,133 @@ class NanoPB(DummyProtocol):
return self.dec_buf1
def get_decode_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n';
+ return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n"
def get_length_var(self):
- return 'serialized_size'
+ return "serialized_size"
def add_field(self, cardinality, fieldtype, key, value):
- extra = ''
+ extra = ""
texttype = self.proto_int_type
dectype = self.int_type
if fieldtype == str:
- texttype = 'string'
+ texttype = "string"
elif fieldtype == float:
texttype = self.proto_float_type
dectype = self.float_type
elif fieldtype == dict:
texttype = key.capitalize()
if type(value) == list:
- extra = '[(nanopb).max_count = {:d}]'.format(len(value))
- self.enc_buf += 'msg.{}{}_count = {:d};\n'.format(self.cc_prefix, key, len(value))
- self.proto_fields += '{} {} {} = {:d} {};\n'.format(
- cardinality, texttype, key, self.fieldnum, extra)
+ extra = "[(nanopb).max_count = {:d}]".format(len(value))
+ self.enc_buf += "msg.{}{}_count = {:d};\n".format(
+ self.cc_prefix, key, len(value)
+ )
+ self.proto_fields += "{} {} {} = {:d} {};\n".format(
+ cardinality, texttype, key, self.fieldnum, extra
+ )
self.fieldnum += 1
if fieldtype == str:
- if cardinality == 'optional':
- self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key)
+ if cardinality == "optional":
+ self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key)
if self.max_strlen:
- self.proto_options += '{}.{} max_size:{:d}\n'.format(self.name, key, self.max_strlen)
+ self.proto_options += "{}.{} max_size:{:d}\n".format(
+ self.name, key, self.max_strlen
+ )
i = -1
for i, character in enumerate(value):
- self.enc_buf += '''msg.{}{}[{:d}] = '{}';\n'''.format(self.cc_prefix, key, i, character)
- self.enc_buf += 'msg.{}{}[{:d}] = 0;\n'.format(self.cc_prefix, key, i+1)
- self.dec_buf += 'kout << msg.{}{};\n'.format(self.cc_prefix, key)
- self.assign_and_kout('char *', 'msg.{}{}'.format(self.cc_prefix, key))
+ self.enc_buf += """msg.{}{}[{:d}] = '{}';\n""".format(
+ self.cc_prefix, key, i, character
+ )
+ self.enc_buf += "msg.{}{}[{:d}] = 0;\n".format(
+ self.cc_prefix, key, i + 1
+ )
+ self.dec_buf += "kout << msg.{}{};\n".format(self.cc_prefix, key)
+ self.assign_and_kout("char *", "msg.{}{}".format(self.cc_prefix, key))
else:
- self.cc_encoders += 'bool encode_{}(pb_ostream_t *stream, const pb_field_t *field, void * const *arg)\n'.format(key)
- self.cc_encoders += '{\n'
- self.cc_encoders += 'if (!pb_encode_tag_for_field(stream, field)) return false;\n'
- self.cc_encoders += 'return pb_encode_string(stream, (uint8_t*)"{}", {:d});\n'.format(value, len(value))
- self.cc_encoders += '}\n'
- self.enc_buf += 'msg.{}{}.funcs.encode = encode_{};\n'.format(self.cc_prefix, key, key)
- self.dec_buf += '// TODO decode string {}{} via callback\n'.format(self.cc_prefix, key)
- self.dec_buf1 += '// TODO decode string {}{} via callback\n'.format(self.cc_prefix, key)
+ self.cc_encoders += "bool encode_{}(pb_ostream_t *stream, const pb_field_t *field, void * const *arg)\n".format(
+ key
+ )
+ self.cc_encoders += "{\n"
+ self.cc_encoders += (
+ "if (!pb_encode_tag_for_field(stream, field)) return false;\n"
+ )
+ self.cc_encoders += 'return pb_encode_string(stream, (uint8_t*)"{}", {:d});\n'.format(
+ value, len(value)
+ )
+ self.cc_encoders += "}\n"
+ self.enc_buf += "msg.{}{}.funcs.encode = encode_{};\n".format(
+ self.cc_prefix, key, key
+ )
+ self.dec_buf += "// TODO decode string {}{} via callback\n".format(
+ self.cc_prefix, key
+ )
+ self.dec_buf1 += "// TODO decode string {}{} via callback\n".format(
+ self.cc_prefix, key
+ )
elif fieldtype == dict:
- if cardinality == 'optional':
- self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key)
+ if cardinality == "optional":
+ self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key)
# The rest is handled recursively in add_to_dict
elif type(value) == list:
for i, elem in enumerate(value):
- self.enc_buf += 'msg.{}{}[{:d}] = {};\n'.format(self.cc_prefix, key, i, elem)
- self.dec_buf += 'kout << msg.{}{}[{:d}];\n'.format(self.cc_prefix, key, i)
+ self.enc_buf += "msg.{}{}[{:d}] = {};\n".format(
+ self.cc_prefix, key, i, elem
+ )
+ self.dec_buf += "kout << msg.{}{}[{:d}];\n".format(
+ self.cc_prefix, key, i
+ )
if fieldtype == float:
- self.assign_and_kout(self.float_type, 'msg.{}{}[{:d}]'.format(self.cc_prefix, key, i))
+ self.assign_and_kout(
+ self.float_type, "msg.{}{}[{:d}]".format(self.cc_prefix, key, i)
+ )
elif fieldtype == int:
- self.assign_and_kout(self.int_type, 'msg.{}{}[{:d}]'.format(self.cc_prefix, key, i))
+ self.assign_and_kout(
+ self.int_type, "msg.{}{}[{:d}]".format(self.cc_prefix, key, i)
+ )
elif fieldtype == int:
- if cardinality == 'optional':
- self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key)
- self.enc_buf += 'msg.{}{} = {};\n'.format(self.cc_prefix, key, value)
- self.dec_buf += 'kout << msg.{}{};\n'.format(self.cc_prefix, key)
- self.assign_and_kout(self.int_type, 'msg.{}{}'.format(self.cc_prefix, key))
+ if cardinality == "optional":
+ self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key)
+ self.enc_buf += "msg.{}{} = {};\n".format(self.cc_prefix, key, value)
+ self.dec_buf += "kout << msg.{}{};\n".format(self.cc_prefix, key)
+ self.assign_and_kout(self.int_type, "msg.{}{}".format(self.cc_prefix, key))
elif fieldtype == float:
- if cardinality == 'optional':
- self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key)
- self.enc_buf += 'msg.{}{} = {};\n'.format(self.cc_prefix, key, value)
- self.dec_buf += 'kout << msg.{}{};\n'.format(self.cc_prefix, key)
- self.assign_and_kout(self.float_type, 'msg.{}{}'.format(self.cc_prefix, key))
+ if cardinality == "optional":
+ self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key)
+ self.enc_buf += "msg.{}{} = {};\n".format(self.cc_prefix, key, value)
+ self.dec_buf += "kout << msg.{}{};\n".format(self.cc_prefix, key)
+ self.assign_and_kout(
+ self.float_type, "msg.{}{}".format(self.cc_prefix, key)
+ )
elif fieldtype == dict:
- if cardinality == 'optional':
- self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key)
- self.enc_buf += 'msg.{}{} = {};\n'.format(self.cc_prefix, key, value)
- self.dec_buf += 'kout << msg.{}{};\n'.format(self.cc_prefix, key)
- self.assign_and_kout(self.float_type, 'msg.{}{}'.format(self.cc_prefix, key))
+ if cardinality == "optional":
+ self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key)
+ self.enc_buf += "msg.{}{} = {};\n".format(self.cc_prefix, key, value)
+ self.dec_buf += "kout << msg.{}{};\n".format(self.cc_prefix, key)
+ self.assign_and_kout(
+ self.float_type, "msg.{}{}".format(self.cc_prefix, key)
+ )
else:
self.note_unsupported(value)
def get_proto(self):
- return self.proto_head + '\n\n'.join(self.get_message_definitions('Benchmark'))
+ return self.proto_head + "\n\n".join(self.get_message_definitions("Benchmark"))
def get_proto_options(self):
return self.proto_options
def get_extra_files(self):
return {
- 'nanopbbench.proto' : self.get_proto(),
- 'nanopbbench.options' : self.get_proto_options()
+ "nanopbbench.proto": self.get_proto(),
+ "nanopbbench.options": self.get_proto_options(),
}
def get_message_definitions(self, msgname):
ret = list(self.sub_protos)
- ret.append('message {} {{\n'.format(msgname) + self.proto_fields + '}\n')
+ ret.append("message {} {{\n".format(msgname) + self.proto_fields + "}\n")
return ret
def get_encode(self):
- ret = 'Benchmark msg = Benchmark_init_zero;\n'
+ ret = "Benchmark msg = Benchmark_init_zero;\n"
return ret + self.enc_buf
def get_cc_functions(self):
@@ -1058,19 +1277,27 @@ class NanoPB(DummyProtocol):
def add_to_dict(self, key, value):
if type(value) == str:
- if len(value) and value[0] == '$':
+ if len(value) and value[0] == "$":
self.add_field(self.cardinality, int, key, value[1:])
else:
self.add_field(self.cardinality, str, key, value)
elif type(value) == list:
- self.add_field('repeated', type(value[0]), key, value)
+ self.add_field("repeated", type(value[0]), key, value)
elif type(value) == dict:
nested_proto = NanoPB(
- value, max_string_length = self.max_strlen, cardinality = self.cardinality, use_maps = self.use_maps, cc_prefix =
- '{}{}.'.format(self.cc_prefix, key), name = key.capitalize(),
- int_type = self.int_type, float_type = self.float_type,
- dec_index = self.dec_index)
- self.sub_protos.extend(nested_proto.get_message_definitions(key.capitalize()))
+ value,
+ max_string_length=self.max_strlen,
+ cardinality=self.cardinality,
+ use_maps=self.use_maps,
+ cc_prefix="{}{}.".format(self.cc_prefix, key),
+ name=key.capitalize(),
+ int_type=self.int_type,
+ float_type=self.float_type,
+ dec_index=self.dec_index,
+ )
+ self.sub_protos.extend(
+ nested_proto.get_message_definitions(key.capitalize())
+ )
self.proto_options += nested_proto.proto_options
self.cc_encoders += nested_proto.cc_encoders
self.add_field(self.cardinality, dict, key, value)
@@ -1089,19 +1316,21 @@ class NanoPB(DummyProtocol):
self.add_to_dict(key, data[key])
-
class UBJ(DummyProtocol):
-
- def __init__(self, data, max_serialized_bytes = 255, int_type = 'uint16_t', float_type = 'float'):
+ def __init__(
+ self, data, max_serialized_bytes=255, int_type="uint16_t", float_type="float"
+ ):
super().__init__()
self.data = data
self.max_serialized_bytes = self.get_serialized_length() + 2
self.int_type = int_type
self.float_type = self.parse_float_type(float_type)
- self.enc_buf += 'ubjw_context_t* ctx = ubjw_open_memory(buf, buf + sizeof(buf));\n'
- self.enc_buf += 'ubjw_begin_object(ctx, UBJ_MIXED, 0);\n'
- self.from_json('root', data)
- self.enc_buf += 'ubjw_end(ctx);\n'
+ self.enc_buf += (
+ "ubjw_context_t* ctx = ubjw_open_memory(buf, buf + sizeof(buf));\n"
+ )
+ self.enc_buf += "ubjw_begin_object(ctx, UBJ_MIXED, 0);\n"
+ self.from_json("root", data)
+ self.enc_buf += "ubjw_end(ctx);\n"
def get_serialized_length(self):
return len(ubjson.dumpb(self.data))
@@ -1113,41 +1342,41 @@ class UBJ(DummyProtocol):
return False
def parse_float_type(self, float_type):
- if float_type == 'float':
+ if float_type == "float":
self.float_bits = 32
else:
self.float_bits = 64
return float_type
def get_buffer_declaration(self):
- ret = 'uint8_t buf[{:d}];\n'.format(self.max_serialized_bytes)
- ret += 'uint16_t serialized_size;\n'
+ ret = "uint8_t buf[{:d}];\n".format(self.max_serialized_bytes)
+ ret += "uint16_t serialized_size;\n"
return ret
def get_buffer_name(self):
- return 'buf'
+ return "buf"
def get_length_var(self):
- return 'serialized_size'
+ return "serialized_size"
def get_serialize(self):
- return 'serialized_size = ubjw_close_context(ctx);\n'
+ return "serialized_size = ubjw_close_context(ctx);\n"
def get_encode(self):
return self.enc_buf
def get_deserialize(self):
- ret = 'ubjr_context_t* ctx = ubjr_open_memory(buf, buf + serialized_size);\n'
- ret += 'ubjr_dynamic_t dynamic_root = ubjr_read_dynamic(ctx);\n'
- ret += 'ubjr_dynamic_t* root_values = (ubjr_dynamic_t*)dynamic_root.container_object.values;\n'
+ ret = "ubjr_context_t* ctx = ubjr_open_memory(buf, buf + serialized_size);\n"
+ ret += "ubjr_dynamic_t dynamic_root = ubjr_read_dynamic(ctx);\n"
+ ret += "ubjr_dynamic_t* root_values = (ubjr_dynamic_t*)dynamic_root.container_object.values;\n"
return ret
def get_decode_and_output(self):
ret = 'kout << dec << "dec:";\n'
ret += self.dec_buf
- ret += 'kout << endl;\n'
- ret += 'ubjr_cleanup_dynamic(&dynamic_root);\n' # This causes the data (including all strings) to be free'd
- ret += 'ubjr_close_context(ctx);\n'
+ ret += "kout << endl;\n"
+ ret += "ubjr_cleanup_dynamic(&dynamic_root);\n" # This causes the data (including all strings) to be free'd
+ ret += "ubjr_close_context(ctx);\n"
return ret
def get_decode_vars(self):
@@ -1157,90 +1386,144 @@ class UBJ(DummyProtocol):
return self.dec_buf1
def get_decode_output(self):
- ret = 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n'
- ret += 'ubjr_cleanup_dynamic(&dynamic_root);\n'
- ret += 'ubjr_close_context(ctx);\n'
+ ret = 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n"
+ ret += "ubjr_cleanup_dynamic(&dynamic_root);\n"
+ ret += "ubjr_close_context(ctx);\n"
return ret
def add_to_list(self, root, index, value):
if type(value) == str:
- if len(value) and value[0] == '$':
- self.enc_buf += 'ubjw_write_integer(ctx, {});\n'.format(value[1:])
- self.dec_buf += 'kout << {}_values[{:d}].integer;\n'.format(root, index)
- self.assign_and_kout(self.int_type, '{}_values[{:d}].integer'.format(root, index))
+ if len(value) and value[0] == "$":
+ self.enc_buf += "ubjw_write_integer(ctx, {});\n".format(value[1:])
+ self.dec_buf += "kout << {}_values[{:d}].integer;\n".format(root, index)
+ self.assign_and_kout(
+ self.int_type, "{}_values[{:d}].integer".format(root, index)
+ )
else:
- self.enc_buf += self.add_transition('ubjw_write_string(ctx, "{}");\n'.format(value), [len(value)])
- self.dec_buf += 'kout << {}_values[{:d}].string;\n'.format(root, index)
- self.assign_and_kout('char *', '{}_values[{:d}].string'.format(root, index))
+ self.enc_buf += self.add_transition(
+ 'ubjw_write_string(ctx, "{}");\n'.format(value), [len(value)]
+ )
+ self.dec_buf += "kout << {}_values[{:d}].string;\n".format(root, index)
+ self.assign_and_kout(
+ "char *", "{}_values[{:d}].string".format(root, index)
+ )
elif type(value) == list:
- self.enc_buf += 'ubjw_begin_array(ctx, UBJ_MIXED, 0);\n'
- self.dec_buf += '// decoding nested lists is not supported\n'
- self.dec_buf1 += '// decoding nested lists is not supported\n'
+ self.enc_buf += "ubjw_begin_array(ctx, UBJ_MIXED, 0);\n"
+ self.dec_buf += "// decoding nested lists is not supported\n"
+ self.dec_buf1 += "// decoding nested lists is not supported\n"
self.from_json(root, value)
- self.enc_buf += 'ubjw_end(ctx);\n'
+ self.enc_buf += "ubjw_end(ctx);\n"
elif type(value) == dict:
- self.enc_buf += 'ubjw_begin_object(ctx, UBJ_MIXED, 0);\n'
- self.dec_buf += '// decoding objects in lists is not supported\n'
- self.dec_buf1 += '// decoding objects in lists is not supported\n'
+ self.enc_buf += "ubjw_begin_object(ctx, UBJ_MIXED, 0);\n"
+ self.dec_buf += "// decoding objects in lists is not supported\n"
+ self.dec_buf1 += "// decoding objects in lists is not supported\n"
self.from_json(root, value)
- self.enc_buf += 'ubjw_end(ctx);\n'
+ self.enc_buf += "ubjw_end(ctx);\n"
elif type(value) == float:
- self.enc_buf += 'ubjw_write_float{:d}(ctx, {});\n'.format(self.float_bits, value)
- self.dec_buf += 'kout << {}_values[{:d}].real;\n'.format(root, index)
- self.assign_and_kout(self.float_type, '{}_values[{:d}].real'.format(root, index))
+ self.enc_buf += "ubjw_write_float{:d}(ctx, {});\n".format(
+ self.float_bits, value
+ )
+ self.dec_buf += "kout << {}_values[{:d}].real;\n".format(root, index)
+ self.assign_and_kout(
+ self.float_type, "{}_values[{:d}].real".format(root, index)
+ )
elif type(value) == int:
- self.enc_buf += 'ubjw_write_integer(ctx, {});\n'.format(value)
- self.dec_buf += 'kout << {}_values[{:d}].integer;\n'.format(root, index)
- self.assign_and_kout(self.int_type, '{}_values[{:d}].integer'.format(root, index))
+ self.enc_buf += "ubjw_write_integer(ctx, {});\n".format(value)
+ self.dec_buf += "kout << {}_values[{:d}].integer;\n".format(root, index)
+ self.assign_and_kout(
+ self.int_type, "{}_values[{:d}].integer".format(root, index)
+ )
else:
- raise TypeError('Cannot handle {} of type {}'.format(value, type(value)))
+ raise TypeError("Cannot handle {} of type {}".format(value, type(value)))
def add_to_dict(self, root, index, key, value):
if type(value) == str:
- if len(value) and value[0] == '$':
- self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_write_integer(ctx, {});\n'.format(key, value[1:]), [len(key)])
- self.dec_buf += 'kout << {}_values[{:d}].integer;\n'.format(root, index)
- self.assign_and_kout(self.int_type, '{}_values[{:d}].integer'.format(root, index))
+ if len(value) and value[0] == "$":
+ self.enc_buf += self.add_transition(
+ 'ubjw_write_key(ctx, "{}"); ubjw_write_integer(ctx, {});\n'.format(
+ key, value[1:]
+ ),
+ [len(key)],
+ )
+ self.dec_buf += "kout << {}_values[{:d}].integer;\n".format(root, index)
+ self.assign_and_kout(
+ self.int_type, "{}_values[{:d}].integer".format(root, index)
+ )
else:
- self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_write_string(ctx, "{}");\n'.format(key, value), [len(key), len(value)])
- self.dec_buf += 'kout << {}_values[{:d}].string;\n'.format(root, index)
- self.assign_and_kout('char *', '{}_values[{:d}].string'.format(root, index))
+ self.enc_buf += self.add_transition(
+ 'ubjw_write_key(ctx, "{}"); ubjw_write_string(ctx, "{}");\n'.format(
+ key, value
+ ),
+ [len(key), len(value)],
+ )
+ self.dec_buf += "kout << {}_values[{:d}].string;\n".format(root, index)
+ self.assign_and_kout(
+ "char *", "{}_values[{:d}].string".format(root, index)
+ )
elif type(value) == list:
- self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_begin_array(ctx, UBJ_MIXED, 0);\n'.format(key), [len(key)])
- self.dec_buf += 'ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_array.values;\n'.format(
- key, root, index)
- self.dec_buf1 += 'ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_array.values;\n'.format(
- key, root, index)
+ self.enc_buf += self.add_transition(
+ 'ubjw_write_key(ctx, "{}"); ubjw_begin_array(ctx, UBJ_MIXED, 0);\n'.format(
+ key
+ ),
+ [len(key)],
+ )
+ self.dec_buf += "ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_array.values;\n".format(
+ key, root, index
+ )
+ self.dec_buf1 += "ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_array.values;\n".format(
+ key, root, index
+ )
self.from_json(key, value)
- self.enc_buf += 'ubjw_end(ctx);\n'
+ self.enc_buf += "ubjw_end(ctx);\n"
elif type(value) == dict:
- self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_begin_object(ctx, UBJ_MIXED, 0);\n'.format(key), [len(key)])
- self.dec_buf += 'ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_object.values;\n'.format(
- key, root, index)
- self.dec_buf1 += 'ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_object.values;\n'.format(
- key, root, index)
+ self.enc_buf += self.add_transition(
+ 'ubjw_write_key(ctx, "{}"); ubjw_begin_object(ctx, UBJ_MIXED, 0);\n'.format(
+ key
+ ),
+ [len(key)],
+ )
+ self.dec_buf += "ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_object.values;\n".format(
+ key, root, index
+ )
+ self.dec_buf1 += "ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_object.values;\n".format(
+ key, root, index
+ )
self.from_json(key, value)
- self.enc_buf += 'ubjw_end(ctx);\n'
+ self.enc_buf += "ubjw_end(ctx);\n"
elif type(value) == float:
- self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_write_float{:d}(ctx, {});\n'.format(key, self.float_bits, value), [len(key)])
- self.dec_buf += 'kout << {}_values[{:d}].real;\n'.format(root, index)
- self.assign_and_kout(self.float_type, '{}_values[{:d}].real'.format(root, index))
+ self.enc_buf += self.add_transition(
+ 'ubjw_write_key(ctx, "{}"); ubjw_write_float{:d}(ctx, {});\n'.format(
+ key, self.float_bits, value
+ ),
+ [len(key)],
+ )
+ self.dec_buf += "kout << {}_values[{:d}].real;\n".format(root, index)
+ self.assign_and_kout(
+ self.float_type, "{}_values[{:d}].real".format(root, index)
+ )
elif type(value) == int:
- self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_write_integer(ctx, {});\n'.format(key, value), [len(key)])
- self.dec_buf += 'kout << {}_values[{:d}].integer;\n'.format(root, index)
- self.assign_and_kout(self.int_type, '{}_values[{:d}].integer'.format(root, index))
+ self.enc_buf += self.add_transition(
+ 'ubjw_write_key(ctx, "{}"); ubjw_write_integer(ctx, {});\n'.format(
+ key, value
+ ),
+ [len(key)],
+ )
+ self.dec_buf += "kout << {}_values[{:d}].integer;\n".format(root, index)
+ self.assign_and_kout(
+ self.int_type, "{}_values[{:d}].integer".format(root, index)
+ )
else:
- raise TypeError('Cannot handle {} of type {}'.format(value, type(value)))
+ raise TypeError("Cannot handle {} of type {}".format(value, type(value)))
def from_json(self, root, data):
if type(data) == dict:
@@ -1253,63 +1536,64 @@ class UBJ(DummyProtocol):
class XDR(DummyProtocol):
-
- def __init__(self, data, max_serialized_bytes = 256, int_type = 'uint16_t', float_type = 'float'):
+ def __init__(
+ self, data, max_serialized_bytes=256, int_type="uint16_t", float_type="float"
+ ):
super().__init__()
self.data = data
self.max_serialized_bytes = 256
self.enc_int_type = int_type
self.dec_int_type = self.parse_int_type(int_type)
self.float_type = self.parse_float_type(float_type)
- self.enc_buf += 'XDRWriter xdrstream(buf);\n'
- self.dec_buf += 'XDRReader xdrinput(buf);\n'
- self.dec_buf0 += 'XDRReader xdrinput(buf);\n'
+ self.enc_buf += "XDRWriter xdrstream(buf);\n"
+ self.dec_buf += "XDRReader xdrinput(buf);\n"
+ self.dec_buf0 += "XDRReader xdrinput(buf);\n"
# By default, XDR does not even include a version / protocol specifier.
# This seems rather impractical -> emulate that here.
- #self.enc_buf += 'xdrstream << (uint32_t)22075;\n'
- self.dec_buf += 'char strbuf[16];\n'
- #self.dec_buf += 'xdrinput.get_uint32();\n'
- self.dec_buf0 += 'char strbuf[16];\n'
- #self.dec_buf0 += 'xdrinput.get_uint32();\n'
+ # self.enc_buf += 'xdrstream << (uint32_t)22075;\n'
+ self.dec_buf += "char strbuf[16];\n"
+ # self.dec_buf += 'xdrinput.get_uint32();\n'
+ self.dec_buf0 += "char strbuf[16];\n"
+ # self.dec_buf0 += 'xdrinput.get_uint32();\n'
self.from_json(data)
def is_ascii(self):
return False
def parse_int_type(self, int_type):
- sign = ''
- if int_type[0] == 'u':
- sign = 'u'
- if '64' in int_type:
+ sign = ""
+ if int_type[0] == "u":
+ sign = "u"
+ if "64" in int_type:
self.int_bits = 64
- return sign + 'int64'
+ return sign + "int64"
else:
self.int_bits = 32
- return sign + 'int32'
+ return sign + "int32"
def parse_float_type(self, float_type):
- if float_type == 'float':
+ if float_type == "float":
self.float_bits = 32
else:
self.float_bits = 64
return float_type
def get_buffer_declaration(self):
- ret = 'uint16_t serialized_size;\n'
- ret += 'char buf[{:d}];\n'.format(self.max_serialized_bytes)
+ ret = "uint16_t serialized_size;\n"
+ ret += "char buf[{:d}];\n".format(self.max_serialized_bytes)
return ret
def get_buffer_name(self):
- return 'buf'
+ return "buf"
def get_length_var(self):
- return 'xdrstream.size()'
+ return "xdrstream.size()"
def get_encode(self):
return self.enc_buf
def get_decode_and_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf + 'kout << endl;\n'
+ return 'kout << dec << "dec:";\n' + self.dec_buf + "kout << endl;\n"
def get_decode_vars(self):
return self.dec_buf0
@@ -1318,115 +1602,129 @@ class XDR(DummyProtocol):
return self.dec_buf1
def get_decode_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n'
+ return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n"
def from_json(self, data):
if type(data) == dict:
for key in sorted(data.keys()):
self.from_json(data[key])
elif type(data) == list:
- self.enc_buf += 'xdrstream.setNextArrayLen({});\n'.format(len(data))
- self.enc_buf += 'xdrstream.setVariableLength();\n'
- self.enc_buf += 'xdrstream.startList();\n'
- self.dec_buf += 'xdrinput.get_uint32();\n'
- self.dec_buf1 += 'xdrinput.get_uint32();\n'
+ self.enc_buf += "xdrstream.setNextArrayLen({});\n".format(len(data))
+ self.enc_buf += "xdrstream.setVariableLength();\n"
+ self.enc_buf += "xdrstream.startList();\n"
+ self.dec_buf += "xdrinput.get_uint32();\n"
+ self.dec_buf1 += "xdrinput.get_uint32();\n"
for elem in data:
self.from_json(elem)
elif type(data) == str:
- if len(data) and data[0] == '$':
- self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.enc_int_type, data[1:])
- self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.dec_int_type)
- self.dec_buf0 += '{} dec_{};\n'.format(self.enc_int_type, self.dec_index)
- self.dec_buf1 += 'dec_{} = xdrinput.get_{}();;\n'.format(self.dec_index, self.dec_int_type)
- self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index)
+ if len(data) and data[0] == "$":
+ self.enc_buf += "xdrstream.put(({}){});\n".format(
+ self.enc_int_type, data[1:]
+ )
+ self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.dec_int_type)
+ self.dec_buf0 += "{} dec_{};\n".format(
+ self.enc_int_type, self.dec_index
+ )
+ self.dec_buf1 += "dec_{} = xdrinput.get_{}();;\n".format(
+ self.dec_index, self.dec_int_type
+ )
+ self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index)
else:
# Kodierte Strings haben nicht immer ein Nullbyte am Ende
- self.enc_buf += 'xdrstream.setNextArrayLen({});\n'.format(len(data))
- self.enc_buf += 'xdrstream.setVariableLength();\n'
- self.enc_buf += self.add_transition('xdrstream.put("{}");\n'.format(data), [len(data)])
- self.dec_buf += 'xdrinput.get_string(strbuf);\n'
- self.dec_buf += 'kout << strbuf;\n'
- self.dec_buf1 += 'xdrinput.get_string(strbuf);\n'
- self.dec_buf2 += 'kout << strbuf;\n'
+ self.enc_buf += "xdrstream.setNextArrayLen({});\n".format(len(data))
+ self.enc_buf += "xdrstream.setVariableLength();\n"
+ self.enc_buf += self.add_transition(
+ 'xdrstream.put("{}");\n'.format(data), [len(data)]
+ )
+ self.dec_buf += "xdrinput.get_string(strbuf);\n"
+ self.dec_buf += "kout << strbuf;\n"
+ self.dec_buf1 += "xdrinput.get_string(strbuf);\n"
+ self.dec_buf2 += "kout << strbuf;\n"
elif type(data) == float:
- self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.float_type, data)
- self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.float_type)
- self.dec_buf0 += '{} dec_{};\n'.format(self.float_type, self.dec_index)
- self.dec_buf1 += 'dec_{} = xdrinput.get_{}();\n'.format(self.dec_index, self.float_type)
- self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index)
+ self.enc_buf += "xdrstream.put(({}){});\n".format(self.float_type, data)
+ self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.float_type)
+ self.dec_buf0 += "{} dec_{};\n".format(self.float_type, self.dec_index)
+ self.dec_buf1 += "dec_{} = xdrinput.get_{}();\n".format(
+ self.dec_index, self.float_type
+ )
+ self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index)
elif type(data) == int:
- self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.enc_int_type, data)
- self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.dec_int_type)
- self.dec_buf0 += '{} dec_{};\n'.format(self.enc_int_type, self.dec_index)
- self.dec_buf1 += 'dec_{} = xdrinput.get_{}();\n'.format(self.dec_index, self.dec_int_type)
- self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index)
+ self.enc_buf += "xdrstream.put(({}){});\n".format(self.enc_int_type, data)
+ self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.dec_int_type)
+ self.dec_buf0 += "{} dec_{};\n".format(self.enc_int_type, self.dec_index)
+ self.dec_buf1 += "dec_{} = xdrinput.get_{}();\n".format(
+ self.dec_index, self.dec_int_type
+ )
+ self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index)
else:
- self.enc_buf += 'xdrstream.put({});\n'.format(data)
- self.dec_buf += '// unsupported type {} of {}\n'.format(type(data), data)
- self.dec_buf1 += '// unsupported type {} of {}\n'.format(type(data), data)
+ self.enc_buf += "xdrstream.put({});\n".format(data)
+ self.dec_buf += "// unsupported type {} of {}\n".format(type(data), data)
+ self.dec_buf1 += "// unsupported type {} of {}\n".format(type(data), data)
self.dec_index += 1
-class XDR16(DummyProtocol):
- def __init__(self, data, max_serialized_bytes = 256, int_type = 'uint16_t', float_type = 'float'):
+class XDR16(DummyProtocol):
+ def __init__(
+ self, data, max_serialized_bytes=256, int_type="uint16_t", float_type="float"
+ ):
super().__init__()
self.data = data
self.max_serialized_bytes = 256
self.enc_int_type = int_type
self.dec_int_type = self.parse_int_type(int_type)
self.float_type = self.parse_float_type(float_type)
- self.enc_buf += 'XDRWriter xdrstream(buf);\n'
- self.dec_buf += 'XDRReader xdrinput(buf);\n'
- self.dec_buf0 += 'XDRReader xdrinput(buf);\n'
+ self.enc_buf += "XDRWriter xdrstream(buf);\n"
+ self.dec_buf += "XDRReader xdrinput(buf);\n"
+ self.dec_buf0 += "XDRReader xdrinput(buf);\n"
# By default, XDR does not even include a version / protocol specifier.
# This seems rather impractical -> emulate that here.
- #self.enc_buf += 'xdrstream << (uint32_t)22075;\n'
- self.dec_buf += 'char strbuf[16];\n'
- #self.dec_buf += 'xdrinput.get_uint32();\n'
- self.dec_buf0 += 'char strbuf[16];\n'
- #self.dec_buf0 += 'xdrinput.get_uint32();\n'
+ # self.enc_buf += 'xdrstream << (uint32_t)22075;\n'
+ self.dec_buf += "char strbuf[16];\n"
+ # self.dec_buf += 'xdrinput.get_uint32();\n'
+ self.dec_buf0 += "char strbuf[16];\n"
+ # self.dec_buf0 += 'xdrinput.get_uint32();\n'
self.from_json(data)
def is_ascii(self):
return False
def parse_int_type(self, int_type):
- sign = ''
- if int_type[0] == 'u':
- sign = 'u'
- if '64' in int_type:
+ sign = ""
+ if int_type[0] == "u":
+ sign = "u"
+ if "64" in int_type:
self.int_bits = 64
- return sign + 'int64'
- if '32' in int_type:
+ return sign + "int64"
+ if "32" in int_type:
self.int_bits = 32
- return sign + 'int32'
+ return sign + "int32"
else:
self.int_bits = 16
- return sign + 'int16'
+ return sign + "int16"
def parse_float_type(self, float_type):
- if float_type == 'float':
+ if float_type == "float":
self.float_bits = 32
else:
self.float_bits = 64
return float_type
def get_buffer_declaration(self):
- ret = 'uint16_t serialized_size;\n'
- ret += 'char buf[{:d}];\n'.format(self.max_serialized_bytes)
+ ret = "uint16_t serialized_size;\n"
+ ret += "char buf[{:d}];\n".format(self.max_serialized_bytes)
return ret
def get_buffer_name(self):
- return 'buf'
+ return "buf"
def get_length_var(self):
- return 'xdrstream.size()'
+ return "xdrstream.size()"
def get_encode(self):
return self.enc_buf
def get_decode_and_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf + 'kout << endl;\n'
+ return 'kout << dec << "dec:";\n' + self.dec_buf + "kout << endl;\n"
def get_decode_vars(self):
return self.dec_buf0
@@ -1435,76 +1733,99 @@ class XDR16(DummyProtocol):
return self.dec_buf1
def get_decode_output(self):
- return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n';
+ return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n"
def from_json(self, data):
if type(data) == dict:
for key in sorted(data.keys()):
self.from_json(data[key])
elif type(data) == list:
- self.enc_buf += 'xdrstream.setNextArrayLen({});\n'.format(len(data))
- self.enc_buf += 'xdrstream.setVariableLength();\n'
- self.enc_buf += 'xdrstream.startList();\n'
- self.dec_buf += 'xdrinput.get_uint16();\n'
- self.dec_buf1 += 'xdrinput.get_uint16();\n'
+ self.enc_buf += "xdrstream.setNextArrayLen({});\n".format(len(data))
+ self.enc_buf += "xdrstream.setVariableLength();\n"
+ self.enc_buf += "xdrstream.startList();\n"
+ self.dec_buf += "xdrinput.get_uint16();\n"
+ self.dec_buf1 += "xdrinput.get_uint16();\n"
for elem in data:
self.from_json(elem)
elif type(data) == str:
- if len(data) and data[0] == '$':
- self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.enc_int_type, data[1:])
- self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.dec_int_type)
- self.dec_buf0 += '{} dec_{};\n'.format(self.enc_int_type, self.dec_index)
- self.dec_buf1 += 'dec_{} = xdrinput.get_{}();;\n'.format(self.dec_index, self.dec_int_type)
- self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index)
+ if len(data) and data[0] == "$":
+ self.enc_buf += "xdrstream.put(({}){});\n".format(
+ self.enc_int_type, data[1:]
+ )
+ self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.dec_int_type)
+ self.dec_buf0 += "{} dec_{};\n".format(
+ self.enc_int_type, self.dec_index
+ )
+ self.dec_buf1 += "dec_{} = xdrinput.get_{}();;\n".format(
+ self.dec_index, self.dec_int_type
+ )
+ self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index)
else:
# Kodierte Strings haben nicht immer ein Nullbyte am Ende
- self.enc_buf += 'xdrstream.setNextArrayLen({});\n'.format(len(data))
- self.enc_buf += 'xdrstream.setVariableLength();\n'
- self.enc_buf += self.add_transition('xdrstream.put("{}");\n'.format(data), [len(data)])
- self.dec_buf += 'xdrinput.get_string(strbuf);\n'
- self.dec_buf += 'kout << strbuf;\n'
- self.dec_buf1 += 'xdrinput.get_string(strbuf);\n'
- self.dec_buf2 += 'kout << strbuf;\n'
+ self.enc_buf += "xdrstream.setNextArrayLen({});\n".format(len(data))
+ self.enc_buf += "xdrstream.setVariableLength();\n"
+ self.enc_buf += self.add_transition(
+ 'xdrstream.put("{}");\n'.format(data), [len(data)]
+ )
+ self.dec_buf += "xdrinput.get_string(strbuf);\n"
+ self.dec_buf += "kout << strbuf;\n"
+ self.dec_buf1 += "xdrinput.get_string(strbuf);\n"
+ self.dec_buf2 += "kout << strbuf;\n"
elif type(data) == float:
- self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.float_type, data)
- self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.float_type)
- self.dec_buf0 += '{} dec_{};\n'.format(self.float_type, self.dec_index)
- self.dec_buf1 += 'dec_{} = xdrinput.get_{}();\n'.format(self.dec_index, self.float_type)
- self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index)
+ self.enc_buf += "xdrstream.put(({}){});\n".format(self.float_type, data)
+ self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.float_type)
+ self.dec_buf0 += "{} dec_{};\n".format(self.float_type, self.dec_index)
+ self.dec_buf1 += "dec_{} = xdrinput.get_{}();\n".format(
+ self.dec_index, self.float_type
+ )
+ self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index)
elif type(data) == int:
- self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.enc_int_type, data)
- self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.dec_int_type)
- self.dec_buf0 += '{} dec_{};\n'.format(self.enc_int_type, self.dec_index)
- self.dec_buf1 += 'dec_{} = xdrinput.get_{}();\n'.format(self.dec_index, self.dec_int_type)
- self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index)
+ self.enc_buf += "xdrstream.put(({}){});\n".format(self.enc_int_type, data)
+ self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.dec_int_type)
+ self.dec_buf0 += "{} dec_{};\n".format(self.enc_int_type, self.dec_index)
+ self.dec_buf1 += "dec_{} = xdrinput.get_{}();\n".format(
+ self.dec_index, self.dec_int_type
+ )
+ self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index)
else:
- self.enc_buf += 'xdrstream.put({});\n'.format(data)
- self.dec_buf += '// unsupported type {} of {}\n'.format(type(data), data)
- self.dec_buf1 += '// unsupported type {} of {}\n'.format(type(data), data)
- self.dec_index += 1;
+ self.enc_buf += "xdrstream.put({});\n".format(data)
+ self.dec_buf += "// unsupported type {} of {}\n".format(type(data), data)
+ self.dec_buf1 += "// unsupported type {} of {}\n".format(type(data), data)
+ self.dec_index += 1
-class Benchmark:
+class Benchmark:
def __init__(self, logfile):
self.atomic = True
self.logfile = logfile
def __enter__(self):
self.atomic = False
- with FileLock(self.logfile + '.lock'):
+ with FileLock(self.logfile + ".lock"):
if os.path.exists(self.logfile):
- with open(self.logfile, 'rb') as f:
+ with open(self.logfile, "rb") as f:
self.data = ubjson.load(f)
else:
self.data = {}
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
- with FileLock(self.logfile + '.lock'):
- with open(self.logfile, 'wb') as f:
+ with FileLock(self.logfile + ".lock"):
+ with open(self.logfile, "wb") as f:
ubjson.dump(self.data, f)
- def _add_log_entry(self, benchmark_data, arch, libkey, bench_name, bench_index, data, key, value, error):
+ def _add_log_entry(
+ self,
+ benchmark_data,
+ arch,
+ libkey,
+ bench_name,
+ bench_index,
+ data,
+ key,
+ value,
+ error,
+ ):
if not libkey in benchmark_data:
benchmark_data[libkey] = dict()
if not bench_name in benchmark_data[libkey]:
@@ -1514,94 +1835,127 @@ class Benchmark:
this_result = benchmark_data[libkey][bench_name][bench_index]
# data is unset for log(...) calls from postprocessing
if data != None:
- this_result['data'] = data
+ this_result["data"] = data
if value != None:
- this_result[key] = {
- 'v' : value,
- 'ts' : int(time.time())
- }
- print('{} {} {} ({}) :: {} -> {}'.format(
- libkey, bench_name, bench_index, data, key, value))
+ this_result[key] = {"v": value, "ts": int(time.time())}
+ print(
+ "{} {} {} ({}) :: {} -> {}".format(
+ libkey, bench_name, bench_index, data, key, value
+ )
+ )
else:
- this_result[key] = {
- 'e' : error,
- 'ts' : int(time.time())
- }
- print('{} {} {} ({}) :: {} -> [E] {}'.format(
- libkey, bench_name, bench_index, data, key, error[:500]))
-
- def log(self, arch, library, library_options, bench_name, bench_index, data, key, value = None, error = None):
+ this_result[key] = {"e": error, "ts": int(time.time())}
+ print(
+ "{} {} {} ({}) :: {} -> [E] {}".format(
+ libkey, bench_name, bench_index, data, key, error[:500]
+ )
+ )
+
+ def log(
+ self,
+ arch,
+ library,
+ library_options,
+ bench_name,
+ bench_index,
+ data,
+ key,
+ value=None,
+ error=None,
+ ):
if not library_options:
library_options = []
- libkey = '{}:{}:{}'.format(arch, library, ','.join(library_options))
+ libkey = "{}:{}:{}".format(arch, library, ",".join(library_options))
# JSON does not differentiate between int and str keys -> always use
# str(bench_index)
bench_index = str(bench_index)
if self.atomic:
- with FileLock(self.logfile + '.lock'):
+ with FileLock(self.logfile + ".lock"):
if os.path.exists(self.logfile):
- with open(self.logfile, 'rb') as f:
+ with open(self.logfile, "rb") as f:
benchmark_data = ubjson.load(f)
else:
benchmark_data = {}
- self._add_log_entry(benchmark_data, arch, libkey, bench_name, bench_index, data, key, value, error)
- with open(self.logfile, 'wb') as f:
+ self._add_log_entry(
+ benchmark_data,
+ arch,
+ libkey,
+ bench_name,
+ bench_index,
+ data,
+ key,
+ value,
+ error,
+ )
+ with open(self.logfile, "wb") as f:
ubjson.dump(benchmark_data, f)
else:
- self._add_log_entry(self.data, arch, libkey, bench_name, bench_index, data, key, value, error)
+ self._add_log_entry(
+ self.data,
+ arch,
+ libkey,
+ bench_name,
+ bench_index,
+ data,
+ key,
+ value,
+ error,
+ )
def get_snapshot(self):
- with FileLock(self.logfile + '.lock'):
+ with FileLock(self.logfile + ".lock"):
if os.path.exists(self.logfile):
- with open(self.logfile, 'rb') as f:
+ with open(self.logfile, "rb") as f:
benchmark_data = ubjson.load(f)
else:
benchmark_data = {}
return benchmark_data
+
def codegen_for_lib(library, library_options, data):
- if library == 'arduinojson':
- return ArduinoJSON(data, bufsize = 512)
+ if library == "arduinojson":
+ return ArduinoJSON(data, bufsize=512)
- if library == 'avro':
+ if library == "avro":
strip_schema = bool(int(library_options[0]))
- return Avro(data, strip_schema = strip_schema)
+ return Avro(data, strip_schema=strip_schema)
- if library == 'capnproto_c':
+ if library == "capnproto_c":
packed = bool(int(library_options[0]))
- return CapnProtoC(data, packed = packed)
+ return CapnProtoC(data, packed=packed)
- if library == 'manualjson':
+ if library == "manualjson":
return ManualJSON(data)
- if library == 'modernjson':
- dataformat, = library_options
+ if library == "modernjson":
+ (dataformat,) = library_options
return ModernJSON(data, dataformat)
- if library == 'mpack':
+ if library == "mpack":
return MPack(data)
- if library == 'nanopb':
+ if library == "nanopb":
cardinality, strbuf = library_options
- if not len(strbuf) or strbuf == '0':
+ if not len(strbuf) or strbuf == "0":
strbuf = None
else:
strbuf = int(strbuf)
- return NanoPB(data, cardinality = cardinality, max_string_length = strbuf)
+ return NanoPB(data, cardinality=cardinality, max_string_length=strbuf)
- if library == 'thrift':
+ if library == "thrift":
return Thrift(data)
- if library == 'ubjson':
+ if library == "ubjson":
return UBJ(data)
- if library == 'xdr':
+ if library == "xdr":
return XDR(data)
- if library == 'xdr16':
+ if library == "xdr16":
return XDR16(data)
- raise ValueError('Unsupported library: {}'.format(library))
+ raise ValueError("Unsupported library: {}".format(library))
+
def shorten_call(snippet, lib):
"""
@@ -1612,54 +1966,78 @@ def shorten_call(snippet, lib):
"""
# The following adjustments are protobench-specific
# "xdrstream << (uint16_t)123" -> "xdrstream << (uint16_t"
- if 'xdrstream << (' in snippet:
- snippet = snippet.split(')')[0]
+ if "xdrstream << (" in snippet:
+ snippet = snippet.split(")")[0]
# "xdrstream << variable << ..." -> "xdrstream << variable"
- elif 'xdrstream << variable' in snippet:
- snippet = '<<'.join(snippet.split('<<')[0:2])
- elif 'xdrstream.setNextArrayLen(' in snippet:
- snippet = 'xdrstream.setNextArrayLen'
- elif 'ubjw' in snippet:
- snippet = re.sub('ubjw_write_key\(ctx, [^)]+\)', 'ubjw_write_key(ctx, ?)', snippet)
- snippet = re.sub('ubjw_write_([^(]+)\(ctx, [^)]+\)', 'ubjw_write_\\1(ctx, ?)', snippet)
+ elif "xdrstream << variable" in snippet:
+ snippet = "<<".join(snippet.split("<<")[0:2])
+ elif "xdrstream.setNextArrayLen(" in snippet:
+ snippet = "xdrstream.setNextArrayLen"
+ elif "ubjw" in snippet:
+ snippet = re.sub(
+ "ubjw_write_key\(ctx, [^)]+\)", "ubjw_write_key(ctx, ?)", snippet
+ )
+ snippet = re.sub(
+ "ubjw_write_([^(]+)\(ctx, [^)]+\)", "ubjw_write_\\1(ctx, ?)", snippet
+ )
# mpack_write(&writer, (type)value) -> mpack_write(&writer, (type
- elif 'mpack_write(' in snippet:
- snippet = snippet.split(')')[0]
+ elif "mpack_write(" in snippet:
+ snippet = snippet.split(")")[0]
# mpack_write_cstr(&writer, "foo") -> mpack_write_cstr(&writer,
# same for mpack_write_cstr_or_nil
- elif 'mpack_write_cstr' in snippet:
+ elif "mpack_write_cstr" in snippet:
snippet = snippet.split('"')[0]
# mpack_start_map(&writer, x) -> mpack_start_map(&writer
# mpack_start_array(&writer, x) -> mpack_start_array(&writer
- elif 'mpack_start_' in snippet:
- snippet = snippet.split(',')[0]
- #elif 'bout <<' in snippet:
+ elif "mpack_start_" in snippet:
+ snippet = snippet.split(",")[0]
+ # elif 'bout <<' in snippet:
# if '\\":\\"' in snippet:
# snippet = 'bout << key:str'
# elif 'bout << "\\"' in snippet:
# snippet = 'bout << key'
# else:
# snippet = 'bout << other'
- elif 'msg.' in snippet:
- snippet = re.sub('msg.(?:[^[]+)(?:\[.*?\])? = .*', 'msg.? = ?', snippet)
- elif lib == 'arduinojson:':
- snippet = re.sub('ArduinoJson::JsonObject& [^ ]+ = [^.]+.createNestedObject\([^")]*\);', 'ArduinoJson::JsonObject& ? = ?.createNestedObject();', snippet)
- snippet = re.sub('ArduinoJson::JsonObject& [^ ]+ = [^.]+.createNestedObject\("[^")]*"\);', 'ArduinoJson::JsonObject& ? = ?.createNestedObject(?);', snippet)
- snippet = re.sub('ArduinoJson::JsonArray& [^ ]+ = [^.]+.createNestedArray\([^")]*\);', 'ArduinoJson::JsonArray& ? = ?.createNestedArray();', snippet)
- snippet = re.sub('ArduinoJson::JsonArray& [^ ]+ = [^.]+.createNestedArray\("[^")]*"\);', 'ArduinoJson::JsonArray& ? = ?.createNestedArray(?);', snippet)
+ elif "msg." in snippet:
+ snippet = re.sub("msg.(?:[^[]+)(?:\[.*?\])? = .*", "msg.? = ?", snippet)
+ elif lib == "arduinojson:":
+ snippet = re.sub(
+ 'ArduinoJson::JsonObject& [^ ]+ = [^.]+.createNestedObject\([^")]*\);',
+ "ArduinoJson::JsonObject& ? = ?.createNestedObject();",
+ snippet,
+ )
+ snippet = re.sub(
+ 'ArduinoJson::JsonObject& [^ ]+ = [^.]+.createNestedObject\("[^")]*"\);',
+ "ArduinoJson::JsonObject& ? = ?.createNestedObject(?);",
+ snippet,
+ )
+ snippet = re.sub(
+ 'ArduinoJson::JsonArray& [^ ]+ = [^.]+.createNestedArray\([^")]*\);',
+ "ArduinoJson::JsonArray& ? = ?.createNestedArray();",
+ snippet,
+ )
+ snippet = re.sub(
+ 'ArduinoJson::JsonArray& [^ ]+ = [^.]+.createNestedArray\("[^")]*"\);',
+ "ArduinoJson::JsonArray& ? = ?.createNestedArray(?);",
+ snippet,
+ )
snippet = re.sub('root[^[]*\["[^"]*"\] = [^";]+', 'root?["?"] = ?', snippet)
snippet = re.sub('root[^[]*\["[^"]*"\] = "[^"]+"', 'root?["?"] = "?"', snippet)
- snippet = re.sub('rootl.add\([^)]*\)', 'rootl.add(?)', snippet)
+ snippet = re.sub("rootl.add\([^)]*\)", "rootl.add(?)", snippet)
- snippet = re.sub('^dec_[^ ]*', 'dec_?', snippet)
- if lib == 'arduinojson:':
- snippet = re.sub('root[^[]*\[[^]"]+\]\.as', 'root[?].as', snippet)
+ snippet = re.sub("^dec_[^ ]*", "dec_?", snippet)
+ if lib == "arduinojson:":
+ snippet = re.sub('root[^[]*\[[^]"]+\]\.as', "root[?].as", snippet)
snippet = re.sub('root[^[]*\["[^]]+"\]\.as', 'root["?"].as', snippet)
- elif 'nanopb:' in lib:
- snippet = re.sub('= msg\.[^;]+;', '= msg.?;', snippet)
- elif lib == 'mpack:':
- snippet = re.sub('mpack_expect_([^_]+)_max\(&reader, [^)]+\)', 'mpack_expect_\\1_max(&reader, ?)', snippet)
- elif lib == 'ubjson:':
- snippet = re.sub('[^_ ]+_values[^.]+\.', '?_values[?].', snippet)
+ elif "nanopb:" in lib:
+ snippet = re.sub("= msg\.[^;]+;", "= msg.?;", snippet)
+ elif lib == "mpack:":
+ snippet = re.sub(
+ "mpack_expect_([^_]+)_max\(&reader, [^)]+\)",
+ "mpack_expect_\\1_max(&reader, ?)",
+ snippet,
+ )
+ elif lib == "ubjson:":
+ snippet = re.sub("[^_ ]+_values[^.]+\.", "?_values[?].", snippet)
return snippet
diff --git a/lib/pubcode/__init__.py b/lib/pubcode/__init__.py
index 30c8490..3007bdf 100644
--- a/lib/pubcode/__init__.py
+++ b/lib/pubcode/__init__.py
@@ -1,6 +1,6 @@
"""A simple module for creating barcodes.
"""
-__version__ = '1.1.0'
+__version__ = "1.1.0"
from .code128 import Code128
diff --git a/lib/pubcode/code128.py b/lib/pubcode/code128.py
index 1c37f37..4fd7aed 100644
--- a/lib/pubcode/code128.py
+++ b/lib/pubcode/code128.py
@@ -5,6 +5,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
from builtins import * # Use Python3-like builtins for Python2.
import base64
import io
+
try:
from PIL import Image
except ImportError:
@@ -34,74 +35,189 @@ class Code128(object):
# List of bar and space weights, indexed by symbol character values (0-105), and the STOP character (106).
# The first weights is a bar and then it alternates.
_val2bars = [
- '212222', '222122', '222221', '121223', '121322', '131222', '122213', '122312', '132212', '221213',
- '221312', '231212', '112232', '122132', '122231', '113222', '123122', '123221', '223211', '221132',
- '221231', '213212', '223112', '312131', '311222', '321122', '321221', '312212', '322112', '322211',
- '212123', '212321', '232121', '111323', '131123', '131321', '112313', '132113', '132311', '211313',
- '231113', '231311', '112133', '112331', '132131', '113123', '113321', '133121', '313121', '211331',
- '231131', '213113', '213311', '213131', '311123', '311321', '331121', '312113', '312311', '332111',
- '314111', '221411', '431111', '111224', '111422', '121124', '121421', '141122', '141221', '112214',
- '112412', '122114', '122411', '142112', '142211', '241211', '221114', '413111', '241112', '134111',
- '111242', '121142', '121241', '114212', '124112', '124211', '411212', '421112', '421211', '212141',
- '214121', '412121', '111143', '111341', '131141', '114113', '114311', '411113', '411311', '113141',
- '114131', '311141', '411131', '211412', '211214', '211232', '2331112'
+ "212222",
+ "222122",
+ "222221",
+ "121223",
+ "121322",
+ "131222",
+ "122213",
+ "122312",
+ "132212",
+ "221213",
+ "221312",
+ "231212",
+ "112232",
+ "122132",
+ "122231",
+ "113222",
+ "123122",
+ "123221",
+ "223211",
+ "221132",
+ "221231",
+ "213212",
+ "223112",
+ "312131",
+ "311222",
+ "321122",
+ "321221",
+ "312212",
+ "322112",
+ "322211",
+ "212123",
+ "212321",
+ "232121",
+ "111323",
+ "131123",
+ "131321",
+ "112313",
+ "132113",
+ "132311",
+ "211313",
+ "231113",
+ "231311",
+ "112133",
+ "112331",
+ "132131",
+ "113123",
+ "113321",
+ "133121",
+ "313121",
+ "211331",
+ "231131",
+ "213113",
+ "213311",
+ "213131",
+ "311123",
+ "311321",
+ "331121",
+ "312113",
+ "312311",
+ "332111",
+ "314111",
+ "221411",
+ "431111",
+ "111224",
+ "111422",
+ "121124",
+ "121421",
+ "141122",
+ "141221",
+ "112214",
+ "112412",
+ "122114",
+ "122411",
+ "142112",
+ "142211",
+ "241211",
+ "221114",
+ "413111",
+ "241112",
+ "134111",
+ "111242",
+ "121142",
+ "121241",
+ "114212",
+ "124112",
+ "124211",
+ "411212",
+ "421112",
+ "421211",
+ "212141",
+ "214121",
+ "412121",
+ "111143",
+ "111341",
+ "131141",
+ "114113",
+ "114311",
+ "411113",
+ "411311",
+ "113141",
+ "114131",
+ "311141",
+ "411131",
+ "211412",
+ "211214",
+ "211232",
+ "2331112",
]
class Special(object):
"""These are special characters used by the Code128 encoding."""
- START_A = '[Start Code A]'
- START_B = '[Start Code B]'
- START_C = '[Start Code C]'
- CODE_A = '[Code A]'
- CODE_B = '[Code B]'
- CODE_C = '[Code C]'
- SHIFT_A = '[Shift A]'
- SHIFT_B = '[Shift B]'
- FNC_1 = '[FNC 1]'
- FNC_2 = '[FNC 2]'
- FNC_3 = '[FNC 3]'
- FNC_4 = '[FNC 4]'
- STOP = '[Stop]'
-
- _start_codes = {'A': Special.START_A, 'B': Special.START_B, 'C': Special.START_C}
- _char_codes = {'A': Special.CODE_A, 'B': Special.CODE_B, 'C': Special.CODE_C}
+
+ START_A = "[Start Code A]"
+ START_B = "[Start Code B]"
+ START_C = "[Start Code C]"
+ CODE_A = "[Code A]"
+ CODE_B = "[Code B]"
+ CODE_C = "[Code C]"
+ SHIFT_A = "[Shift A]"
+ SHIFT_B = "[Shift B]"
+ FNC_1 = "[FNC 1]"
+ FNC_2 = "[FNC 2]"
+ FNC_3 = "[FNC 3]"
+ FNC_4 = "[FNC 4]"
+ STOP = "[Stop]"
+
+ _start_codes = {"A": Special.START_A, "B": Special.START_B, "C": Special.START_C}
+ _char_codes = {"A": Special.CODE_A, "B": Special.CODE_B, "C": Special.CODE_C}
# Lists mapping symbol values to characters in each character set. This defines the alphabet and Code128._sym2val
# is derived from this structure.
_val2sym = {
# Code Set A includes ordinals 0 through 95 and 7 special characters. The ordinals include digits,
# upper case characters, punctuation and control characters.
- 'A':
- [chr(x) for x in range(32, 95 + 1)] +
- [chr(x) for x in range(0, 31 + 1)] +
- [
- Special.FNC_3, Special.FNC_2, Special.SHIFT_B, Special.CODE_C,
- Special.CODE_B, Special.FNC_4, Special.FNC_1,
- Special.START_A, Special.START_B, Special.START_C, Special.STOP
- ],
+ "A": [chr(x) for x in range(32, 95 + 1)]
+ + [chr(x) for x in range(0, 31 + 1)]
+ + [
+ Special.FNC_3,
+ Special.FNC_2,
+ Special.SHIFT_B,
+ Special.CODE_C,
+ Special.CODE_B,
+ Special.FNC_4,
+ Special.FNC_1,
+ Special.START_A,
+ Special.START_B,
+ Special.START_C,
+ Special.STOP,
+ ],
# Code Set B includes ordinals 32 through 127 and 7 special characters. The ordinals include digits,
# upper and lover case characters and punctuation.
- 'B':
- [chr(x) for x in range(32, 127 + 1)] +
- [
- Special.FNC_3, Special.FNC_2, Special.SHIFT_A, Special.CODE_C,
- Special.FNC_4, Special.CODE_A, Special.FNC_1,
- Special.START_A, Special.START_B, Special.START_C, Special.STOP
- ],
+ "B": [chr(x) for x in range(32, 127 + 1)]
+ + [
+ Special.FNC_3,
+ Special.FNC_2,
+ Special.SHIFT_A,
+ Special.CODE_C,
+ Special.FNC_4,
+ Special.CODE_A,
+ Special.FNC_1,
+ Special.START_A,
+ Special.START_B,
+ Special.START_C,
+ Special.STOP,
+ ],
# Code Set C includes all pairs of 2 digits and 3 special characters.
- 'C':
- ['%02d' % (x,) for x in range(0, 99 + 1)] +
- [
- Special.CODE_B, Special.CODE_A, Special.FNC_1,
- Special.START_A, Special.START_B, Special.START_C, Special.STOP
- ],
+ "C": ["%02d" % (x,) for x in range(0, 99 + 1)]
+ + [
+ Special.CODE_B,
+ Special.CODE_A,
+ Special.FNC_1,
+ Special.START_A,
+ Special.START_B,
+ Special.START_C,
+ Special.STOP,
+ ],
}
# Dicts mapping characters to symbol values in each character set.
_sym2val = {
- 'A': {char: val for val, char in enumerate(_val2sym['A'])},
- 'B': {char: val for val, char in enumerate(_val2sym['B'])},
- 'C': {char: val for val, char in enumerate(_val2sym['C'])},
+ "A": {char: val for val, char in enumerate(_val2sym["A"])},
+ "B": {char: val for val, char in enumerate(_val2sym["B"])},
+ "C": {char: val for val, char in enumerate(_val2sym["C"])},
}
# How large the quiet zone is on either side of the barcode, when quiet zone is used.
@@ -121,13 +237,13 @@ class Code128(object):
"""
self._validate_charset(data, charset)
- if charset in ('A', 'B'):
+ if charset in ("A", "B"):
charset *= len(data)
- elif charset in ('C',):
- charset *= (len(data) // 2)
+ elif charset in ("C",):
+ charset *= len(data) // 2
if len(data) % 2 == 1:
# If there are an odd number of characters for charset C, encode the last character with charset B.
- charset += 'B'
+ charset += "B"
self.data = data
self.symbol_values = self._encode(data, charset)
@@ -148,13 +264,13 @@ class Code128(object):
if len(charset) > 1:
charset_data_length = 0
for symbol_charset in charset:
- if symbol_charset not in ('A', 'B', 'C'):
+ if symbol_charset not in ("A", "B", "C"):
raise Code128.CharsetError
- charset_data_length += 2 if symbol_charset is 'C' else 1
+ charset_data_length += 2 if symbol_charset is "C" else 1
if charset_data_length != len(data):
raise Code128.CharsetLengthError
elif len(charset) == 1:
- if charset not in ('A', 'B', 'C'):
+ if charset not in ("A", "B", "C"):
raise Code128.CharsetError
elif charset is not None:
raise Code128.CharsetError
@@ -182,10 +298,12 @@ class Code128(object):
if charset is not prev_charset:
# Handle a special case of there being a single A in middle of two B's or the other way around, where
# using a single shift character is more efficient than using two character set switches.
- next_charset = charsets[symbol_num + 1] if symbol_num + 1 < len(charsets) else None
- if charset == 'A' and prev_charset == next_charset == 'B':
+ next_charset = (
+ charsets[symbol_num + 1] if symbol_num + 1 < len(charsets) else None
+ )
+ if charset == "A" and prev_charset == next_charset == "B":
result.append(cls._sym2val[prev_charset][cls.Special.SHIFT_A])
- elif charset == 'B' and prev_charset == next_charset == 'A':
+ elif charset == "B" and prev_charset == next_charset == "A":
result.append(cls._sym2val[prev_charset][cls.Special.SHIFT_B])
else:
# This is the normal case.
@@ -193,7 +311,7 @@ class Code128(object):
result.append(cls._sym2val[prev_charset][charset_symbol])
prev_charset = charset
- nxt = cur + (2 if charset == 'C' else 1)
+ nxt = cur + (2 if charset == "C" else 1)
symbol = data[cur:nxt]
cur = nxt
result.append(cls._sym2val[charset][symbol])
@@ -206,9 +324,10 @@ class Code128(object):
@property
def symbols(self):
"""List of the coded symbols as strings, with special characters included."""
+
def _iter_symbols(symbol_values):
# The initial charset doesn't matter, as the start codes have the same symbol values in all charsets.
- charset = 'A'
+ charset = "A"
shift_charset = None
for symbol_value in symbol_values:
@@ -219,15 +338,15 @@ class Code128(object):
symbol = self._val2sym[charset][symbol_value]
if symbol in (self.Special.START_A, self.Special.CODE_A):
- charset = 'A'
+ charset = "A"
elif symbol in (self.Special.START_B, self.Special.CODE_B):
- charset = 'B'
+ charset = "B"
elif symbol in (self.Special.START_C, self.Special.CODE_C):
- charset = 'C'
+ charset = "C"
elif symbol in (self.Special.SHIFT_A,):
- shift_charset = 'A'
+ shift_charset = "A"
elif symbol in (self.Special.SHIFT_B,):
- shift_charset = 'B'
+ shift_charset = "B"
yield symbol
@@ -243,7 +362,7 @@ class Code128(object):
:rtype: string
"""
- return ''.join(map((lambda val: self._val2bars[val]), self.symbol_values))
+ return "".join(map((lambda val: self._val2bars[val]), self.symbol_values))
@property
def modules(self):
@@ -255,6 +374,7 @@ class Code128(object):
:rtype: list[int]
"""
+
def _iterate_modules(bars):
is_bar = True
for char in map(int, bars):
@@ -288,7 +408,9 @@ class Code128(object):
:return: A monochromatic image containing the barcode as black bars on white background.
"""
if Image is None:
- raise Code128.MissingDependencyError("PIL module is required to use image method.")
+ raise Code128.MissingDependencyError(
+ "PIL module is required to use image method."
+ )
modules = list(self.modules)
if add_quiet_zone:
@@ -296,7 +418,7 @@ class Code128(object):
modules = [1] * self.quiet_zone + modules + [1] * self.quiet_zone
width = len(modules)
- img = Image.new(mode='1', size=(width, 1))
+ img = Image.new(mode="1", size=(width, 1))
img.putdata(modules)
if height == 1 and module_width == 1:
@@ -305,7 +427,7 @@ class Code128(object):
new_size = (width * module_width, height)
return img.resize(new_size, resample=Image.NEAREST)
- def data_url(self, image_format='png', add_quiet_zone=True):
+ def data_url(self, image_format="png", add_quiet_zone=True):
"""Get a data URL representing the barcode.
>>> barcode = Code128('Hello!', charset='B')
@@ -327,22 +449,21 @@ class Code128(object):
# Using BMP can often result in smaller data URLs than PNG, but it isn't as widely supported by browsers as PNG.
# GIFs result in data URLs 10 times bigger than PNG or BMP, possibly due to lack of support for monochrome GIFs
# in Pillow, so they shouldn't be used.
- if image_format == 'png':
+ if image_format == "png":
# Unfortunately there is no way to avoid adding the zlib headers.
# Using compress_level=0 sometimes results in a slightly bigger data size (by a few bytes), but there
# doesn't appear to be a difference between levels 9 and 1, so let's just use 1.
- pil_image.save(memory_file, format='png', compress_level=1)
- elif image_format == 'bmp':
- pil_image.save(memory_file, format='bmp')
+ pil_image.save(memory_file, format="png", compress_level=1)
+ elif image_format == "bmp":
+ pil_image.save(memory_file, format="bmp")
else:
- raise Code128.UnknownFormatError('Only png and bmp are supported.')
+ raise Code128.UnknownFormatError("Only png and bmp are supported.")
# Encode the data in the BytesIO object and convert the result into unicode.
- base64_image = base64.b64encode(memory_file.getvalue()).decode('ascii')
+ base64_image = base64.b64encode(memory_file.getvalue()).decode("ascii")
- data_url = 'data:image/{format};base64,{base64_data}'.format(
- format=image_format,
- base64_data=base64_image
+ data_url = "data:image/{format};base64,{base64_data}".format(
+ format=image_format, base64_data=base64_image
)
return data_url
diff --git a/lib/runner.py b/lib/runner.py
index 85df2b6..16f0a29 100644
--- a/lib/runner.py
+++ b/lib/runner.py
@@ -30,7 +30,7 @@ class SerialReader(serial.threaded.Protocol):
def __init__(self, callback=None):
"""Create a new SerialReader object."""
self.callback = callback
- self.recv_buf = ''
+ self.recv_buf = ""
self.lines = []
def __call__(self):
@@ -39,13 +39,13 @@ class SerialReader(serial.threaded.Protocol):
def data_received(self, data):
"""Append newly received serial data to the line buffer."""
try:
- str_data = data.decode('UTF-8')
+ str_data = data.decode("UTF-8")
self.recv_buf += str_data
# We may get anything between \r\n, \n\r and simple \n newlines.
# We assume that \n is always present and use str.strip to remove leading/trailing \r symbols
# Note: Do not call str.strip on lines[-1]! Otherwise, lines may be mangled
- lines = self.recv_buf.split('\n')
+ lines = self.recv_buf.split("\n")
if len(lines) > 1:
self.lines.extend(map(str.strip, lines[:-1]))
self.recv_buf = lines[-1]
@@ -94,14 +94,16 @@ class SerialMonitor:
"""
self.ser = serial.serial_for_url(port, do_not_open=True)
self.ser.baudrate = baud
- self.ser.parity = 'N'
+ self.ser.parity = "N"
self.ser.rtscts = False
self.ser.xonxoff = False
try:
self.ser.open()
except serial.SerialException as e:
- sys.stderr.write('Could not open serial port {}: {}\n'.format(self.ser.name, e))
+ sys.stderr.write(
+ "Could not open serial port {}: {}\n".format(self.ser.name, e)
+ )
sys.exit(1)
self.reader = SerialReader(callback=callback)
@@ -131,6 +133,7 @@ class SerialMonitor:
self.worker.stop()
self.ser.close()
+
# TODO Optionale Kalibrierung mit bekannten Widerständen an GPIOs am Anfang
# TODO Sync per LED? -> Vor und ggf nach jeder Transition kurz pulsen
# TODO Für Verbraucher mit wenig Energiebedarf: Versorgung direkt per GPIO
@@ -143,14 +146,14 @@ class EnergyTraceMonitor(SerialMonitor):
def __init__(self, port: str, baud: int, callback=None, voltage=3.3):
super().__init__(port=port, baud=baud, callback=callback)
self._voltage = voltage
- self._output = time.strftime('%Y%m%d-%H%M%S.etlog')
+ self._output = time.strftime("%Y%m%d-%H%M%S.etlog")
self._start_energytrace()
def _start_energytrace(self):
- cmd = ['msp430-etv', '--save', self._output, '0']
- self._logger = subprocess.Popen(cmd,
- stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- universal_newlines=True)
+ cmd = ["msp430-etv", "--save", self._output, "0"]
+ self._logger = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+ )
def close(self):
super().close()
@@ -162,14 +165,16 @@ class EnergyTraceMonitor(SerialMonitor):
def get_config(self) -> dict:
return {
- 'voltage': self._voltage,
+ "voltage": self._voltage,
}
class MIMOSAMonitor(SerialMonitor):
"""MIMOSAMonitor captures serial output and MIMOSA energy data for a specific amount of time."""
- def __init__(self, port: str, baud: int, callback=None, offset=130, shunt=330, voltage=3.3):
+ def __init__(
+ self, port: str, baud: int, callback=None, offset=130, shunt=330, voltage=3.3
+ ):
super().__init__(port=port, baud=baud, callback=callback)
self._offset = offset
self._shunt = shunt
@@ -177,39 +182,41 @@ class MIMOSAMonitor(SerialMonitor):
self._start_mimosa()
def _mimosactl(self, subcommand):
- cmd = ['mimosactl']
+ cmd = ["mimosactl"]
cmd.append(subcommand)
res = subprocess.run(cmd)
if res.returncode != 0:
res = subprocess.run(cmd)
if res.returncode != 0:
- raise RuntimeError('{} returned {}'.format(' '.join(cmd), res.returncode))
+ raise RuntimeError(
+ "{} returned {}".format(" ".join(cmd), res.returncode)
+ )
def _mimosacmd(self, opts):
- cmd = ['MimosaCMD']
+ cmd = ["MimosaCMD"]
cmd.extend(opts)
res = subprocess.run(cmd)
if res.returncode != 0:
- raise RuntimeError('{} returned {}'.format(' '.join(cmd), res.returncode))
+ raise RuntimeError("{} returned {}".format(" ".join(cmd), res.returncode))
def _start_mimosa(self):
- self._mimosactl('disconnect')
- self._mimosacmd(['--start'])
- self._mimosacmd(['--parameter', 'offset', str(self._offset)])
- self._mimosacmd(['--parameter', 'shunt', str(self._shunt)])
- self._mimosacmd(['--parameter', 'voltage', str(self._voltage)])
- self._mimosacmd(['--mimosa-start'])
+ self._mimosactl("disconnect")
+ self._mimosacmd(["--start"])
+ self._mimosacmd(["--parameter", "offset", str(self._offset)])
+ self._mimosacmd(["--parameter", "shunt", str(self._shunt)])
+ self._mimosacmd(["--parameter", "voltage", str(self._voltage)])
+ self._mimosacmd(["--mimosa-start"])
time.sleep(2)
- self._mimosactl('1k') # 987 ohm
+ self._mimosactl("1k") # 987 ohm
time.sleep(2)
- self._mimosactl('100k') # 99.3 kohm
+ self._mimosactl("100k") # 99.3 kohm
time.sleep(2)
- self._mimosactl('connect')
+ self._mimosactl("connect")
def _stop_mimosa(self):
# Make sure the MIMOSA daemon has gathered all needed data
time.sleep(2)
- self._mimosacmd(['--mimosa-stop'])
+ self._mimosacmd(["--mimosa-stop"])
mtime_changed = True
mim_file = None
time.sleep(1)
@@ -218,7 +225,7 @@ class MIMOSAMonitor(SerialMonitor):
# files lying around in the directory will not confuse our
# heuristic.
for filename in sorted(os.listdir(), reverse=True):
- if re.search(r'[.]mim$', filename):
+ if re.search(r"[.]mim$", filename):
mim_file = filename
break
while mtime_changed:
@@ -226,7 +233,7 @@ class MIMOSAMonitor(SerialMonitor):
if time.time() - os.stat(mim_file).st_mtime < 3:
mtime_changed = True
time.sleep(1)
- self._mimosacmd(['--stop'])
+ self._mimosacmd(["--stop"])
return mim_file
def close(self):
@@ -238,9 +245,9 @@ class MIMOSAMonitor(SerialMonitor):
def get_config(self) -> dict:
return {
- 'offset': self._offset,
- 'shunt': self._shunt,
- 'voltage': self._voltage,
+ "offset": self._offset,
+ "shunt": self._shunt,
+ "voltage": self._voltage,
}
@@ -263,14 +270,17 @@ class ShellMonitor:
stderr and return status are discarded at the moment.
"""
if type(timeout) != int:
- raise ValueError('timeout argument must be int')
- res = subprocess.run(['timeout', '{:d}s'.format(timeout), self.script],
- stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- universal_newlines=True)
+ raise ValueError("timeout argument must be int")
+ res = subprocess.run(
+ ["timeout", "{:d}s".format(timeout), self.script],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ universal_newlines=True,
+ )
if self.callback:
- for line in res.stdout.split('\n'):
+ for line in res.stdout.split("\n"):
self.callback(line)
- return res.stdout.split('\n')
+ return res.stdout.split("\n")
def monitor(self):
raise NotImplementedError
@@ -285,28 +295,35 @@ class ShellMonitor:
def build(arch, app, opts=[]):
- command = ['make', 'arch={}'.format(arch), 'app={}'.format(app), 'clean']
+ command = ["make", "arch={}".format(arch), "app={}".format(app), "clean"]
command.extend(opts)
- res = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- universal_newlines=True)
+ res = subprocess.run(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+ )
if res.returncode != 0:
- raise RuntimeError('Build failure, executing {}:\n'.format(command) + res.stderr)
- command = ['make', '-B', 'arch={}'.format(arch), 'app={}'.format(app)]
+ raise RuntimeError(
+ "Build failure, executing {}:\n".format(command) + res.stderr
+ )
+ command = ["make", "-B", "arch={}".format(arch), "app={}".format(app)]
command.extend(opts)
- res = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- universal_newlines=True)
+ res = subprocess.run(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+ )
if res.returncode != 0:
- raise RuntimeError('Build failure, executing {}:\n '.format(command) + res.stderr)
+ raise RuntimeError(
+ "Build failure, executing {}:\n ".format(command) + res.stderr
+ )
return command
def flash(arch, app, opts=[]):
- command = ['make', 'arch={}'.format(arch), 'app={}'.format(app), 'program']
+ command = ["make", "arch={}".format(arch), "app={}".format(app), "program"]
command.extend(opts)
- res = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- universal_newlines=True)
+ res = subprocess.run(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+ )
if res.returncode != 0:
- raise RuntimeError('Flash failure')
+ raise RuntimeError("Flash failure")
return command
@@ -316,13 +333,14 @@ def get_info(arch, opts: list = []) -> list:
Returns a list.
"""
- command = ['make', 'arch={}'.format(arch), 'info']
+ command = ["make", "arch={}".format(arch), "info"]
command.extend(opts)
- res = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- universal_newlines=True)
+ res = subprocess.run(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+ )
if res.returncode != 0:
- raise RuntimeError('make info Failure')
- return res.stdout.split('\n')
+ raise RuntimeError("make info Failure")
+ return res.stdout.split("\n")
def get_monitor(arch: str, **kwargs) -> object:
@@ -336,32 +354,32 @@ def get_monitor(arch: str, **kwargs) -> object:
:param mimosa: `MIMOSAMonitor` options. Returns a MIMOSA monitor if not None.
"""
for line in get_info(arch):
- if 'Monitor:' in line:
- _, port, arg = line.split(' ')
- if port == 'run':
+ if "Monitor:" in line:
+ _, port, arg = line.split(" ")
+ if port == "run":
return ShellMonitor(arg, **kwargs)
- elif 'mimosa' in kwargs and kwargs['mimosa'] is not None:
- mimosa_kwargs = kwargs.pop('mimosa')
+ elif "mimosa" in kwargs and kwargs["mimosa"] is not None:
+ mimosa_kwargs = kwargs.pop("mimosa")
return MIMOSAMonitor(port, arg, **mimosa_kwargs, **kwargs)
- elif 'energytrace' in kwargs and kwargs['energytrace'] is not None:
- energytrace_kwargs = kwargs.pop('energytrace')
+ elif "energytrace" in kwargs and kwargs["energytrace"] is not None:
+ energytrace_kwargs = kwargs.pop("energytrace")
return EnergyTraceMonitor(port, arg, **energytrace_kwargs, **kwargs)
else:
- kwargs.pop('energytrace', None)
- kwargs.pop('mimosa', None)
+ kwargs.pop("energytrace", None)
+ kwargs.pop("mimosa", None)
return SerialMonitor(port, arg, **kwargs)
- raise RuntimeError('Monitor failure')
+ raise RuntimeError("Monitor failure")
def get_counter_limits(arch: str) -> tuple:
"""Return multipass max counter and max overflow value for arch."""
for line in get_info(arch):
- match = re.match('Counter Overflow: ([^/]*)/(.*)', line)
+ match = re.match("Counter Overflow: ([^/]*)/(.*)", line)
if match:
overflow_value = int(match.group(1))
max_overflow = int(match.group(2))
return overflow_value, max_overflow
- raise RuntimeError('Did not find Counter Overflow limits')
+ raise RuntimeError("Did not find Counter Overflow limits")
def get_counter_limits_us(arch: str) -> tuple:
@@ -370,13 +388,13 @@ def get_counter_limits_us(arch: str) -> tuple:
overflow_value = 0
max_overflow = 0
for line in get_info(arch):
- match = re.match(r'CPU\s+Freq:\s+(.*)\s+Hz', line)
+ match = re.match(r"CPU\s+Freq:\s+(.*)\s+Hz", line)
if match:
cpu_freq = int(match.group(1))
- match = re.match(r'Counter Overflow:\s+([^/]*)/(.*)', line)
+ match = re.match(r"Counter Overflow:\s+([^/]*)/(.*)", line)
if match:
overflow_value = int(match.group(1))
max_overflow = int(match.group(2))
if cpu_freq and overflow_value:
return 1000000 / cpu_freq, overflow_value * 1000000 / cpu_freq, max_overflow
- raise RuntimeError('Did not find Counter Overflow limits')
+ raise RuntimeError("Did not find Counter Overflow limits")
diff --git a/lib/size_to_radio_energy.py b/lib/size_to_radio_energy.py
index a3cf7c5..10de1a3 100644
--- a/lib/size_to_radio_energy.py
+++ b/lib/size_to_radio_energy.py
@@ -7,38 +7,42 @@ class can convert a cycle count to an energy consumption.
import numpy as np
+
def get_class(radio_name: str):
"""Return model class for radio_name."""
- if radio_name == 'CC1200tx':
+ if radio_name == "CC1200tx":
return CC1200tx
- if radio_name == 'CC1200rx':
+ if radio_name == "CC1200rx":
return CC1200rx
- if radio_name == 'NRF24L01tx':
+ if radio_name == "NRF24L01tx":
return NRF24L01tx
- if radio_name == 'NRF24L01dtx':
+ if radio_name == "NRF24L01dtx":
return NRF24L01dtx
- if radio_name == 'esp8266dtx':
+ if radio_name == "esp8266dtx":
return ESP8266dtx
- if radio_name == 'esp8266drx':
+ if radio_name == "esp8266drx":
return ESP8266drx
+
def _param_list_to_dict(device, param_list):
param_dict = dict()
for i, parameter in enumerate(sorted(device.parameters.keys())):
param_dict[parameter] = param_list[i]
return param_dict
+
class CC1200tx:
"""CC1200 TX energy based on aemr measurements."""
- name = 'CC1200tx'
+
+ name = "CC1200tx"
parameters = {
- 'symbolrate' : [6, 12, 25, 50, 100, 200, 250], # ksps
- 'txbytes' : [],
- 'txpower' : [10, 20, 30, 40, 47], # dBm = f(txpower)
+ "symbolrate": [6, 12, 25, 50, 100, 200, 250], # ksps
+ "txbytes": [],
+ "txpower": [10, 20, 30, 40, 47], # dBm = f(txpower)
}
default_params = {
- 'symbolrate' : 100,
- 'txpower' : 47,
+ "symbolrate": 100,
+ "txpower": 47,
}
@staticmethod
@@ -48,106 +52,129 @@ class CC1200tx:
# Mittlere TX-Leistung, gefitted von AEMR
# Messdaten erhoben bei 3.6V
- power = 8.18053941e+04
- power -= 1.24208376e+03 * np.sqrt(params['symbolrate'])
- power -= 5.73742779e+02 * np.log(params['txbytes'])
- power += 1.76945886e+01 * (params['txpower'])**2
- power += 2.33469617e+02 * np.sqrt(params['symbolrate']) * np.log(params['txbytes'])
- power -= 6.99137635e-01 * np.sqrt(params['symbolrate']) * (params['txpower'])**2
- power -= 3.31365158e-01 * np.log(params['txbytes']) * (params['txpower'])**2
- power += 1.32784945e-01 * np.sqrt(params['symbolrate']) * np.log(params['txbytes']) * (params['txpower'])**2
+ power = 8.18053941e04
+ power -= 1.24208376e03 * np.sqrt(params["symbolrate"])
+ power -= 5.73742779e02 * np.log(params["txbytes"])
+ power += 1.76945886e01 * (params["txpower"]) ** 2
+ power += (
+ 2.33469617e02 * np.sqrt(params["symbolrate"]) * np.log(params["txbytes"])
+ )
+ power -= (
+ 6.99137635e-01 * np.sqrt(params["symbolrate"]) * (params["txpower"]) ** 2
+ )
+ power -= 3.31365158e-01 * np.log(params["txbytes"]) * (params["txpower"]) ** 2
+ power += (
+ 1.32784945e-01
+ * np.sqrt(params["symbolrate"])
+ * np.log(params["txbytes"])
+ * (params["txpower"]) ** 2
+ )
# txDone-Timeout, gefitted von AEMR
- duration = 3.65513500e+02
- duration += 8.01016526e+04 * 1/(params['symbolrate'])
- duration -= 7.06364515e-03 * params['txbytes']
- duration += 8.00029860e+03 * 1/(params['symbolrate']) * params['txbytes']
+ duration = 3.65513500e02
+ duration += 8.01016526e04 * 1 / (params["symbolrate"])
+ duration -= 7.06364515e-03 * params["txbytes"]
+ duration += 8.00029860e03 * 1 / (params["symbolrate"]) * params["txbytes"]
# TX-Energie, gefitted von AEMR
# Achtung: Energy ist in µJ, nicht (wie in AEMR-Transitionsmodellen üblich) in pJ
# Messdaten erhoben bei 3.6V
- energy = 1.74383259e+01
- energy += 6.29922138e+03 * 1/(params['symbolrate'])
- energy += 1.13307135e-02 * params['txbytes']
- energy -= 1.28121377e-04 * (params['txpower'])**2
- energy += 6.29080184e+02 * 1/(params['symbolrate']) * params['txbytes']
- energy += 1.25647926e+00 * 1/(params['symbolrate']) * (params['txpower'])**2
- energy += 1.31996202e-05 * params['txbytes'] * (params['txpower'])**2
- energy += 1.25676966e-01 * 1/(params['symbolrate']) * params['txbytes'] * (params['txpower'])**2
+ energy = 1.74383259e01
+ energy += 6.29922138e03 * 1 / (params["symbolrate"])
+ energy += 1.13307135e-02 * params["txbytes"]
+ energy -= 1.28121377e-04 * (params["txpower"]) ** 2
+ energy += 6.29080184e02 * 1 / (params["symbolrate"]) * params["txbytes"]
+ energy += 1.25647926e00 * 1 / (params["symbolrate"]) * (params["txpower"]) ** 2
+ energy += 1.31996202e-05 * params["txbytes"] * (params["txpower"]) ** 2
+ energy += (
+ 1.25676966e-01
+ * 1
+ / (params["symbolrate"])
+ * params["txbytes"]
+ * (params["txpower"]) ** 2
+ )
return energy * 1e-6
@staticmethod
def get_energy_per_byte(params):
- A = 8.18053941e+04
- A -= 1.24208376e+03 * np.sqrt(params['symbolrate'])
- A += 1.76945886e+01 * (params['txpower'])**2
- A -= 6.99137635e-01 * np.sqrt(params['symbolrate']) * (params['txpower'])**2
- B = -5.73742779e+02
- B += 2.33469617e+02 * np.sqrt(params['symbolrate'])
- B -= 3.31365158e-01 * (params['txpower'])**2
- B += 1.32784945e-01 * np.sqrt(params['symbolrate']) * (params['txpower'])**2
- C = 3.65513500e+02
- C += 8.01016526e+04 * 1/(params['symbolrate'])
- D = -7.06364515e-03
- D += 8.00029860e+03 * 1/(params['symbolrate'])
-
- x = params['txbytes']
+ A = 8.18053941e04
+ A -= 1.24208376e03 * np.sqrt(params["symbolrate"])
+ A += 1.76945886e01 * (params["txpower"]) ** 2
+ A -= 6.99137635e-01 * np.sqrt(params["symbolrate"]) * (params["txpower"]) ** 2
+ B = -5.73742779e02
+ B += 2.33469617e02 * np.sqrt(params["symbolrate"])
+ B -= 3.31365158e-01 * (params["txpower"]) ** 2
+ B += 1.32784945e-01 * np.sqrt(params["symbolrate"]) * (params["txpower"]) ** 2
+ C = 3.65513500e02
+ C += 8.01016526e04 * 1 / (params["symbolrate"])
+ D = -7.06364515e-03
+ D += 8.00029860e03 * 1 / (params["symbolrate"])
+
+ x = params["txbytes"]
# in pJ
- de_dx = A * D + B * C * 1/x + B * D * (np.log(x) + 1)
+ de_dx = A * D + B * C * 1 / x + B * D * (np.log(x) + 1)
# in µJ
- de_dx = 1.13307135e-02
- de_dx += 6.29080184e+02 * 1/(params['symbolrate'])
- de_dx += 1.31996202e-05 * (params['txpower'])**2
- de_dx += 1.25676966e-01 * 1/(params['symbolrate']) * (params['txpower'])**2
+ de_dx = 1.13307135e-02
+ de_dx += 6.29080184e02 * 1 / (params["symbolrate"])
+ de_dx += 1.31996202e-05 * (params["txpower"]) ** 2
+ de_dx += 1.25676966e-01 * 1 / (params["symbolrate"]) * (params["txpower"]) ** 2
- #de_dx = (B * 1/x) * (C + D * x) + (A + B * np.log(x)) * D
+ # de_dx = (B * 1/x) * (C + D * x) + (A + B * np.log(x)) * D
return de_dx * 1e-6
+
class CC1200rx:
"""CC1200 RX energy based on aemr measurements."""
- name = 'CC1200rx'
+
+ name = "CC1200rx"
parameters = {
- 'symbolrate' : [6, 12, 25, 50, 100, 200, 250], # ksps
- 'txbytes' : [],
- 'txpower' : [10, 20, 30, 40, 47], # dBm = f(txpower)
+ "symbolrate": [6, 12, 25, 50, 100, 200, 250], # ksps
+ "txbytes": [],
+ "txpower": [10, 20, 30, 40, 47], # dBm = f(txpower)
}
default_params = {
- 'symbolrate' : 100,
- 'txpower' : 47,
+ "symbolrate": 100,
+ "txpower": 47,
}
@staticmethod
def get_energy(params):
# TODO
- return params['txbytes'] * CC1200rx.get_energy_per_byte(params)
+ return params["txbytes"] * CC1200rx.get_energy_per_byte(params)
@staticmethod
def get_energy_per_byte(params):
- #RX : 0 + regression_arg(0) + regression_arg(1) * np.log(parameter(symbolrate) + 1)
+ # RX : 0 + regression_arg(0) + regression_arg(1) * np.log(parameter(symbolrate) + 1)
# [84414.91636169 205.63323036]
- de_dx = (84414.91636169 + 205.63323036 * np.log(params['symbolrate'] + 1)) * 8000 / params['symbolrate']
+ de_dx = (
+ (84414.91636169 + 205.63323036 * np.log(params["symbolrate"] + 1))
+ * 8000
+ / params["symbolrate"]
+ )
return de_dx * 1e-12
+
class NRF24L01rx:
"""NRF24L01+ RX energy based on aemr measurements (using variable packet size)"""
- name = 'NRF24L01'
+
+ name = "NRF24L01"
parameters = {
- 'datarate' : [250, 1000, 2000], # kbps
- 'txbytes' : [],
- 'txpower' : [-18, -12, -6, 0], # dBm
- 'voltage' : [1.9, 3.6],
+ "datarate": [250, 1000, 2000], # kbps
+ "txbytes": [],
+ "txpower": [-18, -12, -6, 0], # dBm
+ "voltage": [1.9, 3.6],
}
default_params = {
- 'datarate' : 1000,
- 'txpower' : -6,
- 'voltage' : 3,
+ "datarate": 1000,
+ "txpower": -6,
+ "voltage": 3,
}
@staticmethod
@@ -155,35 +182,40 @@ class NRF24L01rx:
# RX : 0 + regression_arg(0) + regression_arg(1) * np.sqrt(parameter(datarate))
# [48530.73235537 117.25274402]
- de_dx = (48530.73235537 + 117.25274402 * np.sqrt(params['datarate'])) * 8000 / params['datarate']
+ de_dx = (
+ (48530.73235537 + 117.25274402 * np.sqrt(params["datarate"]))
+ * 8000
+ / params["datarate"]
+ )
return de_dx * 1e-12
+
# PYTHONPATH=lib bin/analyze-archive.py --show-model=all --show-quality=table ../data/*_RF24_no_retries.tar
class NRF24L01tx:
"""NRF24L01+ TX(*) energy based on aemr measurements (32B fixed packet size, (*)ack-await, no retries)."""
- name = 'NRF24L01'
+
+ name = "NRF24L01"
parameters = {
- 'datarate' : [250, 1000, 2000], # kbps
- 'txbytes' : [],
- 'txpower' : [-18, -12, -6, 0], # dBm
- 'voltage' : [1.9, 3.6],
+ "datarate": [250, 1000, 2000], # kbps
+ "txbytes": [],
+ "txpower": [-18, -12, -6, 0], # dBm
+ "voltage": [1.9, 3.6],
}
default_params = {
- 'datarate' : 1000,
- 'txpower' : -6,
- 'voltage' : 3,
+ "datarate": 1000,
+ "txpower": -6,
+ "voltage": 3,
}
-# AEMR:
-# TX power / energy:
-#TX : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * (19.47+parameter(txpower))**2 + regression_arg(3) * 1/(parameter(datarate)) * (19.47+parameter(txpower))**2
-# [6.30323056e+03 2.59889924e+06 7.82186268e+00 8.69746093e+03]
-#TX : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * (19.47+parameter(txpower))**2 + regression_arg(3) * 1/(parameter(datarate)) * (19.47+parameter(txpower))**2
-# [7.67932887e+00 1.02969455e+04 4.55116475e-03 2.99786534e+01]
-#epilogue : timeout : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))
-# [ 1624.06589147 332251.93798766]
-
+ # AEMR:
+ # TX power / energy:
+ # TX : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * (19.47+parameter(txpower))**2 + regression_arg(3) * 1/(parameter(datarate)) * (19.47+parameter(txpower))**2
+ # [6.30323056e+03 2.59889924e+06 7.82186268e+00 8.69746093e+03]
+ # TX : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * (19.47+parameter(txpower))**2 + regression_arg(3) * 1/(parameter(datarate)) * (19.47+parameter(txpower))**2
+ # [7.67932887e+00 1.02969455e+04 4.55116475e-03 2.99786534e+01]
+ # epilogue : timeout : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))
+ # [ 1624.06589147 332251.93798766]
@staticmethod
def get_energy(params):
@@ -192,31 +224,37 @@ class NRF24L01tx:
# TX-Leistung, gefitted von AEMR
# Messdaten erhoben bei 3.6V
- power = 6.30323056e+03
- power += 2.59889924e+06 * 1/params['datarate']
- power += 7.82186268e+00 * (19.47+params['txpower'])**2
- power += 8.69746093e+03 * 1/params['datarate'] * (19.47+params['txpower'])**2
+ power = 6.30323056e03
+ power += 2.59889924e06 * 1 / params["datarate"]
+ power += 7.82186268e00 * (19.47 + params["txpower"]) ** 2
+ power += (
+ 8.69746093e03 * 1 / params["datarate"] * (19.47 + params["txpower"]) ** 2
+ )
# TX-Dauer, gefitted von AEMR
duration = 1624.06589147
- duration += 332251.93798766 * 1/params['datarate']
+ duration += 332251.93798766 * 1 / params["datarate"]
# TX-Energie, gefitted von AEMR
# Achtung: Energy ist in µJ, nicht (wie in AEMR-Transitionsmodellen üblich) in pJ
# Messdaten erhoben bei 3.6V
- energy = 7.67932887e+00
- energy += 1.02969455e+04 * 1/params['datarate']
- energy += 4.55116475e-03 * (19.47+params['txpower'])**2
- energy += 2.99786534e+01 * 1/params['datarate'] * (19.47+params['txpower'])**2
+ energy = 7.67932887e00
+ energy += 1.02969455e04 * 1 / params["datarate"]
+ energy += 4.55116475e-03 * (19.47 + params["txpower"]) ** 2
+ energy += (
+ 2.99786534e01 * 1 / params["datarate"] * (19.47 + params["txpower"]) ** 2
+ )
- energy = power * 1e-6 * duration * 1e-6 * np.ceil(params['txbytes'] / 32)
+ energy = power * 1e-6 * duration * 1e-6 * np.ceil(params["txbytes"] / 32)
return energy
@staticmethod
def get_energy_per_byte(params):
if type(params) != dict:
- return NRF24L01tx.get_energy_per_byte(_param_list_to_dict(NRF24L01tx, params))
+ return NRF24L01tx.get_energy_per_byte(
+ _param_list_to_dict(NRF24L01tx, params)
+ )
# in µJ
de_dx = 0
@@ -224,17 +262,18 @@ class NRF24L01tx:
class NRF24L01dtx:
"""nRF24L01+ TX energy based on datasheet values (probably unerestimated)"""
- name = 'NRF24L01'
+
+ name = "NRF24L01"
parameters = {
- 'datarate' : [250, 1000, 2000], # kbps
- 'txbytes' : [],
- 'txpower' : [-18, -12, -6, 0], # dBm
- 'voltage' : [1.9, 3.6],
+ "datarate": [250, 1000, 2000], # kbps
+ "txbytes": [],
+ "txpower": [-18, -12, -6, 0], # dBm
+ "voltage": [1.9, 3.6],
}
default_params = {
- 'datarate' : 1000,
- 'txpower' : -6,
- 'voltage' : 3,
+ "datarate": 1000,
+ "txpower": -6,
+ "voltage": 3,
}
# 130 us RX settling: 8.9 mE
@@ -248,35 +287,37 @@ class NRF24L01dtx:
header_bytes = 7
# TX settling: 130 us @ 8 mA
- energy = 8e-3 * params['voltage'] * 130e-6
+ energy = 8e-3 * params["voltage"] * 130e-6
- if params['txpower'] == -18:
+ if params["txpower"] == -18:
current = 7e-3
- elif params['txpower'] == -12:
+ elif params["txpower"] == -12:
current = 7.5e-3
- elif params['txpower'] == -6:
+ elif params["txpower"] == -6:
current = 9e-3
- elif params['txpower'] == 0:
+ elif params["txpower"] == 0:
current = 11.3e-3
- energy += current * params['voltage'] * ((header_bytes + params['txbytes']) * 8 / (params['datarate'] * 1e3))
+ energy += (
+ current
+ * params["voltage"]
+ * ((header_bytes + params["txbytes"]) * 8 / (params["datarate"] * 1e3))
+ )
return energy
+
class ESP8266dtx:
"""esp8266 TX energy based on (hardly documented) datasheet values"""
- name = 'esp8266'
+
+ name = "esp8266"
parameters = {
- 'voltage' : [2.5, 3.0, 3.3, 3.6],
- 'txbytes' : [],
- 'bitrate' : [65e6],
- 'tx_current' : [120e-3],
- }
- default_params = {
- 'voltage' : 3,
- 'bitrate' : 65e6,
- 'tx_current' : 120e-3
+ "voltage": [2.5, 3.0, 3.3, 3.6],
+ "txbytes": [],
+ "bitrate": [65e6],
+ "tx_current": [120e-3],
}
+ default_params = {"voltage": 3, "bitrate": 65e6, "tx_current": 120e-3}
@staticmethod
def get_energy(params):
@@ -286,26 +327,26 @@ class ESP8266dtx:
@staticmethod
def get_energy_per_byte(params):
if type(params) != dict:
- return ESP8266dtx.get_energy_per_byte(_param_list_to_dict(ESP8266dtx, params))
+ return ESP8266dtx.get_energy_per_byte(
+ _param_list_to_dict(ESP8266dtx, params)
+ )
# TX in 802.11n MCS7 -> 64QAM, 65/72.2 Mbit/s @ 20MHz channel, 135/150 Mbit/s @ 40MHz
# -> Value for 65 Mbit/s @ 20MHz channel
- return params['tx_current'] * params['voltage'] / params['bitrate']
+ return params["tx_current"] * params["voltage"] / params["bitrate"]
+
class ESP8266drx:
"""esp8266 RX energy based on (hardly documented) datasheet values"""
- name = 'esp8266'
+
+ name = "esp8266"
parameters = {
- 'voltage' : [2.5, 3.0, 3.3, 3.6],
- 'txbytes' : [],
- 'bitrate' : [65e6],
- 'rx_current' : [56e-3],
- }
- default_params = {
- 'voltage' : 3,
- 'bitrate' : 65e6,
- 'rx_current' : 56e-3
+ "voltage": [2.5, 3.0, 3.3, 3.6],
+ "txbytes": [],
+ "bitrate": [65e6],
+ "rx_current": [56e-3],
}
+ default_params = {"voltage": 3, "bitrate": 65e6, "rx_current": 56e-3}
@staticmethod
def get_energy(params):
@@ -315,8 +356,10 @@ class ESP8266drx:
@staticmethod
def get_energy_per_byte(params):
if type(params) != dict:
- return ESP8266drx.get_energy_per_byte(_param_list_to_dict(ESP8266drx, params))
+ return ESP8266drx.get_energy_per_byte(
+ _param_list_to_dict(ESP8266drx, params)
+ )
# TX in 802.11n MCS7 -> 64QAM, 65/72.2 Mbit/s @ 20MHz channel, 135/150 Mbit/s @ 40MHz
# -> Value for 65 Mbit/s @ 20MHz channel
- return params['rx_current'] * params['voltage'] / params['bitrate']
+ return params["rx_current"] * params["voltage"] / params["bitrate"]
diff --git a/lib/sly/__init__.py b/lib/sly/__init__.py
index 3c1e708..3a2d92e 100644
--- a/lib/sly/__init__.py
+++ b/lib/sly/__init__.py
@@ -1,6 +1,5 @@
-
from .lex import *
from .yacc import *
__version__ = "0.4"
-__all__ = [ *lex.__all__, *yacc.__all__ ]
+__all__ = [*lex.__all__, *yacc.__all__]
diff --git a/lib/sly/ast.py b/lib/sly/ast.py
index 7b79ac5..05802bd 100644
--- a/lib/sly/ast.py
+++ b/lib/sly/ast.py
@@ -1,25 +1,24 @@
# sly/ast.py
import sys
+
class AST(object):
-
@classmethod
def __init_subclass__(cls, **kwargs):
mod = sys.modules[cls.__module__]
- if not hasattr(cls, '__annotations__'):
+ if not hasattr(cls, "__annotations__"):
return
hints = list(cls.__annotations__.items())
def __init__(self, *args, **kwargs):
if len(hints) != len(args):
- raise TypeError(f'Expected {len(hints)} arguments')
+ raise TypeError(f"Expected {len(hints)} arguments")
for arg, (name, val) in zip(args, hints):
if isinstance(val, str):
val = getattr(mod, val)
if not isinstance(arg, val):
- raise TypeError(f'{name} argument must be {val}')
+ raise TypeError(f"{name} argument must be {val}")
setattr(self, name, arg)
cls.__init__ = __init__
-
diff --git a/lib/sly/docparse.py b/lib/sly/docparse.py
index d5a83ce..0f35c97 100644
--- a/lib/sly/docparse.py
+++ b/lib/sly/docparse.py
@@ -2,7 +2,8 @@
#
# Support doc-string parsing classes
-__all__ = [ 'DocParseMeta' ]
+__all__ = ["DocParseMeta"]
+
class DocParseMeta(type):
'''
@@ -44,17 +45,17 @@ class DocParseMeta(type):
@staticmethod
def __new__(meta, clsname, bases, clsdict):
- if '__doc__' in clsdict:
+ if "__doc__" in clsdict:
lexer = meta.lexer()
parser = meta.parser()
lexer.cls_name = parser.cls_name = clsname
- lexer.cls_qualname = parser.cls_qualname = clsdict['__qualname__']
- lexer.cls_module = parser.cls_module = clsdict['__module__']
- parsedict = parser.parse(lexer.tokenize(clsdict['__doc__']))
- assert isinstance(parsedict, dict), 'Parser must return a dictionary'
+ lexer.cls_qualname = parser.cls_qualname = clsdict["__qualname__"]
+ lexer.cls_module = parser.cls_module = clsdict["__module__"]
+ parsedict = parser.parse(lexer.tokenize(clsdict["__doc__"]))
+ assert isinstance(parsedict, dict), "Parser must return a dictionary"
clsdict.update(parsedict)
return super().__new__(meta, clsname, bases, clsdict)
@classmethod
def __init_subclass__(cls):
- assert hasattr(cls, 'parser') and hasattr(cls, 'lexer')
+ assert hasattr(cls, "parser") and hasattr(cls, "lexer")
diff --git a/lib/sly/lex.py b/lib/sly/lex.py
index 246dd9e..0ab0160 100644
--- a/lib/sly/lex.py
+++ b/lib/sly/lex.py
@@ -31,51 +31,63 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
-__all__ = ['Lexer', 'LexerStateChange']
+__all__ = ["Lexer", "LexerStateChange"]
import re
import copy
+
class LexError(Exception):
- '''
+ """
Exception raised if an invalid character is encountered and no default
error handler function is defined. The .text attribute of the exception
contains all remaining untokenized text. The .error_index is the index
location of the error.
- '''
+ """
+
def __init__(self, message, text, error_index):
self.args = (message,)
self.text = text
self.error_index = error_index
+
class PatternError(Exception):
- '''
+ """
Exception raised if there's some kind of problem with the specified
regex patterns in the lexer.
- '''
+ """
+
pass
+
class LexerBuildError(Exception):
- '''
+ """
Exception raised if there's some sort of problem building the lexer.
- '''
+ """
+
pass
+
class LexerStateChange(Exception):
- '''
+ """
Exception raised to force a lexing state change
- '''
+ """
+
def __init__(self, newstate, tok=None):
self.newstate = newstate
self.tok = tok
+
class Token(object):
- '''
+ """
Representation of a single token.
- '''
- __slots__ = ('type', 'value', 'lineno', 'index')
+ """
+
+ __slots__ = ("type", "value", "lineno", "index")
+
def __repr__(self):
- return f'Token(type={self.type!r}, value={self.value!r}, lineno={self.lineno}, index={self.index})'
+ return f"Token(type={self.type!r}, value={self.value!r}, lineno={self.lineno}, index={self.index})"
+
class TokenStr(str):
@staticmethod
@@ -95,35 +107,38 @@ class TokenStr(str):
if self.remap is not None:
self.remap[self.key, key] = self.key
+
class _Before:
def __init__(self, tok, pattern):
self.tok = tok
self.pattern = pattern
+
class LexerMetaDict(dict):
- '''
+ """
Special dictionary that prohibits duplicate definitions in lexer specifications.
- '''
+ """
+
def __init__(self):
- self.before = { }
- self.delete = [ ]
- self.remap = { }
+ self.before = {}
+ self.delete = []
+ self.remap = {}
def __setitem__(self, key, value):
if isinstance(value, str):
value = TokenStr(value, key, self.remap)
-
+
if isinstance(value, _Before):
self.before[key] = value.tok
value = TokenStr(value.pattern, key, self.remap)
-
+
if key in self and not isinstance(value, property):
prior = self[key]
if isinstance(prior, str):
if callable(value):
value.pattern = prior
else:
- raise AttributeError(f'Name {key} redefined')
+ raise AttributeError(f"Name {key} redefined")
super().__setitem__(key, value)
@@ -135,41 +150,47 @@ class LexerMetaDict(dict):
return super().__delitem__(key)
def __getitem__(self, key):
- if key not in self and key.split('ignore_')[-1].isupper() and key[:1] != '_':
+ if key not in self and key.split("ignore_")[-1].isupper() and key[:1] != "_":
return TokenStr(key, key, self.remap)
else:
return super().__getitem__(key)
+
class LexerMeta(type):
- '''
+ """
Metaclass for collecting lexing rules
- '''
+ """
+
@classmethod
def __prepare__(meta, name, bases):
d = LexerMetaDict()
def _(pattern, *extra):
patterns = [pattern, *extra]
+
def decorate(func):
- pattern = '|'.join(f'({pat})' for pat in patterns )
- if hasattr(func, 'pattern'):
- func.pattern = pattern + '|' + func.pattern
+ pattern = "|".join(f"({pat})" for pat in patterns)
+ if hasattr(func, "pattern"):
+ func.pattern = pattern + "|" + func.pattern
else:
func.pattern = pattern
return func
+
return decorate
- d['_'] = _
- d['before'] = _Before
+ d["_"] = _
+ d["before"] = _Before
return d
def __new__(meta, clsname, bases, attributes):
- del attributes['_']
- del attributes['before']
+ del attributes["_"]
+ del attributes["before"]
# Create attributes for use in the actual class body
- cls_attributes = { str(key): str(val) if isinstance(val, TokenStr) else val
- for key, val in attributes.items() }
+ cls_attributes = {
+ str(key): str(val) if isinstance(val, TokenStr) else val
+ for key, val in attributes.items()
+ }
cls = super().__new__(meta, clsname, bases, cls_attributes)
# Attach various metadata to the class
@@ -180,11 +201,12 @@ class LexerMeta(type):
cls._build()
return cls
+
class Lexer(metaclass=LexerMeta):
# These attributes may be defined in subclasses
tokens = set()
literals = set()
- ignore = ''
+ ignore = ""
reflags = 0
regex_module = re
@@ -214,7 +236,7 @@ class Lexer(metaclass=LexerMeta):
# Such functions can be created with the @_ decorator or by defining
# function with the same name as a previously defined string.
#
- # This function is responsible for keeping rules in order.
+ # This function is responsible for keeping rules in order.
# Collect all previous rules from base classes
rules = []
@@ -222,15 +244,21 @@ class Lexer(metaclass=LexerMeta):
for base in cls.__bases__:
if isinstance(base, LexerMeta):
rules.extend(base._rules)
-
+
# Dictionary of previous rules
existing = dict(rules)
for key, value in cls._attributes.items():
- if (key in cls._token_names) or key.startswith('ignore_') or hasattr(value, 'pattern'):
- if callable(value) and not hasattr(value, 'pattern'):
- raise LexerBuildError(f"function {value} doesn't have a regex pattern")
-
+ if (
+ (key in cls._token_names)
+ or key.startswith("ignore_")
+ or hasattr(value, "pattern")
+ ):
+ if callable(value) and not hasattr(value, "pattern"):
+ raise LexerBuildError(
+ f"function {value} doesn't have a regex pattern"
+ )
+
if key in existing:
# The definition matches something that already existed in the base class.
# We replace it, but keep the original ordering
@@ -252,21 +280,27 @@ class Lexer(metaclass=LexerMeta):
rules.append((key, value))
existing[key] = value
- elif isinstance(value, str) and not key.startswith('_') and key not in {'ignore', 'literals'}:
- raise LexerBuildError(f'{key} does not match a name in tokens')
+ elif (
+ isinstance(value, str)
+ and not key.startswith("_")
+ and key not in {"ignore", "literals"}
+ ):
+ raise LexerBuildError(f"{key} does not match a name in tokens")
# Apply deletion rules
- rules = [ (key, value) for key, value in rules if key not in cls._delete ]
+ rules = [(key, value) for key, value in rules if key not in cls._delete]
cls._rules = rules
@classmethod
def _build(cls):
- '''
+ """
Build the lexer object from the collected tokens and regular expressions.
Validate the rules to make sure they look sane.
- '''
- if 'tokens' not in vars(cls):
- raise LexerBuildError(f'{cls.__qualname__} class does not define a tokens attribute')
+ """
+ if "tokens" not in vars(cls):
+ raise LexerBuildError(
+ f"{cls.__qualname__} class does not define a tokens attribute"
+ )
# Pull definitions created for any parent classes
cls._token_names = cls._token_names | set(cls.tokens)
@@ -282,17 +316,17 @@ class Lexer(metaclass=LexerMeta):
remapped_toks = set()
for d in cls._remapping.values():
remapped_toks.update(d.values())
-
+
undefined = remapped_toks - set(cls._token_names)
if undefined:
- missing = ', '.join(undefined)
- raise LexerBuildError(f'{missing} not included in token(s)')
+ missing = ", ".join(undefined)
+ raise LexerBuildError(f"{missing} not included in token(s)")
cls._collect_rules()
parts = []
for tokname, value in cls._rules:
- if tokname.startswith('ignore_'):
+ if tokname.startswith("ignore_"):
tokname = tokname[7:]
cls._ignored_tokens.add(tokname)
@@ -301,20 +335,20 @@ class Lexer(metaclass=LexerMeta):
elif callable(value):
cls._token_funcs[tokname] = value
- pattern = getattr(value, 'pattern')
+ pattern = getattr(value, "pattern")
# Form the regular expression component
- part = f'(?P<{tokname}>{pattern})'
+ part = f"(?P<{tokname}>{pattern})"
# Make sure the individual regex compiles properly
try:
cpat = cls.regex_module.compile(part, cls.reflags)
except Exception as e:
- raise PatternError(f'Invalid regex for token {tokname}') from e
+ raise PatternError(f"Invalid regex for token {tokname}") from e
# Verify that the pattern doesn't match the empty string
- if cpat.match(''):
- raise PatternError(f'Regex for token {tokname} matches empty input')
+ if cpat.match(""):
+ raise PatternError(f"Regex for token {tokname} matches empty input")
parts.append(part)
@@ -322,43 +356,45 @@ class Lexer(metaclass=LexerMeta):
return
# Form the master regular expression
- #previous = ('|' + cls._master_re.pattern) if cls._master_re else ''
+ # previous = ('|' + cls._master_re.pattern) if cls._master_re else ''
# cls._master_re = cls.regex_module.compile('|'.join(parts) + previous, cls.reflags)
- cls._master_re = cls.regex_module.compile('|'.join(parts), cls.reflags)
+ cls._master_re = cls.regex_module.compile("|".join(parts), cls.reflags)
# Verify that that ignore and literals specifiers match the input type
if not isinstance(cls.ignore, str):
- raise LexerBuildError('ignore specifier must be a string')
+ raise LexerBuildError("ignore specifier must be a string")
if not all(isinstance(lit, str) for lit in cls.literals):
- raise LexerBuildError('literals must be specified as strings')
+ raise LexerBuildError("literals must be specified as strings")
def begin(self, cls):
- '''
+ """
Begin a new lexer state
- '''
+ """
assert isinstance(cls, LexerMeta), "state must be a subclass of Lexer"
if self.__set_state:
self.__set_state(cls)
self.__class__ = cls
def push_state(self, cls):
- '''
+ """
Push a new lexer state onto the stack
- '''
+ """
if self.__state_stack is None:
self.__state_stack = []
self.__state_stack.append(type(self))
self.begin(cls)
def pop_state(self):
- '''
+ """
Pop a lexer state from the stack
- '''
+ """
self.begin(self.__state_stack.pop())
def tokenize(self, text, lineno=1, index=0):
- _ignored_tokens = _master_re = _ignore = _token_funcs = _literals = _remapping = None
+ _ignored_tokens = (
+ _master_re
+ ) = _ignore = _token_funcs = _literals = _remapping = None
def _set_state(cls):
nonlocal _ignored_tokens, _master_re, _ignore, _token_funcs, _literals, _remapping
@@ -419,7 +455,7 @@ class Lexer(metaclass=LexerMeta):
# A lexing error
self.index = index
self.lineno = lineno
- tok.type = 'ERROR'
+ tok.type = "ERROR"
tok.value = text[index:]
tok = self.error(tok)
if tok is not None:
@@ -436,4 +472,8 @@ class Lexer(metaclass=LexerMeta):
# Default implementations of the error handler. May be changed in subclasses
def error(self, t):
- raise LexError(f'Illegal character {t.value[0]!r} at index {self.index}', t.value, self.index)
+ raise LexError(
+ f"Illegal character {t.value[0]!r} at index {self.index}",
+ t.value,
+ self.index,
+ )
diff --git a/lib/sly/yacc.py b/lib/sly/yacc.py
index c30f13c..167168d 100644
--- a/lib/sly/yacc.py
+++ b/lib/sly/yacc.py
@@ -35,22 +35,25 @@ import sys
import inspect
from collections import OrderedDict, defaultdict
-__all__ = [ 'Parser' ]
+__all__ = ["Parser"]
+
class YaccError(Exception):
- '''
+ """
Exception raised for yacc-related build errors.
- '''
+ """
+
pass
-#-----------------------------------------------------------------------------
+
+# -----------------------------------------------------------------------------
# === User configurable parameters ===
#
-# Change these to modify the default behavior of yacc (if you wish).
+# Change these to modify the default behavior of yacc (if you wish).
# Move these parameters to the Yacc class itself.
-#-----------------------------------------------------------------------------
+# -----------------------------------------------------------------------------
-ERROR_COUNT = 3 # Number of symbols that must be shifted to leave recovery mode
+ERROR_COUNT = 3 # Number of symbols that must be shifted to leave recovery mode
MAXINT = sys.maxsize
# This object is a stand-in for a logging object created by the
@@ -59,20 +62,21 @@ MAXINT = sys.maxsize
# information, they can create their own logging object and pass
# it into SLY.
+
class SlyLogger(object):
def __init__(self, f):
self.f = f
def debug(self, msg, *args, **kwargs):
- self.f.write((msg % args) + '\n')
+ self.f.write((msg % args) + "\n")
info = debug
def warning(self, msg, *args, **kwargs):
- self.f.write('WARNING: ' + (msg % args) + '\n')
+ self.f.write("WARNING: " + (msg % args) + "\n")
def error(self, msg, *args, **kwargs):
- self.f.write('ERROR: ' + (msg % args) + '\n')
+ self.f.write("ERROR: " + (msg % args) + "\n")
critical = debug
@@ -86,6 +90,7 @@ class SlyLogger(object):
# .index = Starting lex position
# ----------------------------------------------------------------------
+
class YaccSymbol:
def __str__(self):
return self.type
@@ -93,19 +98,22 @@ class YaccSymbol:
def __repr__(self):
return str(self)
+
# ----------------------------------------------------------------------
# This class is a wrapper around the objects actually passed to each
# grammar rule. Index lookup and assignment actually assign the
# .value attribute of the underlying YaccSymbol object.
# The lineno() method returns the line number of a given
-# item (or 0 if not defined).
+# item (or 0 if not defined).
# ----------------------------------------------------------------------
+
class YaccProduction:
- __slots__ = ('_slice', '_namemap', '_stack')
+ __slots__ = ("_slice", "_namemap", "_stack")
+
def __init__(self, s, stack=None):
self._slice = s
- self._namemap = { }
+ self._namemap = {}
self._stack = stack
def __getitem__(self, n):
@@ -128,34 +136,35 @@ class YaccProduction:
for tok in self._slice:
if isinstance(tok, YaccSymbol):
continue
- lineno = getattr(tok, 'lineno', None)
+ lineno = getattr(tok, "lineno", None)
if lineno:
return lineno
- raise AttributeError('No line number found')
+ raise AttributeError("No line number found")
@property
def index(self):
for tok in self._slice:
if isinstance(tok, YaccSymbol):
continue
- index = getattr(tok, 'index', None)
+ index = getattr(tok, "index", None)
if index is not None:
return index
- raise AttributeError('No index attribute found')
+ raise AttributeError("No index attribute found")
def __getattr__(self, name):
if name in self._namemap:
return self._slice[self._namemap[name]].value
else:
- nameset = '{' + ', '.join(self._namemap) + '}'
- raise AttributeError(f'No symbol {name}. Must be one of {nameset}.')
+ nameset = "{" + ", ".join(self._namemap) + "}"
+ raise AttributeError(f"No symbol {name}. Must be one of {nameset}.")
def __setattr__(self, name, value):
- if name[:1] == '_':
+ if name[:1] == "_":
super().__setattr__(name, value)
else:
raise AttributeError(f"Can't reassign the value of attribute {name!r}")
+
# -----------------------------------------------------------------------------
# === Grammar Representation ===
#
@@ -187,19 +196,23 @@ class YaccProduction:
# usyms - Set of unique symbols found in the production
# -----------------------------------------------------------------------------
+
class Production(object):
reduced = 0
- def __init__(self, number, name, prod, precedence=('right', 0), func=None, file='', line=0):
- self.name = name
- self.prod = tuple(prod)
- self.number = number
- self.func = func
- self.file = file
- self.line = line
- self.prec = precedence
-
+
+ def __init__(
+ self, number, name, prod, precedence=("right", 0), func=None, file="", line=0
+ ):
+ self.name = name
+ self.prod = tuple(prod)
+ self.number = number
+ self.func = func
+ self.file = file
+ self.line = line
+ self.prec = precedence
+
# Internal settings used during table construction
- self.len = len(self.prod) # Length of the production
+ self.len = len(self.prod) # Length of the production
# Create a list of unique production symbols used in the production
self.usyms = []
@@ -216,33 +229,33 @@ class Production(object):
m[key] = indices[0]
else:
for n, index in enumerate(indices):
- m[key+str(n)] = index
+ m[key + str(n)] = index
self.namemap = m
-
+
# List of all LR items for the production
self.lr_items = []
self.lr_next = None
def __str__(self):
if self.prod:
- s = '%s -> %s' % (self.name, ' '.join(self.prod))
+ s = "%s -> %s" % (self.name, " ".join(self.prod))
else:
- s = f'{self.name} -> <empty>'
+ s = f"{self.name} -> <empty>"
if self.prec[1]:
- s += ' [precedence=%s, level=%d]' % self.prec
+ s += " [precedence=%s, level=%d]" % self.prec
return s
def __repr__(self):
- return f'Production({self})'
+ return f"Production({self})"
def __len__(self):
return len(self.prod)
def __nonzero__(self):
- raise RuntimeError('Used')
+ raise RuntimeError("Used")
return 1
def __getitem__(self, index):
@@ -255,15 +268,16 @@ class Production(object):
p = LRItem(self, n)
# Precompute the list of productions immediately following.
try:
- p.lr_after = Prodnames[p.prod[n+1]]
+ p.lr_after = Prodnames[p.prod[n + 1]]
except (IndexError, KeyError):
p.lr_after = []
try:
- p.lr_before = p.prod[n-1]
+ p.lr_before = p.prod[n - 1]
except IndexError:
p.lr_before = None
return p
+
# -----------------------------------------------------------------------------
# class LRItem
#
@@ -288,27 +302,29 @@ class Production(object):
# lr_before - Grammar symbol immediately before
# -----------------------------------------------------------------------------
+
class LRItem(object):
def __init__(self, p, n):
- self.name = p.name
- self.prod = list(p.prod)
- self.number = p.number
- self.lr_index = n
+ self.name = p.name
+ self.prod = list(p.prod)
+ self.number = p.number
+ self.lr_index = n
self.lookaheads = {}
- self.prod.insert(n, '.')
- self.prod = tuple(self.prod)
- self.len = len(self.prod)
- self.usyms = p.usyms
+ self.prod.insert(n, ".")
+ self.prod = tuple(self.prod)
+ self.len = len(self.prod)
+ self.usyms = p.usyms
def __str__(self):
if self.prod:
- s = '%s -> %s' % (self.name, ' '.join(self.prod))
+ s = "%s -> %s" % (self.name, " ".join(self.prod))
else:
- s = f'{self.name} -> <empty>'
+ s = f"{self.name} -> <empty>"
return s
def __repr__(self):
- return f'LRItem({self})'
+ return f"LRItem({self})"
+
# -----------------------------------------------------------------------------
# rightmost_terminal()
@@ -323,6 +339,7 @@ def rightmost_terminal(symbols, terminals):
i -= 1
return None
+
# -----------------------------------------------------------------------------
# === GRAMMAR CLASS ===
#
@@ -331,45 +348,52 @@ def rightmost_terminal(symbols, terminals):
# This data is used for critical parts of the table generation process later.
# -----------------------------------------------------------------------------
+
class GrammarError(YaccError):
pass
+
class Grammar(object):
def __init__(self, terminals):
- self.Productions = [None] # A list of all of the productions. The first
- # entry is always reserved for the purpose of
- # building an augmented grammar
+ self.Productions = [None] # A list of all of the productions. The first
+ # entry is always reserved for the purpose of
+ # building an augmented grammar
- self.Prodnames = {} # A dictionary mapping the names of nonterminals to a list of all
- # productions of that nonterminal.
+ self.Prodnames = (
+ {}
+ ) # A dictionary mapping the names of nonterminals to a list of all
+ # productions of that nonterminal.
- self.Prodmap = {} # A dictionary that is only used to detect duplicate
- # productions.
+ self.Prodmap = {} # A dictionary that is only used to detect duplicate
+ # productions.
- self.Terminals = {} # A dictionary mapping the names of terminal symbols to a
- # list of the rules where they are used.
+ self.Terminals = {} # A dictionary mapping the names of terminal symbols to a
+ # list of the rules where they are used.
for term in terminals:
self.Terminals[term] = []
- self.Terminals['error'] = []
-
- self.Nonterminals = {} # A dictionary mapping names of nonterminals to a list
- # of rule numbers where they are used.
+ self.Terminals["error"] = []
- self.First = {} # A dictionary of precomputed FIRST(x) symbols
+ self.Nonterminals = {} # A dictionary mapping names of nonterminals to a list
+ # of rule numbers where they are used.
- self.Follow = {} # A dictionary of precomputed FOLLOW(x) symbols
+ self.First = {} # A dictionary of precomputed FIRST(x) symbols
- self.Precedence = {} # Precedence rules for each terminal. Contains tuples of the
- # form ('right',level) or ('nonassoc', level) or ('left',level)
+ self.Follow = {} # A dictionary of precomputed FOLLOW(x) symbols
- self.UsedPrecedence = set() # Precedence rules that were actually used by the grammer.
- # This is only used to provide error checking and to generate
- # a warning about unused precedence rules.
+ self.Precedence = (
+ {}
+ ) # Precedence rules for each terminal. Contains tuples of the
+ # form ('right',level) or ('nonassoc', level) or ('left',level)
- self.Start = None # Starting symbol for the grammar
+ self.UsedPrecedence = (
+ set()
+ ) # Precedence rules that were actually used by the grammer.
+ # This is only used to provide error checking and to generate
+ # a warning about unused precedence rules.
+ self.Start = None # Starting symbol for the grammar
def __len__(self):
return len(self.Productions)
@@ -386,11 +410,15 @@ class Grammar(object):
# -----------------------------------------------------------------------------
def set_precedence(self, term, assoc, level):
- assert self.Productions == [None], 'Must call set_precedence() before add_production()'
+ assert self.Productions == [
+ None
+ ], "Must call set_precedence() before add_production()"
if term in self.Precedence:
- raise GrammarError(f'Precedence already specified for terminal {term!r}')
- if assoc not in ['left', 'right', 'nonassoc']:
- raise GrammarError(f"Associativity of {term!r} must be one of 'left','right', or 'nonassoc'")
+ raise GrammarError(f"Precedence already specified for terminal {term!r}")
+ if assoc not in ["left", "right", "nonassoc"]:
+ raise GrammarError(
+ f"Associativity of {term!r} must be one of 'left','right', or 'nonassoc'"
+ )
self.Precedence[term] = (assoc, level)
# -----------------------------------------------------------------------------
@@ -410,51 +438,65 @@ class Grammar(object):
# are valid and that %prec is used correctly.
# -----------------------------------------------------------------------------
- def add_production(self, prodname, syms, func=None, file='', line=0):
+ def add_production(self, prodname, syms, func=None, file="", line=0):
if prodname in self.Terminals:
- raise GrammarError(f'{file}:{line}: Illegal rule name {prodname!r}. Already defined as a token')
- if prodname == 'error':
- raise GrammarError(f'{file}:{line}: Illegal rule name {prodname!r}. error is a reserved word')
+ raise GrammarError(
+ f"{file}:{line}: Illegal rule name {prodname!r}. Already defined as a token"
+ )
+ if prodname == "error":
+ raise GrammarError(
+ f"{file}:{line}: Illegal rule name {prodname!r}. error is a reserved word"
+ )
# Look for literal tokens
for n, s in enumerate(syms):
if s[0] in "'\"" and s[0] == s[-1]:
c = s[1:-1]
- if (len(c) != 1):
- raise GrammarError(f'{file}:{line}: Literal token {s} in rule {prodname!r} may only be a single character')
+ if len(c) != 1:
+ raise GrammarError(
+ f"{file}:{line}: Literal token {s} in rule {prodname!r} may only be a single character"
+ )
if c not in self.Terminals:
self.Terminals[c] = []
syms[n] = c
continue
# Determine the precedence level
- if '%prec' in syms:
- if syms[-1] == '%prec':
- raise GrammarError(f'{file}:{line}: Syntax error. Nothing follows %%prec')
- if syms[-2] != '%prec':
- raise GrammarError(f'{file}:{line}: Syntax error. %prec can only appear at the end of a grammar rule')
+ if "%prec" in syms:
+ if syms[-1] == "%prec":
+ raise GrammarError(
+ f"{file}:{line}: Syntax error. Nothing follows %%prec"
+ )
+ if syms[-2] != "%prec":
+ raise GrammarError(
+ f"{file}:{line}: Syntax error. %prec can only appear at the end of a grammar rule"
+ )
precname = syms[-1]
prodprec = self.Precedence.get(precname)
if not prodprec:
- raise GrammarError(f'{file}:{line}: Nothing known about the precedence of {precname!r}')
+ raise GrammarError(
+ f"{file}:{line}: Nothing known about the precedence of {precname!r}"
+ )
else:
self.UsedPrecedence.add(precname)
- del syms[-2:] # Drop %prec from the rule
+ del syms[-2:] # Drop %prec from the rule
else:
# If no %prec, precedence is determined by the rightmost terminal symbol
precname = rightmost_terminal(syms, self.Terminals)
- prodprec = self.Precedence.get(precname, ('right', 0))
+ prodprec = self.Precedence.get(precname, ("right", 0))
# See if the rule is already in the rulemap
- map = '%s -> %s' % (prodname, syms)
+ map = "%s -> %s" % (prodname, syms)
if map in self.Prodmap:
m = self.Prodmap[map]
- raise GrammarError(f'{file}:{line}: Duplicate rule {m}. ' +
- f'Previous definition at {m.file}:{m.line}')
+ raise GrammarError(
+ f"{file}:{line}: Duplicate rule {m}. "
+ + f"Previous definition at {m.file}:{m.line}"
+ )
# From this point on, everything is valid. Create a new Production instance
- pnumber = len(self.Productions)
+ pnumber = len(self.Productions)
if prodname not in self.Nonterminals:
self.Nonterminals[prodname] = []
@@ -493,7 +535,7 @@ class Grammar(object):
start = self.Productions[1].name
if start not in self.Nonterminals:
- raise GrammarError(f'start symbol {start} undefined')
+ raise GrammarError(f"start symbol {start} undefined")
self.Productions[0] = Production(0, "S'", [start])
self.Nonterminals[start].append(0)
self.Start = start
@@ -535,7 +577,7 @@ class Grammar(object):
for t in self.Terminals:
terminates[t] = True
- terminates['$end'] = True
+ terminates["$end"] = True
# Nonterminals:
@@ -576,7 +618,7 @@ class Grammar(object):
infinite = []
for (s, term) in terminates.items():
if not term:
- if s not in self.Prodnames and s not in self.Terminals and s != 'error':
+ if s not in self.Prodnames and s not in self.Terminals and s != "error":
# s is used-but-not-defined, and we've already warned of that,
# so it would be overkill to say that it's also non-terminating.
pass
@@ -599,7 +641,7 @@ class Grammar(object):
continue
for s in p.prod:
- if s not in self.Prodnames and s not in self.Terminals and s != 'error':
+ if s not in self.Prodnames and s not in self.Terminals and s != "error":
result.append((s, p))
return result
@@ -612,7 +654,7 @@ class Grammar(object):
def unused_terminals(self):
unused_tok = []
for s, v in self.Terminals.items():
- if s != 'error' and not v:
+ if s != "error" and not v:
unused_tok.append(s)
return unused_tok
@@ -666,7 +708,7 @@ class Grammar(object):
# Add all the non-<empty> symbols of First[x] to the result.
for f in self.First[x]:
- if f == '<empty>':
+ if f == "<empty>":
x_produces_empty = True
else:
if f not in result:
@@ -683,7 +725,7 @@ class Grammar(object):
# There was no 'break' from the loop,
# so x_produces_empty was true for all x in beta,
# so beta produces empty as well.
- result.append('<empty>')
+ result.append("<empty>")
return result
@@ -700,7 +742,7 @@ class Grammar(object):
for t in self.Terminals:
self.First[t] = [t]
- self.First['$end'] = ['$end']
+ self.First["$end"] = ["$end"]
# Nonterminals:
@@ -745,7 +787,7 @@ class Grammar(object):
if not start:
start = self.Productions[1].name
- self.Follow[start] = ['$end']
+ self.Follow[start] = ["$end"]
while True:
didadd = False
@@ -754,15 +796,15 @@ class Grammar(object):
for i, B in enumerate(p.prod):
if B in self.Nonterminals:
# Okay. We got a non-terminal in a production
- fst = self._first(p.prod[i+1:])
+ fst = self._first(p.prod[i + 1 :])
hasempty = False
for f in fst:
- if f != '<empty>' and f not in self.Follow[B]:
+ if f != "<empty>" and f not in self.Follow[B]:
self.Follow[B].append(f)
didadd = True
- if f == '<empty>':
+ if f == "<empty>":
hasempty = True
- if hasempty or i == (len(p.prod)-1):
+ if hasempty or i == (len(p.prod) - 1):
# Add elements of follow(a) to follow(b)
for f in self.Follow[p.name]:
if f not in self.Follow[B]:
@@ -772,7 +814,6 @@ class Grammar(object):
break
return self.Follow
-
# -----------------------------------------------------------------------------
# build_lritems()
#
@@ -800,11 +841,11 @@ class Grammar(object):
lri = LRItem(p, i)
# Precompute the list of productions immediately following
try:
- lri.lr_after = self.Prodnames[lri.prod[i+1]]
+ lri.lr_after = self.Prodnames[lri.prod[i + 1]]
except (IndexError, KeyError):
lri.lr_after = []
try:
- lri.lr_before = lri.prod[i-1]
+ lri.lr_before = lri.prod[i - 1]
except IndexError:
lri.lr_before = None
@@ -816,33 +857,38 @@ class Grammar(object):
i += 1
p.lr_items = lr_items
-
# ----------------------------------------------------------------------
# Debugging output. Printing the grammar will produce a detailed
# description along with some diagnostics.
# ----------------------------------------------------------------------
def __str__(self):
out = []
- out.append('Grammar:\n')
+ out.append("Grammar:\n")
for n, p in enumerate(self.Productions):
- out.append(f'Rule {n:<5d} {p}')
-
+ out.append(f"Rule {n:<5d} {p}")
+
unused_terminals = self.unused_terminals()
if unused_terminals:
- out.append('\nUnused terminals:\n')
+ out.append("\nUnused terminals:\n")
for term in unused_terminals:
- out.append(f' {term}')
+ out.append(f" {term}")
- out.append('\nTerminals, with rules where they appear:\n')
+ out.append("\nTerminals, with rules where they appear:\n")
for term in sorted(self.Terminals):
- out.append('%-20s : %s' % (term, ' '.join(str(s) for s in self.Terminals[term])))
+ out.append(
+ "%-20s : %s" % (term, " ".join(str(s) for s in self.Terminals[term]))
+ )
- out.append('\nNonterminals, with rules where they appear:\n')
+ out.append("\nNonterminals, with rules where they appear:\n")
for nonterm in sorted(self.Nonterminals):
- out.append('%-20s : %s' % (nonterm, ' '.join(str(s) for s in self.Nonterminals[nonterm])))
+ out.append(
+ "%-20s : %s"
+ % (nonterm, " ".join(str(s) for s in self.Nonterminals[nonterm]))
+ )
+
+ out.append("")
+ return "\n".join(out)
- out.append('')
- return '\n'.join(out)
# -----------------------------------------------------------------------------
# === LR Generator ===
@@ -868,6 +914,7 @@ class Grammar(object):
# FP - Set-valued function
# ------------------------------------------------------------------------------
+
def digraph(X, R, FP):
N = {}
for x in X:
@@ -879,13 +926,14 @@ def digraph(X, R, FP):
traverse(x, N, stack, F, X, R, FP)
return F
+
def traverse(x, N, stack, F, X, R, FP):
stack.append(x)
d = len(stack)
N[x] = d
- F[x] = FP(x) # F(X) <- F'(x)
+ F[x] = FP(x) # F(X) <- F'(x)
- rel = R(x) # Get y's related to x
+ rel = R(x) # Get y's related to x
for y in rel:
if N[y] == 0:
traverse(y, N, stack, F, X, R, FP)
@@ -902,9 +950,11 @@ def traverse(x, N, stack, F, X, R, FP):
F[stack[-1]] = F[x]
element = stack.pop()
+
class LALRError(YaccError):
pass
+
# -----------------------------------------------------------------------------
# == LRGeneratedTable ==
#
@@ -912,26 +962,27 @@ class LALRError(YaccError):
# public methods except for write()
# -----------------------------------------------------------------------------
+
class LRTable(object):
def __init__(self, grammar):
self.grammar = grammar
# Internal attributes
- self.lr_action = {} # Action table
- self.lr_goto = {} # Goto table
- self.lr_productions = grammar.Productions # Copy of grammar Production array
- self.lr_goto_cache = {} # Cache of computed gotos
- self.lr0_cidhash = {} # Cache of closures
- self._add_count = 0 # Internal counter used to detect cycles
+ self.lr_action = {} # Action table
+ self.lr_goto = {} # Goto table
+ self.lr_productions = grammar.Productions # Copy of grammar Production array
+ self.lr_goto_cache = {} # Cache of computed gotos
+ self.lr0_cidhash = {} # Cache of closures
+ self._add_count = 0 # Internal counter used to detect cycles
# Diagonistic information filled in by the table generator
self.state_descriptions = OrderedDict()
- self.sr_conflict = 0
- self.rr_conflict = 0
- self.conflicts = [] # List of conflicts
+ self.sr_conflict = 0
+ self.rr_conflict = 0
+ self.conflicts = [] # List of conflicts
- self.sr_conflicts = []
- self.rr_conflicts = []
+ self.sr_conflicts = []
+ self.rr_conflicts = []
# Build the tables
self.grammar.build_lritems()
@@ -964,7 +1015,7 @@ class LRTable(object):
didadd = False
for j in J:
for x in j.lr_after:
- if getattr(x, 'lr0_added', 0) == self._add_count:
+ if getattr(x, "lr0_added", 0) == self._add_count:
continue
# Add B --> .G to J
J.append(x.lr_next)
@@ -1004,13 +1055,13 @@ class LRTable(object):
s[id(n)] = s1
gs.append(n)
s = s1
- g = s.get('$end')
+ g = s.get("$end")
if not g:
if gs:
g = self.lr0_closure(gs)
- s['$end'] = g
+ s["$end"] = g
else:
- s['$end'] = gs
+ s["$end"] = gs
self.lr_goto_cache[(id(I), x)] = g
return g
@@ -1105,7 +1156,7 @@ class LRTable(object):
for stateno, state in enumerate(C):
for p in state:
if p.lr_index < p.len - 1:
- t = (stateno, p.prod[p.lr_index+1])
+ t = (stateno, p.prod[p.lr_index + 1])
if t[1] in self.grammar.Nonterminals:
if t not in trans:
trans.append(t)
@@ -1128,14 +1179,14 @@ class LRTable(object):
g = self.lr0_goto(C[state], N)
for p in g:
if p.lr_index < p.len - 1:
- a = p.prod[p.lr_index+1]
+ a = p.prod[p.lr_index + 1]
if a in self.grammar.Terminals:
if a not in terms:
terms.append(a)
# This extra bit is to handle the start state
if state == 0 and N == self.grammar.Productions[0].prod[0]:
- terms.append('$end')
+ terms.append("$end")
return terms
@@ -1189,8 +1240,8 @@ class LRTable(object):
# -----------------------------------------------------------------------------
def compute_lookback_includes(self, C, trans, nullable):
- lookdict = {} # Dictionary of lookback relations
- includedict = {} # Dictionary of include relations
+ lookdict = {} # Dictionary of lookback relations
+ includedict = {} # Dictionary of include relations
# Make a dictionary of non-terminal transitions
dtrans = {}
@@ -1223,7 +1274,7 @@ class LRTable(object):
li = lr_index + 1
while li < p.len:
if p.prod[li] in self.grammar.Terminals:
- break # No forget it
+ break # No forget it
if p.prod[li] not in nullable:
break
li = li + 1
@@ -1231,8 +1282,8 @@ class LRTable(object):
# Appears to be a relation between (j,t) and (state,N)
includes.append((j, t))
- g = self.lr0_goto(C[j], t) # Go to next set
- j = self.lr0_cidhash.get(id(g), -1) # Go to next state
+ g = self.lr0_goto(C[j], t) # Go to next set
+ j = self.lr0_cidhash.get(id(g), -1) # Go to next state
# When we get here, j is the final state, now we have to locate the production
for r in C[j]:
@@ -1243,7 +1294,7 @@ class LRTable(object):
i = 0
# This look is comparing a production ". A B C" with "A B C ."
while i < r.lr_index:
- if r.prod[i] != p.prod[i+1]:
+ if r.prod[i] != p.prod[i + 1]:
break
i = i + 1
else:
@@ -1270,7 +1321,7 @@ class LRTable(object):
def compute_read_sets(self, C, ntrans, nullable):
FP = lambda x: self.dr_relation(C, x, nullable)
- R = lambda x: self.reads_relation(C, x, nullable)
+ R = lambda x: self.reads_relation(C, x, nullable)
F = digraph(ntrans, R, FP)
return F
@@ -1292,7 +1343,7 @@ class LRTable(object):
def compute_follow_sets(self, ntrans, readsets, inclsets):
FP = lambda x: readsets[x]
- R = lambda x: inclsets.get(x, [])
+ R = lambda x: inclsets.get(x, [])
F = digraph(ntrans, R, FP)
return F
@@ -1352,11 +1403,11 @@ class LRTable(object):
# -----------------------------------------------------------------------------
def lr_parse_table(self):
Productions = self.grammar.Productions
- Precedence = self.grammar.Precedence
- goto = self.lr_goto # Goto array
- action = self.lr_action # Action array
+ Precedence = self.grammar.Precedence
+ goto = self.lr_goto # Goto array
+ action = self.lr_action # Action array
- actionp = {} # Action production array (temporary)
+ actionp = {} # Action production array (temporary)
# Step 1: Construct C = { I0, I1, ... IN}, collection of LR(0) items
# This determines the number of states
@@ -1368,129 +1419,149 @@ class LRTable(object):
for st, I in enumerate(C):
descrip = []
# Loop over each production in I
- actlist = [] # List of actions
- st_action = {}
+ actlist = [] # List of actions
+ st_action = {}
st_actionp = {}
- st_goto = {}
+ st_goto = {}
- descrip.append(f'\nstate {st}\n')
+ descrip.append(f"\nstate {st}\n")
for p in I:
- descrip.append(f' ({p.number}) {p}')
+ descrip.append(f" ({p.number}) {p}")
for p in I:
- if p.len == p.lr_index + 1:
- if p.name == "S'":
- # Start symbol. Accept!
- st_action['$end'] = 0
- st_actionp['$end'] = p
- else:
- # We are at the end of a production. Reduce!
- laheads = p.lookaheads[st]
- for a in laheads:
- actlist.append((a, p, f'reduce using rule {p.number} ({p})'))
- r = st_action.get(a)
- if r is not None:
- # Have a shift/reduce or reduce/reduce conflict
- if r > 0:
- # Need to decide on shift or reduce here
- # By default we favor shifting. Need to add
- # some precedence rules here.
-
- # Shift precedence comes from the token
- sprec, slevel = Precedence.get(a, ('right', 0))
-
- # Reduce precedence comes from rule being reduced (p)
- rprec, rlevel = Productions[p.number].prec
-
- if (slevel < rlevel) or ((slevel == rlevel) and (rprec == 'left')):
- # We really need to reduce here.
- st_action[a] = -p.number
- st_actionp[a] = p
- if not slevel and not rlevel:
- descrip.append(f' ! shift/reduce conflict for {a} resolved as reduce')
- self.sr_conflicts.append((st, a, 'reduce'))
- Productions[p.number].reduced += 1
- elif (slevel == rlevel) and (rprec == 'nonassoc'):
- st_action[a] = None
- else:
- # Hmmm. Guess we'll keep the shift
- if not rlevel:
- descrip.append(f' ! shift/reduce conflict for {a} resolved as shift')
- self.sr_conflicts.append((st, a, 'shift'))
- elif r <= 0:
- # Reduce/reduce conflict. In this case, we favor the rule
- # that was defined first in the grammar file
- oldp = Productions[-r]
- pp = Productions[p.number]
- if oldp.line > pp.line:
- st_action[a] = -p.number
- st_actionp[a] = p
- chosenp, rejectp = pp, oldp
- Productions[p.number].reduced += 1
- Productions[oldp.number].reduced -= 1
- else:
- chosenp, rejectp = oldp, pp
- self.rr_conflicts.append((st, chosenp, rejectp))
- descrip.append(' ! reduce/reduce conflict for %s resolved using rule %d (%s)' %
- (a, st_actionp[a].number, st_actionp[a]))
+ if p.len == p.lr_index + 1:
+ if p.name == "S'":
+ # Start symbol. Accept!
+ st_action["$end"] = 0
+ st_actionp["$end"] = p
+ else:
+ # We are at the end of a production. Reduce!
+ laheads = p.lookaheads[st]
+ for a in laheads:
+ actlist.append(
+ (a, p, f"reduce using rule {p.number} ({p})")
+ )
+ r = st_action.get(a)
+ if r is not None:
+ # Have a shift/reduce or reduce/reduce conflict
+ if r > 0:
+ # Need to decide on shift or reduce here
+ # By default we favor shifting. Need to add
+ # some precedence rules here.
+
+ # Shift precedence comes from the token
+ sprec, slevel = Precedence.get(a, ("right", 0))
+
+ # Reduce precedence comes from rule being reduced (p)
+ rprec, rlevel = Productions[p.number].prec
+
+ if (slevel < rlevel) or (
+ (slevel == rlevel) and (rprec == "left")
+ ):
+ # We really need to reduce here.
+ st_action[a] = -p.number
+ st_actionp[a] = p
+ if not slevel and not rlevel:
+ descrip.append(
+ f" ! shift/reduce conflict for {a} resolved as reduce"
+ )
+ self.sr_conflicts.append((st, a, "reduce"))
+ Productions[p.number].reduced += 1
+ elif (slevel == rlevel) and (rprec == "nonassoc"):
+ st_action[a] = None
+ else:
+ # Hmmm. Guess we'll keep the shift
+ if not rlevel:
+ descrip.append(
+ f" ! shift/reduce conflict for {a} resolved as shift"
+ )
+ self.sr_conflicts.append((st, a, "shift"))
+ elif r <= 0:
+ # Reduce/reduce conflict. In this case, we favor the rule
+ # that was defined first in the grammar file
+ oldp = Productions[-r]
+ pp = Productions[p.number]
+ if oldp.line > pp.line:
+ st_action[a] = -p.number
+ st_actionp[a] = p
+ chosenp, rejectp = pp, oldp
+ Productions[p.number].reduced += 1
+ Productions[oldp.number].reduced -= 1
else:
- raise LALRError(f'Unknown conflict in state {st}')
+ chosenp, rejectp = oldp, pp
+ self.rr_conflicts.append((st, chosenp, rejectp))
+ descrip.append(
+ " ! reduce/reduce conflict for %s resolved using rule %d (%s)"
+ % (a, st_actionp[a].number, st_actionp[a])
+ )
else:
- st_action[a] = -p.number
- st_actionp[a] = p
- Productions[p.number].reduced += 1
- else:
- i = p.lr_index
- a = p.prod[i+1] # Get symbol right after the "."
- if a in self.grammar.Terminals:
- g = self.lr0_goto(I, a)
- j = self.lr0_cidhash.get(id(g), -1)
- if j >= 0:
- # We are in a shift state
- actlist.append((a, p, f'shift and go to state {j}'))
- r = st_action.get(a)
- if r is not None:
- # Whoa have a shift/reduce or shift/shift conflict
- if r > 0:
- if r != j:
- raise LALRError(f'Shift/shift conflict in state {st}')
- elif r <= 0:
- # Do a precedence check.
- # - if precedence of reduce rule is higher, we reduce.
- # - if precedence of reduce is same and left assoc, we reduce.
- # - otherwise we shift
- rprec, rlevel = Productions[st_actionp[a].number].prec
- sprec, slevel = Precedence.get(a, ('right', 0))
- if (slevel > rlevel) or ((slevel == rlevel) and (rprec == 'right')):
- # We decide to shift here... highest precedence to shift
- Productions[st_actionp[a].number].reduced -= 1
- st_action[a] = j
- st_actionp[a] = p
- if not rlevel:
- descrip.append(f' ! shift/reduce conflict for {a} resolved as shift')
- self.sr_conflicts.append((st, a, 'shift'))
- elif (slevel == rlevel) and (rprec == 'nonassoc'):
- st_action[a] = None
- else:
- # Hmmm. Guess we'll keep the reduce
- if not slevel and not rlevel:
- descrip.append(f' ! shift/reduce conflict for {a} resolved as reduce')
- self.sr_conflicts.append((st, a, 'reduce'))
-
+ raise LALRError(f"Unknown conflict in state {st}")
+ else:
+ st_action[a] = -p.number
+ st_actionp[a] = p
+ Productions[p.number].reduced += 1
+ else:
+ i = p.lr_index
+ a = p.prod[i + 1] # Get symbol right after the "."
+ if a in self.grammar.Terminals:
+ g = self.lr0_goto(I, a)
+ j = self.lr0_cidhash.get(id(g), -1)
+ if j >= 0:
+ # We are in a shift state
+ actlist.append((a, p, f"shift and go to state {j}"))
+ r = st_action.get(a)
+ if r is not None:
+ # Whoa have a shift/reduce or shift/shift conflict
+ if r > 0:
+ if r != j:
+ raise LALRError(
+ f"Shift/shift conflict in state {st}"
+ )
+ elif r <= 0:
+ # Do a precedence check.
+ # - if precedence of reduce rule is higher, we reduce.
+ # - if precedence of reduce is same and left assoc, we reduce.
+ # - otherwise we shift
+ rprec, rlevel = Productions[
+ st_actionp[a].number
+ ].prec
+ sprec, slevel = Precedence.get(a, ("right", 0))
+ if (slevel > rlevel) or (
+ (slevel == rlevel) and (rprec == "right")
+ ):
+ # We decide to shift here... highest precedence to shift
+ Productions[st_actionp[a].number].reduced -= 1
+ st_action[a] = j
+ st_actionp[a] = p
+ if not rlevel:
+ descrip.append(
+ f" ! shift/reduce conflict for {a} resolved as shift"
+ )
+ self.sr_conflicts.append((st, a, "shift"))
+ elif (slevel == rlevel) and (rprec == "nonassoc"):
+ st_action[a] = None
else:
- raise LALRError(f'Unknown conflict in state {st}')
+ # Hmmm. Guess we'll keep the reduce
+ if not slevel and not rlevel:
+ descrip.append(
+ f" ! shift/reduce conflict for {a} resolved as reduce"
+ )
+ self.sr_conflicts.append((st, a, "reduce"))
+
else:
- st_action[a] = j
- st_actionp[a] = p
+ raise LALRError(f"Unknown conflict in state {st}")
+ else:
+ st_action[a] = j
+ st_actionp[a] = p
# Print the actions associated with each terminal
_actprint = {}
for a, p, m in actlist:
if a in st_action:
if p is st_actionp[a]:
- descrip.append(f' {a:<15s} {m}')
+ descrip.append(f" {a:<15s} {m}")
_actprint[(a, m)] = 1
- descrip.append('')
+ descrip.append("")
# Construct the goto table for this state
nkeys = {}
@@ -1503,12 +1574,12 @@ class LRTable(object):
j = self.lr0_cidhash.get(id(g), -1)
if j >= 0:
st_goto[n] = j
- descrip.append(f' {n:<30s} shift and go to state {j}')
+ descrip.append(f" {n:<30s} shift and go to state {j}")
action[st] = st_action
actionp[st] = st_actionp
goto[st] = st_goto
- self.state_descriptions[st] = '\n'.join(descrip)
+ self.state_descriptions[st] = "\n".join(descrip)
# ----------------------------------------------------------------------
# Debugging output. Printing the LRTable object will produce a listing
@@ -1518,28 +1589,33 @@ class LRTable(object):
out = []
for descrip in self.state_descriptions.values():
out.append(descrip)
-
+
if self.sr_conflicts or self.rr_conflicts:
- out.append('\nConflicts:\n')
+ out.append("\nConflicts:\n")
for state, tok, resolution in self.sr_conflicts:
- out.append(f'shift/reduce conflict for {tok} in state {state} resolved as {resolution}')
+ out.append(
+ f"shift/reduce conflict for {tok} in state {state} resolved as {resolution}"
+ )
already_reported = set()
for state, rule, rejected in self.rr_conflicts:
if (state, id(rule), id(rejected)) in already_reported:
continue
- out.append(f'reduce/reduce conflict in state {state} resolved using rule {rule}')
- out.append(f'rejected rule ({rejected}) in state {state}')
+ out.append(
+ f"reduce/reduce conflict in state {state} resolved using rule {rule}"
+ )
+ out.append(f"rejected rule ({rejected}) in state {state}")
already_reported.add((state, id(rule), id(rejected)))
warned_never = set()
for state, rule, rejected in self.rr_conflicts:
if not rejected.reduced and (rejected not in warned_never):
- out.append(f'Rule ({rejected}) is never reduced')
+ out.append(f"Rule ({rejected}) is never reduced")
warned_never.add(rejected)
- return '\n'.join(out)
+ return "\n".join(out)
+
# Collect grammar rules from a function
def _collect_grammar_rules(func):
@@ -1549,70 +1625,80 @@ def _collect_grammar_rules(func):
unwrapped = inspect.unwrap(func)
filename = unwrapped.__code__.co_filename
lineno = unwrapped.__code__.co_firstlineno
- for rule, lineno in zip(func.rules, range(lineno+len(func.rules)-1, 0, -1)):
+ for rule, lineno in zip(func.rules, range(lineno + len(func.rules) - 1, 0, -1)):
syms = rule.split()
- if syms[1:2] == [':'] or syms[1:2] == ['::=']:
+ if syms[1:2] == [":"] or syms[1:2] == ["::="]:
grammar.append((func, filename, lineno, syms[0], syms[2:]))
else:
grammar.append((func, filename, lineno, prodname, syms))
- func = getattr(func, 'next_func', None)
+ func = getattr(func, "next_func", None)
return grammar
+
class ParserMetaDict(dict):
- '''
+ """
Dictionary that allows decorated grammar rule functions to be overloaded
- '''
+ """
+
def __setitem__(self, key, value):
- if key in self and callable(value) and hasattr(value, 'rules'):
+ if key in self and callable(value) and hasattr(value, "rules"):
value.next_func = self[key]
- if not hasattr(value.next_func, 'rules'):
- raise GrammarError(f'Redefinition of {key}. Perhaps an earlier {key} is missing @_')
+ if not hasattr(value.next_func, "rules"):
+ raise GrammarError(
+ f"Redefinition of {key}. Perhaps an earlier {key} is missing @_"
+ )
super().__setitem__(key, value)
-
+
def __getitem__(self, key):
- if key not in self and key.isupper() and key[:1] != '_':
+ if key not in self and key.isupper() and key[:1] != "_":
return key.upper()
else:
return super().__getitem__(key)
+
class ParserMeta(type):
@classmethod
def __prepare__(meta, *args, **kwargs):
d = ParserMetaDict()
+
def _(rule, *extra):
rules = [rule, *extra]
+
def decorate(func):
- func.rules = [ *getattr(func, 'rules', []), *rules[::-1] ]
+ func.rules = [*getattr(func, "rules", []), *rules[::-1]]
return func
+
return decorate
- d['_'] = _
+
+ d["_"] = _
return d
def __new__(meta, clsname, bases, attributes):
- del attributes['_']
+ del attributes["_"]
cls = super().__new__(meta, clsname, bases, attributes)
cls._build(list(attributes.items()))
return cls
+
class Parser(metaclass=ParserMeta):
# Logging object where debugging/diagnostic messages are sent
- log = SlyLogger(sys.stderr)
+ log = SlyLogger(sys.stderr)
# Debugging filename where parsetab.out data can be written
debugfile = None
@classmethod
def __validate_tokens(cls):
- if not hasattr(cls, 'tokens'):
- cls.log.error('No token list is defined')
+ if not hasattr(cls, "tokens"):
+ cls.log.error("No token list is defined")
return False
if not cls.tokens:
- cls.log.error('tokens is empty')
+ cls.log.error("tokens is empty")
return False
- if 'error' in cls.tokens:
+ if "error" in cls.tokens:
cls.log.error("Illegal token name 'error'. Is a reserved word")
return False
@@ -1620,28 +1706,32 @@ class Parser(metaclass=ParserMeta):
@classmethod
def __validate_precedence(cls):
- if not hasattr(cls, 'precedence'):
+ if not hasattr(cls, "precedence"):
cls.__preclist = []
return True
preclist = []
if not isinstance(cls.precedence, (list, tuple)):
- cls.log.error('precedence must be a list or tuple')
+ cls.log.error("precedence must be a list or tuple")
return False
for level, p in enumerate(cls.precedence, start=1):
if not isinstance(p, (list, tuple)):
- cls.log.error(f'Bad precedence table entry {p!r}. Must be a list or tuple')
+ cls.log.error(
+ f"Bad precedence table entry {p!r}. Must be a list or tuple"
+ )
return False
if len(p) < 2:
- cls.log.error(f'Malformed precedence entry {p!r}. Must be (assoc, term, ..., term)')
+ cls.log.error(
+ f"Malformed precedence entry {p!r}. Must be (assoc, term, ..., term)"
+ )
return False
if not all(isinstance(term, str) for term in p):
- cls.log.error('precedence items must be strings')
+ cls.log.error("precedence items must be strings")
return False
-
+
assoc = p[0]
preclist.extend((term, assoc, level) for term in p[1:])
@@ -1650,9 +1740,9 @@ class Parser(metaclass=ParserMeta):
@classmethod
def __validate_specification(cls):
- '''
+ """
Validate various parts of the grammar specification
- '''
+ """
if not cls.__validate_tokens():
return False
if not cls.__validate_precedence():
@@ -1661,14 +1751,14 @@ class Parser(metaclass=ParserMeta):
@classmethod
def __build_grammar(cls, rules):
- '''
+ """
Build the grammar from the grammar rules
- '''
+ """
grammar_rules = []
- errors = ''
+ errors = ""
# Check for non-empty symbols
if not rules:
- raise YaccError('No grammar rules are defined')
+ raise YaccError("No grammar rules are defined")
grammar = Grammar(cls.tokens)
@@ -1677,95 +1767,110 @@ class Parser(metaclass=ParserMeta):
try:
grammar.set_precedence(term, assoc, level)
except GrammarError as e:
- errors += f'{e}\n'
+ errors += f"{e}\n"
for name, func in rules:
try:
parsed_rule = _collect_grammar_rules(func)
for pfunc, rulefile, ruleline, prodname, syms in parsed_rule:
try:
- grammar.add_production(prodname, syms, pfunc, rulefile, ruleline)
+ grammar.add_production(
+ prodname, syms, pfunc, rulefile, ruleline
+ )
except GrammarError as e:
- errors += f'{e}\n'
+ errors += f"{e}\n"
except SyntaxError as e:
- errors += f'{e}\n'
+ errors += f"{e}\n"
try:
- grammar.set_start(getattr(cls, 'start', None))
+ grammar.set_start(getattr(cls, "start", None))
except GrammarError as e:
- errors += f'{e}\n'
+ errors += f"{e}\n"
undefined_symbols = grammar.undefined_symbols()
for sym, prod in undefined_symbols:
- errors += '%s:%d: Symbol %r used, but not defined as a token or a rule\n' % (prod.file, prod.line, sym)
+ errors += (
+ "%s:%d: Symbol %r used, but not defined as a token or a rule\n"
+ % (prod.file, prod.line, sym)
+ )
unused_terminals = grammar.unused_terminals()
if unused_terminals:
- unused_str = '{' + ','.join(unused_terminals) + '}'
- cls.log.warning(f'Token{"(s)" if len(unused_terminals) >1 else ""} {unused_str} defined, but not used')
+ unused_str = "{" + ",".join(unused_terminals) + "}"
+ cls.log.warning(
+ f'Token{"(s)" if len(unused_terminals) >1 else ""} {unused_str} defined, but not used'
+ )
unused_rules = grammar.unused_rules()
for prod in unused_rules:
- cls.log.warning('%s:%d: Rule %r defined, but not used', prod.file, prod.line, prod.name)
+ cls.log.warning(
+ "%s:%d: Rule %r defined, but not used", prod.file, prod.line, prod.name
+ )
if len(unused_terminals) == 1:
- cls.log.warning('There is 1 unused token')
+ cls.log.warning("There is 1 unused token")
if len(unused_terminals) > 1:
- cls.log.warning('There are %d unused tokens', len(unused_terminals))
+ cls.log.warning("There are %d unused tokens", len(unused_terminals))
if len(unused_rules) == 1:
- cls.log.warning('There is 1 unused rule')
+ cls.log.warning("There is 1 unused rule")
if len(unused_rules) > 1:
- cls.log.warning('There are %d unused rules', len(unused_rules))
+ cls.log.warning("There are %d unused rules", len(unused_rules))
unreachable = grammar.find_unreachable()
for u in unreachable:
- cls.log.warning('Symbol %r is unreachable', u)
+ cls.log.warning("Symbol %r is unreachable", u)
if len(undefined_symbols) == 0:
infinite = grammar.infinite_cycles()
for inf in infinite:
- errors += 'Infinite recursion detected for symbol %r\n' % inf
+ errors += "Infinite recursion detected for symbol %r\n" % inf
unused_prec = grammar.unused_precedence()
for term, assoc in unused_prec:
- errors += 'Precedence rule %r defined for unknown symbol %r\n' % (assoc, term)
+ errors += "Precedence rule %r defined for unknown symbol %r\n" % (
+ assoc,
+ term,
+ )
cls._grammar = grammar
if errors:
- raise YaccError('Unable to build grammar.\n'+errors)
+ raise YaccError("Unable to build grammar.\n" + errors)
@classmethod
def __build_lrtables(cls):
- '''
+ """
Build the LR Parsing tables from the grammar
- '''
+ """
lrtable = LRTable(cls._grammar)
num_sr = len(lrtable.sr_conflicts)
# Report shift/reduce and reduce/reduce conflicts
- if num_sr != getattr(cls, 'expected_shift_reduce', None):
+ if num_sr != getattr(cls, "expected_shift_reduce", None):
if num_sr == 1:
- cls.log.warning('1 shift/reduce conflict')
+ cls.log.warning("1 shift/reduce conflict")
elif num_sr > 1:
- cls.log.warning('%d shift/reduce conflicts', num_sr)
+ cls.log.warning("%d shift/reduce conflicts", num_sr)
num_rr = len(lrtable.rr_conflicts)
- if num_rr != getattr(cls, 'expected_reduce_reduce', None):
+ if num_rr != getattr(cls, "expected_reduce_reduce", None):
if num_rr == 1:
- cls.log.warning('1 reduce/reduce conflict')
+ cls.log.warning("1 reduce/reduce conflict")
elif num_rr > 1:
- cls.log.warning('%d reduce/reduce conflicts', num_rr)
+ cls.log.warning("%d reduce/reduce conflicts", num_rr)
cls._lrtable = lrtable
return True
@classmethod
def __collect_rules(cls, definitions):
- '''
+ """
Collect all of the tagged grammar rules
- '''
- rules = [ (name, value) for name, value in definitions
- if callable(value) and hasattr(value, 'rules') ]
+ """
+ rules = [
+ (name, value)
+ for name, value in definitions
+ if callable(value) and hasattr(value, "rules")
+ ]
return rules
# ----------------------------------------------------------------------
@@ -1775,7 +1880,7 @@ class Parser(metaclass=ParserMeta):
# ----------------------------------------------------------------------
@classmethod
def _build(cls, definitions):
- if vars(cls).get('_build', False):
+ if vars(cls).get("_build", False):
return
# Collect all of the grammar rules from the class definition
@@ -1783,77 +1888,89 @@ class Parser(metaclass=ParserMeta):
# Validate other parts of the grammar specification
if not cls.__validate_specification():
- raise YaccError('Invalid parser specification')
+ raise YaccError("Invalid parser specification")
# Build the underlying grammar object
cls.__build_grammar(rules)
# Build the LR tables
if not cls.__build_lrtables():
- raise YaccError('Can\'t build parsing tables')
+ raise YaccError("Can't build parsing tables")
if cls.debugfile:
- with open(cls.debugfile, 'w') as f:
+ with open(cls.debugfile, "w") as f:
f.write(str(cls._grammar))
- f.write('\n')
+ f.write("\n")
f.write(str(cls._lrtable))
- cls.log.info('Parser debugging for %s written to %s', cls.__qualname__, cls.debugfile)
+ cls.log.info(
+ "Parser debugging for %s written to %s", cls.__qualname__, cls.debugfile
+ )
# ----------------------------------------------------------------------
# Parsing Support. This is the parsing runtime that users use to
# ----------------------------------------------------------------------
def error(self, token):
- '''
+ """
Default error handling function. This may be subclassed.
- '''
+ """
if token:
- lineno = getattr(token, 'lineno', 0)
+ lineno = getattr(token, "lineno", 0)
if lineno:
- sys.stderr.write(f'sly: Syntax error at line {lineno}, token={token.type}\n')
+ sys.stderr.write(
+ f"sly: Syntax error at line {lineno}, token={token.type}\n"
+ )
else:
- sys.stderr.write(f'sly: Syntax error, token={token.type}')
+ sys.stderr.write(f"sly: Syntax error, token={token.type}")
else:
- sys.stderr.write('sly: Parse error in input. EOF\n')
-
+ sys.stderr.write("sly: Parse error in input. EOF\n")
+
def errok(self):
- '''
+ """
Clear the error status
- '''
+ """
self.errorok = True
def restart(self):
- '''
+ """
Force the parser to restart from a fresh state. Clears the statestack
- '''
+ """
del self.statestack[:]
del self.symstack[:]
sym = YaccSymbol()
- sym.type = '$end'
+ sym.type = "$end"
self.symstack.append(sym)
self.statestack.append(0)
self.state = 0
def parse(self, tokens):
- '''
+ """
Parse the given input tokens.
- '''
- lookahead = None # Current lookahead symbol
- lookaheadstack = [] # Stack of lookahead symbols
- actions = self._lrtable.lr_action # Local reference to action table (to avoid lookup on self.)
- goto = self._lrtable.lr_goto # Local reference to goto table (to avoid lookup on self.)
- prod = self._grammar.Productions # Local reference to production list (to avoid lookup on self.)
- defaulted_states = self._lrtable.defaulted_states # Local reference to defaulted states
- pslice = YaccProduction(None) # Production object passed to grammar rules
- errorcount = 0 # Used during error recovery
+ """
+ lookahead = None # Current lookahead symbol
+ lookaheadstack = [] # Stack of lookahead symbols
+ actions = (
+ self._lrtable.lr_action
+ ) # Local reference to action table (to avoid lookup on self.)
+ goto = (
+ self._lrtable.lr_goto
+ ) # Local reference to goto table (to avoid lookup on self.)
+ prod = (
+ self._grammar.Productions
+ ) # Local reference to production list (to avoid lookup on self.)
+ defaulted_states = (
+ self._lrtable.defaulted_states
+ ) # Local reference to defaulted states
+ pslice = YaccProduction(None) # Production object passed to grammar rules
+ errorcount = 0 # Used during error recovery
# Set up the state and symbol stacks
self.tokens = tokens
- self.statestack = statestack = [] # Stack of parsing states
- self.symstack = symstack = [] # Stack of grammar symbols
- pslice._stack = symstack # Associate the stack with the production
+ self.statestack = statestack = [] # Stack of parsing states
+ self.symstack = symstack = [] # Stack of grammar symbols
+ pslice._stack = symstack # Associate the stack with the production
self.restart()
- errtoken = None # Err token
+ errtoken = None # Err token
while True:
# Get the next symbol on the input. If a lookahead symbol
# is already set, we just use that. Otherwise, we'll pull
@@ -1866,7 +1983,7 @@ class Parser(metaclass=ParserMeta):
lookahead = lookaheadstack.pop()
if not lookahead:
lookahead = YaccSymbol()
- lookahead.type = '$end'
+ lookahead.type = "$end"
# Check the action table
ltype = lookahead.type
@@ -1892,14 +2009,14 @@ class Parser(metaclass=ParserMeta):
# reduce a symbol on the stack, emit a production
self.production = p = prod[-t]
pname = p.name
- plen = p.len
+ plen = p.len
pslice._namemap = p.namemap
# Call the production function
pslice._slice = symstack[-plen:] if plen else []
sym = YaccSymbol()
- sym.type = pname
+ sym.type = pname
value = p.func(self, pslice)
if value is pslice:
value = (pname, *(s.value for s in pslice._slice))
@@ -1915,7 +2032,7 @@ class Parser(metaclass=ParserMeta):
if t == 0:
n = symstack[-1]
- result = getattr(n, 'value', None)
+ result = getattr(n, "value", None)
return result
if t is None:
@@ -1932,8 +2049,8 @@ class Parser(metaclass=ParserMeta):
if errorcount == 0 or self.errorok:
errorcount = ERROR_COUNT
self.errorok = False
- if lookahead.type == '$end':
- errtoken = None # End of file!
+ if lookahead.type == "$end":
+ errtoken = None # End of file!
else:
errtoken = lookahead
@@ -1957,7 +2074,7 @@ class Parser(metaclass=ParserMeta):
# entire parse has been rolled back and we're completely hosed. The token is
# discarded and we just keep going.
- if len(statestack) <= 1 and lookahead.type != '$end':
+ if len(statestack) <= 1 and lookahead.type != "$end":
lookahead = None
self.state = 0
# Nuke the lookahead stack
@@ -1968,13 +2085,13 @@ class Parser(metaclass=ParserMeta):
# at the end of the file. nuke the top entry and generate an error token
# Start nuking entries on the stack
- if lookahead.type == '$end':
+ if lookahead.type == "$end":
# Whoa. We're really hosed here. Bail out
return
- if lookahead.type != 'error':
+ if lookahead.type != "error":
sym = symstack[-1]
- if sym.type == 'error':
+ if sym.type == "error":
# Hmmm. Error is on top of stack, we'll just nuke input
# symbol and continue
lookahead = None
@@ -1982,11 +2099,11 @@ class Parser(metaclass=ParserMeta):
# Create the error symbol for the first time and make it the new lookahead symbol
t = YaccSymbol()
- t.type = 'error'
+ t.type = "error"
- if hasattr(lookahead, 'lineno'):
+ if hasattr(lookahead, "lineno"):
t.lineno = lookahead.lineno
- if hasattr(lookahead, 'index'):
+ if hasattr(lookahead, "index"):
t.index = lookahead.index
t.value = lookahead
lookaheadstack.append(lookahead)
@@ -1998,4 +2115,4 @@ class Parser(metaclass=ParserMeta):
continue
# Call an error function here
- raise RuntimeError('sly: internal parser error!!!\n')
+ raise RuntimeError("sly: internal parser error!!!\n")
diff --git a/lib/utils.py b/lib/utils.py
index 26a591e..91dded0 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -21,15 +21,25 @@ def running_mean(x: np.ndarray, N: int) -> np.ndarray:
:param x: 1-Dimensional NumPy array
:param N: how many items to average
"""
+ # FIXME np.insert(x, 0, [x[0] for i in range(N/2)])
+ # FIXME np.insert(x, -1, [x[-1] for i in range(N/2)])
+ # (dabei ungerade N beachten)
cumsum = np.cumsum(np.insert(x, 0, 0))
return (cumsum[N:] - cumsum[:-N]) / N
def human_readable(value, unit):
- for prefix, factor in (('p', 1e-12), ('n', 1e-9), (u'µ', 1e-6), ('m', 1e-3), ('', 1), ('k', 1e3)):
+ for prefix, factor in (
+ ("p", 1e-12),
+ ("n", 1e-9),
+ (u"µ", 1e-6),
+ ("m", 1e-3),
+ ("", 1),
+ ("k", 1e3),
+ ):
if value < 1e3 * factor:
- return '{:.2f} {}{}'.format(value * (1 / factor), prefix, unit)
- return '{:.2f} {}'.format(value, unit)
+ return "{:.2f} {}{}".format(value * (1 / factor), prefix, unit)
+ return "{:.2f} {}".format(value, unit)
def is_numeric(n):
@@ -65,7 +75,7 @@ def soft_cast_int(n):
If `n` is empty, returns None.
If `n` is not numeric, it is left unchanged.
"""
- if n is None or n == '':
+ if n is None or n == "":
return None
try:
return int(n)
@@ -80,7 +90,7 @@ def soft_cast_float(n):
If `n` is empty, returns None.
If `n` is not numeric, it is left unchanged.
"""
- if n is None or n == '':
+ if n is None or n == "":
return None
try:
return float(n)
@@ -104,8 +114,8 @@ def parse_conf_str(conf_str):
Values are casted to float if possible and kept as-is otherwise.
"""
conf_dict = dict()
- for option in conf_str.split(','):
- key, value = option.split('=')
+ for option in conf_str.split(","):
+ key, value = option.split("=")
conf_dict[key] = soft_cast_float(value)
return conf_dict
@@ -118,7 +128,7 @@ def remove_index_from_tuple(parameters, index):
:param index: index of element which is to be removed
:returns: parameters tuple without the element at index
"""
- return (*parameters[:index], *parameters[index + 1:])
+ return (*parameters[:index], *parameters[index + 1 :])
def param_slice_eq(a, b, index):
@@ -137,7 +147,9 @@ def param_slice_eq(a, b, index):
('foo', [1, 4]), ('foo', [2, 4]), 1 -> False
"""
- if (*a[1][:index], *a[1][index + 1:]) == (*b[1][:index], *b[1][index + 1:]) and a[0] == b[0]:
+ if (*a[1][:index], *a[1][index + 1 :]) == (*b[1][:index], *b[1][index + 1 :]) and a[
+ 0
+ ] == b[0]:
return True
return False
@@ -164,20 +176,20 @@ def by_name_to_by_param(by_name: dict):
"""
by_param = dict()
for name in by_name.keys():
- for i, parameters in enumerate(by_name[name]['param']):
+ for i, parameters in enumerate(by_name[name]["param"]):
param_key = (name, tuple(parameters))
if param_key not in by_param:
by_param[param_key] = dict()
for key in by_name[name].keys():
by_param[param_key][key] = list()
- by_param[param_key]['attributes'] = by_name[name]['attributes']
+ by_param[param_key]["attributes"] = by_name[name]["attributes"]
# special case for PTA models
- if 'isa' in by_name[name]:
- by_param[param_key]['isa'] = by_name[name]['isa']
- for attribute in by_name[name]['attributes']:
+ if "isa" in by_name[name]:
+ by_param[param_key]["isa"] = by_name[name]["isa"]
+ for attribute in by_name[name]["attributes"]:
by_param[param_key][attribute].append(by_name[name][attribute][i])
# Required for match_parameter_valuse in _try_fits
- by_param[param_key]['param'].append(by_name[name]['param'][i])
+ by_param[param_key]["param"].append(by_name[name]["param"][i])
return by_param
@@ -197,14 +209,26 @@ def filter_aggregate_by_param(aggregate, parameters, parameter_filter):
param_value = soft_cast_int(param_name_and_value[1])
names_to_remove = set()
for name in aggregate.keys():
- indices_to_keep = list(map(lambda x: x[param_index] == param_value, aggregate[name]['param']))
- aggregate[name]['param'] = list(map(lambda iv: iv[1], filter(lambda iv: indices_to_keep[iv[0]], enumerate(aggregate[name]['param']))))
+ indices_to_keep = list(
+ map(lambda x: x[param_index] == param_value, aggregate[name]["param"])
+ )
+ aggregate[name]["param"] = list(
+ map(
+ lambda iv: iv[1],
+ filter(
+ lambda iv: indices_to_keep[iv[0]],
+ enumerate(aggregate[name]["param"]),
+ ),
+ )
+ )
if len(indices_to_keep) == 0:
- print('??? {}->{}'.format(parameter_filter, name))
+ print("??? {}->{}".format(parameter_filter, name))
names_to_remove.add(name)
else:
- for attribute in aggregate[name]['attributes']:
- aggregate[name][attribute] = aggregate[name][attribute][indices_to_keep]
+ for attribute in aggregate[name]["attributes"]:
+ aggregate[name][attribute] = aggregate[name][attribute][
+ indices_to_keep
+ ]
if len(aggregate[name][attribute]) == 0:
names_to_remove.add(name)
for name in names_to_remove:
@@ -218,25 +242,25 @@ class OptionalTimingAnalysis:
self.index = 1
def get_header(self):
- ret = ''
+ ret = ""
if self.enabled:
- ret += '#define TIMEIT(index, functioncall) '
- ret += 'counter.start(); '
- ret += 'functioncall; '
- ret += 'counter.stop();'
+ ret += "#define TIMEIT(index, functioncall) "
+ ret += "counter.start(); "
+ ret += "functioncall; "
+ ret += "counter.stop();"
ret += 'kout << endl << index << " :: " << counter.value << "/" << counter.overflow << endl;\n'
return ret
def wrap_codeblock(self, codeblock):
if not self.enabled:
return codeblock
- lines = codeblock.split('\n')
+ lines = codeblock.split("\n")
ret = list()
for line in lines:
- if re.fullmatch('.+;', line):
- ret.append('TIMEIT( {:d}, {} )'.format(self.index, line))
+ if re.fullmatch(".+;", line):
+ ret.append("TIMEIT( {:d}, {} )".format(self.index, line))
self.wrapped_lines.append(line)
self.index += 1
else:
ret.append(line)
- return '\n'.join(ret)
+ return "\n".join(ret)