| author | Daniel Friesel <daniel.friesel@uos.de> | 2020-05-28 12:04:37 +0200 |
|---|---|---|
| committer | Daniel Friesel <daniel.friesel@uos.de> | 2020-05-28 12:04:37 +0200 |
| commit | c69331e4d925658b2bf26dcb387981f6530d7b9e (patch) | |
| tree | d19c7f9b0bf51f68c104057e013630e009835268 /lib | |
| parent | 23927051ac3e64cabbaa6c30e8356dfe90ebfa6c (diff) | |
use black(1) for uniform code formatting
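The commit message is terse, so a word of context: black rewrites Python source into one canonical style without changing runtime behavior. Nearly every hunk below is one of three mechanical patterns: single-quoted strings become double-quoted, long signatures and calls are exploded one argument per line with a trailing comma, and long boolean conditions are wrapped in parentheses. A minimal sketch of the before/after shape (hypothetical helper, not from this repository), modeled on the `AspectCFunction.from_function_node` hunk below:

```python
# Before black, dfatool wrote keyword arguments with spaces around "=",
# all on one line:
#     return cls(name = name, kind = kind, function_id = function_id)

# After black: no spaces around "=", and once a call exceeds black's
# 88-column limit it is exploded one argument per line with a trailing
# "magic" comma (shown exploded here to mirror the hunks below).
def describe_function(name, kind, function_id):
    return dict(
        name=name,
        kind=kind,
        function_id=function_id,
    )


print(describe_function("main", "function", "f1"))
```

Reformatting like this is typically applied with `black lib/` and kept in place by running `black --check` in CI; both are standard black invocations, though the commit itself does not say how it was generated.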
Diffstat (limited to 'lib')
| Mode | Path | Lines changed |
|---|---|---|
| -rw-r--r-- | lib/aspectc.py | 35 |
| -rwxr-xr-x | lib/automata.py | 705 |
| -rw-r--r-- | lib/codegen.py | 389 |
| -rw-r--r-- | lib/cycles_to_energy.py | 139 |
| -rw-r--r-- | lib/data_parameters.py | 337 |
| -rw-r--r-- | lib/dfatool.py | 1797 |
| -rw-r--r-- | lib/functions.py | 250 |
| -rw-r--r-- | lib/harness.py | 489 |
| -rwxr-xr-x | lib/ipython_energymodel_prelude.py | 8 |
| -rwxr-xr-x | lib/keysightdlog.py | 136 |
| -rw-r--r-- | lib/lex.py | 104 |
| -rw-r--r-- | lib/modular_arithmetic.py | 57 |
| -rw-r--r-- | lib/parameters.py | 450 |
| -rwxr-xr-x | lib/plotter.py | 261 |
| -rwxr-xr-x | lib/protocol_benchmarks.py | 1654 |
| -rw-r--r-- | lib/pubcode/__init__.py | 2 |
| -rw-r--r-- | lib/pubcode/code128.py | 283 |
| -rw-r--r-- | lib/runner.py | 156 |
| -rw-r--r-- | lib/size_to_radio_energy.py | 315 |
| -rw-r--r-- | lib/sly/__init__.py | 3 |
| -rw-r--r-- | lib/sly/ast.py | 9 |
| -rw-r--r-- | lib/sly/docparse.py | 15 |
| -rw-r--r-- | lib/sly/lex.py | 178 |
| -rw-r--r-- | lib/sly/yacc.py | 867 |
| -rw-r--r-- | lib/utils.py | 82 |
25 files changed, 5558 insertions, 3163 deletions
```diff
diff --git a/lib/aspectc.py b/lib/aspectc.py
index 3229057..f4a102c 100644
--- a/lib/aspectc.py
+++ b/lib/aspectc.py
@@ -1,5 +1,6 @@
 import xml.etree.ElementTree as ET
 
+
 class AspectCClass:
     """
     C++ class information provided by the AspectC++ repo.acp
@@ -19,6 +20,7 @@ class AspectCClass:
         for function in self.functions:
             self.function[function.name] = function
 
+
 class AspectCFunction:
     """
     C++ function informationed provided by the AspectC++ repo.acp
@@ -53,16 +55,23 @@ class AspectCFunction:
         :param function_node: `xml.etree.ElementTree.Element` node
         """
-        name = function_node.get('name')
-        kind = function_node.get('kind')
-        function_id = function_node.get('id')
+        name = function_node.get("name")
+        kind = function_node.get("kind")
+        function_id = function_node.get("id")
         return_type = None
         argument_types = list()
-        for type_node in function_node.findall('result_type/Type'):
-            return_type = type_node.get('signature')
-        for type_node in function_node.findall('arg_types/Type'):
-            argument_types.append(type_node.get('signature'))
-        return cls(name = name, kind = kind, function_id = function_id, argument_types = argument_types, return_type = return_type)
+        for type_node in function_node.findall("result_type/Type"):
+            return_type = type_node.get("signature")
+        for type_node in function_node.findall("arg_types/Type"):
+            argument_types.append(type_node.get("signature"))
+        return cls(
+            name=name,
+            kind=kind,
+            function_id=function_id,
+            argument_types=argument_types,
+            return_type=return_type,
+        )
+
 
 class Repo:
     """
@@ -85,11 +94,13 @@ class Repo:
 
     def _load_classes(self):
         self.classes = list()
-        for class_node in self.root.findall('root/Namespace[@name="::"]/children/Class'):
-            name = class_node.get('name')
-            class_id = class_node.get('id')
+        for class_node in self.root.findall(
+            'root/Namespace[@name="::"]/children/Class'
+        ):
+            name = class_node.get("name")
+            class_id = class_node.get("id")
             functions = list()
-            for function_node in class_node.findall('children/Function'):
+            for function_node in class_node.findall("children/Function"):
                 function = AspectCFunction.from_function_node(function_node)
                 functions.append(function)
             self.classes.append(AspectCClass(name, class_id, functions))
diff --git a/lib/automata.py b/lib/automata.py
index b7668c5..b3318e0 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -28,7 +28,15 @@ class SimulationResult:
     :param mean_power: mean power during run in W
     """
 
-    def __init__(self, duration: float, energy: float, end_state, parameters, duration_mae: float = None, energy_mae: float = None):
+    def __init__(
+        self,
+        duration: float,
+        energy: float,
+        end_state,
+        parameters,
+        duration_mae: float = None,
+        energy_mae: float = None,
+    ):
         u"""
         Create a new SimulationResult.
@@ -77,7 +85,13 @@ class PTAAttribute:
     :param function_error: mean absolute error of function (optional)
     """
 
-    def __init__(self, value: float = 0, function: AnalyticFunction = None, value_error=None, function_error=None):
+    def __init__(
+        self,
+        value: float = 0,
+        function: AnalyticFunction = None,
+        value_error=None,
+        function_error=None,
+    ):
         self.value = value
         self.function = function
         self.value_error = value_error
@@ -85,8 +99,10 @@ class PTAAttribute:
 
     def __repr__(self):
         if self.function is not None:
-            return 'PTAATtribute<{:.0f}, {}>'.format(self.value, self.function._model_str)
-        return 'PTAATtribute<{:.0f}, None>'.format(self.value)
+            return "PTAATtribute<{:.0f}, {}>".format(
+                self.value, self.function._model_str
+            )
+        return "PTAATtribute<{:.0f}, None>".format(self.value)
 
     def eval(self, param_dict=dict(), args=list()):
         """
@@ -108,33 +124,38 @@ class PTAAttribute:
         """
         param_list = _dict_to_list(param_dict)
         if self.function and self.function.is_predictable(param_list):
-            return self.function_error['mae']
-        return self.value_error['mae']
+            return self.function_error["mae"]
+        return self.value_error["mae"]
 
     def to_json(self):
         ret = {
-            'static': self.value,
-            'static_error': self.value_error,
+            "static": self.value,
+            "static_error": self.value_error,
         }
         if self.function:
-            ret['function'] = {
-                'raw': self.function._model_str,
-                'regression_args': list(self.function._regression_args)
+            ret["function"] = {
+                "raw": self.function._model_str,
+                "regression_args": list(self.function._regression_args),
             }
-            ret['function_error'] = self.function_error
+            ret["function_error"] = self.function_error
         return ret
 
     @classmethod
     def from_json(cls, json_input: dict, parameters: dict):
         ret = cls()
-        if 'static' in json_input:
-            ret.value = json_input['static']
-        if 'static_error' in json_input:
-            ret.value_error = json_input['static_error']
-        if 'function' in json_input:
-            ret.function = AnalyticFunction(json_input['function']['raw'], parameters, 0, regression_args=json_input['function']['regression_args'])
-        if 'function_error' in json_input:
-            ret.function_error = json_input['function_error']
+        if "static" in json_input:
+            ret.value = json_input["static"]
+        if "static_error" in json_input:
+            ret.value_error = json_input["static_error"]
+        if "function" in json_input:
+            ret.function = AnalyticFunction(
+                json_input["function"]["raw"],
+                parameters,
+                0,
+                regression_args=json_input["function"]["regression_args"],
+            )
+        if "function_error" in json_input:
+            ret.function_error = json_input["function_error"]
         return ret
 
     @classmethod
@@ -145,16 +166,23 @@ class PTAAttribute:
 
 def _json_function_to_analytic_function(base, attribute: str, parameters: list):
-    if attribute in base and 'function' in base[attribute]:
-        base = base[attribute]['function']
-        return AnalyticFunction(base['raw'], parameters, 0, regression_args=base['regression_args'])
+    if attribute in base and "function" in base[attribute]:
+        base = base[attribute]["function"]
+        return AnalyticFunction(
+            base["raw"], parameters, 0, regression_args=base["regression_args"]
+        )
     return None
 
 
 class State:
     """A single PTA state."""
 
-    def __init__(self, name: str, power: PTAAttribute = PTAAttribute(), power_function: AnalyticFunction = None):
+    def __init__(
+        self,
+        name: str,
+        power: PTAAttribute = PTAAttribute(),
+        power_function: AnalyticFunction = None,
+    ):
         u"""
         Create a new PTA state.
@@ -173,10 +201,10 @@ class State:
         if type(power_function) is AnalyticFunction:
             self.power.function = power_function
         else:
-            raise ValueError('power_function must be None or AnalyticFunction')
+            raise ValueError("power_function must be None or AnalyticFunction")
 
     def __repr__(self):
-        return 'State<{:s}, {}>'.format(self.name, self.power)
+        return "State<{:s}, {}>".format(self.name, self.power)
 
     def add_outgoing_transition(self, new_transition: object):
         """Add a new outgoing transition."""
@@ -206,7 +234,11 @@ class State:
         try:
             return self.outgoing_transitions[transition_name]
         except KeyError:
-            raise ValueError('State {} has no outgoing transition called {}'.format(self.name, transition_name)) from None
+            raise ValueError(
+                "State {} has no outgoing transition called {}".format(
+                    self.name, transition_name
+                )
+            ) from None
 
     def has_interrupt_transitions(self) -> bool:
         """Return whether this state has any outgoing interrupt transitions."""
@@ -224,7 +256,9 @@ class State:
         :param parameters: current parameter values
         :returns: Transition object
         """
-        interrupts = filter(lambda x: x.is_interrupt, self.outgoing_transitions.values())
+        interrupts = filter(
+            lambda x: x.is_interrupt, self.outgoing_transitions.values()
+        )
         interrupts = sorted(interrupts, key=lambda x: x.get_timeout(parameters))
         return interrupts[0]
 
@@ -246,15 +280,32 @@ class State:
         # TODO parametergewahrer Trace-Filter, z.B. "setHeaterDuration nur wenn bme680 power mode => FORCED und GAS_ENABLED"
         # A '$' entry in trace_filter indicates that the trace should (successfully) terminate here regardless of `depth`.
-        if trace_filter is not None and next(filter(lambda x: x == '$', map(lambda x: x[0], trace_filter)), None) is not None:
+        if (
+            trace_filter is not None
+            and next(
+                filter(lambda x: x == "$", map(lambda x: x[0], trace_filter)), None
+            )
+            is not None
+        ):
             yield []
             # there may be other entries in trace_filter that still yield results.
 
         if depth == 0:
             for trans in self.outgoing_transitions.values():
-                if trace_filter is not None and len(list(filter(lambda x: x == trans.name, map(lambda x: x[0], trace_filter)))) == 0:
+                if (
+                    trace_filter is not None
+                    and len(
+                        list(
+                            filter(
+                                lambda x: x == trans.name,
+                                map(lambda x: x[0], trace_filter),
+                            )
+                        )
+                    )
+                    == 0
+                ):
                     continue
                 if with_arguments:
-                    if trans.argument_combination == 'cartesian':
+                    if trans.argument_combination == "cartesian":
                         for args in itertools.product(*trans.argument_values):
                             if sleep:
                                 yield [(None, sleep), (trans, args)]
@@ -273,18 +324,35 @@ class State:
                     yield [(trans,)]
         else:
             for trans in self.outgoing_transitions.values():
-                if trace_filter is not None and next(filter(lambda x: x == trans.name, map(lambda x: x[0], trace_filter)), None) is None:
+                if (
+                    trace_filter is not None
+                    and next(
+                        filter(
+                            lambda x: x == trans.name, map(lambda x: x[0], trace_filter)
+                        ),
+                        None,
+                    )
+                    is None
+                ):
                     continue
                 if trace_filter is not None:
-                    new_trace_filter = map(lambda x: x[1:], filter(lambda x: x[0] == trans.name, trace_filter))
+                    new_trace_filter = map(
+                        lambda x: x[1:],
+                        filter(lambda x: x[0] == trans.name, trace_filter),
+                    )
                     new_trace_filter = list(filter(len, new_trace_filter))
                     if len(new_trace_filter) == 0:
                         new_trace_filter = None
                 else:
                     new_trace_filter = None
-                for suffix in trans.destination.dfs(depth - 1, with_arguments=with_arguments, trace_filter=new_trace_filter, sleep=sleep):
+                for suffix in trans.destination.dfs(
+                    depth - 1,
+                    with_arguments=with_arguments,
+                    trace_filter=new_trace_filter,
+                    sleep=sleep,
+                ):
                     if with_arguments:
-                        if trans.argument_combination == 'cartesian':
+                        if trans.argument_combination == "cartesian":
                             for args in itertools.product(*trans.argument_values):
                                 if sleep:
                                     new_suffix = [(None, sleep), (trans, args)]
@@ -314,10 +382,7 @@ class State:
 
     def to_json(self) -> dict:
         """Return JSON encoding of this state object."""
-        ret = {
-            'name': self.name,
-            'power': self.power.to_json()
-        }
+        ret = {"name": self.name, "power": self.power.to_json()}
         return ret
 
 
@@ -345,19 +410,27 @@ class Transition:
     :param codegen: todo
     """
 
-    def __init__(self, orig_state: State, dest_state: State, name: str,
-                 energy: PTAAttribute = PTAAttribute(), energy_function: AnalyticFunction = None,
-                 duration: PTAAttribute = PTAAttribute(), duration_function: AnalyticFunction = None,
-                 timeout: PTAAttribute = PTAAttribute(), timeout_function: AnalyticFunction = None,
-                 is_interrupt: bool = False,
-                 arguments: list = [],
-                 argument_values: list = [],
-                 argument_combination: str = 'cartesian',  # or 'zip'
-                 param_update_function=None,
-                 arg_to_param_map: dict = None,
-                 set_param=None,
-                 return_value_handlers: list = [],
-                 codegen=dict()):
+    def __init__(
+        self,
+        orig_state: State,
+        dest_state: State,
+        name: str,
+        energy: PTAAttribute = PTAAttribute(),
+        energy_function: AnalyticFunction = None,
+        duration: PTAAttribute = PTAAttribute(),
+        duration_function: AnalyticFunction = None,
+        timeout: PTAAttribute = PTAAttribute(),
+        timeout_function: AnalyticFunction = None,
+        is_interrupt: bool = False,
+        arguments: list = [],
+        argument_values: list = [],
+        argument_combination: str = "cartesian",  # or 'zip'
+        param_update_function=None,
+        arg_to_param_map: dict = None,
+        set_param=None,
+        return_value_handlers: list = [],
+        codegen=dict(),
+    ):
         """
         Create a new transition between two PTA states.
@@ -400,8 +473,8 @@ class Transition:
             self.timeout.function = timeout_function
 
         for handler in self.return_value_handlers:
-            if 'formula' in handler:
-                handler['formula'] = NormalizationFunction(handler['formula'])
+            if "formula" in handler:
+                handler["formula"] = NormalizationFunction(handler["formula"])
 
     def get_duration(self, param_dict: dict = {}, args: list = []) -> float:
         u"""
@@ -465,25 +538,25 @@ class Transition:
     def to_json(self) -> dict:
         """Return JSON encoding of this transition object."""
         ret = {
-            'name': self.name,
-            'origin': self.origin.name,
-            'destination': self.destination.name,
-            'is_interrupt': self.is_interrupt,
-            'arguments': self.arguments,
-            'argument_values': self.argument_values,
-            'argument_combination': self.argument_combination,
-            'arg_to_param_map': self.arg_to_param_map,
-            'set_param': self.set_param,
-            'duration': self.duration.to_json(),
-            'energy': self.energy.to_json(),
-            'timeout': self.timeout.to_json()
+            "name": self.name,
+            "origin": self.origin.name,
+            "destination": self.destination.name,
+            "is_interrupt": self.is_interrupt,
+            "arguments": self.arguments,
+            "argument_values": self.argument_values,
+            "argument_combination": self.argument_combination,
+            "arg_to_param_map": self.arg_to_param_map,
+            "set_param": self.set_param,
+            "duration": self.duration.to_json(),
+            "energy": self.energy.to_json(),
+            "timeout": self.timeout.to_json(),
         }
         return ret
 
 
 def _json_get_static(base, attribute: str):
     if attribute in base:
-        return base[attribute]['static']
+        return base[attribute]["static"]
     return 0
 
 
@@ -505,10 +578,15 @@ class PTA:
     :param transitions: list of `Transition` objects
     """
 
-    def __init__(self, state_names: list = [],
-                 accepting_states: list = None,
-                 parameters: list = [], initial_param_values: list = None,
-                 codegen: dict = {}, parameter_normalization: dict = None):
+    def __init__(
+        self,
+        state_names: list = [],
+        accepting_states: list = None,
+        parameters: list = [],
+        initial_param_values: list = None,
+        codegen: dict = {},
+        parameter_normalization: dict = None,
+    ):
         """
         Return a new PTA object.
 
@@ -524,7 +602,9 @@ class PTA:
         `enum`: maps enum descriptors (keys) to parameter values. Note that the mapping is not required to correspond to the driver API.
         `formula`: maps an argument or return value (passed as `param`) to a parameter value. Must be a string describing a valid python lambda function. NumPy is available as `np`.
         """
-        self.state = dict([[state_name, State(state_name)] for state_name in state_names])
+        self.state = dict(
+            [[state_name, State(state_name)] for state_name in state_names]
+        )
         self.accepting_states = accepting_states.copy() if accepting_states else None
         self.parameters = parameters.copy()
         self.parameter_normalization = parameter_normalization
@@ -535,13 +615,15 @@ class PTA:
             self.initial_param_values = [None for x in self.parameters]
         self.transitions = []
 
-        if 'UNINITIALIZED' not in state_names:
-            self.state['UNINITIALIZED'] = State('UNINITIALIZED')
+        if "UNINITIALIZED" not in state_names:
+            self.state["UNINITIALIZED"] = State("UNINITIALIZED")
 
         if self.parameter_normalization:
             for normalization_spec in self.parameter_normalization.values():
-                if 'formula' in normalization_spec:
-                    normalization_spec['formula'] = NormalizationFunction(normalization_spec['formula'])
+                if "formula" in normalization_spec:
+                    normalization_spec["formula"] = NormalizationFunction(
+                        normalization_spec["formula"]
+                    )
 
     def normalize_parameter(self, parameter_name: str, parameter_value) -> float:
         """
@@ -553,11 +635,23 @@ class PTA:
         :param parameter_name: parameter name.
         :param parameter_value: parameter value.
         """
-        if parameter_value is not None and self.parameter_normalization is not None and parameter_name in self.parameter_normalization:
-            if 'enum' in self.parameter_normalization[parameter_name] and parameter_value in self.parameter_normalization[parameter_name]['enum']:
-                return self.parameter_normalization[parameter_name]['enum'][parameter_value]
-            if 'formula' in self.parameter_normalization[parameter_name]:
-                normalization_formula = self.parameter_normalization[parameter_name]['formula']
+        if (
+            parameter_value is not None
+            and self.parameter_normalization is not None
+            and parameter_name in self.parameter_normalization
+        ):
+            if (
+                "enum" in self.parameter_normalization[parameter_name]
+                and parameter_value
+                in self.parameter_normalization[parameter_name]["enum"]
+            ):
+                return self.parameter_normalization[parameter_name]["enum"][
+                    parameter_value
+                ]
+            if "formula" in self.parameter_normalization[parameter_name]:
+                normalization_formula = self.parameter_normalization[parameter_name][
+                    "formula"
+                ]
                 return normalization_formula.eval(parameter_value)
         return parameter_value
 
@@ -580,8 +674,8 @@ class PTA:
     @classmethod
     def from_file(cls, model_file: str):
         """Return PTA loaded from the provided JSON or YAML file."""
-        with open(model_file, 'r') as f:
-            if '.json' in model_file:
+        with open(model_file, "r") as f:
+            if ".json" in model_file:
                 return cls.from_json(json.load(f))
             else:
                 return cls.from_yaml(yaml.safe_load(f))
@@ -593,36 +687,58 @@ class PTA:
 
         Compatible with the to_json method.
         """
-        if 'transition' in json_input:
+        if "transition" in json_input:
            return cls.from_legacy_json(json_input)
 
         kwargs = dict()
-        for key in ('state_names', 'parameters', 'initial_param_values', 'accepting_states'):
+        for key in (
+            "state_names",
+            "parameters",
+            "initial_param_values",
+            "accepting_states",
+        ):
             if key in json_input:
                 kwargs[key] = json_input[key]
         pta = cls(**kwargs)
 
-        for name, state in json_input['state'].items():
-            pta.add_state(name, power=PTAAttribute.from_json_maybe(state, 'power', pta.parameters))
-        for transition in json_input['transitions']:
+        for name, state in json_input["state"].items():
+            pta.add_state(
+                name, power=PTAAttribute.from_json_maybe(state, "power", pta.parameters)
+            )
+        for transition in json_input["transitions"]:
             kwargs = dict()
-            for key in ['arguments', 'argument_values', 'argument_combination', 'is_interrupt', 'set_param']:
+            for key in [
+                "arguments",
+                "argument_values",
+                "argument_combination",
+                "is_interrupt",
+                "set_param",
+            ]:
                 if key in transition:
                     kwargs[key] = transition[key]
             # arg_to_param_map uses integer indices. This is not supported by JSON
-            if 'arg_to_param_map' in transition:
-                kwargs['arg_to_param_map'] = dict()
-                for arg_index, param_name in transition['arg_to_param_map'].items():
-                    kwargs['arg_to_param_map'][int(arg_index)] = param_name
-            origins = transition['origin']
+            if "arg_to_param_map" in transition:
+                kwargs["arg_to_param_map"] = dict()
+                for arg_index, param_name in transition["arg_to_param_map"].items():
+                    kwargs["arg_to_param_map"][int(arg_index)] = param_name
+            origins = transition["origin"]
             if type(origins) != list:
                 origins = [origins]
             for origin in origins:
-                pta.add_transition(origin, transition['destination'],
-                                   transition['name'],
-                                   duration=PTAAttribute.from_json_maybe(transition, 'duration', pta.parameters),
-                                   energy=PTAAttribute.from_json_maybe(transition, 'energy', pta.parameters),
-                                   timeout=PTAAttribute.from_json_maybe(transition, 'timeout', pta.parameters),
-                                   **kwargs)
+                pta.add_transition(
+                    origin,
+                    transition["destination"],
+                    transition["name"],
+                    duration=PTAAttribute.from_json_maybe(
+                        transition, "duration", pta.parameters
+                    ),
+                    energy=PTAAttribute.from_json_maybe(
+                        transition, "energy", pta.parameters
+                    ),
+                    timeout=PTAAttribute.from_json_maybe(
+                        transition, "timeout", pta.parameters
+                    ),
+                    **kwargs
+                )
 
         return pta
 
@@ -634,37 +750,45 @@ class PTA:
 
         Compatible with the legacy dfatool/perl format.
         """
         kwargs = {
-            'parameters': list(),
-            'initial_param_values': list(),
+            "parameters": list(),
+            "initial_param_values": list(),
         }
-        for param in sorted(json_input['parameter'].keys()):
-            kwargs['parameters'].append(param)
-            kwargs['initial_param_values'].append(json_input['parameter'][param]['default'])
+        for param in sorted(json_input["parameter"].keys()):
+            kwargs["parameters"].append(param)
+            kwargs["initial_param_values"].append(
+                json_input["parameter"][param]["default"]
+            )
         pta = cls(**kwargs)
 
-        for name, state in json_input['state'].items():
-            pta.add_state(name, power=PTAAttribute(value=float(state['power']['static'])))
+        for name, state in json_input["state"].items():
+            pta.add_state(
+                name, power=PTAAttribute(value=float(state["power"]["static"]))
+            )
 
-        for trans_name in sorted(json_input['transition'].keys()):
-            transition = json_input['transition'][trans_name]
-            destination = transition['destination']
+        for trans_name in sorted(json_input["transition"].keys()):
+            transition = json_input["transition"][trans_name]
+            destination = transition["destination"]
             arguments = list()
             argument_values = list()
             is_interrupt = False
-            if transition['level'] == 'epilogue':
+            if transition["level"] == "epilogue":
                 is_interrupt = True
             if type(destination) == list:
                 destination = destination[0]
-            for arg in transition['parameters']:
-                arguments.append(arg['name'])
-                argument_values.append(arg['values'])
-            for origin in transition['origins']:
-                pta.add_transition(origin, destination, trans_name,
-                                   arguments=arguments,
-                                   argument_values=argument_values,
-                                   is_interrupt=is_interrupt)
+            for arg in transition["parameters"]:
+                arguments.append(arg["name"])
+                argument_values.append(arg["values"])
+            for origin in transition["origins"]:
+                pta.add_transition(
+                    origin,
+                    destination,
+                    trans_name,
+                    arguments=arguments,
+                    argument_values=argument_values,
+                    is_interrupt=is_interrupt,
+                )
 
         return pta
 
@@ -674,68 +798,79 @@ class PTA:
 
         kwargs = dict()
 
-        if 'parameters' in yaml_input:
-            kwargs['parameters'] = yaml_input['parameters']
+        if "parameters" in yaml_input:
+            kwargs["parameters"] = yaml_input["parameters"]
 
-        if 'initial_param_values' in yaml_input:
-            kwargs['initial_param_values'] = yaml_input['initial_param_values']
+        if "initial_param_values" in yaml_input:
+            kwargs["initial_param_values"] = yaml_input["initial_param_values"]
 
-        if 'states' in yaml_input:
-            kwargs['state_names'] = yaml_input['states']
+        if "states" in yaml_input:
+            kwargs["state_names"] = yaml_input["states"]
         # else: set to UNINITIALIZED by class constructor
 
-        if 'codegen' in yaml_input:
-            kwargs['codegen'] = yaml_input['codegen']
+        if "codegen" in yaml_input:
+            kwargs["codegen"] = yaml_input["codegen"]
 
-        if 'parameter_normalization' in yaml_input:
-            kwargs['parameter_normalization'] = yaml_input['parameter_normalization']
+        if "parameter_normalization" in yaml_input:
+            kwargs["parameter_normalization"] = yaml_input["parameter_normalization"]
 
         pta = cls(**kwargs)
 
-        if 'state' in yaml_input:
-            for state_name, state in yaml_input['state'].items():
-                pta.add_state(state_name, power=PTAAttribute.from_json_maybe(state, 'power', pta.parameters))
+        if "state" in yaml_input:
+            for state_name, state in yaml_input["state"].items():
+                pta.add_state(
+                    state_name,
+                    power=PTAAttribute.from_json_maybe(state, "power", pta.parameters),
+                )
 
-        for trans_name in sorted(yaml_input['transition'].keys()):
+        for trans_name in sorted(yaml_input["transition"].keys()):
             kwargs = dict()
-            transition = yaml_input['transition'][trans_name]
+            transition = yaml_input["transition"][trans_name]
             arguments = list()
             argument_values = list()
             arg_to_param_map = dict()
-            if 'arguments' in transition:
-                for i, argument in enumerate(transition['arguments']):
-                    arguments.append(argument['name'])
-                    argument_values.append(argument['values'])
-                    if 'parameter' in argument:
-                        arg_to_param_map[i] = argument['parameter']
-            if 'argument_combination' in transition:
-                kwargs['argument_combination'] = transition['argument_combination']
-            if 'set_param' in transition:
-                kwargs['set_param'] = transition['set_param']
-            if 'is_interrupt' in transition:
-                kwargs['is_interrupt'] = transition['is_interrupt']
-            if 'return_value' in transition:
-                kwargs['return_value_handlers'] = transition['return_value']
-            if 'codegen' in transition:
-                kwargs['codegen'] = transition['codegen']
-            if 'loop' in transition:
-                for state_name in transition['loop']:
-                    pta.add_transition(state_name, state_name, trans_name,
-                                       arguments=arguments,
-                                       argument_values=argument_values,
-                                       arg_to_param_map=arg_to_param_map,
-                                       **kwargs)
+            if "arguments" in transition:
+                for i, argument in enumerate(transition["arguments"]):
+                    arguments.append(argument["name"])
+                    argument_values.append(argument["values"])
+                    if "parameter" in argument:
+                        arg_to_param_map[i] = argument["parameter"]
+            if "argument_combination" in transition:
+                kwargs["argument_combination"] = transition["argument_combination"]
+            if "set_param" in transition:
+                kwargs["set_param"] = transition["set_param"]
+            if "is_interrupt" in transition:
+                kwargs["is_interrupt"] = transition["is_interrupt"]
+            if "return_value" in transition:
+                kwargs["return_value_handlers"] = transition["return_value"]
+            if "codegen" in transition:
+                kwargs["codegen"] = transition["codegen"]
+            if "loop" in transition:
+                for state_name in transition["loop"]:
+                    pta.add_transition(
+                        state_name,
+                        state_name,
+                        trans_name,
+                        arguments=arguments,
+                        argument_values=argument_values,
+                        arg_to_param_map=arg_to_param_map,
+                        **kwargs
+                    )
             else:
-                if 'src' not in transition:
-                    transition['src'] = ['UNINITIALIZED']
-                if 'dst' not in transition:
-                    transition['dst'] = 'UNINITIALIZED'
-                for origin in transition['src']:
-                    pta.add_transition(origin, transition['dst'], trans_name,
-                                       arguments=arguments,
-                                       argument_values=argument_values,
-                                       arg_to_param_map=arg_to_param_map,
-                                       **kwargs)
+                if "src" not in transition:
+                    transition["src"] = ["UNINITIALIZED"]
+                if "dst" not in transition:
+                    transition["dst"] = "UNINITIALIZED"
+                for origin in transition["src"]:
+                    pta.add_transition(
+                        origin,
+                        transition["dst"],
+                        trans_name,
+                        arguments=arguments,
+                        argument_values=argument_values,
+                        arg_to_param_map=arg_to_param_map,
+                        **kwargs
+                    )
 
         return pta
 
@@ -746,11 +881,13 @@ class PTA:
         Compatible with the from_json method.
         """
         ret = {
-            'parameters': self.parameters,
-            'initial_param_values': self.initial_param_values,
-            'state': dict([[state.name, state.to_json()] for state in self.state.values()]),
-            'transitions': [trans.to_json() for trans in self.transitions],
-            'accepting_states': self.accepting_states,
+            "parameters": self.parameters,
+            "initial_param_values": self.initial_param_values,
+            "state": dict(
+                [[state.name, state.to_json()] for state in self.state.values()]
+            ),
+            "transitions": [trans.to_json() for trans in self.transitions],
+            "accepting_states": self.accepting_states,
         }
         return ret
 
@@ -760,12 +897,19 @@ class PTA:
 
         See the State() documentation for acceptable arguments.
""" - if 'power_function' in kwargs and type(kwargs['power_function']) != AnalyticFunction and kwargs['power_function'] is not None: - kwargs['power_function'] = AnalyticFunction(kwargs['power_function'], - self.parameters, 0) + if ( + "power_function" in kwargs + and type(kwargs["power_function"]) != AnalyticFunction + and kwargs["power_function"] is not None + ): + kwargs["power_function"] = AnalyticFunction( + kwargs["power_function"], self.parameters, 0 + ) self.state[state_name] = State(state_name, **kwargs) - def add_transition(self, orig_state: str, dest_state: str, function_name: str, **kwargs): + def add_transition( + self, orig_state: str, dest_state: str, function_name: str, **kwargs + ): """ Add function_name as new transition from orig_state to dest_state. @@ -776,8 +920,12 @@ class PTA: """ orig_state = self.state[orig_state] dest_state = self.state[dest_state] - for key in ('duration_function', 'energy_function', 'timeout_function'): - if key in kwargs and kwargs[key] is not None and type(kwargs[key]) != AnalyticFunction: + for key in ("duration_function", "energy_function", "timeout_function"): + if ( + key in kwargs + and kwargs[key] is not None + and type(kwargs[key]) != AnalyticFunction + ): kwargs[key] = AnalyticFunction(kwargs[key], self.parameters, 0) new_transition = Transition(orig_state, dest_state, function_name, **kwargs) @@ -824,7 +972,12 @@ class PTA: return self.get_unique_transitions().index(transition) def get_initial_param_dict(self): - return dict([[self.parameters[i], self.initial_param_values[i]] for i in range(len(self.parameters))]) + return dict( + [ + [self.parameters[i], self.initial_param_values[i]] + for i in range(len(self.parameters)) + ] + ) def set_random_energy_model(self, static_model=True): u""" @@ -841,18 +994,24 @@ class PTA: def get_most_expensive_state(self): max_state = None for state in self.state.values(): - if state.name != 'UNINITIALIZED' and (max_state is None or state.power.value > max_state.power.value): + if state.name != "UNINITIALIZED" and ( + max_state is None or state.power.value > max_state.power.value + ): max_state = state return max_state def get_least_expensive_state(self): min_state = None for state in self.state.values(): - if state.name != 'UNINITIALIZED' and (min_state is None or state.power.value < min_state.power.value): + if state.name != "UNINITIALIZED" and ( + min_state is None or state.power.value < min_state.power.value + ): min_state = state return min_state - def min_duration_until_energy_overflow(self, energy_granularity=1e-12, max_energy_value=2 ** 32 - 1): + def min_duration_until_energy_overflow( + self, energy_granularity=1e-12, max_energy_value=2 ** 32 - 1 + ): """ Return minimum duration (in s) until energy counter overflow during online accounting. 
@@ -862,7 +1021,9 @@ class PTA: max_power_state = self.get_most_expensive_state() if max_power_state.has_interrupt_transitions(): - raise RuntimeWarning('state with maximum power consumption has outgoing interrupt transitions, results will be inaccurate') + raise RuntimeWarning( + "state with maximum power consumption has outgoing interrupt transitions, results will be inaccurate" + ) # convert from µW to W max_power = max_power_state.power.value * 1e-6 @@ -870,7 +1031,9 @@ class PTA: min_duration = max_energy_value * energy_granularity / max_power return min_duration - def max_duration_until_energy_overflow(self, energy_granularity=1e-12, max_energy_value=2 ** 32 - 1): + def max_duration_until_energy_overflow( + self, energy_granularity=1e-12, max_energy_value=2 ** 32 - 1 + ): """ Return maximum duration (in s) until energy counter overflow during online accounting. @@ -880,7 +1043,9 @@ class PTA: min_power_state = self.get_least_expensive_state() if min_power_state.has_interrupt_transitions(): - raise RuntimeWarning('state with maximum power consumption has outgoing interrupt transitions, results will be inaccurate') + raise RuntimeWarning( + "state with maximum power consumption has outgoing interrupt transitions, results will be inaccurate" + ) # convert from µW to W min_power = min_power_state.power.value * 1e-6 @@ -904,14 +1069,19 @@ class PTA: for i, argument in enumerate(transition.arguments): if len(transition.argument_values[i]) <= 2: continue - if transition.argument_combination == 'zip': + if transition.argument_combination == "zip": continue values_are_numeric = True for value in transition.argument_values[i]: - if not is_numeric(self.normalize_parameter(transition.arg_to_param_map[i], value)): + if not is_numeric( + self.normalize_parameter(transition.arg_to_param_map[i], value) + ): values_are_numeric = False if values_are_numeric and len(transition.argument_values[i]) > 2: - transition.argument_values[i] = [transition.argument_values[i][0], transition.argument_values[i][-1]] + transition.argument_values[i] = [ + transition.argument_values[i][0], + transition.argument_values[i][-1], + ] def _dfs_with_param(self, generator, param_dict): for trace in generator: @@ -921,13 +1091,23 @@ class PTA: transition, arguments = elem if transition is not None: param = transition.get_params_after_transition(param, arguments) - ret.append((transition, arguments, self.normalize_parameters(param))) + ret.append( + (transition, arguments, self.normalize_parameters(param)) + ) else: # parameters have already been normalized ret.append((transition, arguments, param)) yield ret - def bfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED', param_dict: dict = None, with_parameters: bool = False, transition_filter=None, state_filter=None): + def bfs( + self, + depth: int = 10, + orig_state: str = "UNINITIALIZED", + param_dict: dict = None, + with_parameters: bool = False, + transition_filter=None, + state_filter=None, + ): """ Return a generator object for breadth-first search of traces starting at orig_state. @@ -968,7 +1148,14 @@ class PTA: yield new_trace state_queue.put((new_trace, transition.destination)) - def dfs(self, depth: int = 10, orig_state: str = 'UNINITIALIZED', param_dict: dict = None, with_parameters: bool = False, **kwargs): + def dfs( + self, + depth: int = 10, + orig_state: str = "UNINITIALIZED", + param_dict: dict = None, + with_parameters: bool = False, + **kwargs + ): """ Return a generator object for depth-first search starting at orig_state. 
@@ -994,12 +1181,14 @@ class PTA: if with_parameters and not param_dict: param_dict = self.get_initial_param_dict() - if with_parameters and 'with_arguments' not in kwargs: + if with_parameters and "with_arguments" not in kwargs: raise ValueError("with_parameters = True requires with_arguments = True") if self.accepting_states: - generator = filter(lambda x: x[-1][0].destination.name in self.accepting_states, - self.state[orig_state].dfs(depth, **kwargs)) + generator = filter( + lambda x: x[-1][0].destination.name in self.accepting_states, + self.state[orig_state].dfs(depth, **kwargs), + ) else: generator = self.state[orig_state].dfs(depth, **kwargs) @@ -1008,7 +1197,13 @@ class PTA: else: return generator - def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED', orig_param=None, accounting=None): + def simulate( + self, + trace: list, + orig_state: str = "UNINITIALIZED", + orig_param=None, + accounting=None, + ): u""" Simulate a single run through the PTA and return total energy, duration, final state, and resulting parameters. @@ -1021,10 +1216,10 @@ class PTA: :returns: SimulationResult with duration in s, total energy in J, end state, and final parameters """ - total_duration = 0. - total_duration_mae = 0. - total_energy = 0. - total_energy_error = 0. + total_duration = 0.0 + total_duration_mae = 0.0 + total_energy = 0.0 + total_energy_error = 0.0 if type(orig_state) is State: state = orig_state else: @@ -1032,7 +1227,12 @@ class PTA: if orig_param: param_dict = orig_param.copy() else: - param_dict = dict([[self.parameters[i], self.initial_param_values[i]] for i in range(len(self.parameters))]) + param_dict = dict( + [ + [self.parameters[i], self.initial_param_values[i]] + for i in range(len(self.parameters)) + ] + ) for function in trace: if isinstance(function[0], Transition): function_name = function[0].name @@ -1040,11 +1240,13 @@ class PTA: else: function_name = function[0] function_args = function[1:] - if function_name is None or function_name == '_': + if function_name is None or function_name == "_": duration = function_args[0] total_energy += state.get_energy(duration, param_dict) if state.power.value_error is not None: - total_energy_error += (duration * state.power.eval_mae(param_dict, function_args))**2 + total_energy_error += ( + duration * state.power.eval_mae(param_dict, function_args) + ) ** 2 total_duration += duration # assumption: sleep is near-exact and does not contribute to the duration error if accounting is not None: @@ -1053,15 +1255,21 @@ class PTA: transition = state.get_transition(function_name) total_duration += transition.duration.eval(param_dict, function_args) if transition.duration.value_error is not None: - total_duration_mae += transition.duration.eval_mae(param_dict, function_args)**2 + total_duration_mae += ( + transition.duration.eval_mae(param_dict, function_args) ** 2 + ) total_energy += transition.get_energy(param_dict, function_args) if transition.energy.value_error is not None: - total_energy_error += transition.energy.eval_mae(param_dict, function_args)**2 - param_dict = transition.get_params_after_transition(param_dict, function_args) + total_energy_error += ( + transition.energy.eval_mae(param_dict, function_args) ** 2 + ) + param_dict = transition.get_params_after_transition( + param_dict, function_args + ) state = transition.destination if accounting is not None: accounting.pass_transition(transition) - while (state.has_interrupt_transitions()): + while state.has_interrupt_transitions(): transition = 
state.get_next_interrupt(param_dict) duration = transition.get_timeout(param_dict) total_duration += duration @@ -1072,45 +1280,82 @@ class PTA: param_dict = transition.get_params_after_transition(param_dict) state = transition.destination - return SimulationResult(total_duration, total_energy, state, param_dict, duration_mae=np.sqrt(total_duration_mae), energy_mae=np.sqrt(total_energy_error)) + return SimulationResult( + total_duration, + total_energy, + state, + param_dict, + duration_mae=np.sqrt(total_duration_mae), + energy_mae=np.sqrt(total_energy_error), + ) def update(self, static_model, param_model, static_error=None, analytic_error=None): for state in self.state.values(): - if state.name != 'UNINITIALIZED': + if state.name != "UNINITIALIZED": try: - state.power.value = static_model(state.name, 'power') + state.power.value = static_model(state.name, "power") if static_error is not None: - state.power.value_error = static_error[state.name]['power'] - if param_model(state.name, 'power'): - state.power.function = param_model(state.name, 'power')['function'] + state.power.value_error = static_error[state.name]["power"] + if param_model(state.name, "power"): + state.power.function = param_model(state.name, "power")[ + "function" + ] if analytic_error is not None: - state.power.function_error = analytic_error[state.name]['power'] + state.power.function_error = analytic_error[state.name][ + "power" + ] except KeyError: - print('[W] skipping model update of state {} due to missing data'.format(state.name)) + print( + "[W] skipping model update of state {} due to missing data".format( + state.name + ) + ) pass for transition in self.transitions: try: - transition.duration.value = static_model(transition.name, 'duration') - if param_model(transition.name, 'duration'): - transition.duration.function = param_model(transition.name, 'duration')['function'] + transition.duration.value = static_model(transition.name, "duration") + if param_model(transition.name, "duration"): + transition.duration.function = param_model( + transition.name, "duration" + )["function"] if analytic_error is not None: - transition.duration.function_error = analytic_error[transition.name]['duration'] - transition.energy.value = static_model(transition.name, 'energy') - if param_model(transition.name, 'energy'): - transition.energy.function = param_model(transition.name, 'energy')['function'] + transition.duration.function_error = analytic_error[ + transition.name + ]["duration"] + transition.energy.value = static_model(transition.name, "energy") + if param_model(transition.name, "energy"): + transition.energy.function = param_model(transition.name, "energy")[ + "function" + ] if analytic_error is not None: - transition.energy.function_error = analytic_error[transition.name]['energy'] + transition.energy.function_error = analytic_error[ + transition.name + ]["energy"] if transition.is_interrupt: - transition.timeout.value = static_model(transition.name, 'timeout') - if param_model(transition.name, 'timeout'): - transition.timeout.function = param_model(transition.name, 'timeout')['function'] + transition.timeout.value = static_model(transition.name, "timeout") + if param_model(transition.name, "timeout"): + transition.timeout.function = param_model( + transition.name, "timeout" + )["function"] if analytic_error is not None: - transition.timeout.function_error = analytic_error[transition.name]['timeout'] + transition.timeout.function_error = analytic_error[ + transition.name + ]["timeout"] if static_error is not None: - 
transition.duration.value_error = static_error[transition.name]['duration'] - transition.energy.value_error = static_error[transition.name]['energy'] - transition.timeout.value_error = static_error[transition.name]['timeout'] + transition.duration.value_error = static_error[transition.name][ + "duration" + ] + transition.energy.value_error = static_error[transition.name][ + "energy" + ] + transition.timeout.value_error = static_error[transition.name][ + "timeout" + ] except KeyError: - print('[W] skipping model update of transition {} due to missing data'.format(transition.name)) + print( + "[W] skipping model update of transition {} due to missing data".format( + transition.name + ) + ) pass diff --git a/lib/codegen.py b/lib/codegen.py index e0bf45f..62776fd 100644 --- a/lib/codegen.py +++ b/lib/codegen.py @@ -60,38 +60,46 @@ class ClassFunction: self.body = body def get_definition(self): - return '{} {}({});'.format(self.return_type, self.name, ', '.join(self.arguments)) + return "{} {}({});".format( + self.return_type, self.name, ", ".join(self.arguments) + ) def get_implementation(self): if self.body is None: - return '' - return '{} {}::{}({}) {{\n{}}}\n'.format(self.return_type, self.class_name, self.name, ', '.join(self.arguments), self.body) + return "" + return "{} {}::{}({}) {{\n{}}}\n".format( + self.return_type, + self.class_name, + self.name, + ", ".join(self.arguments), + self.body, + ) def get_accountingmethod(method): """Return AccountingMethod class for method.""" - if method == 'static_state_immediate': + if method == "static_state_immediate": return StaticStateOnlyAccountingImmediateCalculation - if method == 'static_state': + if method == "static_state": return StaticStateOnlyAccounting - if method == 'static_statetransition_immediate': + if method == "static_statetransition_immediate": return StaticAccountingImmediateCalculation - if method == 'static_statetransition': + if method == "static_statetransition": return StaticAccounting - raise ValueError('Unknown accounting method: {}'.format(method)) + raise ValueError("Unknown accounting method: {}".format(method)) def get_simulated_accountingmethod(method): """Return SimulatedAccountingMethod class for method.""" - if method == 'static_state_immediate': + if method == "static_state_immediate": return SimulatedStaticStateOnlyAccountingImmediateCalculation - if method == 'static_statetransition_immediate': + if method == "static_statetransition_immediate": return SimulatedStaticAccountingImmediateCalculation - if method == 'static_state': + if method == "static_state": return SimulatedStaticStateOnlyAccounting - if method == 'static_statetransition': + if method == "static_statetransition": return SimulatedStaticAccounting - raise ValueError('Unknown accounting method: {}'.format(method)) + raise ValueError("Unknown accounting method: {}".format(method)) class SimulatedAccountingMethod: @@ -104,7 +112,18 @@ class SimulatedAccountingMethod: * variable size for accounting of durations, power and energy values """ - def __init__(self, pta: PTA, timer_freq_hz, timer_type, ts_type, power_type, energy_type, ts_granularity=1e-6, power_granularity=1e-6, energy_granularity=1e-12): + def __init__( + self, + pta: PTA, + timer_freq_hz, + timer_type, + ts_type, + power_type, + energy_type, + ts_granularity=1e-6, + power_granularity=1e-6, + energy_granularity=1e-12, + ): """ Simulate Online Accounting for a given PTA. 
@@ -121,7 +140,7 @@ class SimulatedAccountingMethod:
         self.ts_class = simulate_int_type(ts_type)
         self.power_class = simulate_int_type(power_type)
         self.energy_class = simulate_int_type(energy_type)
-        self.current_state = pta.state['UNINITIALIZED']
+        self.current_state = pta.state["UNINITIALIZED"]
 
         self.ts_granularity = ts_granularity
         self.power_granularity = power_granularity
@@ -137,7 +156,13 @@ class SimulatedAccountingMethod:
         Does not use Module types and therefore does not consider overflows or data-type limitations"""
         if self.energy_granularity == self.power_granularity * self.ts_granularity:
             return power * time
-        return int(power * self.power_granularity * time * self.ts_granularity / self.energy_granularity)
+        return int(
+            power
+            * self.power_granularity
+            * time
+            * self.ts_granularity
+            / self.energy_granularity
+        )
 
     def _sleep_duration(self, duration_us):
         u"""
@@ -202,11 +227,11 @@ class SimulatedStaticAccountingImmediateCalculation(SimulatedAccountingMethod):
 
     def sleep(self, duration_us):
         time = self._sleep_duration(duration_us)
-        print('sleep duration is {}'.format(time))
+        print("sleep duration is {}".format(time))
         power = int(self.current_state.power.value)
-        print('power is {}'.format(power))
+        print("power is {}".format(power))
         energy = self._energy_from_power_and_time(time, power)
-        print('energy is {}'.format(energy))
+        print("energy is {}".format(energy))
         self.energy += energy
 
     def pass_transition(self, transition: Transition):
@@ -232,7 +257,7 @@ class SimulatedStaticAccounting(SimulatedAccountingMethod):
             self.time_in_state[state_name] = self.ts_class(0)
         self.transition_count = list()
         for transition in pta.transitions:
-            self.transition_count.append(simulate_int_type('uint16_t')(0))
+            self.transition_count.append(simulate_int_type("uint16_t")(0))
 
     def sleep(self, duration_us):
         self.time_in_state[self.current_state.name] += self._sleep_duration(duration_us)
@@ -245,7 +270,9 @@ class SimulatedStaticAccounting(SimulatedAccountingMethod):
         pta = self.pta
         energy = self.energy_class(0)
         for state in pta.state.values():
-            energy += self._energy_from_power_and_time(self.time_in_state[state.name], int(state.power.value))
+            energy += self._energy_from_power_and_time(
+                self.time_in_state[state.name], int(state.power.value)
+            )
         for i, transition in enumerate(pta.transitions):
             energy += self.transition_count[i] * int(transition.energy.value)
         return energy.val
@@ -275,7 +302,9 @@ class SimulatedStaticStateOnlyAccounting(SimulatedAccountingMethod):
         pta = self.pta
         energy = self.energy_class(0)
         for state in pta.state.values():
-            energy += self._energy_from_power_and_time(self.time_in_state[state.name], int(state.power.value))
+            energy += self._energy_from_power_and_time(
+                self.time_in_state[state.name], int(state.power.value)
+            )
         return energy.val
 
 
@@ -290,32 +319,50 @@ class AccountingMethod:
         self.public_functions = list()
 
     def pre_transition_hook(self, transition):
-        return ''
+        return ""
 
     def init_code(self):
-        return ''
+        return ""
 
     def get_includes(self):
         return map(lambda x: '#include "{}"'.format(x), self.include_paths)
 
 
 class StaticStateOnlyAccountingImmediateCalculation(AccountingMethod):
-    def __init__(self, class_name: str, pta: PTA, ts_type='unsigned int', power_type='unsigned int', energy_type='unsigned long'):
+    def __init__(
+        self,
+        class_name: str,
+        pta: PTA,
+        ts_type="unsigned int",
+        power_type="unsigned int",
+        energy_type="unsigned long",
+    ):
         super().__init__(class_name, pta)
         self.ts_type = ts_type
-        self.include_paths.append('driver/uptime.h')
-        self.private_variables.append('unsigned char lastState;')
-        self.private_variables.append('{} lastStateChange;'.format(ts_type))
-        self.private_variables.append('{} totalEnergy;'.format(energy_type))
-        self.private_variables.append(array_template.format(
-            type=power_type,
-            name='state_power',
-            length=len(pta.state),
-            elements=', '.join(map(lambda state_name: '{:.0f}'.format(pta.state[state_name].power), pta.get_state_names()))
-        ))
+        self.include_paths.append("driver/uptime.h")
+        self.private_variables.append("unsigned char lastState;")
+        self.private_variables.append("{} lastStateChange;".format(ts_type))
+        self.private_variables.append("{} totalEnergy;".format(energy_type))
+        self.private_variables.append(
+            array_template.format(
+                type=power_type,
+                name="state_power",
+                length=len(pta.state),
+                elements=", ".join(
+                    map(
+                        lambda state_name: "{:.0f}".format(pta.state[state_name].power),
+                        pta.get_state_names(),
+                    )
+                ),
+            )
+        )
 
         get_energy_function = """return totalEnergy;"""
-        self.public_functions.append(ClassFunction(class_name, energy_type, 'getEnergy', list(), get_energy_function))
+        self.public_functions.append(
+            ClassFunction(
+                class_name, energy_type, "getEnergy", list(), get_energy_function
+            )
+        )
 
     def pre_transition_hook(self, transition):
         return """
@@ -323,30 +370,50 @@ class StaticStateOnlyAccountingImmediateCalculation(AccountingMethod):
         totalEnergy += (now - lastStateChange) * state_power[lastState];
         lastStateChange = now;
         lastState = {};
-        """.format(self.pta.get_state_id(transition.destination))
+        """.format(
+            self.pta.get_state_id(transition.destination)
+        )
 
     def init_code(self):
         return """
        totalEnergy = 0;
        lastStateChange = 0;
        lastState = 0;
-        """.format(num_states=len(self.pta.state))
+        """.format(
+            num_states=len(self.pta.state)
+        )
 
 
 class StaticStateOnlyAccounting(AccountingMethod):
-    def __init__(self, class_name: str, pta: PTA, ts_type='unsigned int', power_type='unsigned int', energy_type='unsigned long'):
+    def __init__(
+        self,
+        class_name: str,
+        pta: PTA,
+        ts_type="unsigned int",
+        power_type="unsigned int",
+        energy_type="unsigned long",
+    ):
         super().__init__(class_name, pta)
         self.ts_type = ts_type
-        self.include_paths.append('driver/uptime.h')
-        self.private_variables.append('unsigned char lastState;')
-        self.private_variables.append('{} lastStateChange;'.format(ts_type))
-        self.private_variables.append(array_template.format(
-            type=power_type,
-            name='state_power',
-            length=len(pta.state),
-            elements=', '.join(map(lambda state_name: '{:.0f}'.format(pta.state[state_name].power), pta.get_state_names()))
-        ))
-        self.private_variables.append('{} timeInState[{}];'.format(ts_type, len(pta.state)))
+        self.include_paths.append("driver/uptime.h")
+        self.private_variables.append("unsigned char lastState;")
+        self.private_variables.append("{} lastStateChange;".format(ts_type))
+        self.private_variables.append(
+            array_template.format(
+                type=power_type,
+                name="state_power",
+                length=len(pta.state),
+                elements=", ".join(
+                    map(
+                        lambda state_name: "{:.0f}".format(pta.state[state_name].power),
+                        pta.get_state_names(),
+                    )
+                ),
+            )
+        )
+        self.private_variables.append(
+            "{} timeInState[{}];".format(ts_type, len(pta.state))
+        )
 
         get_energy_function = """
         {energy_type} total_energy = 0;
@@ -354,8 +421,14 @@ class StaticStateOnlyAccounting(AccountingMethod):
            total_energy += timeInState[i] * state_power[i];
         }}
         return total_energy;
-        """.format(energy_type=energy_type, num_states=len(pta.state))
-        self.public_functions.append(ClassFunction(class_name, energy_type, 'getEnergy', list(), get_energy_function))
+        """.format(
+            energy_type=energy_type, num_states=len(pta.state)
+        )
+        self.public_functions.append(
+            ClassFunction(
+                class_name, energy_type, "getEnergy", list(), get_energy_function
+            )
+        )
 
     def pre_transition_hook(self, transition):
         return """
@@ -363,7 +436,9 @@ class StaticStateOnlyAccounting(AccountingMethod):
         timeInState[lastState] += now - lastStateChange;
         lastStateChange = now;
         lastState = {};
-        """.format(self.pta.get_state_id(transition.destination))
+        """.format(
+            self.pta.get_state_id(transition.destination)
+        )
 
     def init_code(self):
         return """
@@ -372,30 +447,59 @@ class StaticStateOnlyAccounting(AccountingMethod):
         }}
         lastState = 0;
         lastStateChange = 0;
-        """.format(num_states=len(self.pta.state))
+        """.format(
+            num_states=len(self.pta.state)
+        )
 
 
 class StaticAccounting(AccountingMethod):
-    def __init__(self, class_name: str, pta: PTA, ts_type='unsigned int', power_type='unsigned int', energy_type='unsigned long'):
+    def __init__(
+        self,
+        class_name: str,
+        pta: PTA,
+        ts_type="unsigned int",
+        power_type="unsigned int",
+        energy_type="unsigned long",
+    ):
         super().__init__(class_name, pta)
         self.ts_type = ts_type
-        self.include_paths.append('driver/uptime.h')
-        self.private_variables.append('unsigned char lastState;')
-        self.private_variables.append('{} lastStateChange;'.format(ts_type))
-        self.private_variables.append(array_template.format(
-            type=power_type,
-            name='state_power',
-            length=len(pta.state),
-            elements=', '.join(map(lambda state_name: '{:.0f}'.format(pta.state[state_name].power), pta.get_state_names()))
-        ))
-        self.private_variables.append(array_template.format(
-            type=energy_type,
-            name='transition_energy',
-            length=len(pta.get_unique_transitions()),
-            elements=', '.join(map(lambda transition: '{:.0f}'.format(transition.energy), pta.get_unique_transitions()))
-        ))
-        self.private_variables.append('{} timeInState[{}];'.format(ts_type, len(pta.state)))
-        self.private_variables.append('{} transitionCount[{}];'.format('unsigned int', len(pta.get_unique_transitions())))
+        self.include_paths.append("driver/uptime.h")
+        self.private_variables.append("unsigned char lastState;")
+        self.private_variables.append("{} lastStateChange;".format(ts_type))
+        self.private_variables.append(
+            array_template.format(
+                type=power_type,
+                name="state_power",
+                length=len(pta.state),
+                elements=", ".join(
+                    map(
+                        lambda state_name: "{:.0f}".format(pta.state[state_name].power),
+                        pta.get_state_names(),
+                    )
+                ),
+            )
+        )
+        self.private_variables.append(
+            array_template.format(
+                type=energy_type,
+                name="transition_energy",
+                length=len(pta.get_unique_transitions()),
+                elements=", ".join(
+                    map(
+                        lambda transition: "{:.0f}".format(transition.energy),
+                        pta.get_unique_transitions(),
+                    )
+                ),
+            )
+        )
+        self.private_variables.append(
+            "{} timeInState[{}];".format(ts_type, len(pta.state))
+        )
+        self.private_variables.append(
+            "{} transitionCount[{}];".format(
+                "unsigned int", len(pta.get_unique_transitions())
+            )
+        )
 
         get_energy_function = """
         {energy_type} total_energy = 0;
@@ -406,8 +510,16 @@ class StaticAccounting(AccountingMethod):
            total_energy += transitionCount[i] * transition_energy[i];
         }}
         return total_energy;
-        """.format(energy_type=energy_type, num_states=len(pta.state), num_transitions=len(pta.get_unique_transitions()))
-        self.public_functions.append(ClassFunction(class_name, energy_type, 'getEnergy', list(), get_energy_function))
+        """.format(
+            energy_type=energy_type,
+            num_states=len(pta.state),
+            num_transitions=len(pta.get_unique_transitions()),
+        )
+        self.public_functions.append(
+            ClassFunction(
+                class_name, energy_type, "getEnergy", list(), get_energy_function
+            )
+        )
 
     def pre_transition_hook(self, transition):
         return """
@@ -416,7 +528,10 @@ class StaticAccounting(AccountingMethod):
         timeInState[lastState] += now - lastStateChange;
         transitionCount[{}]++;
         lastStateChange = now;
         lastState = {};
-        """.format(self.pta.get_unique_transition_id(transition), self.pta.get_state_id(transition.destination))
+        """.format(
+            self.pta.get_unique_transition_id(transition),
+            self.pta.get_state_id(transition.destination),
+        )
 
     def init_code(self):
         return """
@@ -428,28 +543,53 @@ class StaticAccounting(AccountingMethod):
         }}
         lastState = 0;
         lastStateChange = 0;
-        """.format(num_states=len(self.pta.state), num_transitions=len(self.pta.get_unique_transitions()))
+        """.format(
+            num_states=len(self.pta.state),
+            num_transitions=len(self.pta.get_unique_transitions()),
+        )
 
 
 class StaticAccountingImmediateCalculation(AccountingMethod):
-    def __init__(self, class_name: str, pta: PTA, ts_type='unsigned int', power_type='unsigned int', energy_type='unsigned long'):
+    def __init__(
+        self,
+        class_name: str,
+        pta: PTA,
+        ts_type="unsigned int",
+        power_type="unsigned int",
+        energy_type="unsigned long",
+    ):
         super().__init__(class_name, pta)
         self.ts_type = ts_type
-        self.include_paths.append('driver/uptime.h')
-        self.private_variables.append('unsigned char lastState;')
-        self.private_variables.append('{} lastStateChange;'.format(ts_type))
-        self.private_variables.append('{} totalEnergy;'.format(energy_type))
-        self.private_variables.append(array_template.format(
-            type=power_type,
-            name='state_power',
-            length=len(pta.state),
-            elements=', '.join(map(lambda state_name: '{:.0f}'.format(pta.state[state_name].power), pta.get_state_names()))
-        ))
+        self.include_paths.append("driver/uptime.h")
+        self.private_variables.append("unsigned char lastState;")
+        self.private_variables.append("{} lastStateChange;".format(ts_type))
+        self.private_variables.append("{} totalEnergy;".format(energy_type))
+        self.private_variables.append(
+            array_template.format(
+                type=power_type,
+                name="state_power",
+                length=len(pta.state),
+                elements=", ".join(
+                    map(
+                        lambda state_name: "{:.0f}".format(pta.state[state_name].power),
+                        pta.get_state_names(),
+                    )
+                ),
+            )
+        )
 
         get_energy_function = """
         return totalEnergy;
-        """.format(energy_type=energy_type, num_states=len(pta.state), num_transitions=len(pta.get_unique_transitions()))
-        self.public_functions.append(ClassFunction(class_name, energy_type, 'getEnergy', list(), get_energy_function))
+        """.format(
+            energy_type=energy_type,
+            num_states=len(pta.state),
+            num_transitions=len(pta.get_unique_transitions()),
+        )
+        self.public_functions.append(
+            ClassFunction(
+                class_name, energy_type, "getEnergy", list(), get_energy_function
+            )
+        )
 
     def pre_transition_hook(self, transition):
         return """
@@ -458,21 +598,26 @@ class StaticAccountingImmediateCalculation(AccountingMethod):
         totalEnergy += {};
         lastStateChange = now;
         lastState = {};
-        """.format(transition.energy, self.pta.get_state_id(transition.destination))
+        """.format(
+            transition.energy, self.pta.get_state_id(transition.destination)
+        )
 
     def init_code(self):
         return """
         lastState = 0;
         lastStateChange = 0;
-        """.format(num_states=len(self.pta.state), num_transitions=len(self.pta.get_unique_transitions()))
+        """.format(
+            num_states=len(self.pta.state),
+            num_transitions=len(self.pta.get_unique_transitions()),
+        )
 
 
 class MultipassDriver:
     """Generate C++ header and no-op implementation for a multipass driver based on a DFA model."""
 
     def __init__(self, name, pta, class_info, enum=dict(), accounting=AccountingMethod):
-        self.impl = ''
-        self.header = ''
+        self.impl = ""
+        self.header = ""
         self.name = name
         self.pta = pta
         self.class_info = class_info
@@ -484,35 +629,53 @@ class MultipassDriver:
         private_variables = list()
         public_variables = list()
 
-        public_functions.append(ClassFunction(self.name, '', self.name, list(), accounting.init_code()))
+        public_functions.append(
+            ClassFunction(self.name, "", self.name, list(), accounting.init_code())
+        )
 
         for transition in self.pta.get_unique_transitions():
-            if transition.name == 'getEnergy':
+            if transition.name == "getEnergy":
                 continue
 
            # XXX right now we only verify whether both functions have the
            # same number of arguments. This breaks in many overloading cases.
            function_info = self.class_info.function[transition.name]
            for function_candidate in self.class_info.functions:
-                if function_candidate.name == transition.name and len(function_candidate.argument_types) == len(transition.arguments):
+                if function_candidate.name == transition.name and len(
+                    function_candidate.argument_types
+                ) == len(transition.arguments):
                    function_info = function_candidate
 
            function_arguments = list()
 
            for i in range(len(transition.arguments)):
-                function_arguments.append('{} {}'.format(function_info.argument_types[i], transition.arguments[i]))
+                function_arguments.append(
+                    "{} {}".format(
+                        function_info.argument_types[i], transition.arguments[i]
+                    )
+                )
 
            function_body = accounting.pre_transition_hook(transition)
 
-            if function_info.return_type != 'void':
-                function_body += 'return 0;\n'
+            if function_info.return_type != "void":
+                function_body += "return 0;\n"
 
-            public_functions.append(ClassFunction(self.name, function_info.return_type, transition.name, function_arguments, function_body))
+            public_functions.append(
+                ClassFunction(
+                    self.name,
+                    function_info.return_type,
+                    transition.name,
+                    function_arguments,
+                    function_body,
+                )
+            )
 
         enums = list()
         for enum_name in self.enum.keys():
-            enums.append('enum {} {{ {} }};'.format(enum_name, ', '.join(self.enum[enum_name])))
+            enums.append(
+                "enum {} {{ {} }};".format(enum_name, ", ".join(self.enum[enum_name]))
+            )
 
         if accounting:
            includes.extend(accounting.get_includes())
@@ -522,11 +685,21 @@ class MultipassDriver:
            public_variables.extend(accounting.public_variables)
 
         self.header = header_template.format(
-            name=self.name, name_lower=self.name.lower(),
-            includes='\n'.join(includes),
-            private_variables='\n'.join(private_variables),
-            public_variables='\n'.join(public_variables),
-            public_functions='\n'.join(map(lambda x: x.get_definition(), public_functions)),
-            private_functions='',
-            enums='\n'.join(enums))
-        self.impl = implementation_template.format(name=self.name, name_lower=self.name.lower(), functions='\n\n'.join(map(lambda x: x.get_implementation(), public_functions)))
+            name=self.name,
+            name_lower=self.name.lower(),
+            includes="\n".join(includes),
+            private_variables="\n".join(private_variables),
+            public_variables="\n".join(public_variables),
+            public_functions="\n".join(
+                map(lambda x: x.get_definition(), public_functions)
+            ),
+            private_functions="",
+            enums="\n".join(enums),
+        )
+        self.impl = implementation_template.format(
+            name=self.name,
+            name_lower=self.name.lower(),
+            functions="\n\n".join(
+                map(lambda x: x.get_implementation(), public_functions)
+            ),
+        )
diff --git a/lib/cycles_to_energy.py b/lib/cycles_to_energy.py
index 35f9199..a8e88b8 100644
--- a/lib/cycles_to_energy.py
+++ b/lib/cycles_to_energy.py
@@ -5,86 +5,80 @@ Contains classes for some embedded CPUs/MCUs.
 Given a configuration, each class can convert a cycle count to an energy consumption.
 """
 
+
 def get_class(cpu_name):
     """Return model class for cpu_name."""
-    if cpu_name == 'MSP430':
+    if cpu_name == "MSP430":
         return MSP430
-    if cpu_name == 'ATMega168':
+    if cpu_name == "ATMega168":
         return ATMega168
-    if cpu_name == 'ATMega328':
+    if cpu_name == "ATMega328":
         return ATMega328
-    if cpu_name == 'ATTiny88':
+    if cpu_name == "ATTiny88":
         return ATTiny88
-    if cpu_name == 'esp8266':
+    if cpu_name == "esp8266":
         return ESP8266
 
+
 def _param_list_to_dict(device, param_list):
     param_dict = dict()
     for i, parameter in enumerate(sorted(device.parameters.keys())):
         param_dict[parameter] = param_list[i]
     return param_dict
 
+
 class MSP430:
-    name = 'MSP430'
+    name = "MSP430"
     parameters = {
-        'cpu_freq': [1e6, 4e6, 8e6, 12e6, 16e6],
-        'memory' : ['unified', 'fram0', 'fram50', 'fram66', 'fram75', 'fram100', 'ram'],
-        'voltage': [2.2, 3.0],
-    }
-    default_params = {
-        'cpu_freq': 4e6,
-        'memory' : 'unified',
-        'voltage': 3
+        "cpu_freq": [1e6, 4e6, 8e6, 12e6, 16e6],
+        "memory": ["unified", "fram0", "fram50", "fram66", "fram75", "fram100", "ram"],
+        "voltage": [2.2, 3.0],
     }
+    default_params = {"cpu_freq": 4e6, "memory": "unified", "voltage": 3}
 
     current_by_mem = {
-        'unified' : [210, 640, 1220, 1475, 1845],
-        'fram0' : [370, 1280, 2510, 2080, 2650],
-        'fram50' : [240, 745, 1440, 1575, 1990],
-        'fram66' : [200, 560, 1070, 1300, 1620],
-        'fram75' : [170, 480, 890, 1155, 1420],
-        'fram100' : [110, 235, 420, 640, 730],
-        'ram' : [130, 320, 585, 890, 1070],
+        "unified": [210, 640, 1220, 1475, 1845],
+        "fram0": [370, 1280, 2510, 2080, 2650],
+        "fram50": [240, 745, 1440, 1575, 1990],
+        "fram66": [200, 560, 1070, 1300, 1620],
+        "fram75": [170, 480, 890, 1155, 1420],
+        "fram100": [110, 235, 420, 640, 730],
+        "ram": [130, 320, 585, 890, 1070],
     }
 
     def get_current(params):
         if type(params) != dict:
            return MSP430.get_current(_param_list_to_dict(MSP430, params))
-        cpu_freq_index = MSP430.parameters['cpu_freq'].index(params['cpu_freq'])
+        cpu_freq_index = MSP430.parameters["cpu_freq"].index(params["cpu_freq"])
 
-        return MSP430.current_by_mem[params['memory']][cpu_freq_index] * 1e-6
+        return MSP430.current_by_mem[params["memory"]][cpu_freq_index] * 1e-6
 
     def get_power(params):
         if type(params) != dict:
            return MSP430.get_energy(_param_list_to_dict(MSP430, params))
-        return MSP430.get_current(params) * params['voltage']
+        return MSP430.get_current(params) * params["voltage"]
 
     def get_energy_per_cycle(params):
         if type(params) != dict:
            return MSP430.get_energy_per_cycle(_param_list_to_dict(MSP430, params))
-        return MSP430.get_power(params) / params['cpu_freq']
+        return MSP430.get_power(params) / params["cpu_freq"]
 
+
 class ATMega168:
-    name = 'ATMega168'
-    parameters = {
-        'cpu_freq': [1e6, 4e6, 8e6],
-        'voltage': [2, 3, 5]
-    }
-    default_params = {
-        'cpu_freq': 4e6,
-        'voltage': 3
-    }
+    name = "ATMega168"
+    parameters = {"cpu_freq": [1e6, 4e6, 8e6], "voltage": [2, 3, 5]}
+    default_params = {"cpu_freq": 4e6, "voltage": 3}
 
     def get_current(params):
         if type(params) != dict:
            return ATMega168.get_current(_param_list_to_dict(ATMega168, params))
-        if params['cpu_freq'] == 1e6 and params['voltage'] <= 2:
+        if params["cpu_freq"] == 1e6 and params["voltage"] <= 2:
            return 0.5e-3
-        if params['cpu_freq'] == 4e6 and params['voltage'] <= 3:
+        if params["cpu_freq"] == 4e6 and params["voltage"] <= 3:
            return 3.5e-3
-        if params['cpu_freq'] == 8e6
```
and params['voltage'] <= 5: + if params["cpu_freq"] == 8e6 and params["voltage"] <= 5: return 12e-3 return None @@ -92,37 +86,34 @@ class ATMega168: if type(params) != dict: return ATMega168.get_energy(_param_list_to_dict(ATMega168, params)) - return ATMega168.get_current(params) * params['voltage'] + return ATMega168.get_current(params) * params["voltage"] def get_energy_per_cycle(params): if type(params) != dict: - return ATMega168.get_energy_per_cycle(_param_list_to_dict(ATMega168, params)) + return ATMega168.get_energy_per_cycle( + _param_list_to_dict(ATMega168, params) + ) + + return ATMega168.get_power(params) / params["cpu_freq"] - return ATMega168.get_power(params) / params['cpu_freq'] class ATMega328: - name = 'ATMega328' - parameters = { - 'cpu_freq': [1e6, 4e6, 8e6], - 'voltage': [2, 3, 5] - } - default_params = { - 'cpu_freq': 4e6, - 'voltage': 3 - } + name = "ATMega328" + parameters = {"cpu_freq": [1e6, 4e6, 8e6], "voltage": [2, 3, 5]} + default_params = {"cpu_freq": 4e6, "voltage": 3} # Source: ATMega328P Datasheet p.316 / table 28.2.4 def get_current(params): if type(params) != dict: return ATMega328.get_current(_param_list_to_dict(ATMega328, params)) # specified for 2V - if params['cpu_freq'] == 1e6 and params['voltage'] >= 1.8: + if params["cpu_freq"] == 1e6 and params["voltage"] >= 1.8: return 0.3e-3 # specified for 3V - if params['cpu_freq'] == 4e6 and params['voltage'] >= 1.8: + if params["cpu_freq"] == 4e6 and params["voltage"] >= 1.8: return 1.7e-3 # specified for 5V - if params['cpu_freq'] == 8e6 and params['voltage'] >= 2.5: + if params["cpu_freq"] == 8e6 and params["voltage"] >= 2.5: return 5.2e-3 return None @@ -130,26 +121,26 @@ class ATMega328: if type(params) != dict: return ATMega328.get_energy(_param_list_to_dict(ATMega328, params)) - return ATMega328.get_current(params) * params['voltage'] + return ATMega328.get_current(params) * params["voltage"] def get_energy_per_cycle(params): if type(params) != dict: - return ATMega328.get_energy_per_cycle(_param_list_to_dict(ATMega328, params)) + return ATMega328.get_energy_per_cycle( + _param_list_to_dict(ATMega328, params) + ) + + return ATMega328.get_power(params) / params["cpu_freq"] - return ATMega328.get_power(params) / params['cpu_freq'] class ESP8266: # Source: ESP8266EX Datasheet, table 5-2 (v2017.11) / table 3-4 (v2018.11) # Taken at 3.0V - name = 'ESP8266' + name = "ESP8266" parameters = { - 'cpu_freq': [80e6], - 'voltage': [2.5, 3.0, 3.3, 3.6] # min / ... / typ / max - } - default_params = { - 'cpu_freq': 80e6, - 'voltage': 3.0 + "cpu_freq": [80e6], + "voltage": [2.5, 3.0, 3.3, 3.6], # min / ... 
/ typ / max } + default_params = {"cpu_freq": 80e6, "voltage": 3.0} def get_current(params): if type(params) != dict: @@ -161,33 +152,27 @@ class ESP8266: if type(params) != dict: return ESP8266.get_power(_param_list_to_dict(ESP8266, params)) - return ESP8266.get_current(params) * params['voltage'] + return ESP8266.get_current(params) * params["voltage"] def get_energy_per_cycle(params): if type(params) != dict: return ESP8266.get_energy_per_cycle(_param_list_to_dict(ESP8266, params)) - return ESP8266.get_power(params) / params['cpu_freq'] + return ESP8266.get_power(params) / params["cpu_freq"] + class ATTiny88: - name = 'ATTiny88' - parameters = { - 'cpu_freq': [1e6, 4e6, 8e6], - 'voltage': [2, 3, 5] - } - default_params = { - 'cpu_freq' : 4e6, - 'voltage' : 3 - } + name = "ATTiny88" + parameters = {"cpu_freq": [1e6, 4e6, 8e6], "voltage": [2, 3, 5]} + default_params = {"cpu_freq": 4e6, "voltage": 3} def get_current(params): if type(params) != dict: return ATTiny88.get_current(_param_list_to_dict(ATTiny88, params)) - if params['cpu_freq'] == 1e6 and params['voltage'] <= 2: + if params["cpu_freq"] == 1e6 and params["voltage"] <= 2: return 0.2e-3 - if params['cpu_freq'] == 4e6 and params['voltage'] <= 3: + if params["cpu_freq"] == 4e6 and params["voltage"] <= 3: return 1.4e-3 - if params['cpu_freq'] == 8e6 and params['voltage'] <= 5: + if params["cpu_freq"] == 8e6 and params["voltage"] <= 5: return 4.5e-3 return None - diff --git a/lib/data_parameters.py b/lib/data_parameters.py index 3b7a148..1150b71 100644 --- a/lib/data_parameters.py +++ b/lib/data_parameters.py @@ -10,6 +10,7 @@ from . import cycles_to_energy, size_to_radio_energy, utils import numpy as np import ubjson + def _string_value_length(json): if type(json) == str: return len(json) @@ -22,6 +23,7 @@ def _string_value_length(json): return 0 + # TODO distinguish between int and uint, which is not visible from the # data value alone def _int_value_length(json): @@ -40,18 +42,21 @@ def _int_value_length(json): return 0 + def _string_key_length(json): if type(json) == dict: return sum(map(len, json.keys())) + sum(map(_string_key_length, json.values())) return 0 + def _num_keys(json): if type(json) == dict: return len(json.keys()) + sum(map(_num_keys, json.values())) return 0 + def _num_of_type(json, wanted_type): ret = 0 if type(json) == wanted_type: @@ -65,16 +70,17 @@ def _num_of_type(json, wanted_type): return ret + def json_to_param(json): """Return numeric parameters describing the structure of JSON data.""" ret = dict() - ret['strlen_keys'] = _string_key_length(json) - ret['strlen_values'] = _string_value_length(json) - ret['bytelen_int'] = _int_value_length(json) - ret['num_int'] = _num_of_type(json, int) - ret['num_float'] = _num_of_type(json, float) - ret['num_str'] = _num_of_type(json, str) + ret["strlen_keys"] = _string_key_length(json) + ret["strlen_values"] = _string_value_length(json) + ret["bytelen_int"] = _int_value_length(json) + ret["num_int"] = _num_of_type(json, int) + ret["num_float"] = _num_of_type(json, float) + ret["num_str"] = _num_of_type(json, str) return ret @@ -127,16 +133,16 @@ class Protolog: # bogus data if val > 10_000_000: return np.nan - for val in data['nop']: + for val in data["nop"]: # bogus data if val > 10_000_000: return np.nan # All measurements in data[key] cover the same instructions, so they # should be identical -> it's safe to take the median. # However, we leave out the first measurement as it is often bogus. 
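# (Editorial illustration, not part of this patch: the median / nop-baseline
# subtraction implemented just below, applied to made-up cycle counts. The
# `data` dict is a hypothetical stand-in for a protobench cycle log.)
import numpy as np

data = {"nop": [99999, 100, 102, 101], "ser": [99999, 350, 352, 349]}
nop_baseline = np.median(data["nop"][1:])  # first (bogus) value dropped -> 101.0
ser_cycles = max(0, int(np.median(data["ser"][1:]) - nop_baseline))  # -> 249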
- if key == 'nop': - return np.median(data['nop'][1:]) - return max(0, int(np.median(data[key][1:]) - np.median(data['nop'][1:]))) + if key == "nop": + return np.median(data["nop"][1:]) + return max(0, int(np.median(data[key][1:]) - np.median(data["nop"][1:]))) def _median_callcycles(data): ret = dict() @@ -146,37 +152,44 @@ class Protolog: idem = lambda x: x datamap = [ - ['bss_nop', 'bss_size_nop', idem], - ['bss_ser', 'bss_size_ser', idem], - ['bss_serdes', 'bss_size_serdes', idem], - ['callcycles_raw', 'callcycles', idem], - ['callcycles_median', 'callcycles', _median_callcycles], + ["bss_nop", "bss_size_nop", idem], + ["bss_ser", "bss_size_ser", idem], + ["bss_serdes", "bss_size_serdes", idem], + ["callcycles_raw", "callcycles", idem], + ["callcycles_median", "callcycles", _median_callcycles], # Used to remove nop cycles from callcycles_median - ['cycles_nop', 'cycles', lambda x: Protolog._median_cycles(x, 'nop')], - ['cycles_ser', 'cycles', lambda x: Protolog._median_cycles(x, 'ser')], - ['cycles_des', 'cycles', lambda x: Protolog._median_cycles(x, 'des')], - ['cycles_enc', 'cycles', lambda x: Protolog._median_cycles(x, 'enc')], - ['cycles_dec', 'cycles', lambda x: Protolog._median_cycles(x, 'dec')], - #['cycles_ser_arr', 'cycles', lambda x: np.array(x['ser'][1:]) - np.mean(x['nop'][1:])], - #['cycles_des_arr', 'cycles', lambda x: np.array(x['des'][1:]) - np.mean(x['nop'][1:])], - #['cycles_enc_arr', 'cycles', lambda x: np.array(x['enc'][1:]) - np.mean(x['nop'][1:])], - #['cycles_dec_arr', 'cycles', lambda x: np.array(x['dec'][1:]) - np.mean(x['nop'][1:])], - ['data_nop', 'data_size_nop', idem], - ['data_ser', 'data_size_ser', idem], - ['data_serdes', 'data_size_serdes', idem], - ['heap_ser', 'heap_usage_ser', idem], - ['heap_des', 'heap_usage_des', idem], - ['serialized_size', 'serialized_size', idem], - ['stack_alloc_ser', 'stack_online_ser', lambda x: x['allocated']], - ['stack_set_ser', 'stack_online_ser', lambda x: x['used']], - ['stack_alloc_des', 'stack_online_des', lambda x: x['allocated']], - ['stack_set_des', 'stack_online_des', lambda x: x['used']], - ['text_nop', 'text_size_nop', idem], - ['text_ser', 'text_size_ser', idem], - ['text_serdes', 'text_size_serdes', idem], + ["cycles_nop", "cycles", lambda x: Protolog._median_cycles(x, "nop")], + ["cycles_ser", "cycles", lambda x: Protolog._median_cycles(x, "ser")], + ["cycles_des", "cycles", lambda x: Protolog._median_cycles(x, "des")], + ["cycles_enc", "cycles", lambda x: Protolog._median_cycles(x, "enc")], + ["cycles_dec", "cycles", lambda x: Protolog._median_cycles(x, "dec")], + # ['cycles_ser_arr', 'cycles', lambda x: np.array(x['ser'][1:]) - np.mean(x['nop'][1:])], + # ['cycles_des_arr', 'cycles', lambda x: np.array(x['des'][1:]) - np.mean(x['nop'][1:])], + # ['cycles_enc_arr', 'cycles', lambda x: np.array(x['enc'][1:]) - np.mean(x['nop'][1:])], + # ['cycles_dec_arr', 'cycles', lambda x: np.array(x['dec'][1:]) - np.mean(x['nop'][1:])], + ["data_nop", "data_size_nop", idem], + ["data_ser", "data_size_ser", idem], + ["data_serdes", "data_size_serdes", idem], + ["heap_ser", "heap_usage_ser", idem], + ["heap_des", "heap_usage_des", idem], + ["serialized_size", "serialized_size", idem], + ["stack_alloc_ser", "stack_online_ser", lambda x: x["allocated"]], + ["stack_set_ser", "stack_online_ser", lambda x: x["used"]], + ["stack_alloc_des", "stack_online_des", lambda x: x["allocated"]], + ["stack_set_des", "stack_online_des", lambda x: x["used"]], + ["text_nop", "text_size_nop", idem], + ["text_ser", "text_size_ser", idem], + 
["text_serdes", "text_size_serdes", idem], ] - def __init__(self, logfile, cpu_conf = None, cpu_conf_str = None, radio_conf = None, radio_conf_str = None): + def __init__( + self, + logfile, + cpu_conf=None, + cpu_conf_str=None, + radio_conf=None, + radio_conf_str=None, + ): """ Load and enrich raw protobench log data. @@ -185,116 +198,177 @@ class Protolog: """ self.cpu = None self.radio = None - with open(logfile, 'rb') as f: + with open(logfile, "rb") as f: self.data = ubjson.load(f) self.libraries = set() self.architectures = set() self.aggregate = dict() for arch_lib in self.data.keys(): - arch, lib, libopts = arch_lib.split(':') - library = lib + ':' + libopts + arch, lib, libopts = arch_lib.split(":") + library = lib + ":" + libopts for benchmark in self.data[arch_lib].keys(): for benchmark_item in self.data[arch_lib][benchmark].keys(): subv = self.data[arch_lib][benchmark][benchmark_item] for aggregate_label, data_label, getter in Protolog.datamap: try: - self.add_datapoint(arch, library, (benchmark, benchmark_item), subv, aggregate_label, data_label, getter) + self.add_datapoint( + arch, + library, + (benchmark, benchmark_item), + subv, + aggregate_label, + data_label, + getter, + ) except KeyError: pass except TypeError as e: - print('TypeError in {} {} {} {}: {} -> {}'.format( - arch_lib, benchmark, benchmark_item, aggregate_label, - subv[data_label]['v'], str(e))) + print( + "TypeError in {} {} {} {}: {} -> {}".format( + arch_lib, + benchmark, + benchmark_item, + aggregate_label, + subv[data_label]["v"], + str(e), + ) + ) pass try: - codegen = codegen_for_lib(lib, libopts.split(','), subv['data']) + codegen = codegen_for_lib(lib, libopts.split(","), subv["data"]) if codegen.max_serialized_bytes != None: - self.add_datapoint(arch, library, (benchmark, benchmark_item), subv, 'buffer_size', data_label, lambda x: codegen.max_serialized_bytes) + self.add_datapoint( + arch, + library, + (benchmark, benchmark_item), + subv, + "buffer_size", + data_label, + lambda x: codegen.max_serialized_bytes, + ) else: - self.add_datapoint(arch, library, (benchmark, benchmark_item), subv, 'buffer_size', data_label, lambda x: 0) + self.add_datapoint( + arch, + library, + (benchmark, benchmark_item), + subv, + "buffer_size", + data_label, + lambda x: 0, + ) except: # avro's codegen will raise RuntimeError("Unsupported Schema") on unsupported data. Other libraries may just silently ignore it. 
- self.add_datapoint(arch, library, (benchmark, benchmark_item), subv, 'buffer_size', data_label, lambda x: 0) - #self.aggregate[(benchmark, benchmark_item)][arch][lib][aggregate_label] = getter(value[data_label]['v']) - + self.add_datapoint( + arch, + library, + (benchmark, benchmark_item), + subv, + "buffer_size", + data_label, + lambda x: 0, + ) + # self.aggregate[(benchmark, benchmark_item)][arch][lib][aggregate_label] = getter(value[data_label]['v']) for key in self.aggregate.keys(): for arch in self.aggregate[key].keys(): for lib, val in self.aggregate[key][arch].items(): try: - val['cycles_encser'] = val['cycles_enc'] + val['cycles_ser'] + val["cycles_encser"] = val["cycles_enc"] + val["cycles_ser"] except KeyError: pass try: - val['cycles_desdec'] = val['cycles_des'] + val['cycles_dec'] + val["cycles_desdec"] = val["cycles_des"] + val["cycles_dec"] except KeyError: pass try: - for line in val['callcycles_median'].keys(): - val['callcycles_median'][line] -= val['cycles_nop'] + for line in val["callcycles_median"].keys(): + val["callcycles_median"][line] -= val["cycles_nop"] except KeyError: pass try: - val['data_serdes_delta'] = val['data_serdes'] - val['data_nop'] + val["data_serdes_delta"] = val["data_serdes"] - val["data_nop"] except KeyError: pass try: - val['data_serdes_delta_nobuf'] = val['data_serdes'] - val['data_nop'] - val['buffer_size'] + val["data_serdes_delta_nobuf"] = ( + val["data_serdes"] - val["data_nop"] - val["buffer_size"] + ) except KeyError: pass try: - val['bss_serdes_delta'] = val['bss_serdes'] - val['bss_nop'] + val["bss_serdes_delta"] = val["bss_serdes"] - val["bss_nop"] except KeyError: pass try: - val['bss_serdes_delta_nobuf'] = val['bss_serdes'] - val['bss_nop'] - val['buffer_size'] + val["bss_serdes_delta_nobuf"] = ( + val["bss_serdes"] - val["bss_nop"] - val["buffer_size"] + ) except KeyError: pass try: - val['text_serdes_delta'] = val['text_serdes'] - val['text_nop'] + val["text_serdes_delta"] = val["text_serdes"] - val["text_nop"] except KeyError: pass try: - val['total_dmem_ser'] = val['stack_alloc_ser'] - val['written_dmem_ser'] = val['stack_set_ser'] - val['total_dmem_ser'] += val['heap_ser'] - val['written_dmem_ser'] += val['heap_ser'] + val["total_dmem_ser"] = val["stack_alloc_ser"] + val["written_dmem_ser"] = val["stack_set_ser"] + val["total_dmem_ser"] += val["heap_ser"] + val["written_dmem_ser"] += val["heap_ser"] except KeyError: pass try: - val['total_dmem_des'] = val['stack_alloc_des'] - val['written_dmem_des'] = val['stack_set_des'] - val['total_dmem_des'] += val['heap_des'] - val['written_dmem_des'] += val['heap_des'] + val["total_dmem_des"] = val["stack_alloc_des"] + val["written_dmem_des"] = val["stack_set_des"] + val["total_dmem_des"] += val["heap_des"] + val["written_dmem_des"] += val["heap_des"] except KeyError: pass try: - val['total_dmem_serdes'] = max(val['total_dmem_ser'], val['total_dmem_des']) + val["total_dmem_serdes"] = max( + val["total_dmem_ser"], val["total_dmem_des"] + ) except KeyError: pass try: - val['text_ser_delta'] = val['text_ser'] - val['text_nop'] - val['text_serdes_delta'] = val['text_serdes'] - val['text_nop'] + val["text_ser_delta"] = val["text_ser"] - val["text_nop"] + val["text_serdes_delta"] = val["text_serdes"] - val["text_nop"] except KeyError: pass try: - val['bss_ser_delta'] = val['bss_ser'] - val['bss_nop'] - val['bss_serdes_delta'] = val['bss_serdes'] - val['bss_nop'] + val["bss_ser_delta"] = val["bss_ser"] - val["bss_nop"] + val["bss_serdes_delta"] = val["bss_serdes"] - val["bss_nop"] except 
KeyError: pass try: - val['data_ser_delta'] = val['data_ser'] - val['data_nop'] - val['data_serdes_delta'] = val['data_serdes'] - val['data_nop'] + val["data_ser_delta"] = val["data_ser"] - val["data_nop"] + val["data_serdes_delta"] = val["data_serdes"] - val["data_nop"] except KeyError: pass try: - val['allmem_ser'] = val['text_ser'] + val['data_ser'] + val['bss_ser'] + val['total_dmem_ser'] - val['buffer_size'] - val['allmem_serdes'] = val['text_serdes'] + val['data_serdes'] + val['bss_serdes'] + val['total_dmem_serdes'] - val['buffer_size'] + val["allmem_ser"] = ( + val["text_ser"] + + val["data_ser"] + + val["bss_ser"] + + val["total_dmem_ser"] + - val["buffer_size"] + ) + val["allmem_serdes"] = ( + val["text_serdes"] + + val["data_serdes"] + + val["bss_serdes"] + + val["total_dmem_serdes"] + - val["buffer_size"] + ) except KeyError: pass try: - val['smem_serdes'] = val['text_serdes'] + val['data_serdes'] + val['bss_serdes'] - val['buffer_size'] + val["smem_serdes"] = ( + val["text_serdes"] + + val["data_serdes"] + + val["bss_serdes"] + - val["buffer_size"] + ) except KeyError: pass @@ -303,7 +377,7 @@ class Protolog: if cpu_conf: self.cpu_conf = cpu_conf - cpu = self.cpu = cycles_to_energy.get_class(cpu_conf['model']) + cpu = self.cpu = cycles_to_energy.get_class(cpu_conf["model"]) for key, value in cpu.default_params.items(): if not key in cpu_conf: cpu_conf[key] = value @@ -312,48 +386,102 @@ class Protolog: for lib, val in self.aggregate[key][arch].items(): # All energy data is stored in nanojoules (nJ) try: - val['energy_enc'] = int(val['cycles_enc'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9) + val["energy_enc"] = int( + val["cycles_enc"] + * cpu.get_power(cpu_conf) + / cpu_conf["cpu_freq"] + * 1e9 + ) except KeyError: pass except ValueError: - print('cycles_enc is NaN for {} -> {} -> {}'.format(arch, lib, key)) + print( + "cycles_enc is NaN for {} -> {} -> {}".format( + arch, lib, key + ) + ) try: - val['energy_ser'] = int(val['cycles_ser'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9) + val["energy_ser"] = int( + val["cycles_ser"] + * cpu.get_power(cpu_conf) + / cpu_conf["cpu_freq"] + * 1e9 + ) except KeyError: pass except ValueError: - print('cycles_ser is NaN for {} -> {} -> {}'.format(arch, lib, key)) + print( + "cycles_ser is NaN for {} -> {} -> {}".format( + arch, lib, key + ) + ) try: - val['energy_encser'] = int(val['cycles_encser'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9) + val["energy_encser"] = int( + val["cycles_encser"] + * cpu.get_power(cpu_conf) + / cpu_conf["cpu_freq"] + * 1e9 + ) except KeyError: pass except ValueError: - print('cycles_encser is NaN for {} -> {} -> {}'.format(arch, lib, key)) + print( + "cycles_encser is NaN for {} -> {} -> {}".format( + arch, lib, key + ) + ) try: - val['energy_des'] = int(val['cycles_des'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9) + val["energy_des"] = int( + val["cycles_des"] + * cpu.get_power(cpu_conf) + / cpu_conf["cpu_freq"] + * 1e9 + ) except KeyError: pass except ValueError: - print('cycles_des is NaN for {} -> {} -> {}'.format(arch, lib, key)) + print( + "cycles_des is NaN for {} -> {} -> {}".format( + arch, lib, key + ) + ) try: - val['energy_dec'] = int(val['cycles_dec'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9) + val["energy_dec"] = int( + val["cycles_dec"] + * cpu.get_power(cpu_conf) + / cpu_conf["cpu_freq"] + * 1e9 + ) except KeyError: pass except ValueError: - print('cycles_dec is NaN for {} -> {} -> {}'.format(arch, lib, key)) + print( + "cycles_dec 
is NaN for {} -> {} -> {}".format( + arch, lib, key + ) + ) try: - val['energy_desdec'] = int(val['cycles_desdec'] * cpu.get_power(cpu_conf) / cpu_conf['cpu_freq'] * 1e9) + val["energy_desdec"] = int( + val["cycles_desdec"] + * cpu.get_power(cpu_conf) + / cpu_conf["cpu_freq"] + * 1e9 + ) except KeyError: pass except ValueError: - print('cycles_desdec is NaN for {} -> {} -> {}'.format(arch, lib, key)) + print( + "cycles_desdec is NaN for {} -> {} -> {}".format( + arch, lib, key + ) + ) if radio_conf_str: radio_conf = utils.parse_conf_str(radio_conf_str) if radio_conf: self.radio_conf = radio_conf - radio = self.radio = size_to_radio_energy.get_class(radio_conf['model']) + radio = self.radio = size_to_radio_energy.get_class(radio_conf["model"]) for key, value in radio.default_params.items(): if not key in radio_conf: radio_conf[key] = value @@ -361,17 +489,22 @@ class Protolog: for arch in self.aggregate[key].keys(): for lib, val in self.aggregate[key][arch].items(): try: - radio_conf['txbytes'] = val['serialized_size'] - if radio_conf['txbytes'] > 0: - val['energy_tx'] = int(radio.get_energy(radio_conf) * 1e9) + radio_conf["txbytes"] = val["serialized_size"] + if radio_conf["txbytes"] > 0: + val["energy_tx"] = int( + radio.get_energy(radio_conf) * 1e9 + ) else: - val['energy_tx'] = 0 - val['energy_encsertx'] = val['energy_encser'] + val['energy_tx'] - val['energy_desdecrx'] = val['energy_desdec'] + val['energy_tx'] + val["energy_tx"] = 0 + val["energy_encsertx"] = ( + val["energy_encser"] + val["energy_tx"] + ) + val["energy_desdecrx"] = ( + val["energy_desdec"] + val["energy_tx"] + ) except KeyError: pass - def add_datapoint(self, arch, lib, key, value, aggregate_label, data_label, getter): """ Set self.aggregate[key][arch][lib][aggregate_Label] = getter(value[data_label]['v']). @@ -379,7 +512,7 @@ class Protolog: Additionally, add lib to self.libraries and arch to self.architectures key usually is ('benchmark name', 'sub-benchmark index'). 
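Example (hypothetical values, for illustration only):
add_datapoint('msp430', 'nanopb:', ('bench', 0), {'cycles': {'v': [100, 200]}}, 'cycles_raw', 'cycles', lambda x: x)
results in self.aggregate[('bench', 0)]['msp430']['nanopb:']['cycles_raw'] == [100, 200],
with 'msp430' added to self.architectures and 'nanopb:' added to self.libraries.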
""" - if data_label in value and 'v' in value[data_label]: + if data_label in value and "v" in value[data_label]: self.architectures.add(arch) self.libraries.add(lib) if not key in self.aggregate: @@ -388,4 +521,6 @@ class Protolog: self.aggregate[key][arch] = dict() if not lib in self.aggregate[key][arch]: self.aggregate[key][arch][lib] = dict() - self.aggregate[key][arch][lib][aggregate_label] = getter(value[data_label]['v']) + self.aggregate[key][arch][lib][aggregate_label] = getter( + value[data_label]["v"] + ) diff --git a/lib/dfatool.py b/lib/dfatool.py index 8fb41a5..56f0f2d 100644 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -15,12 +15,19 @@ from multiprocessing import Pool from .functions import analytic from .functions import AnalyticFunction from .parameters import ParamStats -from .utils import vprint, is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple +from .utils import ( + vprint, + is_numeric, + soft_cast_int, + param_slice_eq, + remove_index_from_tuple, +) from .utils import by_name_to_by_param, match_parameter_values, running_mean try: from .pubcode import Code128 import zbar + zbar_available = True except ImportError: zbar_available = False @@ -47,25 +54,25 @@ def gplearn_to_function(function_str: str): inv -- 1 / x if |x| > 0.001, otherwise 0 """ eval_globals = { - 'add': lambda x, y: x + y, - 'sub': lambda x, y: x - y, - 'mul': lambda x, y: x * y, - 'div': lambda x, y: np.divide(x, y) if np.abs(y) > 0.001 else 1., - 'sqrt': lambda x: np.sqrt(np.abs(x)), - 'log': lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 0., - 'inv': lambda x: 1. / x if np.abs(x) > 0.001 else 0., + "add": lambda x, y: x + y, + "sub": lambda x, y: x - y, + "mul": lambda x, y: x * y, + "div": lambda x, y: np.divide(x, y) if np.abs(y) > 0.001 else 1.0, + "sqrt": lambda x: np.sqrt(np.abs(x)), + "log": lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 0.0, + "inv": lambda x: 1.0 / x if np.abs(x) > 0.001 else 0.0, } last_arg_index = 0 for i in range(0, 100): - if function_str.find('X{:d}'.format(i)) >= 0: + if function_str.find("X{:d}".format(i)) >= 0: last_arg_index = i arg_list = [] for i in range(0, last_arg_index + 1): - arg_list.append('X{:d}'.format(i)) + arg_list.append("X{:d}".format(i)) - eval_str = 'lambda {}, *whatever: {}'.format(','.join(arg_list), function_str) + eval_str = "lambda {}, *whatever: {}".format(",".join(arg_list), function_str) print(eval_str) return eval(eval_str, eval_globals) @@ -123,32 +130,35 @@ def regression_measures(predicted: np.ndarray, actual: np.ndarray): count -- Number of values """ if type(predicted) != np.ndarray: - raise ValueError('first arg must be ndarray, is {}'.format(type(predicted))) + raise ValueError("first arg must be ndarray, is {}".format(type(predicted))) if type(actual) != np.ndarray: - raise ValueError('second arg must be ndarray, is {}'.format(type(actual))) + raise ValueError("second arg must be ndarray, is {}".format(type(actual))) deviations = predicted - actual # mean = np.mean(actual) if len(deviations) == 0: return {} measures = { - 'mae': np.mean(np.abs(deviations), dtype=np.float64), - 'msd': np.mean(deviations**2, dtype=np.float64), - 'rmsd': np.sqrt(np.mean(deviations**2), dtype=np.float64), - 'ssr': np.sum(deviations**2, dtype=np.float64), - 'rsq': r2_score(actual, predicted), - 'count': len(actual), + "mae": np.mean(np.abs(deviations), dtype=np.float64), + "msd": np.mean(deviations ** 2, dtype=np.float64), + "rmsd": np.sqrt(np.mean(deviations ** 2), dtype=np.float64), + "ssr": np.sum(deviations ** 2, 
dtype=np.float64), + "rsq": r2_score(actual, predicted), + "count": len(actual), } # rsq_quotient = np.sum((actual - mean)**2, dtype=np.float64) * np.sum((predicted - mean)**2, dtype=np.float64) if np.all(actual != 0): - measures['mape'] = np.mean(np.abs(deviations / actual)) * 100 # bad measure + measures["mape"] = np.mean(np.abs(deviations / actual)) * 100 # bad measure else: - measures['mape'] = np.nan + measures["mape"] = np.nan if np.all(np.abs(predicted) + np.abs(actual) != 0): - measures['smape'] = np.mean(np.abs(deviations) / ((np.abs(predicted) + np.abs(actual)) / 2)) * 100 + measures["smape"] = ( + np.mean(np.abs(deviations) / ((np.abs(predicted) + np.abs(actual)) / 2)) + * 100 + ) else: - measures['smape'] = np.nan + measures["smape"] = np.nan # if np.all(rsq_quotient != 0): # measures['rsq'] = (np.sum((actual - mean) * (predicted - mean), dtype=np.float64)**2) / rsq_quotient @@ -177,7 +187,7 @@ class KeysightCSV: with open(filename) as f: for _ in range(4): next(f) - reader = csv.reader(f, delimiter=',') + reader = csv.reader(f, delimiter=",") for i, row in enumerate(reader): timestamps[i] = float(row[0]) currents[i] = float(row[2]) * -1 @@ -266,29 +276,35 @@ class CrossValidator: } } """ - ret = { - 'by_name': dict() - } + ret = {"by_name": dict()} for name in self.names: - ret['by_name'][name] = dict() - for attribute in self.by_name[name]['attributes']: - ret['by_name'][name][attribute] = { - 'mae_list': list(), - 'smape_list': list() + ret["by_name"][name] = dict() + for attribute in self.by_name[name]["attributes"]: + ret["by_name"][name][attribute] = { + "mae_list": list(), + "smape_list": list(), } for _ in range(count): res = self._single_montecarlo(model_getter) for name in self.names: - for attribute in self.by_name[name]['attributes']: - ret['by_name'][name][attribute]['mae_list'].append(res['by_name'][name][attribute]['mae']) - ret['by_name'][name][attribute]['smape_list'].append(res['by_name'][name][attribute]['smape']) + for attribute in self.by_name[name]["attributes"]: + ret["by_name"][name][attribute]["mae_list"].append( + res["by_name"][name][attribute]["mae"] + ) + ret["by_name"][name][attribute]["smape_list"].append( + res["by_name"][name][attribute]["smape"] + ) for name in self.names: - for attribute in self.by_name[name]['attributes']: - ret['by_name'][name][attribute]['mae'] = np.mean(ret['by_name'][name][attribute]['mae_list']) - ret['by_name'][name][attribute]['smape'] = np.mean(ret['by_name'][name][attribute]['smape_list']) + for attribute in self.by_name[name]["attributes"]: + ret["by_name"][name][attribute]["mae"] = np.mean( + ret["by_name"][name][attribute]["mae_list"] + ) + ret["by_name"][name][attribute]["smape"] = np.mean( + ret["by_name"][name][attribute]["smape_list"] + ) return ret @@ -296,77 +312,87 @@ class CrossValidator: training = dict() validation = dict() for name in self.names: - training[name] = { - 'attributes': self.by_name[name]['attributes'] - } - validation[name] = { - 'attributes': self.by_name[name]['attributes'] - } + training[name] = {"attributes": self.by_name[name]["attributes"]} + validation[name] = {"attributes": self.by_name[name]["attributes"]} - if 'isa' in self.by_name[name]: - training[name]['isa'] = self.by_name[name]['isa'] - validation[name]['isa'] = self.by_name[name]['isa'] + if "isa" in self.by_name[name]: + training[name]["isa"] = self.by_name[name]["isa"] + validation[name]["isa"] = self.by_name[name]["isa"] - data_count = len(self.by_name[name]['param']) + data_count = len(self.by_name[name]["param"]) 
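# (Editorial sketch, not part of this patch: _xv_partition_montecarlo is
# defined elsewhere in this module. A minimal implementation of the
# behaviour relied upon here -- one random, disjoint holdout split per
# call -- might look as follows; the 2:1 train/validation ratio is an
# assumption for illustration.)
import numpy as np

def _xv_partition_montecarlo_sketch(length):
    indexes = np.random.permutation(length)
    border = length * 2 // 3
    return indexes[:border], indexes[border:]  # (training, validation)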
training_subset, validation_subset = _xv_partition_montecarlo(data_count) - for attribute in self.by_name[name]['attributes']: + for attribute in self.by_name[name]["attributes"]: self.by_name[name][attribute] = np.array(self.by_name[name][attribute]) - training[name][attribute] = self.by_name[name][attribute][training_subset] - validation[name][attribute] = self.by_name[name][attribute][validation_subset] + training[name][attribute] = self.by_name[name][attribute][ + training_subset + ] + validation[name][attribute] = self.by_name[name][attribute][ + validation_subset + ] # We can't use slice syntax for 'param', which may contain strings and other odd values - training[name]['param'] = list() - validation[name]['param'] = list() + training[name]["param"] = list() + validation[name]["param"] = list() for idx in training_subset: - training[name]['param'].append(self.by_name[name]['param'][idx]) + training[name]["param"].append(self.by_name[name]["param"][idx]) for idx in validation_subset: - validation[name]['param'].append(self.by_name[name]['param'][idx]) + validation[name]["param"].append(self.by_name[name]["param"][idx]) - training_data = self.model_class(training, self.parameters, self.arg_count, verbose=False) + training_data = self.model_class( + training, self.parameters, self.arg_count, verbose=False + ) training_model = model_getter(training_data) - validation_data = self.model_class(validation, self.parameters, self.arg_count, verbose=False) + validation_data = self.model_class( + validation, self.parameters, self.arg_count, verbose=False + ) return validation_data.assess(training_model) def _preprocess_mimosa(measurement): - setup = measurement['setup'] - mim = MIMOSA(float(setup['mimosa_voltage']), int(setup['mimosa_shunt']), with_traces=measurement['with_traces']) + setup = measurement["setup"] + mim = MIMOSA( + float(setup["mimosa_voltage"]), + int(setup["mimosa_shunt"]), + with_traces=measurement["with_traces"], + ) try: - charges, triggers = mim.load_data(measurement['content']) + charges, triggers = mim.load_data(measurement["content"]) trigidx = mim.trigger_edges(triggers) except EOFError as e: - mim.errors.append('MIMOSA logfile error: {}'.format(e)) + mim.errors.append("MIMOSA logfile error: {}".format(e)) trigidx = list() if len(trigidx) == 0: - mim.errors.append('MIMOSA log has no triggers') + mim.errors.append("MIMOSA log has no triggers") return { - 'fileno': measurement['fileno'], - 'info': measurement['info'], - 'has_datasource_error': len(mim.errors) > 0, - 'datasource_errors': mim.errors, - 'expected_trace': measurement['expected_trace'], - 'repeat_id': measurement['repeat_id'], + "fileno": measurement["fileno"], + "info": measurement["info"], + "has_datasource_error": len(mim.errors) > 0, + "datasource_errors": mim.errors, + "expected_trace": measurement["expected_trace"], + "repeat_id": measurement["repeat_id"], } - cal_edges = mim.calibration_edges(running_mean(mim.currents_nocal(charges[0:trigidx[0]]), 10)) + cal_edges = mim.calibration_edges( + running_mean(mim.currents_nocal(charges[0 : trigidx[0]]), 10) + ) calfunc, caldata = mim.calibration_function(charges, cal_edges) vcalfunc = np.vectorize(calfunc, otypes=[np.float64]) processed_data = { - 'fileno': measurement['fileno'], - 'info': measurement['info'], - 'triggers': len(trigidx), - 'first_trig': trigidx[0] * 10, - 'calibration': caldata, - 'energy_trace': mim.analyze_states(charges, trigidx, vcalfunc), - 'has_datasource_error': len(mim.errors) > 0, - 'datasource_errors': mim.errors, + "fileno": 
measurement["fileno"], + "info": measurement["info"], + "triggers": len(trigidx), + "first_trig": trigidx[0] * 10, + "calibration": caldata, + "energy_trace": mim.analyze_states(charges, trigidx, vcalfunc), + "has_datasource_error": len(mim.errors) > 0, + "datasource_errors": mim.errors, } - for key in ['expected_trace', 'repeat_id']: + for key in ["expected_trace", "repeat_id"]: if key in measurement: processed_data[key] = measurement[key] @@ -374,22 +400,28 @@ def _preprocess_mimosa(measurement): def _preprocess_etlog(measurement): - setup = measurement['setup'] - etlog = EnergyTraceLog(float(setup['voltage']), int(setup['state_duration']), measurement['transition_names']) + setup = measurement["setup"] + etlog = EnergyTraceLog( + float(setup["voltage"]), + int(setup["state_duration"]), + measurement["transition_names"], + ) try: - etlog.load_data(measurement['content']) - states_and_transitions = etlog.analyze_states(measurement['expected_trace'], measurement['repeat_id']) + etlog.load_data(measurement["content"]) + states_and_transitions = etlog.analyze_states( + measurement["expected_trace"], measurement["repeat_id"] + ) except EOFError as e: - etlog.errors.append('EnergyTrace logfile error: {}'.format(e)) + etlog.errors.append("EnergyTrace logfile error: {}".format(e)) processed_data = { - 'fileno': measurement['fileno'], - 'repeat_id': measurement['repeat_id'], - 'info': measurement['info'], - 'expected_trace': measurement['expected_trace'], - 'energy_trace': states_and_transitions, - 'has_datasource_error': len(etlog.errors) > 0, - 'datasource_errors': etlog.errors, + "fileno": measurement["fileno"], + "repeat_id": measurement["repeat_id"], + "info": measurement["info"], + "expected_trace": measurement["expected_trace"], + "energy_trace": states_and_transitions, + "has_datasource_error": len(etlog.errors) > 0, + "datasource_errors": etlog.errors, } return processed_data @@ -421,35 +453,40 @@ class TimingData: for trace_group in self.traces_by_fileno: for trace in trace_group: # TimingHarness logs states, but does not aggregate any data for them at the moment -> throw all states away - transitions = list(filter(lambda x: x['isa'] == 'transition', trace['trace'])) - self.traces.append({ - 'id': trace['id'], - 'trace': transitions, - }) + transitions = list( + filter(lambda x: x["isa"] == "transition", trace["trace"]) + ) + self.traces.append( + {"id": trace["id"], "trace": transitions,} + ) for i, trace in enumerate(self.traces): - trace['orig_id'] = trace['id'] - trace['id'] = i - for log_entry in trace['trace']: - paramkeys = sorted(log_entry['parameter'].keys()) - if 'param' not in log_entry['offline_aggregates']: - log_entry['offline_aggregates']['param'] = list() - if 'duration' in log_entry['offline_aggregates']: - for i in range(len(log_entry['offline_aggregates']['duration'])): + trace["orig_id"] = trace["id"] + trace["id"] = i + for log_entry in trace["trace"]: + paramkeys = sorted(log_entry["parameter"].keys()) + if "param" not in log_entry["offline_aggregates"]: + log_entry["offline_aggregates"]["param"] = list() + if "duration" in log_entry["offline_aggregates"]: + for i in range(len(log_entry["offline_aggregates"]["duration"])): paramvalues = list() for paramkey in paramkeys: - if type(log_entry['parameter'][paramkey]) is list: - paramvalues.append(soft_cast_int(log_entry['parameter'][paramkey][i])) + if type(log_entry["parameter"][paramkey]) is list: + paramvalues.append( + soft_cast_int(log_entry["parameter"][paramkey][i]) + ) else: - 
paramvalues.append(soft_cast_int(log_entry['parameter'][paramkey])) - if arg_support_enabled and 'args' in log_entry: - paramvalues.extend(map(soft_cast_int, log_entry['args'])) - log_entry['offline_aggregates']['param'].append(paramvalues) + paramvalues.append( + soft_cast_int(log_entry["parameter"][paramkey]) + ) + if arg_support_enabled and "args" in log_entry: + paramvalues.extend(map(soft_cast_int, log_entry["args"])) + log_entry["offline_aggregates"]["param"].append(paramvalues) def _preprocess_0(self): for filename in self.filenames: - with open(filename, 'r') as f: + with open(filename, "r") as f: log_data = json.load(f) - self.traces_by_fileno.extend(log_data['traces']) + self.traces_by_fileno.extend(log_data["traces"]) self._concatenate_analyzed_traces() def get_preprocessed_data(self, verbose=True): @@ -470,17 +507,25 @@ class TimingData: def sanity_check_aggregate(aggregate): for key in aggregate: - if 'param' not in aggregate[key]: - raise RuntimeError('aggregate[{}][param] does not exist'.format(key)) - if 'attributes' not in aggregate[key]: - raise RuntimeError('aggregate[{}][attributes] does not exist'.format(key)) - for attribute in aggregate[key]['attributes']: + if "param" not in aggregate[key]: + raise RuntimeError("aggregate[{}][param] does not exist".format(key)) + if "attributes" not in aggregate[key]: + raise RuntimeError("aggregate[{}][attributes] does not exist".format(key)) + for attribute in aggregate[key]["attributes"]: if attribute not in aggregate[key]: - raise RuntimeError('aggregate[{}][{}] does not exist, even though it is contained in aggregate[{}][attributes]'.format(key, attribute, key)) - param_len = len(aggregate[key]['param']) + raise RuntimeError( + "aggregate[{}][{}] does not exist, even though it is contained in aggregate[{}][attributes]".format( + key, attribute, key + ) + ) + param_len = len(aggregate[key]["param"]) attr_len = len(aggregate[key][attribute]) if param_len != attr_len: - raise RuntimeError('parameter mismatch: len(aggregate[{}][param]) == {} != len(aggregate[{}][{}]) == {}'.format(key, param_len, key, attribute, attr_len)) + raise RuntimeError( + "parameter mismatch: len(aggregate[{}][param]) == {} != len(aggregate[{}][{}]) == {}".format( + key, param_len, key, attribute, attr_len + ) + ) class RawData: @@ -559,11 +604,11 @@ class RawData: with tarfile.open(filenames[0]) as tf: for member in tf.getmembers(): - if member.name == 'ptalog.json' and self.version == 0: + if member.name == "ptalog.json" and self.version == 0: self.version = 1 # might also be version 2 # depends on whether *.etlog exists or not - elif '.etlog' in member.name: + elif ".etlog" in member.name: self.version = 2 break @@ -572,18 +617,18 @@ class RawData: self.load_cache() def set_cache_file(self): - cache_key = hashlib.sha256('!'.join(self.filenames).encode()).hexdigest() - self.cache_dir = os.path.dirname(self.filenames[0]) + '/cache' - self.cache_file = '{}/{}.json'.format(self.cache_dir, cache_key) + cache_key = hashlib.sha256("!".join(self.filenames).encode()).hexdigest() + self.cache_dir = os.path.dirname(self.filenames[0]) + "/cache" + self.cache_file = "{}/{}.json".format(self.cache_dir, cache_key) def load_cache(self): if os.path.exists(self.cache_file): - with open(self.cache_file, 'r') as f: + with open(self.cache_file, "r") as f: cache_data = json.load(f) - self.traces = cache_data['traces'] - self.preprocessing_stats = cache_data['preprocessing_stats'] - if 'pta' in cache_data: - self.pta = cache_data['pta'] + self.traces = cache_data["traces"] + 
self.preprocessing_stats = cache_data["preprocessing_stats"]
+ if "pta" in cache_data:
+ self.pta = cache_data["pta"]
self.preprocessed = True

def save_cache(self):
@@ -593,30 +638,30 @@ class RawData:
os.mkdir(self.cache_dir)
except FileExistsError:
pass
- with open(self.cache_file, 'w') as f:
+ with open(self.cache_file, "w") as f:
cache_data = {
- 'traces': self.traces,
- 'preprocessing_stats': self.preprocessing_stats,
- 'pta': self.pta,
+ "traces": self.traces,
+ "preprocessing_stats": self.preprocessing_stats,
+ "pta": self.pta,
}
json.dump(cache_data, f)

def _state_is_too_short(self, online, offline, state_duration, next_transition):
# We cannot control when an interrupt causes a state to be left
- if next_transition['plan']['level'] == 'epilogue':
+ if next_transition["plan"]["level"] == "epilogue":
return False
# Note: state_duration is stored as ms, not us
- return offline['us'] < state_duration * 500
+ return offline["us"] < state_duration * 500

def _state_is_too_long(self, online, offline, state_duration, prev_transition):
# If the previous state was left by an interrupt, we may have some
# waiting time left over. So it's okay if the current state is longer
# than expected.
- if prev_transition['plan']['level'] == 'epilogue':
+ if prev_transition["plan"]["level"] == "epilogue":
return False
# state_duration is stored as ms, not us
- return offline['us'] > state_duration * 1500
+ return offline["us"] > state_duration * 1500

def _measurement_is_valid_2(self, processed_data):
"""
@@ -642,8 +687,8 @@ class RawData:
"""
# Check for low-level parser errors
- if processed_data['has_datasource_error']:
- processed_data['error'] = '; '.join(processed_data['datasource_errors'])
+ if processed_data["has_datasource_error"]:
+ processed_data["error"] = "; ".join(processed_data["datasource_errors"])
return False
# Note that the low-level parser (EnergyTraceLog) already checks
@@ -680,26 +725,27 @@ class RawData:
- uW_mean_delta_prev: difference between uW_mean and the uW_mean of the previous state
- uW_mean_delta_next: difference between uW_mean and the uW_mean of the following state
"""
- setup = self.setup_by_fileno[processed_data['fileno']]
- if 'expected_trace' in processed_data:
- traces = processed_data['expected_trace']
+ setup = self.setup_by_fileno[processed_data["fileno"]]
+ if "expected_trace" in processed_data:
+ traces = processed_data["expected_trace"]
else:
- traces = self.traces_by_fileno[processed_data['fileno']]
- state_duration = setup['state_duration']
+ traces = self.traces_by_fileno[processed_data["fileno"]]
+ state_duration = setup["state_duration"]
# Check MIMOSA error
- if processed_data['has_datasource_error']:
- processed_data['error'] = '; '.join(processed_data['datasource_errors'])
+ if processed_data["has_datasource_error"]:
+ processed_data["error"] = "; ".join(processed_data["datasource_errors"])
return False
# Check trigger count
sched_trigger_count = 0
for run in traces:
- sched_trigger_count += len(run['trace'])
- if sched_trigger_count != processed_data['triggers']:
- processed_data['error'] = 'got {got:d} trigger edges, expected {exp:d}'.format(
- got=processed_data['triggers'],
- exp=sched_trigger_count
+ sched_trigger_count += len(run["trace"])
+ if sched_trigger_count != processed_data["triggers"]:
+ processed_data[
+ "error"
+ ] = "got {got:d} trigger edges, expected {exp:d}".format(
+ got=processed_data["triggers"], exp=sched_trigger_count
)
return False
# Check state durations. 
Very short or long states can indicate a @@ -707,62 +753,102 @@ class RawData: # triggers elsewhere online_datapoints = [] for run_idx, run in enumerate(traces): - for trace_part_idx in range(len(run['trace'])): + for trace_part_idx in range(len(run["trace"])): online_datapoints.append((run_idx, trace_part_idx)) for offline_idx, online_ref in enumerate(online_datapoints): online_run_idx, online_trace_part_idx = online_ref - offline_trace_part = processed_data['energy_trace'][offline_idx] - online_trace_part = traces[online_run_idx]['trace'][online_trace_part_idx] + offline_trace_part = processed_data["energy_trace"][offline_idx] + online_trace_part = traces[online_run_idx]["trace"][online_trace_part_idx] if self._parameter_names is None: - self._parameter_names = sorted(online_trace_part['parameter'].keys()) - - if sorted(online_trace_part['parameter'].keys()) != self._parameter_names: - processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) has inconsistent parameter set: should be {param_want:s}, is {param_is:s}'.format( - off_idx=offline_idx, on_idx=online_run_idx, + self._parameter_names = sorted(online_trace_part["parameter"].keys()) + + if sorted(online_trace_part["parameter"].keys()) != self._parameter_names: + processed_data[ + "error" + ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) has inconsistent parameter set: should be {param_want:s}, is {param_is:s}".format( + off_idx=offline_idx, + on_idx=online_run_idx, on_sub=online_trace_part_idx, - on_name=online_trace_part['name'], + on_name=online_trace_part["name"], param_want=self._parameter_names, - param_is=sorted(online_trace_part['parameter'].keys()) + param_is=sorted(online_trace_part["parameter"].keys()), ) - if online_trace_part['isa'] != offline_trace_part['isa']: - processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) claims to be {off_isa:s}, but should be {on_isa:s}'.format( - off_idx=offline_idx, on_idx=online_run_idx, + if online_trace_part["isa"] != offline_trace_part["isa"]: + processed_data[ + "error" + ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) claims to be {off_isa:s}, but should be {on_isa:s}".format( + off_idx=offline_idx, + on_idx=online_run_idx, on_sub=online_trace_part_idx, - on_name=online_trace_part['name'], - off_isa=offline_trace_part['isa'], - on_isa=online_trace_part['isa']) + on_name=online_trace_part["name"], + off_isa=offline_trace_part["isa"], + on_isa=online_trace_part["isa"], + ) return False # Clipping in UNINITIALIZED (offline_idx == 0) can happen during # calibration and is handled by MIMOSA - if offline_idx != 0 and offline_trace_part['clip_rate'] != 0 and not self.ignore_clipping: - processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) was clipping {clip:f}% of the time'.format( - off_idx=offline_idx, on_idx=online_run_idx, + if ( + offline_idx != 0 + and offline_trace_part["clip_rate"] != 0 + and not self.ignore_clipping + ): + processed_data[ + "error" + ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) was clipping {clip:f}% of the time".format( + off_idx=offline_idx, + on_idx=online_run_idx, on_sub=online_trace_part_idx, - on_name=online_trace_part['name'], - clip=offline_trace_part['clip_rate'] * 100, + on_name=online_trace_part["name"], + clip=offline_trace_part["clip_rate"] * 100, ) return False - if online_trace_part['isa'] == 'state' and online_trace_part['name'] != 'UNINITIALIZED' and 
len(traces[online_run_idx]['trace']) > online_trace_part_idx + 1: - online_prev_transition = traces[online_run_idx]['trace'][online_trace_part_idx - 1] - online_next_transition = traces[online_run_idx]['trace'][online_trace_part_idx + 1] + if ( + online_trace_part["isa"] == "state" + and online_trace_part["name"] != "UNINITIALIZED" + and len(traces[online_run_idx]["trace"]) > online_trace_part_idx + 1 + ): + online_prev_transition = traces[online_run_idx]["trace"][ + online_trace_part_idx - 1 + ] + online_next_transition = traces[online_run_idx]["trace"][ + online_trace_part_idx + 1 + ] try: - if self._state_is_too_short(online_trace_part, offline_trace_part, state_duration, online_next_transition): - processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) is too short (duration = {dur:d} us)'.format( - off_idx=offline_idx, on_idx=online_run_idx, + if self._state_is_too_short( + online_trace_part, + offline_trace_part, + state_duration, + online_next_transition, + ): + processed_data[ + "error" + ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) is too short (duration = {dur:d} us)".format( + off_idx=offline_idx, + on_idx=online_run_idx, on_sub=online_trace_part_idx, - on_name=online_trace_part['name'], - dur=offline_trace_part['us']) + on_name=online_trace_part["name"], + dur=offline_trace_part["us"], + ) return False - if self._state_is_too_long(online_trace_part, offline_trace_part, state_duration, online_prev_transition): - processed_data['error'] = 'Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) is too long (duration = {dur:d} us)'.format( - off_idx=offline_idx, on_idx=online_run_idx, + if self._state_is_too_long( + online_trace_part, + offline_trace_part, + state_duration, + online_prev_transition, + ): + processed_data[ + "error" + ] = "Offline #{off_idx:d} (online {on_name:s} @ {on_idx:d}/{on_sub:d}) is too long (duration = {dur:d} us)".format( + off_idx=offline_idx, + on_idx=online_run_idx, on_sub=online_trace_part_idx, - on_name=online_trace_part['name'], - dur=offline_trace_part['us']) + on_name=online_trace_part["name"], + dur=offline_trace_part["us"], + ) return False except KeyError: pass @@ -775,136 +861,169 @@ class RawData: # (appends data from measurement['energy_trace']) # If measurement['expected_trace'] exists, it is edited in place instead online_datapoints = [] - if 'expected_trace' in measurement: - traces = measurement['expected_trace'] - traces = self.traces_by_fileno[measurement['fileno']] + if "expected_trace" in measurement: + traces = measurement["expected_trace"] + traces = self.traces_by_fileno[measurement["fileno"]] else: - traces = self.traces_by_fileno[measurement['fileno']] + traces = self.traces_by_fileno[measurement["fileno"]] for run_idx, run in enumerate(traces): - for trace_part_idx in range(len(run['trace'])): + for trace_part_idx in range(len(run["trace"])): online_datapoints.append((run_idx, trace_part_idx)) for offline_idx, online_ref in enumerate(online_datapoints): online_run_idx, online_trace_part_idx = online_ref - offline_trace_part = measurement['energy_trace'][offline_idx] - online_trace_part = traces[online_run_idx]['trace'][online_trace_part_idx] + offline_trace_part = measurement["energy_trace"][offline_idx] + online_trace_part = traces[online_run_idx]["trace"][online_trace_part_idx] - if 'offline' not in online_trace_part: - online_trace_part['offline'] = [offline_trace_part] + if "offline" not in online_trace_part: + online_trace_part["offline"] = 
[offline_trace_part] else: - online_trace_part['offline'].append(offline_trace_part) + online_trace_part["offline"].append(offline_trace_part) - paramkeys = sorted(online_trace_part['parameter'].keys()) + paramkeys = sorted(online_trace_part["parameter"].keys()) paramvalues = list() for paramkey in paramkeys: - if type(online_trace_part['parameter'][paramkey]) is list: - paramvalues.append(soft_cast_int(online_trace_part['parameter'][paramkey][measurement['repeat_id']])) + if type(online_trace_part["parameter"][paramkey]) is list: + paramvalues.append( + soft_cast_int( + online_trace_part["parameter"][paramkey][ + measurement["repeat_id"] + ] + ) + ) else: - paramvalues.append(soft_cast_int(online_trace_part['parameter'][paramkey])) + paramvalues.append( + soft_cast_int(online_trace_part["parameter"][paramkey]) + ) # NB: Unscheduled transitions do not have an 'args' field set. # However, they should only be caused by interrupts, and # interrupts don't have args anyways. - if arg_support_enabled and 'args' in online_trace_part: - paramvalues.extend(map(soft_cast_int, online_trace_part['args'])) - - if 'offline_aggregates' not in online_trace_part: - online_trace_part['offline_attributes'] = ['power', 'duration', 'energy'] - online_trace_part['offline_aggregates'] = { - 'power': [], - 'duration': [], - 'power_std': [], - 'energy': [], - 'paramkeys': [], - 'param': [], + if arg_support_enabled and "args" in online_trace_part: + paramvalues.extend(map(soft_cast_int, online_trace_part["args"])) + + if "offline_aggregates" not in online_trace_part: + online_trace_part["offline_attributes"] = [ + "power", + "duration", + "energy", + ] + online_trace_part["offline_aggregates"] = { + "power": [], + "duration": [], + "power_std": [], + "energy": [], + "paramkeys": [], + "param": [], } - if online_trace_part['isa'] == 'transition': - online_trace_part['offline_attributes'].extend(['rel_energy_prev', 'rel_energy_next', 'timeout']) - online_trace_part['offline_aggregates']['rel_energy_prev'] = [] - online_trace_part['offline_aggregates']['rel_energy_next'] = [] - online_trace_part['offline_aggregates']['timeout'] = [] + if online_trace_part["isa"] == "transition": + online_trace_part["offline_attributes"].extend( + ["rel_energy_prev", "rel_energy_next", "timeout"] + ) + online_trace_part["offline_aggregates"]["rel_energy_prev"] = [] + online_trace_part["offline_aggregates"]["rel_energy_next"] = [] + online_trace_part["offline_aggregates"]["timeout"] = [] # Note: All state/transitions are 20us "too long" due to injected # active wait states. These are needed to work around MIMOSA's # relatively low sample rate of 100 kHz (10us) and removed here. 
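# (Editorial illustration, not part of this patch: the 20 us correction with
# made-up numbers. uW * us yields pJ, hence the pJ energy aggregates below.)
uW_mean, us = 1000.0, 120  # 1000 uW measured over a 120 us window
duration = us - 20  # 100 us of actual state/transition time
energy = uW_mean * duration  # 100000 pJ == 100 nJ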
- online_trace_part['offline_aggregates']['power'].append( - offline_trace_part['uW_mean']) - online_trace_part['offline_aggregates']['duration'].append( - offline_trace_part['us'] - 20) - online_trace_part['offline_aggregates']['power_std'].append( - offline_trace_part['uW_std']) - online_trace_part['offline_aggregates']['energy'].append( - offline_trace_part['uW_mean'] * (offline_trace_part['us'] - 20)) - online_trace_part['offline_aggregates']['paramkeys'].append(paramkeys) - online_trace_part['offline_aggregates']['param'].append(paramvalues) - if online_trace_part['isa'] == 'transition': - online_trace_part['offline_aggregates']['rel_energy_prev'].append( - offline_trace_part['uW_mean_delta_prev'] * (offline_trace_part['us'] - 20)) - online_trace_part['offline_aggregates']['rel_energy_next'].append( - offline_trace_part['uW_mean_delta_next'] * (offline_trace_part['us'] - 20)) - online_trace_part['offline_aggregates']['timeout'].append( - offline_trace_part['timeout']) + online_trace_part["offline_aggregates"]["power"].append( + offline_trace_part["uW_mean"] + ) + online_trace_part["offline_aggregates"]["duration"].append( + offline_trace_part["us"] - 20 + ) + online_trace_part["offline_aggregates"]["power_std"].append( + offline_trace_part["uW_std"] + ) + online_trace_part["offline_aggregates"]["energy"].append( + offline_trace_part["uW_mean"] * (offline_trace_part["us"] - 20) + ) + online_trace_part["offline_aggregates"]["paramkeys"].append(paramkeys) + online_trace_part["offline_aggregates"]["param"].append(paramvalues) + if online_trace_part["isa"] == "transition": + online_trace_part["offline_aggregates"]["rel_energy_prev"].append( + offline_trace_part["uW_mean_delta_prev"] + * (offline_trace_part["us"] - 20) + ) + online_trace_part["offline_aggregates"]["rel_energy_next"].append( + offline_trace_part["uW_mean_delta_next"] + * (offline_trace_part["us"] - 20) + ) + online_trace_part["offline_aggregates"]["timeout"].append( + offline_trace_part["timeout"] + ) def _merge_online_and_etlog(self, measurement): # Edits self.traces_by_fileno[measurement['fileno']][*]['trace'][*]['offline'] # and self.traces_by_fileno[measurement['fileno']][*]['trace'][*]['offline_aggregates'] in place # (appends data from measurement['energy_trace']) online_datapoints = [] - traces = self.traces_by_fileno[measurement['fileno']] + traces = self.traces_by_fileno[measurement["fileno"]] for run_idx, run in enumerate(traces): - for trace_part_idx in range(len(run['trace'])): + for trace_part_idx in range(len(run["trace"])): online_datapoints.append((run_idx, trace_part_idx)) for offline_idx, online_ref in enumerate(online_datapoints): online_run_idx, online_trace_part_idx = online_ref - offline_trace_part = measurement['energy_trace'][offline_idx] - online_trace_part = traces[online_run_idx]['trace'][online_trace_part_idx] + offline_trace_part = measurement["energy_trace"][offline_idx] + online_trace_part = traces[online_run_idx]["trace"][online_trace_part_idx] - if 'offline' not in online_trace_part: - online_trace_part['offline'] = [offline_trace_part] + if "offline" not in online_trace_part: + online_trace_part["offline"] = [offline_trace_part] else: - online_trace_part['offline'].append(offline_trace_part) + online_trace_part["offline"].append(offline_trace_part) - paramkeys = sorted(online_trace_part['parameter'].keys()) + paramkeys = sorted(online_trace_part["parameter"].keys()) paramvalues = list() for paramkey in paramkeys: - if type(online_trace_part['parameter'][paramkey]) is list: - 
paramvalues.append(soft_cast_int(online_trace_part['parameter'][paramkey][measurement['repeat_id']])) + if type(online_trace_part["parameter"][paramkey]) is list: + paramvalues.append( + soft_cast_int( + online_trace_part["parameter"][paramkey][ + measurement["repeat_id"] + ] + ) + ) else: - paramvalues.append(soft_cast_int(online_trace_part['parameter'][paramkey])) + paramvalues.append( + soft_cast_int(online_trace_part["parameter"][paramkey]) + ) # NB: Unscheduled transitions do not have an 'args' field set. # However, they should only be caused by interrupts, and # interrupts don't have args anyways. - if arg_support_enabled and 'args' in online_trace_part: - paramvalues.extend(map(soft_cast_int, online_trace_part['args'])) - - if 'offline_aggregates' not in online_trace_part: - online_trace_part['offline_aggregates'] = { - 'offline_attributes': ['power', 'duration', 'energy'], - 'duration': list(), - 'power': list(), - 'power_std': list(), - 'energy': list(), - 'paramkeys': list(), - 'param': list() + if arg_support_enabled and "args" in online_trace_part: + paramvalues.extend(map(soft_cast_int, online_trace_part["args"])) + + if "offline_aggregates" not in online_trace_part: + online_trace_part["offline_aggregates"] = { + "offline_attributes": ["power", "duration", "energy"], + "duration": list(), + "power": list(), + "power_std": list(), + "energy": list(), + "paramkeys": list(), + "param": list(), } - offline_aggregates = online_trace_part['offline_aggregates'] + offline_aggregates = online_trace_part["offline_aggregates"] # if online_trace_part['isa'] == 'transitions': # online_trace_part['offline_attributes'].extend(['rel_energy_prev', 'rel_energy_next']) # offline_aggregates['rel_energy_prev'] = list() # offline_aggregates['rel_energy_next'] = list() - offline_aggregates['duration'].append(offline_trace_part['s'] * 1e6) - offline_aggregates['power'].append(offline_trace_part['W_mean'] * 1e6) - offline_aggregates['power_std'].append(offline_trace_part['W_std'] * 1e6) - offline_aggregates['energy'].append(offline_trace_part['W_mean'] * offline_trace_part['s'] * 1e12) - offline_aggregates['paramkeys'].append(paramkeys) - offline_aggregates['param'].append(paramvalues) + offline_aggregates["duration"].append(offline_trace_part["s"] * 1e6) + offline_aggregates["power"].append(offline_trace_part["W_mean"] * 1e6) + offline_aggregates["power_std"].append(offline_trace_part["W_std"] * 1e6) + offline_aggregates["energy"].append( + offline_trace_part["W_mean"] * offline_trace_part["s"] * 1e12 + ) + offline_aggregates["paramkeys"].append(paramkeys) + offline_aggregates["param"].append(paramvalues) # if online_trace_part['isa'] == 'transition': # offline_aggregates['rel_energy_prev'].append(offline_trace_part['W_mean_delta_prev'] * offline_trace_part['s'] * 1e12) @@ -922,8 +1041,8 @@ class RawData: for trace in list_of_traces: trace_output.extend(trace.copy()) for i, trace in enumerate(trace_output): - trace['orig_id'] = trace['id'] - trace['id'] = i + trace["orig_id"] = trace["id"] + trace["id"] = i return trace_output def get_preprocessed_data(self, verbose=True): @@ -1000,25 +1119,29 @@ class RawData: if version == 0: with tarfile.open(filename) as tf: - self.setup_by_fileno.append(json.load(tf.extractfile('setup.json'))) - self.traces_by_fileno.append(json.load(tf.extractfile('src/apps/DriverEval/DriverLog.json'))) + self.setup_by_fileno.append(json.load(tf.extractfile("setup.json"))) + self.traces_by_fileno.append( + json.load(tf.extractfile("src/apps/DriverEval/DriverLog.json")) + ) for 
member in tf.getmembers(): _, extension = os.path.splitext(member.name) - if extension == '.mim': - offline_data.append({ - 'content': tf.extractfile(member).read(), - 'fileno': i, - 'info': member, - 'setup': self.setup_by_fileno[i], - 'with_traces': self.with_traces, - }) + if extension == ".mim": + offline_data.append( + { + "content": tf.extractfile(member).read(), + "fileno": i, + "info": member, + "setup": self.setup_by_fileno[i], + "with_traces": self.with_traces, + } + ) elif version == 1: new_filenames = list() with tarfile.open(filename) as tf: - ptalog = json.load(tf.extractfile(tf.getmember('ptalog.json'))) - self.pta = ptalog['pta'] + ptalog = json.load(tf.extractfile(tf.getmember("ptalog.json"))) + self.pta = ptalog["pta"] # Benchmark code may be too large to be executed in a single # run, so benchmarks (a benchmark is basically a list of DFA runs) @@ -1043,33 +1166,37 @@ class RawData: # ptalog['files'][0][0] is its first iteration/repetition, # ptalog['files'][0][1] the second, etc. - for j, traces in enumerate(ptalog['traces']): - new_filenames.append('{}#{}'.format(filename, j)) + for j, traces in enumerate(ptalog["traces"]): + new_filenames.append("{}#{}".format(filename, j)) self.traces_by_fileno.append(traces) - self.setup_by_fileno.append({ - 'mimosa_voltage': ptalog['configs'][j]['voltage'], - 'mimosa_shunt': ptalog['configs'][j]['shunt'], - 'state_duration': ptalog['opt']['sleep'], - }) - for repeat_id, mim_file in enumerate(ptalog['files'][j]): + self.setup_by_fileno.append( + { + "mimosa_voltage": ptalog["configs"][j]["voltage"], + "mimosa_shunt": ptalog["configs"][j]["shunt"], + "state_duration": ptalog["opt"]["sleep"], + } + ) + for repeat_id, mim_file in enumerate(ptalog["files"][j]): member = tf.getmember(mim_file) - offline_data.append({ - 'content': tf.extractfile(member).read(), - 'fileno': j, - 'info': member, - 'setup': self.setup_by_fileno[j], - 'repeat_id': repeat_id, - 'expected_trace': ptalog['traces'][j], - 'with_traces': self.with_traces, - }) + offline_data.append( + { + "content": tf.extractfile(member).read(), + "fileno": j, + "info": member, + "setup": self.setup_by_fileno[j], + "repeat_id": repeat_id, + "expected_trace": ptalog["traces"][j], + "with_traces": self.with_traces, + } + ) self.filenames = new_filenames elif version == 2: new_filenames = list() with tarfile.open(filename) as tf: - ptalog = json.load(tf.extractfile(tf.getmember('ptalog.json'))) - self.pta = ptalog['pta'] + ptalog = json.load(tf.extractfile(tf.getmember("ptalog.json"))) + self.pta = ptalog["pta"] # Benchmark code may be too large to be executed in a single # run, so benchmarks (a benchmark is basically a list of DFA runs) @@ -1103,32 +1230,45 @@ class RawData: # to an invalid measurement and thus power[b] corresponding # to duration[C]. At the moment, this is harmless, but in the # future it might not be. 
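# Standalone sketch of the key migration performed in the hunk below, on an
# invented stand-in entry: per-run aggregates recorded by an earlier analysis
# move from "offline_aggregates" to "online_aggregates" so they cannot be
# mistaken for aggregates computed from the current set of measurements.
entry = {"isa": "state", "name": "IDLE", "offline_aggregates": {"power": [1070.4]}}
aggregates = entry.pop("offline_aggregates", None)
if aggregates:
    entry["online_aggregates"] = aggregates
assert "offline_aggregates" not in entry and "online_aggregates" in entry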
- if 'offline_aggregates' in ptalog['traces'][0][0]['trace'][0]: - for trace_group in ptalog['traces']: + if "offline_aggregates" in ptalog["traces"][0][0]["trace"][0]: + for trace_group in ptalog["traces"]: for trace in trace_group: - for state_or_transition in trace['trace']: - offline_aggregates = state_or_transition.pop('offline_aggregates', None) + for state_or_transition in trace["trace"]: + offline_aggregates = state_or_transition.pop( + "offline_aggregates", None + ) if offline_aggregates: - state_or_transition['online_aggregates'] = offline_aggregates + state_or_transition[ + "online_aggregates" + ] = offline_aggregates - for j, traces in enumerate(ptalog['traces']): - new_filenames.append('{}#{}'.format(filename, j)) + for j, traces in enumerate(ptalog["traces"]): + new_filenames.append("{}#{}".format(filename, j)) self.traces_by_fileno.append(traces) - self.setup_by_fileno.append({ - 'voltage': ptalog['configs'][j]['voltage'], - 'state_duration': ptalog['opt']['sleep'], - }) - for repeat_id, etlog_file in enumerate(ptalog['files'][j]): + self.setup_by_fileno.append( + { + "voltage": ptalog["configs"][j]["voltage"], + "state_duration": ptalog["opt"]["sleep"], + } + ) + for repeat_id, etlog_file in enumerate(ptalog["files"][j]): member = tf.getmember(etlog_file) - offline_data.append({ - 'content': tf.extractfile(member).read(), - 'fileno': j, - 'info': member, - 'setup': self.setup_by_fileno[j], - 'repeat_id': repeat_id, - 'expected_trace': ptalog['traces'][j], - 'transition_names': list(map(lambda x: x['name'], ptalog['pta']['transitions'])) - }) + offline_data.append( + { + "content": tf.extractfile(member).read(), + "fileno": j, + "info": member, + "setup": self.setup_by_fileno[j], + "repeat_id": repeat_id, + "expected_trace": ptalog["traces"][j], + "transition_names": list( + map( + lambda x: x["name"], + ptalog["pta"]["transitions"], + ) + ), + } + ) self.filenames = new_filenames # TODO remove 'offline_aggregates' from pre-parse data and place # it under 'online_aggregates' or similar instead. This way, if @@ -1145,52 +1285,69 @@ class RawData: num_valid = 0 for measurement in measurements: - if 'energy_trace' not in measurement: - vprint(self.verbose, '[W] Skipping {ar:s}/{m:s}: {e:s}'.format( - ar=self.filenames[measurement['fileno']], - m=measurement['info'].name, - e='; '.join(measurement['datasource_errors']))) + if "energy_trace" not in measurement: + vprint( + self.verbose, + "[W] Skipping {ar:s}/{m:s}: {e:s}".format( + ar=self.filenames[measurement["fileno"]], + m=measurement["info"].name, + e="; ".join(measurement["datasource_errors"]), + ), + ) continue if version == 0: # Strip the last state (it is not part of the scheduled measurement) - measurement['energy_trace'].pop() + measurement["energy_trace"].pop() elif version == 1: # The first online measurement is the UNINITIALIZED state. In v1, # it is not part of the expected PTA trace -> remove it. 
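# Standalone sketch of the version-dependent trimming described above, with
# invented entries: v0 logs one unscheduled state at the end, v1 logs the
# UNINITIALIZED state at the front; neither has a counterpart in the
# expected trace, so the corresponding energy_trace entry is dropped.
energy_trace = ["UNINITIALIZED", "T_first_transition", "FIRST_STATE"]
version = 1
if version == 0:
    energy_trace.pop()   # trailing state, not part of the scheduled measurement
elif version == 1:
    energy_trace.pop(0)  # leading UNINITIALIZED measurement
assert energy_trace == ["T_first_transition", "FIRST_STATE"]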
- measurement['energy_trace'].pop(0) + measurement["energy_trace"].pop(0) if version == 0 or version == 1: if self._measurement_is_valid_01(measurement): self._merge_online_and_offline(measurement) num_valid += 1 else: - vprint(self.verbose, '[W] Skipping {ar:s}/{m:s}: {e:s}'.format( - ar=self.filenames[measurement['fileno']], - m=measurement['info'].name, - e=measurement['error'])) + vprint( + self.verbose, + "[W] Skipping {ar:s}/{m:s}: {e:s}".format( + ar=self.filenames[measurement["fileno"]], + m=measurement["info"].name, + e=measurement["error"], + ), + ) elif version == 2: if self._measurement_is_valid_2(measurement): self._merge_online_and_etlog(measurement) num_valid += 1 else: - vprint(self.verbose, '[W] Skipping {ar:s}/{m:s}: {e:s}'.format( - ar=self.filenames[measurement['fileno']], - m=measurement['info'].name, - e=measurement['error'])) - vprint(self.verbose, '[I] {num_valid:d}/{num_total:d} measurements are valid'.format( - num_valid=num_valid, - num_total=len(measurements))) + vprint( + self.verbose, + "[W] Skipping {ar:s}/{m:s}: {e:s}".format( + ar=self.filenames[measurement["fileno"]], + m=measurement["info"].name, + e=measurement["error"], + ), + ) + vprint( + self.verbose, + "[I] {num_valid:d}/{num_total:d} measurements are valid".format( + num_valid=num_valid, num_total=len(measurements) + ), + ) if version == 0: self.traces = self._concatenate_traces(self.traces_by_fileno) elif version == 1: - self.traces = self._concatenate_traces(map(lambda x: x['expected_trace'], measurements)) + self.traces = self._concatenate_traces( + map(lambda x: x["expected_trace"], measurements) + ) self.traces = self._concatenate_traces(self.traces_by_fileno) elif version == 2: self.traces = self._concatenate_traces(self.traces_by_fileno) self.preprocessing_stats = { - 'num_runs': len(measurements), - 'num_valid': num_valid + "num_runs": len(measurements), + "num_valid": num_valid, } @@ -1207,16 +1364,33 @@ class ParallelParamFit: self.fit_queue = [] self.by_param = by_param - def enqueue(self, state_or_tran, attribute, param_index, param_name, safe_functions_enabled=False, param_filter=None): + def enqueue( + self, + state_or_tran, + attribute, + param_index, + param_name, + safe_functions_enabled=False, + param_filter=None, + ): """ Add state_or_tran/attribute/param_name to fit queue. This causes fit() to compute the best-fitting function for this model part. """ - self.fit_queue.append({ - 'key': [state_or_tran, attribute, param_name, param_filter], - 'args': [self.by_param, state_or_tran, attribute, param_index, safe_functions_enabled, param_filter] - }) + self.fit_queue.append( + { + "key": [state_or_tran, attribute, param_name, param_filter], + "args": [ + self.by_param, + state_or_tran, + attribute, + param_index, + safe_functions_enabled, + param_filter, + ], + } + ) def fit(self): """ @@ -1236,13 +1410,17 @@ def _try_fits_parallel(arg): Must be a global function as it is called from a multiprocessing Pool. """ - return { - 'key': arg['key'], - 'result': _try_fits(*arg['args']) - } + return {"key": arg["key"], "result": _try_fits(*arg["args"])} -def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled=False, param_filter: dict = None): +def _try_fits( + by_param, + state_or_tran, + model_attribute, + param_index, + safe_functions_enabled=False, + param_filter: dict = None, +): """ Determine goodness-of-fit for prediction of `by_param[(state_or_tran, *)][model_attribute]` dependence on `param_index` using various functions. 
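# Standalone sketch of the fitting idea behind _try_fits, reduced to a single
# parameter dimension: each candidate function is fitted to (parameter value,
# attribute value) pairs via least squares and judged by RMSD. The candidate
# set and the data points below are invented.
import numpy as np
from scipy import optimize

def best_candidate(X, Y, candidates):
    best_name, best_rmsd = None, np.inf
    for fname, func in candidates.items():
        residuals = lambda coeffs, f=func: f(coeffs, X) - Y
        res = optimize.least_squares(residuals, [0, 1], xtol=2e-15)
        rmsd = np.sqrt(np.mean((func(res.x, X) - Y) ** 2))
        if rmsd < best_rmsd:
            best_name, best_rmsd = fname, rmsd
    return best_name, best_rmsd

candidates = {
    "linear": lambda c, x: c[0] + c[1] * x,
    "logarithmic": lambda c, x: c[0] + c[1] * np.log(x),
}
X = np.array([1.0, 2.0, 4.0, 8.0])
Y = np.array([10.2, 19.8, 40.5, 79.9])  # roughly linear in X
print(best_candidate(X, Y, candidates))  # -> ('linear', <small rmsd>)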
@@ -1281,22 +1459,28 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
     function_names = list(functions.keys())
     for function_name in function_names:
         function_object = functions[function_name]
-        if is_numeric(param_key[1][param_index]) and not function_object.is_valid(param_key[1][param_index]):
+        if is_numeric(param_key[1][param_index]) and not function_object.is_valid(
+            param_key[1][param_index]
+        ):
             functions.pop(function_name, None)
 
     raw_results = dict()
     raw_results_by_param = dict()
-    ref_results = {
-        'mean': list(),
-        'median': list()
-    }
+    ref_results = {"mean": list(), "median": list()}
     results = dict()
     results_by_param = dict()
 
     seen_parameter_combinations = set()
 
     # for each parameter combination:
-    for param_key in filter(lambda x: x[0] == state_or_tran and remove_index_from_tuple(x[1], param_index) not in seen_parameter_combinations and len(by_param[x]['param']) and match_parameter_values(by_param[x]['param'][0], param_filter), by_param.keys()):
+    for param_key in filter(
+        lambda x: x[0] == state_or_tran
+        and remove_index_from_tuple(x[1], param_index)
+        not in seen_parameter_combinations
+        and len(by_param[x]["param"])
+        and match_parameter_values(by_param[x]["param"][0], param_filter),
+        by_param.keys(),
+    ):
         X = []
         Y = []
         num_valid = 0
@@ -1304,10 +1488,14 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
 
         # Ensure that each parameter combination is only optimized once. Otherwise, with parameters (1, 2, 5), (1, 3, 5), (1, 4, 5) and param_index == 1,
         # the parameter combination (1, *, 5) would be optimized three times, both wasting time and biasing results towards more frequently occurring combinations of non-param_index parameters
-        seen_parameter_combinations.add(remove_index_from_tuple(param_key[1], param_index))
+        seen_parameter_combinations.add(
+            remove_index_from_tuple(param_key[1], param_index)
+        )
 
         # for each value of the parameter denoted by param_index (all other parameters remain the same):
-        for k, v in filter(lambda kv: param_slice_eq(kv[0], param_key, param_index), by_param.items()):
+        for k, v in filter(
+            lambda kv: param_slice_eq(kv[0], param_key, param_index), by_param.items()
+        ):
             num_total += 1
             if is_numeric(k[1][param_index]):
                 num_valid += 1
@@ -1324,7 +1512,9 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
             if function_name not in raw_results:
                 raw_results[function_name] = dict()
             error_function = param_function.error_function
-            res = optimize.least_squares(error_function, [0, 1], args=(X, Y), xtol=2e-15)
+            res = optimize.least_squares(
+                error_function, [0, 1], args=(X, Y), xtol=2e-15
+            )
             measures = regression_measures(param_function.eval(res.x, X), Y)
             raw_results_by_param[other_parameters][function_name] = measures
             for measure, error_rate in measures.items():
@@ -1333,38 +1523,37 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
                     raw_results[function_name][measure].append(error_rate)
             # print(function_name, res, measures)
         mean_measures = aggregate_measures(np.mean(Y), Y)
-        ref_results['mean'].append(mean_measures['rmsd'])
-        raw_results_by_param[other_parameters]['mean'] = mean_measures
+        ref_results["mean"].append(mean_measures["rmsd"])
+        raw_results_by_param[other_parameters]["mean"] = mean_measures
         median_measures = aggregate_measures(np.median(Y), Y)
-        ref_results['median'].append(median_measures['rmsd'])
-        raw_results_by_param[other_parameters]['median'] = median_measures
+        ref_results["median"].append(median_measures["rmsd"])
+        raw_results_by_param[other_parameters]["median"] = median_measures
 
-    if not len(ref_results['mean']):
+    if not len(ref_results["mean"]):
         # Insufficient data for fitting
         # print('[W] Insufficient data for fitting {}/{}/{}'.format(state_or_tran, model_attribute, param_index))
-        return {
-            'best': None,
-            'best_rmsd': np.inf,
-            'results': results
-        }
+        return {"best": None, "best_rmsd": np.inf, "results": results}
 
-    for other_parameter_combination, other_parameter_results in raw_results_by_param.items():
+    for (
+        other_parameter_combination,
+        other_parameter_results,
+    ) in raw_results_by_param.items():
         best_fit_val = np.inf
         best_fit_name = None
         results = dict()
         for function_name, result in other_parameter_results.items():
             if len(result) > 0:
                 results[function_name] = result
-                rmsd = result['rmsd']
+                rmsd = result["rmsd"]
                 if rmsd < best_fit_val:
                     best_fit_val = rmsd
                     best_fit_name = function_name
         results_by_param[other_parameter_combination] = {
-            'best': best_fit_name,
-            'best_rmsd': best_fit_val,
-            'mean_rmsd': results['mean']['rmsd'],
-            'median_rmsd': results['median']['rmsd'],
-            'results': results
+            "best": best_fit_name,
+            "best_rmsd": best_fit_val,
+            "mean_rmsd": results["mean"]["rmsd"],
+            "median_rmsd": results["median"]["rmsd"],
+            "results": results,
         }
 
     best_fit_val = np.inf
@@ -1375,26 +1564,26 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
             results[function_name] = {}
             for measure in result.keys():
                 results[function_name][measure] = np.mean(result[measure])
-        rmsd = results[function_name]['rmsd']
+        rmsd = results[function_name]["rmsd"]
         if rmsd < best_fit_val:
             best_fit_val = rmsd
             best_fit_name = function_name
 
     return {
-        'best': best_fit_name,
-        'best_rmsd': best_fit_val,
-        'mean_rmsd': np.mean(ref_results['mean']),
-        'median_rmsd': np.mean(ref_results['median']),
-        'results': results,
-        'results_by_other_param': results_by_param
+        "best": best_fit_name,
+        "best_rmsd": best_fit_val,
+        "mean_rmsd": np.mean(ref_results["mean"]),
+        "median_rmsd": np.mean(ref_results["median"]),
+        "results": results,
+        "results_by_other_param": results_by_param,
    }


def _num_args_from_by_name(by_name):
    num_args = dict()
    for key, value in by_name.items():
-        if 'args' in value:
-            num_args[key] = len(value['args'][0])
+        if "args" in value:
+            num_args[key] = len(value["args"][0])
    return num_args


@@ -1413,19 +1602,44 @@ def get_fit_result(results, name, attribute, verbose=False, param_filter: dict =
    """
    fit_result = dict()
    for result in results:
-        if result['key'][0] == name and result['key'][1] == attribute and result['key'][3] == param_filter and result['result']['best'] is not None: # probably due to ['best'] != None -> does the fit fail for filtered data?
-            this_result = result['result']
-            if this_result['best_rmsd'] >= min(this_result['mean_rmsd'], this_result['median_rmsd']):
-                vprint(verbose, '[I] Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})'.format(
-                    name, attribute, result['key'][2], this_result['best_rmsd'],
-                    this_result['mean_rmsd'], this_result['median_rmsd']))
+        if (
+            result["key"][0] == name
+            and result["key"][1] == attribute
+            and result["key"][3] == param_filter
+            and result["result"]["best"] is not None
+        ):  # probably due to ['best'] != None -> does the fit fail for filtered data?
+ this_result = result["result"] + if this_result["best_rmsd"] >= min( + this_result["mean_rmsd"], this_result["median_rmsd"] + ): + vprint( + verbose, + "[I] Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format( + name, + attribute, + result["key"][2], + this_result["best_rmsd"], + this_result["mean_rmsd"], + this_result["median_rmsd"], + ), + ) # See notes on depends_on_param - elif this_result['best_rmsd'] >= 0.8 * min(this_result['mean_rmsd'], this_result['median_rmsd']): - vprint(verbose, '[I] Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})'.format( - name, attribute, result['key'][2], this_result['best_rmsd'], - this_result['mean_rmsd'], this_result['median_rmsd'])) + elif this_result["best_rmsd"] >= 0.8 * min( + this_result["mean_rmsd"], this_result["median_rmsd"] + ): + vprint( + verbose, + "[I] Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format( + name, + attribute, + result["key"][2], + this_result["best_rmsd"], + this_result["mean_rmsd"], + this_result["median_rmsd"], + ), + ) else: - fit_result[result['key'][2]] = this_result + fit_result[result["key"][2]] = this_result return fit_result @@ -1471,7 +1685,15 @@ class AnalyticModel: assess -- calculate model quality """ - def __init__(self, by_name, parameters, arg_count=None, function_override=dict(), verbose=True, use_corrcoef=False): + def __init__( + self, + by_name, + parameters, + arg_count=None, + function_override=dict(), + verbose=True, + use_corrcoef=False, + ): """ Create a new AnalyticModel and compute parameter statistics. @@ -1521,19 +1743,29 @@ class AnalyticModel: if self._num_args is None: self._num_args = _num_args_from_by_name(by_name) - self.stats = ParamStats(self.by_name, self.by_param, self.parameters, self._num_args, verbose=verbose, use_corrcoef=use_corrcoef) + self.stats = ParamStats( + self.by_name, + self.by_param, + self.parameters, + self._num_args, + verbose=verbose, + use_corrcoef=use_corrcoef, + ) def _get_model_from_dict(self, model_dict, model_function): model = {} for name, elem in model_dict.items(): model[name] = {} - for key in elem['attributes']: + for key in elem["attributes"]: try: model[name][key] = model_function(elem[key]) except RuntimeWarning: - vprint(self.verbose, '[W] Got no data for {} {}'.format(name, key)) + vprint(self.verbose, "[W] Got no data for {} {}".format(name, key)) except FloatingPointError as fpe: - vprint(self.verbose, '[W] Got no data for {} {}: {}'.format(name, key, fpe)) + vprint( + self.verbose, + "[W] Got no data for {} {}: {}".format(name, key, fpe), + ) return model def param_index(self, param_name): @@ -1596,22 +1828,28 @@ class AnalyticModel: model_function(name, attribute, param=parameter values) -> model value. model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... 
} or None """ - if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache: - return self.cache['fitted_model_getter'], self.cache['fitted_info_getter'] + if "fitted_model_getter" in self.cache and "fitted_info_getter" in self.cache: + return self.cache["fitted_model_getter"], self.cache["fitted_info_getter"] static_model = self._get_model_from_dict(self.by_name, np.median) param_model = dict([[name, {}] for name in self.by_name.keys()]) paramfit = ParallelParamFit(self.by_param) for name in self.by_name.keys(): - for attribute in self.by_name[name]['attributes']: + for attribute in self.by_name[name]["attributes"]: for param_index, param in enumerate(self.parameters): if self.stats.depends_on_param(name, attribute, param): paramfit.enqueue(name, attribute, param_index, param, False) if arg_support_enabled and name in self._num_args: for arg_index in range(self._num_args[name]): if self.stats.depends_on_arg(name, attribute, arg_index): - paramfit.enqueue(name, attribute, len(self.parameters) + arg_index, arg_index, False) + paramfit.enqueue( + name, + attribute, + len(self.parameters) + arg_index, + arg_index, + False, + ) paramfit.fit() @@ -1619,8 +1857,10 @@ class AnalyticModel: num_args = 0 if name in self._num_args: num_args = self._num_args[name] - for attribute in self.by_name[name]['attributes']: - fit_result = get_fit_result(paramfit.results, name, attribute, self.verbose) + for attribute in self.by_name[name]["attributes"]: + fit_result = get_fit_result( + paramfit.results, name, attribute, self.verbose + ) if (name, attribute) in self.function_override: function_str = self.function_override[(name, attribute)] @@ -1628,25 +1868,27 @@ class AnalyticModel: x.fit(self.by_param, name, attribute) if x.fit_success: param_model[name][attribute] = { - 'fit_result': fit_result, - 'function': x + "fit_result": fit_result, + "function": x, } elif len(fit_result.keys()): - x = analytic.function_powerset(fit_result, self.parameters, num_args) + x = analytic.function_powerset( + fit_result, self.parameters, num_args + ) x.fit(self.by_param, name, attribute) if x.fit_success: param_model[name][attribute] = { - 'fit_result': fit_result, - 'function': x + "fit_result": fit_result, + "function": x, } def model_getter(name, key, **kwargs): - if 'arg' in kwargs and 'param' in kwargs: - kwargs['param'].extend(map(soft_cast_int, kwargs['arg'])) + if "arg" in kwargs and "param" in kwargs: + kwargs["param"].extend(map(soft_cast_int, kwargs["arg"])) if key in param_model[name]: - param_list = kwargs['param'] - param_function = param_model[name][key]['function'] + param_list = kwargs["param"] + param_function = param_model[name][key]["function"] if param_function.is_predictable(param_list): return param_function.eval(param_list) return static_model[name][key] @@ -1656,8 +1898,8 @@ class AnalyticModel: return param_model[name][key] return None - self.cache['fitted_model_getter'] = model_getter - self.cache['fitted_info_getter'] = info_getter + self.cache["fitted_model_getter"] = model_getter + self.cache["fitted_info_getter"] = info_getter return model_getter, info_getter @@ -1677,13 +1919,22 @@ class AnalyticModel: detailed_results = {} for name, elem in sorted(self.by_name.items()): detailed_results[name] = {} - for attribute in elem['attributes']: - predicted_data = np.array(list(map(lambda i: model_function(name, attribute, param=elem['param'][i]), range(len(elem[attribute]))))) + for attribute in elem["attributes"]: + predicted_data = np.array( + list( + map( + lambda i: 
model_function( + name, attribute, param=elem["param"][i] + ), + range(len(elem[attribute])), + ) + ) + ) measures = regression_measures(predicted_data, elem[attribute]) detailed_results[name][attribute] = measures return { - 'by_name': detailed_results, + "by_name": detailed_results, } def to_json(self): @@ -1695,25 +1946,28 @@ def _add_trace_data_to_aggregate(aggregate, key, element): # Only cares about element['isa'], element['offline_aggregates'], and # element['plan']['level'] if key not in aggregate: - aggregate[key] = { - 'isa': element['isa'] - } - for datakey in element['offline_aggregates'].keys(): + aggregate[key] = {"isa": element["isa"]} + for datakey in element["offline_aggregates"].keys(): aggregate[key][datakey] = [] - if element['isa'] == 'state': - aggregate[key]['attributes'] = ['power'] + if element["isa"] == "state": + aggregate[key]["attributes"] = ["power"] else: # TODO do not hardcode values - aggregate[key]['attributes'] = ['duration', 'energy', 'rel_energy_prev', 'rel_energy_next'] + aggregate[key]["attributes"] = [ + "duration", + "energy", + "rel_energy_prev", + "rel_energy_next", + ] # Uncomment this line if you also want to analyze mean transition power # aggrgate[key]['attributes'].append('power') - if 'plan' in element and element['plan']['level'] == 'epilogue': - aggregate[key]['attributes'].insert(0, 'timeout') - attributes = aggregate[key]['attributes'].copy() + if "plan" in element and element["plan"]["level"] == "epilogue": + aggregate[key]["attributes"].insert(0, "timeout") + attributes = aggregate[key]["attributes"].copy() for attribute in attributes: - if attribute not in element['offline_aggregates']: - aggregate[key]['attributes'].remove(attribute) - for datakey, dataval in element['offline_aggregates'].items(): + if attribute not in element["offline_aggregates"]: + aggregate[key]["attributes"].remove(attribute) + for datakey, dataval in element["offline_aggregates"].items(): aggregate[key][datakey].extend(dataval) @@ -1771,16 +2025,20 @@ def pta_trace_to_aggregate(traces, ignore_trace_indexes=[]): """ arg_count = dict() by_name = dict() - parameter_names = sorted(traces[0]['trace'][0]['parameter'].keys()) + parameter_names = sorted(traces[0]["trace"][0]["parameter"].keys()) for run in traces: - if run['id'] not in ignore_trace_indexes: - for elem in run['trace']: - if elem['isa'] == 'transition' and not elem['name'] in arg_count and 'args' in elem: - arg_count[elem['name']] = len(elem['args']) - if elem['name'] != 'UNINITIALIZED': - _add_trace_data_to_aggregate(by_name, elem['name'], elem) + if run["id"] not in ignore_trace_indexes: + for elem in run["trace"]: + if ( + elem["isa"] == "transition" + and not elem["name"] in arg_count + and "args" in elem + ): + arg_count[elem["name"]] = len(elem["args"]) + if elem["name"] != "UNINITIALIZED": + _add_trace_data_to_aggregate(by_name, elem["name"], elem) for elem in by_name.values(): - for key in elem['attributes']: + for key in elem["attributes"]: elem[key] = np.array(elem[key]) return by_name, parameter_names, arg_count @@ -1817,7 +2075,19 @@ class PTAModel: - rel_energy_next: transition energy relative to next state mean power in pJ """ - def __init__(self, by_name, parameters, arg_count, traces=[], ignore_trace_indexes=[], discard_outliers=None, function_override={}, verbose=True, use_corrcoef=False, pta=None): + def __init__( + self, + by_name, + parameters, + arg_count, + traces=[], + ignore_trace_indexes=[], + discard_outliers=None, + function_override={}, + verbose=True, + use_corrcoef=False, + 
pta=None, + ): """ Prepare a new PTA energy model. @@ -1854,9 +2124,16 @@ class PTAModel: self._num_args = arg_count self._use_corrcoef = use_corrcoef self.traces = traces - self.stats = ParamStats(self.by_name, self.by_param, self._parameter_names, self._num_args, self._use_corrcoef, verbose=verbose) + self.stats = ParamStats( + self.by_name, + self.by_param, + self._parameter_names, + self._num_args, + self._use_corrcoef, + verbose=verbose, + ) self.cache = {} - np.seterr('raise') + np.seterr("raise") self._outlier_threshold = discard_outliers self.function_override = function_override.copy() self.verbose = verbose @@ -1866,7 +2143,7 @@ class PTAModel: def _aggregate_to_ndarray(self, aggregate): for elem in aggregate.values(): - for key in elem['attributes']: + for key in elem["attributes"]: elem[key] = np.array(elem[key]) # This heuristic is very similar to the "function is not much better than @@ -1884,13 +2161,16 @@ class PTAModel: model = {} for name, elem in model_dict.items(): model[name] = {} - for key in elem['attributes']: + for key in elem["attributes"]: try: model[name][key] = model_function(elem[key]) except RuntimeWarning: - vprint(self.verbose, '[W] Got no data for {} {}'.format(name, key)) + vprint(self.verbose, "[W] Got no data for {} {}".format(name, key)) except FloatingPointError as fpe: - vprint(self.verbose, '[W] Got no data for {} {}: {}'.format(name, key, fpe)) + vprint( + self.verbose, + "[W] Got no data for {} {}: {}".format(name, key, fpe), + ) return model def get_static(self, use_mean=False): @@ -1953,63 +2233,110 @@ class PTAModel: model_function(name, attribute, param=parameter values) -> model value. model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None """ - if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache: - return self.cache['fitted_model_getter'], self.cache['fitted_info_getter'] + if "fitted_model_getter" in self.cache and "fitted_info_getter" in self.cache: + return self.cache["fitted_model_getter"], self.cache["fitted_info_getter"] static_model = self._get_model_from_dict(self.by_name, np.median) - param_model = dict([[state_or_tran, {}] for state_or_tran in self.by_name.keys()]) + param_model = dict( + [[state_or_tran, {}] for state_or_tran in self.by_name.keys()] + ) paramfit = ParallelParamFit(self.by_param) for state_or_tran in self.by_name.keys(): - for model_attribute in self.by_name[state_or_tran]['attributes']: + for model_attribute in self.by_name[state_or_tran]["attributes"]: fit_results = {} for parameter_index, parameter_name in enumerate(self._parameter_names): - if self.depends_on_param(state_or_tran, model_attribute, parameter_name): - paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled) - for codependent_param_dict in self.stats.codependent_parameter_value_dicts(state_or_tran, model_attribute, parameter_name): - paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled, codependent_param_dict) - if arg_support_enabled and self.by_name[state_or_tran]['isa'] == 'transition': + if self.depends_on_param( + state_or_tran, model_attribute, parameter_name + ): + paramfit.enqueue( + state_or_tran, + model_attribute, + parameter_index, + parameter_name, + safe_functions_enabled, + ) + for ( + codependent_param_dict + ) in self.stats.codependent_parameter_value_dicts( + state_or_tran, model_attribute, parameter_name + ): + paramfit.enqueue( + state_or_tran, + model_attribute, + 
parameter_index,
+                            parameter_name,
+                            safe_functions_enabled,
+                            codependent_param_dict,
+                        )
+                if (
+                    arg_support_enabled
+                    and self.by_name[state_or_tran]["isa"] == "transition"
+                ):
                     for arg_index in range(self._num_args[state_or_tran]):
-                        if self.depends_on_arg(state_or_tran, model_attribute, arg_index):
-                            paramfit.enqueue(state_or_tran, model_attribute, len(self._parameter_names) + arg_index, arg_index, safe_functions_enabled)
+                        if self.depends_on_arg(
+                            state_or_tran, model_attribute, arg_index
+                        ):
+                            paramfit.enqueue(
+                                state_or_tran,
+                                model_attribute,
+                                len(self._parameter_names) + arg_index,
+                                arg_index,
+                                safe_functions_enabled,
+                            )
 
         paramfit.fit()
 
         for state_or_tran in self.by_name.keys():
             num_args = 0
-            if arg_support_enabled and self.by_name[state_or_tran]['isa'] == 'transition':
+            if (
+                arg_support_enabled
+                and self.by_name[state_or_tran]["isa"] == "transition"
+            ):
                 num_args = self._num_args[state_or_tran]
-            for model_attribute in self.by_name[state_or_tran]['attributes']:
-                fit_results = get_fit_result(paramfit.results, state_or_tran, model_attribute, self.verbose)
+            for model_attribute in self.by_name[state_or_tran]["attributes"]:
+                fit_results = get_fit_result(
+                    paramfit.results, state_or_tran, model_attribute, self.verbose
+                )
 
                 for parameter_name in self._parameter_names:
-                    if self.depends_on_param(state_or_tran, model_attribute, parameter_name):
-                        for codependent_param_dict in self.stats.codependent_parameter_value_dicts(state_or_tran, model_attribute, parameter_name):
+                    if self.depends_on_param(
+                        state_or_tran, model_attribute, parameter_name
+                    ):
+                        for (
+                            codependent_param_dict
+                        ) in self.stats.codependent_parameter_value_dicts(
+                            state_or_tran, model_attribute, parameter_name
+                        ):
                            pass  # FIXME get_fit_result doesn't take a parameter as an argument at all...
if (state_or_tran, model_attribute) in self.function_override: - function_str = self.function_override[(state_or_tran, model_attribute)] + function_str = self.function_override[ + (state_or_tran, model_attribute) + ] x = AnalyticFunction(function_str, self._parameter_names, num_args) x.fit(self.by_param, state_or_tran, model_attribute) if x.fit_success: param_model[state_or_tran][model_attribute] = { - 'fit_result': fit_results, - 'function': x + "fit_result": fit_results, + "function": x, } elif len(fit_results.keys()): - x = analytic.function_powerset(fit_results, self._parameter_names, num_args) + x = analytic.function_powerset( + fit_results, self._parameter_names, num_args + ) x.fit(self.by_param, state_or_tran, model_attribute) if x.fit_success: param_model[state_or_tran][model_attribute] = { - 'fit_result': fit_results, - 'function': x + "fit_result": fit_results, + "function": x, } def model_getter(name, key, **kwargs): - if 'arg' in kwargs and 'param' in kwargs: - kwargs['param'].extend(map(soft_cast_int, kwargs['arg'])) + if "arg" in kwargs and "param" in kwargs: + kwargs["param"].extend(map(soft_cast_int, kwargs["arg"])) if key in param_model[name]: - param_list = kwargs['param'] - param_function = param_model[name][key]['function'] + param_list = kwargs["param"] + param_function = param_model[name][key]["function"] if param_function.is_predictable(param_list): return param_function.eval(param_list) return static_model[name][key] @@ -2019,8 +2346,8 @@ class PTAModel: return param_model[name][key] return None - self.cache['fitted_model_getter'] = model_getter - self.cache['fitted_info_getter'] = info_getter + self.cache["fitted_model_getter"] = model_getter + self.cache["fitted_info_getter"] = info_getter return model_getter, info_getter @@ -2029,16 +2356,32 @@ class PTAModel: static_quality = self.assess(static_model) param_model, param_info = self.get_fitted() analytic_quality = self.assess(param_model) - self.pta.update(static_model, param_info, static_error=static_quality['by_name'], analytic_error=analytic_quality['by_name']) + self.pta.update( + static_model, + param_info, + static_error=static_quality["by_name"], + analytic_error=analytic_quality["by_name"], + ) return self.pta.to_json() def states(self): """Return sorted list of state names.""" - return sorted(list(filter(lambda k: self.by_name[k]['isa'] == 'state', self.by_name.keys()))) + return sorted( + list( + filter(lambda k: self.by_name[k]["isa"] == "state", self.by_name.keys()) + ) + ) def transitions(self): """Return sorted list of transition names.""" - return sorted(list(filter(lambda k: self.by_name[k]['isa'] == 'transition', self.by_name.keys()))) + return sorted( + list( + filter( + lambda k: self.by_name[k]["isa"] == "transition", + self.by_name.keys(), + ) + ) + ) def states_and_transitions(self): """Return list of states and transition names.""" @@ -2050,7 +2393,7 @@ class PTAModel: return self._parameter_names def attributes(self, state_or_trans): - return self.by_name[state_or_trans]['attributes'] + return self.by_name[state_or_trans]["attributes"] def assess(self, model_function): """ @@ -2068,16 +2411,23 @@ class PTAModel: detailed_results = {} for name, elem in sorted(self.by_name.items()): detailed_results[name] = {} - for key in elem['attributes']: - predicted_data = np.array(list(map(lambda i: model_function(name, key, param=elem['param'][i]), range(len(elem[key]))))) + for key in elem["attributes"]: + predicted_data = np.array( + list( + map( + lambda i: model_function(name, key, 
param=elem["param"][i]), + range(len(elem[key])), + ) + ) + ) measures = regression_measures(predicted_data, elem[key]) detailed_results[name][key] = measures - return { - 'by_name': detailed_results - } + return {"by_name": detailed_results} - def assess_states(self, model_function, model_attribute='power', distribution: dict = None): + def assess_states( + self, model_function, model_attribute="power", distribution: dict = None + ): """ Calculate overall model error assuming equal distribution of states """ @@ -2089,7 +2439,9 @@ class PTAModel: distribution = dict(map(lambda x: [x, 1 / num_states], self.states())) if not np.isclose(sum(distribution.values()), 1): - raise ValueError('distribution must be a probability distribution with sum 1') + raise ValueError( + "distribution must be a probability distribution with sum 1" + ) # total_value = None # try: @@ -2097,7 +2449,17 @@ class PTAModel: # except KeyError: # pass - total_error = np.sqrt(sum(map(lambda x: np.square(model_quality['by_name'][x][model_attribute]['mae'] * distribution[x]), self.states()))) + total_error = np.sqrt( + sum( + map( + lambda x: np.square( + model_quality["by_name"][x][model_attribute]["mae"] + * distribution[x] + ), + self.states(), + ) + ) + ) return total_error def assess_on_traces(self, model_function): @@ -2118,44 +2480,72 @@ class PTAModel: real_timeout_list = [] for trace in self.traces: - if trace['id'] not in self.ignore_trace_indexes: - for rep_id in range(len(trace['trace'][0]['offline'])): - model_energy = 0. - real_energy = 0. - model_rel_energy = 0. - model_state_energy = 0. - model_duration = 0. - real_duration = 0. - model_timeout = 0. - real_timeout = 0. - for i, trace_part in enumerate(trace['trace']): - name = trace_part['name'] - prev_name = trace['trace'][i - 1]['name'] - isa = trace_part['isa'] - if name != 'UNINITIALIZED': + if trace["id"] not in self.ignore_trace_indexes: + for rep_id in range(len(trace["trace"][0]["offline"])): + model_energy = 0.0 + real_energy = 0.0 + model_rel_energy = 0.0 + model_state_energy = 0.0 + model_duration = 0.0 + real_duration = 0.0 + model_timeout = 0.0 + real_timeout = 0.0 + for i, trace_part in enumerate(trace["trace"]): + name = trace_part["name"] + prev_name = trace["trace"][i - 1]["name"] + isa = trace_part["isa"] + if name != "UNINITIALIZED": try: - param = trace_part['offline_aggregates']['param'][rep_id] - prev_param = trace['trace'][i - 1]['offline_aggregates']['param'][rep_id] - power = trace_part['offline'][rep_id]['uW_mean'] - duration = trace_part['offline'][rep_id]['us'] - prev_duration = trace['trace'][i - 1]['offline'][rep_id]['us'] + param = trace_part["offline_aggregates"]["param"][ + rep_id + ] + prev_param = trace["trace"][i - 1][ + "offline_aggregates" + ]["param"][rep_id] + power = trace_part["offline"][rep_id]["uW_mean"] + duration = trace_part["offline"][rep_id]["us"] + prev_duration = trace["trace"][i - 1]["offline"][ + rep_id + ]["us"] real_energy += power * duration - if isa == 'state': - model_energy += model_function(name, 'power', param=param) * duration + if isa == "state": + model_energy += ( + model_function(name, "power", param=param) + * duration + ) else: - model_energy += model_function(name, 'energy', param=param) + model_energy += model_function( + name, "energy", param=param + ) # If i == 1, the previous state was UNINITIALIZED, for which we do not have model data if i == 1: - model_rel_energy += model_function(name, 'energy', param=param) + model_rel_energy += model_function( + name, "energy", param=param + ) 
else: - model_rel_energy += model_function(prev_name, 'power', param=prev_param) * (prev_duration + duration) - model_state_energy += model_function(prev_name, 'power', param=prev_param) * (prev_duration + duration) - model_rel_energy += model_function(name, 'rel_energy_prev', param=param) + model_rel_energy += model_function( + prev_name, "power", param=prev_param + ) * (prev_duration + duration) + model_state_energy += model_function( + prev_name, "power", param=prev_param + ) * (prev_duration + duration) + model_rel_energy += model_function( + name, "rel_energy_prev", param=param + ) real_duration += duration - model_duration += model_function(name, 'duration', param=param) - if 'plan' in trace_part and trace_part['plan']['level'] == 'epilogue': - real_timeout += trace_part['offline'][rep_id]['timeout'] - model_timeout += model_function(name, 'timeout', param=param) + model_duration += model_function( + name, "duration", param=param + ) + if ( + "plan" in trace_part + and trace_part["plan"]["level"] == "epilogue" + ): + real_timeout += trace_part["offline"][rep_id][ + "timeout" + ] + model_timeout += model_function( + name, "timeout", param=param + ) except KeyError: # if states/transitions have been removed via --filter-param, this is harmless pass @@ -2169,11 +2559,21 @@ class PTAModel: model_timeout_list.append(model_timeout) return { - 'duration_by_trace': regression_measures(np.array(model_duration_list), np.array(real_duration_list)), - 'energy_by_trace': regression_measures(np.array(model_energy_list), np.array(real_energy_list)), - 'timeout_by_trace': regression_measures(np.array(model_timeout_list), np.array(real_timeout_list)), - 'rel_energy_by_trace': regression_measures(np.array(model_rel_energy_list), np.array(real_energy_list)), - 'state_energy_by_trace': regression_measures(np.array(model_state_energy_list), np.array(real_energy_list)), + "duration_by_trace": regression_measures( + np.array(model_duration_list), np.array(real_duration_list) + ), + "energy_by_trace": regression_measures( + np.array(model_energy_list), np.array(real_energy_list) + ), + "timeout_by_trace": regression_measures( + np.array(model_timeout_list), np.array(real_timeout_list) + ), + "rel_energy_by_trace": regression_measures( + np.array(model_rel_energy_list), np.array(real_energy_list) + ), + "state_energy_by_trace": regression_measures( + np.array(model_state_energy_list), np.array(real_energy_list) + ), } @@ -2230,17 +2630,19 @@ class EnergyTraceLog: """ if not zbar_available: - self.errors.append('zbar module is not available. Try "apt install python3-zbar"') + self.errors.append( + 'zbar module is not available. 
Try "apt install python3-zbar"' + ) return list() - lines = log_data.decode('ascii').split('\n') - data_count = sum(map(lambda x: len(x) > 0 and x[0] != '#', lines)) - data_lines = filter(lambda x: len(x) > 0 and x[0] != '#', lines) + lines = log_data.decode("ascii").split("\n") + data_count = sum(map(lambda x: len(x) > 0 and x[0] != "#", lines)) + data_lines = filter(lambda x: len(x) > 0 and x[0] != "#", lines) data = np.empty((data_count, 4)) for i, line in enumerate(data_lines): - fields = line.split(' ') + fields = line.split(" ") if len(fields) == 4: timestamp, current, voltage, total_energy = map(int, fields) elif len(fields) == 5: @@ -2252,15 +2654,26 @@ class EnergyTraceLog: self.interval_start_timestamp = data[:-1, 0] * 1e-6 self.interval_duration = (data[1:, 0] - data[:-1, 0]) * 1e-6 - self.interval_power = ((data[1:, 3] - data[:-1, 3]) * 1e-9) / ((data[1:, 0] - data[:-1, 0]) * 1e-6) + self.interval_power = ((data[1:, 3] - data[:-1, 3]) * 1e-9) / ( + (data[1:, 0] - data[:-1, 0]) * 1e-6 + ) m_duration_us = data[-1, 0] - data[0, 0] self.sample_rate = data_count / (m_duration_us * 1e-6) - vprint(self.verbose, 'got {} samples with {} seconds of log data ({} Hz)'.format(data_count, m_duration_us * 1e-6, self.sample_rate)) + vprint( + self.verbose, + "got {} samples with {} seconds of log data ({} Hz)".format( + data_count, m_duration_us * 1e-6, self.sample_rate + ), + ) - return self.interval_start_timestamp, self.interval_duration, self.interval_power + return ( + self.interval_start_timestamp, + self.interval_duration, + self.interval_power, + ) def ts_to_index(self, timestamp): """ @@ -2279,7 +2692,12 @@ class EnergyTraceLog: mid_index = left_index + (right_index - left_index) // 2 # I'm feeling lucky - if timestamp > self.interval_start_timestamp[mid_index] and timestamp <= self.interval_start_timestamp[mid_index] + self.interval_duration[mid_index]: + if ( + timestamp > self.interval_start_timestamp[mid_index] + and timestamp + <= self.interval_start_timestamp[mid_index] + + self.interval_duration[mid_index] + ): return mid_index if timestamp <= self.interval_start_timestamp[mid_index]: @@ -2322,16 +2740,29 @@ class EnergyTraceLog: expected_transitions = list() for trace_number, trace in enumerate(traces): - for state_or_transition_number, state_or_transition in enumerate(trace['trace']): - if state_or_transition['isa'] == 'transition': + for state_or_transition_number, state_or_transition in enumerate( + trace["trace"] + ): + if state_or_transition["isa"] == "transition": try: - expected_transitions.append(( - state_or_transition['name'], - state_or_transition['online_aggregates']['duration'][offline_index] * 1e-6 - )) + expected_transitions.append( + ( + state_or_transition["name"], + state_or_transition["online_aggregates"]["duration"][ + offline_index + ] + * 1e-6, + ) + ) except IndexError: - self.errors.append('Entry #{} ("{}") in trace #{} has no duration entry for offline_index/repeat_id {}'.format( - state_or_transition_number, state_or_transition['name'], trace_number, offline_index)) + self.errors.append( + 'Entry #{} ("{}") in trace #{} has no duration entry for offline_index/repeat_id {}'.format( + state_or_transition_number, + state_or_transition["name"], + trace_number, + offline_index, + ) + ) return energy_trace next_barcode = first_sync @@ -2342,51 +2773,101 @@ class EnergyTraceLog: print('[!!!] did not find transition "{}"'.format(name)) break next_barcode = end + self.state_duration + duration - vprint(self.verbose, '{} barcode "{}" area: {:0.2f} .. 
{:0.2f} / {:0.2f} seconds'.format(offline_index, bc, start, stop, end))
+            vprint(
+                self.verbose,
+                '{} barcode "{}" area: {:0.2f} .. {:0.2f} / {:0.2f} seconds'.format(
+                    offline_index, bc, start, stop, end
+                ),
+            )
             if bc != name:
-                vprint(self.verbose, '[!!!] mismatch: expected "{}", got "{}"'.format(name, bc))
+                vprint(
+                    self.verbose,
+                    '[!!!] mismatch: expected "{}", got "{}"'.format(name, bc),
+                )
-            vprint(self.verbose, '{} estimated transition area: {:0.3f} .. {:0.3f} seconds'.format(offline_index, end, end + duration))
+            vprint(
+                self.verbose,
+                "{} estimated transition area: {:0.3f} .. {:0.3f} seconds".format(
+                    offline_index, end, end + duration
+                ),
+            )
 
             transition_start_index = self.ts_to_index(end)
             transition_done_index = self.ts_to_index(end + duration) + 1
             state_start_index = transition_done_index
-            state_done_index = self.ts_to_index(end + duration + self.state_duration) + 1
-
-            vprint(self.verbose, '{} estimated transition index: {:0.3f} .. {:0.3f} seconds'.format(offline_index, transition_start_index / self.sample_rate, transition_done_index / self.sample_rate))
+            state_done_index = (
+                self.ts_to_index(end + duration + self.state_duration) + 1
+            )
 
-            energy_trace.append({
-                'isa': 'transition',
-                'W_mean': np.mean(self.interval_power[transition_start_index: transition_done_index]),
-                'W_std': np.std(self.interval_power[transition_start_index: transition_done_index]),
-                's': duration,
-                's_coarse': self.interval_start_timestamp[transition_done_index] - self.interval_start_timestamp[transition_start_index]
+            vprint(
+                self.verbose,
+                "{} estimated transition index: {:0.3f} .. {:0.3f} seconds".format(
+                    offline_index,
+                    transition_start_index / self.sample_rate,
+                    transition_done_index / self.sample_rate,
+                ),
+            )
 
-            })
+            energy_trace.append(
+                {
+                    "isa": "transition",
+                    "W_mean": np.mean(
+                        self.interval_power[
+                            transition_start_index:transition_done_index
+                        ]
+                    ),
+                    "W_std": np.std(
+                        self.interval_power[
+                            transition_start_index:transition_done_index
+                        ]
+                    ),
+                    "s": duration,
+                    "s_coarse": self.interval_start_timestamp[transition_done_index]
+                    - self.interval_start_timestamp[transition_start_index],
+                }
+            )
 
             if len(energy_trace) > 1:
-                energy_trace[-1]['W_mean_delta_prev'] = energy_trace[-1]['W_mean'] - energy_trace[-2]['W_mean']
+                energy_trace[-1]["W_mean_delta_prev"] = (
+                    energy_trace[-1]["W_mean"] - energy_trace[-2]["W_mean"]
+                )
 
-            energy_trace.append({
-                'isa': 'state',
-                'W_mean': np.mean(self.interval_power[state_start_index: state_done_index]),
-                'W_std': np.std(self.interval_power[state_start_index: state_done_index]),
-                's': self.state_duration,
-                's_coarse': self.interval_start_timestamp[state_done_index] - self.interval_start_timestamp[state_start_index]
-            })
+            energy_trace.append(
+                {
+                    "isa": "state",
+                    "W_mean": np.mean(
+                        self.interval_power[state_start_index:state_done_index]
+                    ),
+                    "W_std": np.std(
+                        self.interval_power[state_start_index:state_done_index]
+                    ),
+                    "s": self.state_duration,
+                    "s_coarse": self.interval_start_timestamp[state_done_index]
+                    - self.interval_start_timestamp[state_start_index],
+                }
+            )
 
-            energy_trace[-2]['W_mean_delta_next'] = energy_trace[-2]['W_mean'] - energy_trace[-1]['W_mean']
+            energy_trace[-2]["W_mean_delta_next"] = (
+                energy_trace[-2]["W_mean"] - energy_trace[-1]["W_mean"]
+            )
 
         expected_transition_count = len(expected_transitions)
         recovered_transition_count = len(energy_trace) // 2
         if expected_transition_count != recovered_transition_count:
-            self.errors.append('Expected {:d} transitions, got {:d}'.format(expected_transition_count, recovered_transition_count))
+            self.errors.append(
+                "Expected {:d} transitions, got {:d}".format(
+                    expected_transition_count, recovered_transition_count
+                )
+            )
 
         return energy_trace
 
     def find_first_sync(self):
         # LED Power is approx. self.led_power W, use self.led_power/3 W above surrounding median as threshold
-        sync_threshold_power = np.median(self.interval_power[: int(3 * self.sample_rate)]) + self.led_power / 3
+        sync_threshold_power = (
+            np.median(self.interval_power[: int(3 * self.sample_rate)])
+            + self.led_power / 3
+        )
         for i, ts in enumerate(self.interval_start_timestamp):
             if ts > 2 and self.interval_power[i] > sync_threshold_power:
                 return self.interval_start_timestamp[i - 300]
@@ -2410,26 +2891,56 @@ class EnergyTraceLog:
         lookaround = int(0.1 * self.sample_rate)
 
         # LED Power is approx. self.led_power W, use self.led_power/3 W above surrounding median as threshold
-        sync_threshold_power = np.median(self.interval_power[start_position - lookaround: start_position + lookaround]) + self.led_power / 3
+        sync_threshold_power = (
+            np.median(
+                self.interval_power[
+                    start_position - lookaround : start_position + lookaround
+                ]
+            )
+            + self.led_power / 3
+        )
 
-        vprint(self.verbose, 'looking for barcode starting at {:0.2f} s, threshold is {:0.1f} mW'.format(start_ts, sync_threshold_power * 1e3))
+        vprint(
+            self.verbose,
+            "looking for barcode starting at {:0.2f} s, threshold is {:0.1f} mW".format(
+                start_ts, sync_threshold_power * 1e3
+            ),
+        )
 
         sync_area_start = None
         sync_start_ts = None
         sync_area_end = None
         sync_end_ts = None
         for i, ts in enumerate(self.interval_start_timestamp):
-            if sync_area_start is None and ts >= start_ts and self.interval_power[i] > sync_threshold_power:
+            if (
+                sync_area_start is None
+                and ts >= start_ts
+                and self.interval_power[i] > sync_threshold_power
+            ):
                 sync_area_start = i - 300
                 sync_start_ts = ts
-            if sync_area_start is not None and sync_area_end is None and ts > sync_start_ts + self.min_barcode_duration and (ts > sync_start_ts + self.max_barcode_duration or abs(sync_threshold_power - self.interval_power[i]) > self.led_power):
+            if (
+                sync_area_start is not None
+                and sync_area_end is None
+                and ts > sync_start_ts + self.min_barcode_duration
+                and (
+                    ts > sync_start_ts + self.max_barcode_duration
+                    or abs(sync_threshold_power - self.interval_power[i])
+                    > self.led_power
+                )
+            ):
                 sync_area_end = i
                 sync_end_ts = ts
                 break
 
-        barcode_data = self.interval_power[sync_area_start: sync_area_end]
+        barcode_data = self.interval_power[sync_area_start:sync_area_end]
 
-        vprint(self.verbose, 'barcode search area: {:0.2f} .. {:0.2f} seconds ({} samples)'.format(sync_start_ts, sync_end_ts, len(barcode_data)))
+        vprint(
+            self.verbose,
+            "barcode search area: {:0.2f} .. {:0.2f} seconds ({} samples)".format(
+                sync_start_ts, sync_end_ts, len(barcode_data)
+            ),
+        )
 
         bc, start, stop, padding_bits = self.find_barcode_in_power_data(barcode_data)
 
@@ -2439,7 +2950,9 @@ class EnergyTraceLog:
         start_ts = self.interval_start_timestamp[sync_area_start + start]
         stop_ts = self.interval_start_timestamp[sync_area_start + stop]
 
-        end_ts = stop_ts + self.module_duration * padding_bits + self.quiet_zone_duration
+        end_ts = (
+            stop_ts + self.module_duration * padding_bits + self.quiet_zone_duration
+        )
 
         # barcode content, barcode start timestamp, barcode stop timestamp, barcode end (stop + padding) timestamp
         return bc, start_ts, stop_ts, end_ts
@@ -2455,7 +2968,9 @@ class EnergyTraceLog:
         # -> Create a black and white (not grayscale) image to avoid this.
# Unfortunately, this decreases resilience against background noise # (e.g. a not-exactly-idle peripheral device or CPU interrupts). - image_data = np.around(1 - ((barcode_data - min_power) / (max_power - min_power))) + image_data = np.around( + 1 - ((barcode_data - min_power) / (max_power - min_power)) + ) image_data *= 255 # zbar only returns the complete barcode position if it is at least @@ -2469,12 +2984,12 @@ class EnergyTraceLog: # img = Image.frombytes('L', (width, height), image_data).resize((width, 100)) # img.save('/tmp/test-{}.png'.format(os.getpid())) - zbimg = zbar.Image(width, height, 'Y800', image_data) + zbimg = zbar.Image(width, height, "Y800", image_data) scanner = zbar.ImageScanner() - scanner.parse_config('enable') + scanner.parse_config("enable") if scanner.scan(zbimg): - sym, = zbimg.symbols + (sym,) = zbimg.symbols content = sym.data try: sym_start = sym.location[1][0] @@ -2482,7 +2997,7 @@ class EnergyTraceLog: sym_start = 0 sym_end = sym.location[0][0] - match = re.fullmatch(r'T(\d+)', content) + match = re.fullmatch(r"T(\d+)", content) if match: content = self.transition_names[int(match.group(1))] @@ -2490,7 +3005,7 @@ class EnergyTraceLog: # additional non-barcode padding (encoded as LED off / image white). # Calculate the amount of extra bits to determine the offset until # the transition starts. - padding_bits = len(Code128(sym.data, charset='B').modules) % 8 + padding_bits = len(Code128(sym.data, charset="B").modules) % 8 # sym_start leaves out the first two bars, but we don't do anything about that here # sym_end leaves out the last three bars, each of which is one padding bit long. @@ -2499,7 +3014,7 @@ class EnergyTraceLog: return content, sym_start, sym_end, padding_bits else: - vprint(self.verbose, 'unable to find barcode') + vprint(self.verbose, "unable to find barcode") return None, None, None, None @@ -2555,15 +3070,15 @@ class MIMOSA: :returns: (numpy array of charges (pJ per 10µs), numpy array of triggers (0/1 int, per 10µs)) """ - num_bytes = tf.getmember('/tmp/mimosa//mimosa_scale_1.tmp').size + num_bytes = tf.getmember("/tmp/mimosa//mimosa_scale_1.tmp").size charges = np.ndarray(shape=(int(num_bytes / 4)), dtype=np.int32) triggers = np.ndarray(shape=(int(num_bytes / 4)), dtype=np.int8) - with tf.extractfile('/tmp/mimosa//mimosa_scale_1.tmp') as f: + with tf.extractfile("/tmp/mimosa//mimosa_scale_1.tmp") as f: content = f.read() - iterator = struct.iter_unpack('<I', content) + iterator = struct.iter_unpack("<I", content) i = 0 for word in iterator: - charges[i] = (word[0] >> 4) + charges[i] = word[0] >> 4 triggers[i] = (word[0] & 0x08) >> 3 i += 1 return charges, triggers @@ -2616,7 +3131,7 @@ class MIMOSA: trigidx = [] if len(triggers) < 1000000: - self.errors.append('MIMOSA log is too short') + self.errors.append("MIMOSA log is too short") return trigidx prevtrig = triggers[999999] @@ -2625,13 +3140,17 @@ class MIMOSA: # something went wrong and are unable to determine when the first # transition starts. if prevtrig != 0: - self.errors.append('Unable to find start of first transition (log starts with trigger == {} != 0)'.format(prevtrig)) + self.errors.append( + "Unable to find start of first transition (log starts with trigger == {} != 0)".format( + prevtrig + ) + ) # if the last trigger is high (i.e., trigger/buzzer pin is active when the benchmark ends), # it terminated in the middle of a transition -- meaning that it was not # measured in its entirety. 
        if triggers[-1] != 0:
-            self.errors.append('Log ends during a transition')
+            self.errors.append("Log ends during a transition")

         # the device is reset for MIMOSA calibration in the first 10s and may
         # send bogus interrupts -> bogus triggers
@@ -2663,11 +3182,23 @@
         for i in range(100000, len(currents)):
             if r1idx == 0 and currents[i] > ua_r1 * 0.6:
                 r1idx = i
-            elif r1idx != 0 and r2idx == 0 and i > (r1idx + 180000) and currents[i] < ua_r1 * 0.4:
+            elif (
+                r1idx != 0
+                and r2idx == 0
+                and i > (r1idx + 180000)
+                and currents[i] < ua_r1 * 0.4
+            ):
                 r2idx = i
         # 2s disconnected, 2s r1, 2s r2 with r1 < r2 -> ua_r1 > ua_r2
         # allow 5ms buffer in both directions to account for bouncing relay contacts
-        return r1idx - 180500, r1idx - 500, r1idx + 500, r2idx - 500, r2idx + 500, r2idx + 180500
+        return (
+            r1idx - 180500,
+            r1idx - 500,
+            r1idx + 500,
+            r2idx - 500,
+            r2idx + 500,
+            r2idx + 180500,
+        )

     def calibration_function(self, charges, cal_edges):
         u"""
@@ -2711,7 +3242,7 @@
         if cal_r2_mean > cal_0_mean:
             b_lower = (ua_r2 - 0) / (cal_r2_mean - cal_0_mean)
         else:
-            vprint(self.verbose, '[W] 0 uA == %.f uA during calibration' % (ua_r2))
+            vprint(self.verbose, "[W] 0 uA == %.f uA during calibration" % (ua_r2))
             b_lower = 0

         b_upper = (ua_r1 - ua_r2) / (cal_r1_mean - cal_r2_mean)
@@ -2726,7 +3257,9 @@
                 return 0
             else:
                 return charge * b_lower + a_lower
+
         else:
+
             def calfunc(charge):
                 if charge < cal_0_mean:
                     return 0
@@ -2736,19 +3269,19 @@
                 return charge * b_upper + a_upper + ua_r2

         caldata = {
-            'edges': [x * 10 for x in cal_edges],
-            'offset': cal_0_mean,
-            'offset2': cal_r2_mean,
-            'slope_low': b_lower,
-            'slope_high': b_upper,
-            'add_low': a_lower,
-            'add_high': a_upper,
-            'r0_err_uW': np.mean(self.currents_nocal(chg_r0)) * self.voltage,
-            'r0_std_uW': np.std(self.currents_nocal(chg_r0)) * self.voltage,
-            'r1_err_uW': (np.mean(self.currents_nocal(chg_r1)) - ua_r1) * self.voltage,
-            'r1_std_uW': np.std(self.currents_nocal(chg_r1)) * self.voltage,
-            'r2_err_uW': (np.mean(self.currents_nocal(chg_r2)) - ua_r2) * self.voltage,
-            'r2_std_uW': np.std(self.currents_nocal(chg_r2)) * self.voltage,
+            "edges": [x * 10 for x in cal_edges],
+            "offset": cal_0_mean,
+            "offset2": cal_r2_mean,
+            "slope_low": b_lower,
+            "slope_high": b_upper,
+            "add_low": a_lower,
+            "add_high": a_upper,
+            "r0_err_uW": np.mean(self.currents_nocal(chg_r0)) * self.voltage,
+            "r0_std_uW": np.std(self.currents_nocal(chg_r0)) * self.voltage,
+            "r1_err_uW": (np.mean(self.currents_nocal(chg_r1)) - ua_r1) * self.voltage,
+            "r1_std_uW": np.std(self.currents_nocal(chg_r1)) * self.voltage,
+            "r2_err_uW": (np.mean(self.currents_nocal(chg_r2)) - ua_r2) * self.voltage,
+            "r2_std_uW": np.std(self.currents_nocal(chg_r2)) * self.voltage,
         }

         # print("if charge < %f : return 0" % cal_0_mean)
@@ -2843,51 +3376,59 @@
             statelist = []
             prevsubidx = 0
             for subidx in subst:
-                statelist.append({
-                    'duration': (subidx - prevsubidx) * 10,
-                    'uW_mean': np.mean(range_ua[prevsubidx: subidx] * self.voltage),
-                    'uW_std': np.std(range_ua[prevsubidx: subidx] * self.voltage),
-                })
+                statelist.append(
+                    {
+                        "duration": (subidx - prevsubidx) * 10,
+                        "uW_mean": np.mean(
+                            range_ua[prevsubidx:subidx] * self.voltage
+                        ),
+                        "uW_std": np.std(
+                            range_ua[prevsubidx:subidx] * self.voltage
+                        ),
+                    }
+                )
                 prevsubidx = subidx
             substates = {
-                'threshold': thr,
-                'states': statelist,
+                "threshold": thr,
+                "states": statelist,
             }

-            isa = 'state'
+            isa = "state"
             if not is_state:
-                isa = 'transition'
+ isa = "transition" data = { - 'isa': isa, - 'clip_rate': np.mean(range_raw == 65535), - 'raw_mean': np.mean(range_raw), - 'raw_std': np.std(range_raw), - 'uW_mean': np.mean(range_ua * self.voltage), - 'uW_std': np.std(range_ua * self.voltage), - 'us': (idx - previdx) * 10, + "isa": isa, + "clip_rate": np.mean(range_raw == 65535), + "raw_mean": np.mean(range_raw), + "raw_std": np.std(range_raw), + "uW_mean": np.mean(range_ua * self.voltage), + "uW_std": np.std(range_ua * self.voltage), + "us": (idx - previdx) * 10, } if self.with_traces: - data['uW'] = range_ua * self.voltage + data["uW"] = range_ua * self.voltage - if 'states' in substates: - data['substates'] = substates - ssum = np.sum(list(map(lambda x: x['duration'], substates['states']))) - if ssum != data['us']: - vprint(self.verbose, "ERR: duration %d vs %d" % (data['us'], ssum)) + if "states" in substates: + data["substates"] = substates + ssum = np.sum(list(map(lambda x: x["duration"], substates["states"]))) + if ssum != data["us"]: + vprint(self.verbose, "ERR: duration %d vs %d" % (data["us"], ssum)) - if isa == 'transition': + if isa == "transition": # subtract average power of previous state # (that is, the state from which this transition originates) - data['uW_mean_delta_prev'] = data['uW_mean'] - iterdata[-1]['uW_mean'] + data["uW_mean_delta_prev"] = data["uW_mean"] - iterdata[-1]["uW_mean"] # placeholder to avoid extra cases in the analysis - data['uW_mean_delta_next'] = data['uW_mean'] - data['timeout'] = iterdata[-1]['us'] + data["uW_mean_delta_next"] = data["uW_mean"] + data["timeout"] = iterdata[-1]["us"] elif len(iterdata) > 0: # subtract average power of next state # (the state into which this transition leads) - iterdata[-1]['uW_mean_delta_next'] = iterdata[-1]['uW_mean'] - data['uW_mean'] + iterdata[-1]["uW_mean_delta_next"] = ( + iterdata[-1]["uW_mean"] - data["uW_mean"] + ) iterdata.append(data) diff --git a/lib/functions.py b/lib/functions.py index 2451ef6..6d8daa4 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -12,6 +12,7 @@ from .utils import is_numeric, vprint arg_support_enabled = True + def powerset(iterable): """ Return powerset of `iterable` elements. @@ -19,7 +20,8 @@ def powerset(iterable): Example: `powerset([1, 2])` -> `[(), (1), (2), (1, 2)]` """ s = list(iterable) - return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) + class ParamFunction: """ @@ -82,6 +84,7 @@ class ParamFunction: """ return self._param_function(P, X) - y + class NormalizationFunction: """ Wrapper for parameter normalization functions used in YAML PTA/DFA models. @@ -95,7 +98,7 @@ class NormalizationFunction: `param` and return a float. """ self._function_str = function_str - self._function = eval('lambda param: ' + function_str) + self._function = eval("lambda param: " + function_str) def eval(self, param_value: float) -> float: """ @@ -105,6 +108,7 @@ class NormalizationFunction: """ return self._function(param_value) + class AnalyticFunction: """ A multi-dimensional model function, generated from a string, which can be optimized using regression. @@ -114,7 +118,9 @@ class AnalyticFunction: packet length. """ - def __init__(self, function_str, parameters, num_args, verbose = True, regression_args = None): + def __init__( + self, function_str, parameters, num_args, verbose=True, regression_args=None + ): """ Create a new AnalyticFunction object from a function string. 
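For orientation, since the constructor below substitutes these placeholders: a function string combines `regression_arg(i)` slots for fitted coefficients with `parameter(name)` and `function_arg(i)` slots for model inputs. A minimal usage sketch; the parameter name `txbytes` and the `"TX"`/`"energy"` keys are hypothetical, and the fitting entry point `fit(by_param, state_or_tran, model_attribute)` is inferred from the method body shown further down:

    # Hypothetical model: E = beta_0 + beta_1 * txbytes, betas fitted by least squares.
    f = AnalyticFunction(
        "regression_arg(0) + regression_arg(1) * parameter(txbytes)",
        parameters=["txbytes"],
        num_args=0,
    )
    f.fit(by_param, "TX", "energy")  # by_param: measurements grouped by parameter tuple
    if f.fit_success:
        estimate = f.eval([1024])  # evaluate the fitted model at txbytes=1024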
@@ -143,22 +149,30 @@ class AnalyticFunction: self.verbose = verbose if type(function_str) == str: - num_vars_re = re.compile(r'regression_arg\(([0-9]+)\)') + num_vars_re = re.compile(r"regression_arg\(([0-9]+)\)") num_vars = max(map(int, num_vars_re.findall(function_str))) + 1 for i in range(len(parameters)): - if rawfunction.find('parameter({})'.format(parameters[i])) >= 0: + if rawfunction.find("parameter({})".format(parameters[i])) >= 0: self._dependson[i] = True - rawfunction = rawfunction.replace('parameter({})'.format(parameters[i]), 'model_param[{:d}]'.format(i)) + rawfunction = rawfunction.replace( + "parameter({})".format(parameters[i]), + "model_param[{:d}]".format(i), + ) for i in range(0, num_args): - if rawfunction.find('function_arg({:d})'.format(i)) >= 0: + if rawfunction.find("function_arg({:d})".format(i)) >= 0: self._dependson[len(parameters) + i] = True - rawfunction = rawfunction.replace('function_arg({:d})'.format(i), 'model_param[{:d}]'.format(len(parameters) + i)) + rawfunction = rawfunction.replace( + "function_arg({:d})".format(i), + "model_param[{:d}]".format(len(parameters) + i), + ) for i in range(num_vars): - rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i)) + rawfunction = rawfunction.replace( + "regression_arg({:d})".format(i), "reg_param[{:d}]".format(i) + ) self._function_str = rawfunction - self._function = eval('lambda reg_param, model_param: ' + rawfunction) + self._function = eval("lambda reg_param, model_param: " + rawfunction) else: - self._function_str = 'raise ValueError' + self._function_str = "raise ValueError" self._function = function_str if regression_args: @@ -217,7 +231,12 @@ class AnalyticFunction: else: X[i].extend([np.nan] * len(val[model_attribute])) elif key[0] == state_or_tran and len(key[1]) != dimension: - vprint(self.verbose, '[W] Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.'.format(state_or_tran, model_attribute, len(key[1]), dimension)) + vprint( + self.verbose, + "[W] Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.".format( + state_or_tran, model_attribute, len(key[1]), dimension + ), + ) X = np.array(X) Y = np.array(Y) @@ -237,21 +256,40 @@ class AnalyticFunction: argument values are present, they must come after parameter values in the order of their appearance in the function signature. 
""" - X, Y, num_valid, num_total = self.get_fit_data(by_param, state_or_tran, model_attribute) + X, Y, num_valid, num_total = self.get_fit_data( + by_param, state_or_tran, model_attribute + ) if num_valid > 2: error_function = lambda P, X, y: self._function(P, X) - y try: - res = optimize.least_squares(error_function, self._regression_args, args=(X, Y), xtol=2e-15) + res = optimize.least_squares( + error_function, self._regression_args, args=(X, Y), xtol=2e-15 + ) except ValueError as err: - vprint(self.verbose, '[W] Fit failed for {}/{}: {} (function: {})'.format(state_or_tran, model_attribute, err, self._model_str)) + vprint( + self.verbose, + "[W] Fit failed for {}/{}: {} (function: {})".format( + state_or_tran, model_attribute, err, self._model_str + ), + ) return if res.status > 0: self._regression_args = res.x self.fit_success = True else: - vprint(self.verbose, '[W] Fit failed for {}/{}: {} (function: {})'.format(state_or_tran, model_attribute, res.message, self._model_str)) + vprint( + self.verbose, + "[W] Fit failed for {}/{}: {} (function: {})".format( + state_or_tran, model_attribute, res.message, self._model_str + ), + ) else: - vprint(self.verbose, '[W] Insufficient amount of valid parameter keys, cannot fit {}/{}'.format(state_or_tran, model_attribute)) + vprint( + self.verbose, + "[W] Insufficient amount of valid parameter keys, cannot fit {}/{}".format( + state_or_tran, model_attribute + ), + ) def is_predictable(self, param_list): """ @@ -268,7 +306,7 @@ class AnalyticFunction: return False return True - def eval(self, param_list, arg_list = []): + def eval(self, param_list, arg_list=[]): """ Evaluate model function with specified param/arg values. @@ -280,6 +318,7 @@ class AnalyticFunction: return self._function(param_list, arg_list) return self._function(self._regression_args, param_list) + class analytic: """ Utilities for analytic description of parameter-dependent model attributes and regression analysis. @@ -292,28 +331,28 @@ class analytic: _num0_8 = np.vectorize(lambda x: 8 - bin(int(x)).count("1")) _num0_16 = np.vectorize(lambda x: 16 - bin(int(x)).count("1")) _num1 = np.vectorize(lambda x: bin(int(x)).count("1")) - _safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.) - _safe_inv = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.) 
+ _safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0) + _safe_inv = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.0) _safe_sqrt = np.vectorize(lambda x: np.sqrt(np.abs(x))) _function_map = { - 'linear' : lambda x: x, - 'logarithmic' : np.log, - 'logarithmic1' : lambda x: np.log(x + 1), - 'exponential' : np.exp, - 'square' : lambda x : x ** 2, - 'inverse' : lambda x : 1 / x, - 'sqrt' : lambda x: np.sqrt(np.abs(x)), - 'num0_8' : _num0_8, - 'num0_16' : _num0_16, - 'num1' : _num1, - 'safe_log' : lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1., - 'safe_inv' : lambda x: 1 / x if np.abs(x) > 0.001 else 1., - 'safe_sqrt': lambda x: np.sqrt(np.abs(x)), + "linear": lambda x: x, + "logarithmic": np.log, + "logarithmic1": lambda x: np.log(x + 1), + "exponential": np.exp, + "square": lambda x: x ** 2, + "inverse": lambda x: 1 / x, + "sqrt": lambda x: np.sqrt(np.abs(x)), + "num0_8": _num0_8, + "num0_16": _num0_16, + "num1": _num1, + "safe_log": lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0, + "safe_inv": lambda x: 1 / x if np.abs(x) > 0.001 else 1.0, + "safe_sqrt": lambda x: np.sqrt(np.abs(x)), } @staticmethod - def functions(safe_functions_enabled = False): + def functions(safe_functions_enabled=False): """ Retrieve pre-defined set of regression function candidates. @@ -329,74 +368,87 @@ class analytic: variables are expected. """ functions = { - 'linear' : ParamFunction( - lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param, + "linear": ParamFunction( + lambda reg_param, model_param: reg_param[0] + + reg_param[1] * model_param, lambda model_param: True, - 2 + 2, ), - 'logarithmic' : ParamFunction( - lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.log(model_param), + "logarithmic": ParamFunction( + lambda reg_param, model_param: reg_param[0] + + reg_param[1] * np.log(model_param), lambda model_param: model_param > 0, - 2 + 2, ), - 'logarithmic1' : ParamFunction( - lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.log(model_param + 1), + "logarithmic1": ParamFunction( + lambda reg_param, model_param: reg_param[0] + + reg_param[1] * np.log(model_param + 1), lambda model_param: model_param > -1, - 2 + 2, ), - 'exponential' : ParamFunction( - lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.exp(model_param), + "exponential": ParamFunction( + lambda reg_param, model_param: reg_param[0] + + reg_param[1] * np.exp(model_param), lambda model_param: model_param <= 64, - 2 + 2, ), #'polynomial' : lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param + reg_param[2] * model_param ** 2, - 'square' : ParamFunction( - lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param ** 2, + "square": ParamFunction( + lambda reg_param, model_param: reg_param[0] + + reg_param[1] * model_param ** 2, lambda model_param: True, - 2 + 2, ), - 'inverse' : ParamFunction( - lambda reg_param, model_param: reg_param[0] + reg_param[1] / model_param, + "inverse": ParamFunction( + lambda reg_param, model_param: reg_param[0] + + reg_param[1] / model_param, lambda model_param: model_param != 0, - 2 + 2, ), - 'sqrt' : ParamFunction( - lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.sqrt(model_param), + "sqrt": ParamFunction( + lambda reg_param, model_param: reg_param[0] + + reg_param[1] * np.sqrt(model_param), lambda model_param: model_param >= 0, - 2 + 2, ), - 'num0_8' : ParamFunction( - lambda reg_param, model_param: reg_param[0] + reg_param[1] * 
analytic._num0_8(model_param),
+            "num0_8": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._num0_8(model_param),
                 lambda model_param: True,
-                2
+                2,
             ),
-            'num0_16' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._num0_16(model_param),
+            "num0_16": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._num0_16(model_param),
                 lambda model_param: True,
-                2
+                2,
             ),
-            'num1' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._num1(model_param),
+            "num1": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._num1(model_param),
                 lambda model_param: True,
-                2
+                2,
            ),
        }

        if safe_functions_enabled:
-            functions['safe_log'] = ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_log(model_param),
+            functions["safe_log"] = ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._safe_log(model_param),
                lambda model_param: True,
-                2
+                2,
            )
-            functions['safe_inv'] = ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_inv(model_param),
+            functions["safe_inv"] = ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._safe_inv(model_param),
                lambda model_param: True,
-                2
+                2,
            )
-            functions['safe_sqrt'] = ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_sqrt(model_param),
+            functions["safe_sqrt"] = ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._safe_sqrt(model_param),
                lambda model_param: True,
-                2
+                2,
            )

        return functions

@@ -404,27 +456,27 @@
    @staticmethod
    def _fmap(reference_type, reference_name, function_type):
        """Map arg/parameter name and best-fit function name to function text suitable for AnalyticFunction."""
-        ref_str = '{}({})'.format(reference_type,reference_name)
-        if function_type == 'linear':
+        ref_str = "{}({})".format(reference_type, reference_name)
+        if function_type == "linear":
            return ref_str
-        if function_type == 'logarithmic':
-            return 'np.log({})'.format(ref_str)
-        if function_type == 'logarithmic1':
-            return 'np.log({} + 1)'.format(ref_str)
-        if function_type == 'exponential':
-            return 'np.exp({})'.format(ref_str)
-        if function_type == 'square':
-            return '({})**2'.format(ref_str)
-        if function_type == 'inverse':
-            return '1/({})'.format(ref_str)
-        if function_type == 'sqrt':
-            return 'np.sqrt({})'.format(ref_str)
-        return 'analytic._{}({})'.format(function_type, ref_str)
+        if function_type == "logarithmic":
+            return "np.log({})".format(ref_str)
+        if function_type == "logarithmic1":
+            return "np.log({} + 1)".format(ref_str)
+        if function_type == "exponential":
+            return "np.exp({})".format(ref_str)
+        if function_type == "square":
+            return "({})**2".format(ref_str)
+        if function_type == "inverse":
+            return "1/({})".format(ref_str)
+        if function_type == "sqrt":
+            return "np.sqrt({})".format(ref_str)
+        return "analytic._{}({})".format(function_type, ref_str)

    @staticmethod
-    def function_powerset(fit_results, parameter_names, num_args = 0):
+    def function_powerset(fit_results, parameter_names, num_args=0):
        """
        Combine per-parameter regression results into a single multi-dimensional function.
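As a worked example of what this powerset construction produces (a sketch; the parameter names and per-parameter "best" fits are hypothetical): with two parameters whose best fits are `linear` and `inverse`, one regression argument is generated per subset of parameters:

    fit_results = {
        "txbytes": {"best": "linear"},
        "bitrate": {"best": "inverse"},
    }
    f = analytic.function_powerset(fit_results, ["txbytes", "bitrate"])
    # Generated function string (one term per parameter subset):
    #   0 + regression_arg(0)
    #     + regression_arg(1) * parameter(txbytes)
    #     + regression_arg(2) * 1/(parameter(bitrate))
    #     + regression_arg(3) * parameter(txbytes) * 1/(parameter(bitrate))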
@@ -443,14 +495,22 @@

        Returns an AnalyticFunction instance corresponding to the combined function.
        """
-        buf = '0'
+        buf = "0"
        arg_idx = 0
        for combination in powerset(fit_results.items()):
-            buf += ' + regression_arg({:d})'.format(arg_idx)
+            buf += " + regression_arg({:d})".format(arg_idx)
            arg_idx += 1
            for function_item in combination:
                if arg_support_enabled and is_numeric(function_item[0]):
-                    buf += ' * {}'.format(analytic._fmap('function_arg', function_item[0], function_item[1]['best']))
+                    buf += " * {}".format(
+                        analytic._fmap(
+                            "function_arg", function_item[0], function_item[1]["best"]
+                        )
+                    )
                else:
-                    buf += ' * {}'.format(analytic._fmap('parameter', function_item[0], function_item[1]['best']))
+                    buf += " * {}".format(
+                        analytic._fmap(
+                            "parameter", function_item[0], function_item[1]["best"]
+                        )
+                    )
        return AnalyticFunction(buf, parameter_names, num_args)
diff --git a/lib/harness.py b/lib/harness.py
index 54518e3..3b279c0 100644
--- a/lib/harness.py
+++ b/lib/harness.py
@@ -24,7 +24,16 @@ class TransitionHarness:
        primitive values (-> set by the return value of the current run, not necessarily constant)
    * `args`: function arguments, if isa == 'transition'
    """
-    def __init__(self, gpio_pin = None, gpio_mode = 'around', pta = None, log_return_values = False, repeat = 0, post_transition_delay_us = 0):
+
+    def __init__(
+        self,
+        gpio_pin=None,
+        gpio_mode="around",
+        pta=None,
+        log_return_values=False,
+        repeat=0,
+        post_transition_delay_us=0,
+    ):
        """
        Create a new TransitionHarness

@@ -47,7 +56,14 @@ class TransitionHarness:
        self.reset()

    def copy(self):
-        new_object = __class__(gpio_pin = self.gpio_pin, gpio_mode = self.gpio_mode, pta = self.pta, log_return_values = self.log_return_values, repeat = self.repeat, post_transition_delay_us = self.post_transition_delay_us)
+        new_object = __class__(
+            gpio_pin=self.gpio_pin,
+            gpio_mode=self.gpio_mode,
+            pta=self.pta,
+            log_return_values=self.log_return_values,
+            repeat=self.repeat,
+            post_transition_delay_us=self.post_transition_delay_us,
+        )
        new_object.traces = self.traces.copy()
        new_object.trace_id = self.trace_id
        return new_object
@@ -62,12 +78,16 @@ class TransitionHarness:
        of the current benchmark iteration.
Resets `done` and `synced`, """ for trace in self.traces: - for state_or_transition in trace['trace']: - if 'return_values' in state_or_transition: - state_or_transition['return_values'] = state_or_transition['return_values'][:undo_from] - for param_name in state_or_transition['parameter'].keys(): - if type(state_or_transition['parameter'][param_name]) is list: - state_or_transition['parameter'][param_name] = state_or_transition['parameter'][param_name][:undo_from] + for state_or_transition in trace["trace"]: + if "return_values" in state_or_transition: + state_or_transition["return_values"] = state_or_transition[ + "return_values" + ][:undo_from] + for param_name in state_or_transition["parameter"].keys(): + if type(state_or_transition["parameter"][param_name]) is list: + state_or_transition["parameter"][ + param_name + ] = state_or_transition["parameter"][param_name][:undo_from] def reset(self): """ @@ -95,33 +115,32 @@ class TransitionHarness: def global_code(self): """Return global (pre-`main()`) C++ code needed for tracing.""" - ret = '' + ret = "" if self.gpio_pin != None: - ret += '#define PTALOG_GPIO {}\n'.format(self.gpio_pin) - if self.gpio_mode == 'before': - ret += '#define PTALOG_GPIO_BEFORE\n' - elif self.gpio_mode == 'bar': - ret += '#define PTALOG_GPIO_BAR\n' + ret += "#define PTALOG_GPIO {}\n".format(self.gpio_pin) + if self.gpio_mode == "before": + ret += "#define PTALOG_GPIO_BEFORE\n" + elif self.gpio_mode == "bar": + ret += "#define PTALOG_GPIO_BAR\n" if self.log_return_values: - ret += '#define PTALOG_WITH_RETURNVALUES\n' - ret += 'uint16_t transition_return_value;\n' + ret += "#define PTALOG_WITH_RETURNVALUES\n" + ret += "uint16_t transition_return_value;\n" ret += '#include "object/ptalog.h"\n' if self.gpio_pin != None: - ret += 'PTALog ptalog({});\n'.format(self.gpio_pin) + ret += "PTALog ptalog({});\n".format(self.gpio_pin) else: - ret += 'PTALog ptalog;\n' + ret += "PTALog ptalog;\n" return ret - def start_benchmark(self, benchmark_id = 0): + def start_benchmark(self, benchmark_id=0): """Return C++ code to signal benchmark start to harness.""" - return 'ptalog.startBenchmark({:d});\n'.format(benchmark_id) + return "ptalog.startBenchmark({:d});\n".format(benchmark_id) def start_trace(self): """Prepare a new trace/run in the internal `.traces` structure.""" - self.traces.append({ - 'id' : self.trace_id, - 'trace' : list(), - }) + self.traces.append( + {"id": self.trace_id, "trace": list(),} + ) self.trace_id += 1 def append_state(self, state_name, param): @@ -131,13 +150,11 @@ class TransitionHarness: :param state_name: state name :param param: parameter dict """ - self.traces[-1]['trace'].append({ - 'name': state_name, - 'isa': 'state', - 'parameter': param, - }) + self.traces[-1]["trace"].append( + {"name": state_name, "isa": "state", "parameter": param,} + ) - def append_transition(self, transition_name, param, args = []): + def append_transition(self, transition_name, param, args=[]): """ Append a transition to the current run in the internal `.traces` structure. 
@@ -145,122 +162,188 @@ class TransitionHarness: :param param: parameter dict :param args: function arguments (optional) """ - self.traces[-1]['trace'].append({ - 'name': transition_name, - 'isa': 'transition', - 'parameter': param, - 'args' : args, - }) + self.traces[-1]["trace"].append( + { + "name": transition_name, + "isa": "transition", + "parameter": param, + "args": args, + } + ) def start_run(self): """Return C++ code used to start a new run/trace.""" - return 'ptalog.reset();\n' + return "ptalog.reset();\n" def _pass_transition_call(self, transition_id): - if self.gpio_mode == 'bar': - barcode_bits = Code128('T{}'.format(transition_id), charset='B').modules + if self.gpio_mode == "bar": + barcode_bits = Code128("T{}".format(transition_id), charset="B").modules if len(barcode_bits) % 8 != 0: barcode_bits.extend([1] * (8 - (len(barcode_bits) % 8))) - barcode_bytes = [255 - int("".join(map(str, reversed(barcode_bits[i:i+8]))), 2) for i in range(0, len(barcode_bits), 8)] - inline_array = "".join(map(lambda s: '\\x{:02x}'.format(s), barcode_bytes)) - return 'ptalog.startTransition("{}", {});\n'.format(inline_array, len(barcode_bytes)) + barcode_bytes = [ + 255 - int("".join(map(str, reversed(barcode_bits[i : i + 8]))), 2) + for i in range(0, len(barcode_bits), 8) + ] + inline_array = "".join(map(lambda s: "\\x{:02x}".format(s), barcode_bytes)) + return 'ptalog.startTransition("{}", {});\n'.format( + inline_array, len(barcode_bytes) + ) else: - return 'ptalog.startTransition();\n' + return "ptalog.startTransition();\n" - def pass_transition(self, transition_id, transition_code, transition: object = None): + def pass_transition( + self, transition_id, transition_code, transition: object = None + ): """ Return C++ code used to pass a transition, including the corresponding function call. Tracks which transition has been executed and optionally its return value. May also inject a delay, if `post_transition_delay_us` is set. 
""" - ret = 'ptalog.passTransition({:d});\n'.format(transition_id) + ret = "ptalog.passTransition({:d});\n".format(transition_id) ret += self._pass_transition_call(transition_id) - if self.log_return_values and transition and len(transition.return_value_handlers): - ret += 'transition_return_value = {}\n'.format(transition_code) - ret += 'ptalog.logReturn(transition_return_value);\n' + if ( + self.log_return_values + and transition + and len(transition.return_value_handlers) + ): + ret += "transition_return_value = {}\n".format(transition_code) + ret += "ptalog.logReturn(transition_return_value);\n" else: - ret += '{}\n'.format(transition_code) + ret += "{}\n".format(transition_code) if self.post_transition_delay_us: - ret += 'arch.delay_us({});\n'.format(self.post_transition_delay_us) - ret += 'ptalog.stopTransition();\n' + ret += "arch.delay_us({});\n".format(self.post_transition_delay_us) + ret += "ptalog.stopTransition();\n" return ret - def stop_run(self, num_traces = 0): - return 'ptalog.dump({:d});\n'.format(num_traces) + def stop_run(self, num_traces=0): + return "ptalog.dump({:d});\n".format(num_traces) def stop_benchmark(self): - return 'ptalog.stopBenchmark();\n' + return "ptalog.stopBenchmark();\n" - def _append_nondeterministic_parameter_value(self, log_data_target, parameter_name, parameter_value): - if log_data_target['parameter'][parameter_name] is None: - log_data_target['parameter'][parameter_name] = list() - log_data_target['parameter'][parameter_name].append(parameter_value) + def _append_nondeterministic_parameter_value( + self, log_data_target, parameter_name, parameter_value + ): + if log_data_target["parameter"][parameter_name] is None: + log_data_target["parameter"][parameter_name] = list() + log_data_target["parameter"][parameter_name].append(parameter_value) def parser_cb(self, line): - #print('[HARNESS] got line {}'.format(line)) - if re.match(r'\[PTA\] benchmark stop', line): + # print('[HARNESS] got line {}'.format(line)) + if re.match(r"\[PTA\] benchmark stop", line): self.repetitions += 1 self.synced = False if self.repeat > 0 and self.repetitions == self.repeat: self.done = True - print('[HARNESS] done') + print("[HARNESS] done") return - if re.match(r'\[PTA\] benchmark start, id=(\S+)', line): + if re.match(r"\[PTA\] benchmark start, id=(\S+)", line): self.synced = True - print('[HARNESS] synced, {}/{}'.format(self.repetitions + 1, self.repeat)) + print("[HARNESS] synced, {}/{}".format(self.repetitions + 1, self.repeat)) if self.synced: - res = re.match(r'\[PTA\] trace=(\S+) count=(\S+)', line) + res = re.match(r"\[PTA\] trace=(\S+) count=(\S+)", line) if res: self.trace_id = int(res.group(1)) self.trace_length = int(res.group(2)) self.current_transition_in_trace = 0 if self.log_return_values: - res = re.match(r'\[PTA\] transition=(\S+) return=(\S+)', line) + res = re.match(r"\[PTA\] transition=(\S+) return=(\S+)", line) else: - res = re.match(r'\[PTA\] transition=(\S+)', line) + res = re.match(r"\[PTA\] transition=(\S+)", line) if res: transition_id = int(res.group(1)) # self.traces contains transitions and states, UART output only contains transitions -> use index * 2 try: - log_data_target = self.traces[self.trace_id]['trace'][self.current_transition_in_trace * 2] + log_data_target = self.traces[self.trace_id]["trace"][ + self.current_transition_in_trace * 2 + ] except IndexError: transition_name = None if self.pta: transition_name = self.pta.transitions[transition_id].name - print('[HARNESS] benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}, name 
{}) is out of bounds'.format(0, self.trace_id, self.current_transition_in_trace, transition_id, transition_name)) - print(' Offending line: {}'.format(line)) + print( + "[HARNESS] benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}, name {}) is out of bounds".format( + 0, + self.trace_id, + self.current_transition_in_trace, + transition_id, + transition_name, + ) + ) + print(" Offending line: {}".format(line)) return - if log_data_target['isa'] != 'transition': + if log_data_target["isa"] != "transition": self.abort = True - raise RuntimeError('Log mismatch: Expected transition, got {:s}'.format(log_data_target['isa'])) + raise RuntimeError( + "Log mismatch: Expected transition, got {:s}".format( + log_data_target["isa"] + ) + ) if self.pta: transition = self.pta.transitions[transition_id] - if transition.name != log_data_target['name']: + if transition.name != log_data_target["name"]: self.abort = True - raise RuntimeError('Log mismatch: Expected transition {:s}, got transition {:s} -- may have been caused by preceding malformed UART output'.format(log_data_target['name'], transition.name)) + raise RuntimeError( + "Log mismatch: Expected transition {:s}, got transition {:s} -- may have been caused by preceding malformed UART output".format( + log_data_target["name"], transition.name + ) + ) if self.log_return_values and len(transition.return_value_handlers): for handler in transition.return_value_handlers: - if 'parameter' in handler: + if "parameter" in handler: parameter_value = return_value = int(res.group(2)) - if 'return_values' not in log_data_target: - log_data_target['return_values'] = list() - log_data_target['return_values'].append(return_value) - - if 'formula' in handler: - parameter_value = handler['formula'].eval(return_value) - - self._append_nondeterministic_parameter_value(log_data_target, handler['parameter'], parameter_value) - for following_log_data_target in self.traces[self.trace_id]['trace'][(self.current_transition_in_trace * 2 + 1) :]: - self._append_nondeterministic_parameter_value(following_log_data_target, handler['parameter'], parameter_value) - if 'apply_from' in handler and any(map(lambda x: x['name'] == handler['apply_from'], self.traces[self.trace_id]['trace'][: (self.current_transition_in_trace * 2 + 1)])): - for preceding_log_data_target in reversed(self.traces[self.trace_id]['trace'][: (self.current_transition_in_trace * 2)]): - self._append_nondeterministic_parameter_value(preceding_log_data_target, handler['parameter'], parameter_value) - if preceding_log_data_target['name'] == handler['apply_from']: + if "return_values" not in log_data_target: + log_data_target["return_values"] = list() + log_data_target["return_values"].append(return_value) + + if "formula" in handler: + parameter_value = handler["formula"].eval( + return_value + ) + + self._append_nondeterministic_parameter_value( + log_data_target, + handler["parameter"], + parameter_value, + ) + for following_log_data_target in self.traces[ + self.trace_id + ]["trace"][ + (self.current_transition_in_trace * 2 + 1) : + ]: + self._append_nondeterministic_parameter_value( + following_log_data_target, + handler["parameter"], + parameter_value, + ) + if "apply_from" in handler and any( + map( + lambda x: x["name"] == handler["apply_from"], + self.traces[self.trace_id]["trace"][ + : (self.current_transition_in_trace * 2 + 1) + ], + ) + ): + for preceding_log_data_target in reversed( + self.traces[self.trace_id]["trace"][ + : (self.current_transition_in_trace * 2) + ] + ): + 
self._append_nondeterministic_parameter_value( + preceding_log_data_target, + handler["parameter"], + parameter_value, + ) + if ( + preceding_log_data_target["name"] + == handler["apply_from"] + ): break self.current_transition_in_trace += 1 + class OnboardTimerHarness(TransitionHarness): """TODO @@ -271,13 +354,25 @@ class OnboardTimerHarness(TransitionHarness): benchmark iteration. I.e. `.traces[*]['trace'][*]['offline_aggregates']['duration'] = [..., ...]` """ + def __init__(self, counter_limits, **kwargs): super().__init__(**kwargs) self.trace_length = 0 - self.one_cycle_in_us, self.one_overflow_in_us, self.counter_max_overflow = counter_limits + ( + self.one_cycle_in_us, + self.one_overflow_in_us, + self.counter_max_overflow, + ) = counter_limits def copy(self): - new_harness = __class__((self.one_cycle_in_us, self.one_overflow_in_us, self.counter_max_overflow), gpio_pin = self.gpio_pin, gpio_mode = self.gpio_mode, pta = self.pta, log_return_values = self.log_return_values, repeat = self.repeat) + new_harness = __class__( + (self.one_cycle_in_us, self.one_overflow_in_us, self.counter_max_overflow), + gpio_pin=self.gpio_pin, + gpio_mode=self.gpio_mode, + pta=self.pta, + log_return_values=self.log_return_values, + repeat=self.repeat, + ) new_harness.traces = self.traces.copy() new_harness.trace_id = self.trace_id return new_harness @@ -293,123 +388,215 @@ class OnboardTimerHarness(TransitionHarness): """ super().undo(undo_from) for trace in self.traces: - for state_or_transition in trace['trace']: - if 'offline_aggregates' in state_or_transition: - state_or_transition['offline_aggregates']['duration'] = state_or_transition['offline_aggregates']['duration'][:undo_from] + for state_or_transition in trace["trace"]: + if "offline_aggregates" in state_or_transition: + state_or_transition["offline_aggregates"][ + "duration" + ] = state_or_transition["offline_aggregates"]["duration"][ + :undo_from + ] def global_code(self): ret = '#include "driver/counter.h"\n' - ret += '#define PTALOG_TIMING\n' + ret += "#define PTALOG_TIMING\n" ret += super().global_code() return ret - def start_benchmark(self, benchmark_id = 0): - ret = 'counter.start();\n' - ret += 'counter.stop();\n' - ret += 'ptalog.passNop(counter);\n' + def start_benchmark(self, benchmark_id=0): + ret = "counter.start();\n" + ret += "counter.stop();\n" + ret += "ptalog.passNop(counter);\n" ret += super().start_benchmark(benchmark_id) return ret - def pass_transition(self, transition_id, transition_code, transition: object = None): - ret = 'ptalog.passTransition({:d});\n'.format(transition_id) + def pass_transition( + self, transition_id, transition_code, transition: object = None + ): + ret = "ptalog.passTransition({:d});\n".format(transition_id) ret += self._pass_transition_call(transition_id) - ret += 'counter.start();\n' - if self.log_return_values and transition and len(transition.return_value_handlers): - ret += 'transition_return_value = {}\n'.format(transition_code) + ret += "counter.start();\n" + if ( + self.log_return_values + and transition + and len(transition.return_value_handlers) + ): + ret += "transition_return_value = {}\n".format(transition_code) else: - ret += '{}\n'.format(transition_code) - ret += 'counter.stop();\n' - if self.log_return_values and transition and len(transition.return_value_handlers): - ret += 'ptalog.logReturn(transition_return_value);\n' - ret += 'ptalog.stopTransition(counter);\n' + ret += "{}\n".format(transition_code) + ret += "counter.stop();\n" + if ( + self.log_return_values + and transition 
+ and len(transition.return_value_handlers) + ): + ret += "ptalog.logReturn(transition_return_value);\n" + ret += "ptalog.stopTransition(counter);\n" return ret - def _append_nondeterministic_parameter_value(self, log_data_target, parameter_name, parameter_value): - if log_data_target['parameter'][parameter_name] is None: - log_data_target['parameter'][parameter_name] = list() - log_data_target['parameter'][parameter_name].append(parameter_value) + def _append_nondeterministic_parameter_value( + self, log_data_target, parameter_name, parameter_value + ): + if log_data_target["parameter"][parameter_name] is None: + log_data_target["parameter"][parameter_name] = list() + log_data_target["parameter"][parameter_name].append(parameter_value) def parser_cb(self, line): # print('[HARNESS] got line {}'.format(line)) - res = re.match(r'\[PTA\] nop=(\S+)/(\S+)', line) + res = re.match(r"\[PTA\] nop=(\S+)/(\S+)", line) if res: self.nop_cycles = int(res.group(1)) if int(res.group(2)): - raise RuntimeError('Counter overflow ({:d}/{:d}) during NOP test, wtf?!'.format(res.group(1), res.group(2))) - if re.match(r'\[PTA\] benchmark stop', line): + raise RuntimeError( + "Counter overflow ({:d}/{:d}) during NOP test, wtf?!".format( + res.group(1), res.group(2) + ) + ) + if re.match(r"\[PTA\] benchmark stop", line): self.repetitions += 1 self.synced = False if self.repeat > 0 and self.repetitions == self.repeat: self.done = True - print('[HARNESS] done') + print("[HARNESS] done") return # May be repeated, e.g. if the device is reset shortly after start by # EnergyTrace. - if re.match(r'\[PTA\] benchmark start, id=(\S+)', line): + if re.match(r"\[PTA\] benchmark start, id=(\S+)", line): self.synced = True - print('[HARNESS] synced, {}/{}'.format(self.repetitions + 1, self.repeat)) + print("[HARNESS] synced, {}/{}".format(self.repetitions + 1, self.repeat)) if self.synced: - res = re.match(r'\[PTA\] trace=(\S+) count=(\S+)', line) + res = re.match(r"\[PTA\] trace=(\S+) count=(\S+)", line) if res: self.trace_id = int(res.group(1)) self.trace_length = int(res.group(2)) self.current_transition_in_trace = 0 if self.log_return_values: - res = re.match(r'\[PTA\] transition=(\S+) cycles=(\S+)/(\S+) return=(\S+)', line) + res = re.match( + r"\[PTA\] transition=(\S+) cycles=(\S+)/(\S+) return=(\S+)", line + ) else: - res = re.match(r'\[PTA\] transition=(\S+) cycles=(\S+)/(\S+)', line) + res = re.match(r"\[PTA\] transition=(\S+) cycles=(\S+)/(\S+)", line) if res: transition_id = int(res.group(1)) cycles = int(res.group(2)) overflow = int(res.group(3)) if overflow >= self.counter_max_overflow: self.abort = True - raise RuntimeError('Counter overflow ({:d}/{:d}) in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d})'.format(cycles, overflow, 0, self.trace_id, self.current_transition_in_trace, transition_id)) - duration_us = cycles * self.one_cycle_in_us + overflow * self.one_overflow_in_us - self.nop_cycles * self.one_cycle_in_us + raise RuntimeError( + "Counter overflow ({:d}/{:d}) in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d})".format( + cycles, + overflow, + 0, + self.trace_id, + self.current_transition_in_trace, + transition_id, + ) + ) + duration_us = ( + cycles * self.one_cycle_in_us + + overflow * self.one_overflow_in_us + - self.nop_cycles * self.one_cycle_in_us + ) if duration_us < 0: duration_us = 0 # self.traces contains transitions and states, UART output only contains transitions -> use index * 2 try: - log_data_target = self.traces[self.trace_id]['trace'][self.current_transition_in_trace * 
2]
+                    log_data_target = self.traces[self.trace_id]["trace"][
+                        self.current_transition_in_trace * 2
+                    ]
                except IndexError:
                    transition_name = None
                    if self.pta:
                        transition_name = self.pta.transitions[transition_id].name
-                    print('[HARNESS] benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}, name {}) is out of bounds'.format(0, self.trace_id, self.current_transition_in_trace, transition_id, transition_name))
-                    print('    Offending line: {}'.format(line))
+                    print(
+                        "[HARNESS] benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}, name {}) is out of bounds".format(
+                            0,
+                            self.trace_id,
+                            self.current_transition_in_trace,
+                            transition_id,
+                            transition_name,
+                        )
+                    )
+                    print("    Offending line: {}".format(line))
                    return
-                if log_data_target['isa'] != 'transition':
+                if log_data_target["isa"] != "transition":
                    self.abort = True
-                    raise RuntimeError('Log mismatch in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}): Expected transition, got {:s}'.format(0,
-                        self.trace_id, self.current_transition_in_trace, transition_id, log_data_target['isa']))
+                    raise RuntimeError(
+                        "Log mismatch in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}): Expected transition, got {:s}".format(
+                            0,
+                            self.trace_id,
+                            self.current_transition_in_trace,
+                            transition_id,
+                            log_data_target["isa"],
+                        )
+                    )
                if self.pta:
                    transition = self.pta.transitions[transition_id]
-                    if transition.name != log_data_target['name']:
+                    if transition.name != log_data_target["name"]:
                        self.abort = True
-                        raise RuntimeError('Log mismatch in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}): Expected transition {:s}, got transition {:s} -- may have been caused by preceding malformed UART output'.format(0, self.trace_id, self.current_transition_in_trace, transition_id, log_data_target['name'], transition.name, line))
+                        raise RuntimeError(
+                            "Log mismatch in benchmark id={:d} trace={:d}: transition #{:d} (ID {:d}): Expected transition {:s}, got transition {:s} -- may have been caused by preceding malformed UART output".format(
+                                0,
+                                self.trace_id,
+                                self.current_transition_in_trace,
+                                transition_id,
+                                log_data_target["name"],
+                                transition.name,
+                                line,
+                            )
+                        )
                    if self.log_return_values and len(transition.return_value_handlers):
                        for handler in transition.return_value_handlers:
-                            if 'parameter' in handler:
+                            if "parameter" in handler:
                                parameter_value = return_value = int(res.group(4))

-                                if 'return_values' not in log_data_target:
-                                    log_data_target['return_values'] = list()
-                                log_data_target['return_values'].append(return_value)
-
-                                if 'formula' in handler:
-                                    parameter_value = handler['formula'].eval(return_value)
-
-                                self._append_nondeterministic_parameter_value(log_data_target, handler['parameter'], parameter_value)
-                                for following_log_data_target in self.traces[self.trace_id]['trace'][(self.current_transition_in_trace * 2 + 1) :]:
-                                    self._append_nondeterministic_parameter_value(following_log_data_target, handler['parameter'], parameter_value)
-                                if 'apply_from' in handler and any(map(lambda x: x['name'] == handler['apply_from'], self.traces[self.trace_id]['trace'][: (self.current_transition_in_trace * 2 + 1)])):
-                                    for preceding_log_data_target in reversed(self.traces[self.trace_id]['trace'][: (self.current_transition_in_trace * 2)]):
-                                        self._append_nondeterministic_parameter_value(preceding_log_data_target, handler['parameter'], parameter_value)
-                                        if preceding_log_data_target['name'] == handler['apply_from']:
+                                if "return_values" not in log_data_target:
+                                    log_data_target["return_values"] = list()
+
log_data_target["return_values"].append(return_value) + + if "formula" in handler: + parameter_value = handler["formula"].eval( + return_value + ) + + self._append_nondeterministic_parameter_value( + log_data_target, + handler["parameter"], + parameter_value, + ) + for following_log_data_target in self.traces[ + self.trace_id + ]["trace"][ + (self.current_transition_in_trace * 2 + 1) : + ]: + self._append_nondeterministic_parameter_value( + following_log_data_target, + handler["parameter"], + parameter_value, + ) + if "apply_from" in handler and any( + map( + lambda x: x["name"] == handler["apply_from"], + self.traces[self.trace_id]["trace"][ + : (self.current_transition_in_trace * 2 + 1) + ], + ) + ): + for preceding_log_data_target in reversed( + self.traces[self.trace_id]["trace"][ + : (self.current_transition_in_trace * 2) + ] + ): + self._append_nondeterministic_parameter_value( + preceding_log_data_target, + handler["parameter"], + parameter_value, + ) + if ( + preceding_log_data_target["name"] + == handler["apply_from"] + ): break - if 'offline_aggregates' not in log_data_target: - log_data_target['offline_aggregates'] = { - 'duration' : list() - } - log_data_target['offline_aggregates']['duration'].append(duration_us) + if "offline_aggregates" not in log_data_target: + log_data_target["offline_aggregates"] = {"duration": list()} + log_data_target["offline_aggregates"]["duration"].append(duration_us) self.current_transition_in_trace += 1 diff --git a/lib/ipython_energymodel_prelude.py b/lib/ipython_energymodel_prelude.py index 6777b17..0457838 100755 --- a/lib/ipython_energymodel_prelude.py +++ b/lib/ipython_energymodel_prelude.py @@ -5,9 +5,11 @@ from dfatool import PTAModel, RawData, soft_cast_int ignored_trace_indexes = None -files = '../data/20170125_125433_cc1200.tar ../data/20170125_142420_cc1200.tar ../data/20170125_144957_cc1200.tar ../data/20170125_151149_cc1200.tar ../data/20170125_151824_cc1200.tar ../data/20170125_154019_cc1200.tar'.split(' ') -#files = '../data/20170116_124500_LM75x.tar ../data/20170116_131306_LM75x.tar'.split(' ') +files = "../data/20170125_125433_cc1200.tar ../data/20170125_142420_cc1200.tar ../data/20170125_144957_cc1200.tar ../data/20170125_151149_cc1200.tar ../data/20170125_151824_cc1200.tar ../data/20170125_154019_cc1200.tar".split( + " " +) +# files = '../data/20170116_124500_LM75x.tar ../data/20170116_131306_LM75x.tar'.split(' ') raw_data = RawData(files) preprocessed_data = raw_data.get_preprocessed_data() -model = PTAModel(preprocessed_data, ignore_trace_indexes = ignored_trace_indexes) +model = PTAModel(preprocessed_data, ignore_trace_indexes=ignored_trace_indexes) diff --git a/lib/keysightdlog.py b/lib/keysightdlog.py index 0cf8da1..89264b9 100755 --- a/lib/keysightdlog.py +++ b/lib/keysightdlog.py @@ -8,69 +8,74 @@ import struct import sys import xml.etree.ElementTree as ET + def plot_y(Y, **kwargs): plot_xy(np.arange(len(Y)), Y, **kwargs) -def plot_xy(X, Y, xlabel = None, ylabel = None, title = None, output = None): - fig, ax1 = plt.subplots(figsize=(10,6)) + +def plot_xy(X, Y, xlabel=None, ylabel=None, title=None, output=None): + fig, ax1 = plt.subplots(figsize=(10, 6)) if title != None: fig.canvas.set_window_title(title) if xlabel != None: ax1.set_xlabel(xlabel) if ylabel != None: ax1.set_ylabel(ylabel) - plt.subplots_adjust(left = 0.1, bottom = 0.1, right = 0.99, top = 0.99) + plt.subplots_adjust(left=0.1, bottom=0.1, right=0.99, top=0.99) plt.plot(X, Y, "bo", markersize=2) if output: plt.savefig(output) - with 
open('{}.txt'.format(output), 'w') as f: - print('X Y', file=f) + with open("{}.txt".format(output), "w") as f: + print("X Y", file=f) for i in range(len(X)): - print('{} {}'.format(X[i], Y[i]), file=f) + print("{} {}".format(X[i], Y[i]), file=f) else: plt.show() + filename = sys.argv[1] -with open(filename, 'rb') as logfile: +with open(filename, "rb") as logfile: lines = [] - line = '' + line = "" - if '.xz' in filename: + if ".xz" in filename: f = lzma.open(logfile) else: f = logfile - while line != '</dlog>\n': + while line != "</dlog>\n": line = f.readline().decode() lines.append(line) - xml_header = ''.join(lines) + xml_header = "".join(lines) raw_header = f.read(8) data_offset = f.tell() raw_data = f.read() - xml_header = xml_header.replace('1ua>', 'X1ua>') - xml_header = xml_header.replace('2ua>', 'X2ua>') + xml_header = xml_header.replace("1ua>", "X1ua>") + xml_header = xml_header.replace("2ua>", "X2ua>") dlog = ET.fromstring(xml_header) channels = [] - for channel in dlog.findall('channel'): - channel_id = int(channel.get('id')) - sense_curr = channel.find('sense_curr').text - sense_volt = channel.find('sense_volt').text - model = channel.find('ident').find('model').text - if sense_volt == '1': - channels.append((channel_id, model, 'V')) - if sense_curr == '1': - channels.append((channel_id, model, 'A')) + for channel in dlog.findall("channel"): + channel_id = int(channel.get("id")) + sense_curr = channel.find("sense_curr").text + sense_volt = channel.find("sense_volt").text + model = channel.find("ident").find("model").text + if sense_volt == "1": + channels.append((channel_id, model, "V")) + if sense_curr == "1": + channels.append((channel_id, model, "A")) num_channels = len(channels) - duration = int(dlog.find('frame').find('time').text) - interval = float(dlog.find('frame').find('tint').text) + duration = int(dlog.find("frame").find("time").text) + interval = float(dlog.find("frame").find("tint").text) real_duration = interval * int(len(raw_data) / (4 * num_channels)) - data = np.ndarray(shape=(num_channels, int(len(raw_data) / (4 * num_channels))), dtype=np.float32) + data = np.ndarray( + shape=(num_channels, int(len(raw_data) / (4 * num_channels))), dtype=np.float32 + ) - iterator = struct.iter_unpack('>f', raw_data) + iterator = struct.iter_unpack(">f", raw_data) channel_offset = 0 measurement_offset = 0 for value in iterator: @@ -82,34 +87,59 @@ with open(filename, 'rb') as logfile: channel_offset += 1 if int(real_duration) != duration: - print('Measurement duration: {:f} of {:d} seconds at {:f} µs per sample'.format( - real_duration, duration, interval * 1000000)) + print( + "Measurement duration: {:f} of {:d} seconds at {:f} µs per sample".format( + real_duration, duration, interval * 1000000 + ) + ) else: - print('Measurement duration: {:d} seconds at {:f} µs per sample'.format( - duration, interval * 1000000)) + print( + "Measurement duration: {:d} seconds at {:f} µs per sample".format( + duration, interval * 1000000 + ) + ) for i, channel in enumerate(channels): channel_id, channel_model, channel_type = channel - print('channel {:d} ({:s}): min {:f}, max {:f}, mean {:f} {:s}'.format( - channel_id, channel_model, np.min(data[i]), np.max(data[i]), np.mean(data[i]), - channel_type)) + print( + "channel {:d} ({:s}): min {:f}, max {:f}, mean {:f} {:s}".format( + channel_id, + channel_model, + np.min(data[i]), + np.max(data[i]), + np.mean(data[i]), + channel_type, + ) + ) - if i > 0 and channel_type == 'A' and channels[i-1][2] == 'V' and channel_id == channels[i-1][0]: - 
power = data[i-1] * data[i]
+    if (
+        i > 0
+        and channel_type == "A"
+        and channels[i - 1][2] == "V"
+        and channel_id == channels[i - 1][0]
+    ):
+        power = data[i - 1] * data[i]
        power = 3.6 * data[i]
-        print('channel {:d} ({:s}): min {:f}, max {:f}, mean {:f} W'.format(
-            channel_id, channel_model, np.min(power), np.max(power), np.mean(power)))
+        print(
+            "channel {:d} ({:s}): min {:f}, max {:f}, mean {:f} W".format(
+                channel_id, channel_model, np.min(power), np.max(power), np.mean(power)
+            )
+        )
        min_power = np.min(power)
        max_power = np.max(power)
        power_border = np.mean([min_power, max_power])
        low_power = power[power < power_border]
        high_power = power[power >= power_border]
        plot_y(power)
-        print('    avg low / high power (delta): {:f} / {:f} ({:f}) W'.format(
-            np.mean(low_power), np.mean(high_power),
-            np.mean(high_power) - np.mean(low_power)))
-        #plot_y(low_power)
-        #plot_y(high_power)
+        print(
+            "    avg low / high power (delta): {:f} / {:f} ({:f}) W".format(
+                np.mean(low_power),
+                np.mean(high_power),
+                np.mean(high_power) - np.mean(low_power),
+            )
+        )
+        # plot_y(low_power)
+        # plot_y(high_power)
        high_power_durations = []
        current_high_power_duration = 0
        for is_hpe in power >= power_border:
@@ -119,12 +149,16 @@ for i, channel in enumerate(channels):
            if current_high_power_duration > 0:
                high_power_durations.append(current_high_power_duration)
                current_high_power_duration = 0
-        print('    avg high-power duration: {:f} µs'.format(np.mean(high_power_durations) * 1000000))
-
-#print(xml_header)
-#print(raw_header)
-#print(channels)
-#print(data)
-#print(np.mean(data[0]))
-#print(np.mean(data[1]))
-#print(np.mean(data[0] * data[1]))
+        print(
+            "    avg high-power duration: {:f} µs".format(
+                np.mean(high_power_durations) * 1000000
+            )
+        )
+
+# print(xml_header)
+# print(raw_header)
+# print(channels)
+# print(data)
+# print(np.mean(data[0]))
+# print(np.mean(data[1]))
+# print(np.mean(data[0] * data[1]))
diff --git a/lib/lex.py b/lib/lex.py
--- a/lib/lex.py
+++ b/lib/lex.py
@@ -3,34 +3,44 @@ from .sly import Lexer, Parser

class TimedWordLexer(Lexer):
    tokens = {LPAREN, RPAREN, IDENTIFIER, NUMBER, ARGSEP, FUNCTIONSEP}
-    ignore = ' \t'
+    ignore = " \t"

-    LPAREN = r'\('
-    RPAREN = r'\)'
-    IDENTIFIER = r'[a-zA-Z_][a-zA-Z0-9_]*'
-    NUMBER = r'[0-9e.]+'
-    ARGSEP = r','
-    FUNCTIONSEP = r';'
+    LPAREN = r"\("
+    RPAREN = r"\)"
+    IDENTIFIER = r"[a-zA-Z_][a-zA-Z0-9_]*"
+    NUMBER = r"[0-9e.]+"
+    ARGSEP = r","
+    FUNCTIONSEP = r";"


class TimedSequenceLexer(Lexer):
-    tokens = {LPAREN, RPAREN, LBRACE, RBRACE, CYCLE, IDENTIFIER, NUMBER, ARGSEP, FUNCTIONSEP}
-    ignore = ' \t'
-
-    LPAREN = r'\('
-    RPAREN = r'\)'
-    LBRACE = r'\{'
-    RBRACE = r'\}'
-    CYCLE = r'cycle'
-    IDENTIFIER = r'[a-zA-Z_][a-zA-Z0-9_]*'
-    NUMBER = r'[0-9e.]+'
-    ARGSEP = r','
-    FUNCTIONSEP = r';'
+    tokens = {
+        LPAREN,
+        RPAREN,
+        LBRACE,
+        RBRACE,
+        CYCLE,
+        IDENTIFIER,
+        NUMBER,
+        ARGSEP,
+        FUNCTIONSEP,
+    }
+    ignore = " \t"
+
+    LPAREN = r"\("
+    RPAREN = r"\)"
+    LBRACE = r"\{"
+    RBRACE = r"\}"
+    CYCLE = r"cycle"
+    IDENTIFIER = r"[a-zA-Z_][a-zA-Z0-9_]*"
+    NUMBER = r"[0-9e.]+"
+    ARGSEP = r","
+    FUNCTIONSEP = r";"

    def error(self, t):
        print("Illegal character '%s'" % t.value[0])
-        if t.value[0] == '{' and t.value.find('}'):
-            self.index += 1 + t.value.find('}')
+        if t.value[0] == "{" and t.value.find("}"):
+            self.index += 1 + t.value.find("}")
        else:
            self.index += 1

@@ -38,39 +48,39 @@ class TimedSequenceLexer(Lexer):
class TimedWordParser(Parser):
    tokens = TimedWordLexer.tokens

-    @_('timedSymbol FUNCTIONSEP timedWord')
+    @_("timedSymbol FUNCTIONSEP timedWord")
    def timedWord(self, p):
        ret = [p.timedSymbol]
        ret.extend(p.timedWord)
return ret - @_('timedSymbol FUNCTIONSEP', 'timedSymbol') + @_("timedSymbol FUNCTIONSEP", "timedSymbol") def timedWord(self, p): return [p.timedSymbol] - @_('IDENTIFIER', 'IDENTIFIER LPAREN RPAREN') + @_("IDENTIFIER", "IDENTIFIER LPAREN RPAREN") def timedSymbol(self, p): return (p.IDENTIFIER,) - @_('IDENTIFIER LPAREN args RPAREN') + @_("IDENTIFIER LPAREN args RPAREN") def timedSymbol(self, p): return (p.IDENTIFIER, *p.args) - @_('arg ARGSEP args') + @_("arg ARGSEP args") def args(self, p): ret = [p.arg] ret.extend(p.args) return ret - @_('arg') + @_("arg") def args(self, p): return [p.arg] - @_('NUMBER') + @_("NUMBER") def arg(self, p): return float(p.NUMBER) - @_('IDENTIFIER') + @_("IDENTIFIER") def arg(self, p): return p.IDENTIFIER @@ -78,66 +88,66 @@ class TimedWordParser(Parser): class TimedSequenceParser(Parser): tokens = TimedSequenceLexer.tokens - @_('timedSequenceL', 'timedSequenceW') + @_("timedSequenceL", "timedSequenceW") def timedSequence(self, p): return p[0] - @_('loop') + @_("loop") def timedSequenceL(self, p): return [p.loop] - @_('loop timedSequenceW') + @_("loop timedSequenceW") def timedSequenceL(self, p): ret = [p.loop] ret.extend(p.timedSequenceW) return ret - @_('timedWord') + @_("timedWord") def timedSequenceW(self, p): return [p.timedWord] - @_('timedWord timedSequenceL') + @_("timedWord timedSequenceL") def timedSequenceW(self, p): ret = [p.timedWord] ret.extend(p.timedSequenceL) return ret - @_('timedSymbol FUNCTIONSEP timedWord') + @_("timedSymbol FUNCTIONSEP timedWord") def timedWord(self, p): p.timedWord.word.insert(0, p.timedSymbol) return p.timedWord - @_('timedSymbol FUNCTIONSEP') + @_("timedSymbol FUNCTIONSEP") def timedWord(self, p): return TimedWord(word=[p.timedSymbol]) - @_('CYCLE LPAREN IDENTIFIER RPAREN LBRACE timedWord RBRACE') + @_("CYCLE LPAREN IDENTIFIER RPAREN LBRACE timedWord RBRACE") def loop(self, p): return Workload(p.IDENTIFIER, p.timedWord) - @_('IDENTIFIER', 'IDENTIFIER LPAREN RPAREN') + @_("IDENTIFIER", "IDENTIFIER LPAREN RPAREN") def timedSymbol(self, p): return (p.IDENTIFIER,) - @_('IDENTIFIER LPAREN args RPAREN') + @_("IDENTIFIER LPAREN args RPAREN") def timedSymbol(self, p): return (p.IDENTIFIER, *p.args) - @_('arg ARGSEP args') + @_("arg ARGSEP args") def args(self, p): ret = [p.arg] ret.extend(p.args) return ret - @_('arg') + @_("arg") def args(self, p): return [p.arg] - @_('NUMBER') + @_("NUMBER") def arg(self, p): return float(p.NUMBER) - @_('IDENTIFIER') + @_("IDENTIFIER") def arg(self, p): return p.IDENTIFIER @@ -165,8 +175,8 @@ class TimedWord: def __repr__(self): ret = list() for symbol in self.word: - ret.append('{}({})'.format(symbol[0], ', '.join(map(str, symbol[1:])))) - return 'TimedWord<"{}">'.format('; '.join(ret)) + ret.append("{}({})".format(symbol[0], ", ".join(map(str, symbol[1:])))) + return 'TimedWord<"{}">'.format("; ".join(ret)) class Workload: @@ -196,5 +206,5 @@ class TimedSequence: def __repr__(self): ret = list() for symbol in self.seq: - ret.append('{}'.format(symbol)) - return 'TimedSequence(seq=[{}])'.format(', '.join(ret)) + ret.append("{}".format(symbol)) + return "TimedSequence(seq=[{}])".format(", ".join(ret)) diff --git a/lib/modular_arithmetic.py b/lib/modular_arithmetic.py index 0a69b79..c5ed1aa 100644 --- a/lib/modular_arithmetic.py +++ b/lib/modular_arithmetic.py @@ -3,6 +3,7 @@ import operator import functools + @functools.total_ordering class Mod: """A class for modular arithmetic, useful to simulate behaviour of uint8 and other limited data types. 
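# A minimal usage sketch for the Mod class above and its fixed-width
# subclasses defined further below (illustrative, not part of this diff;
# assumes lib/ is on the import path). Arithmetic wraps at the modulus,
# mimicking unsigned integer overflow; the operators installed via _make_op
# return plain Mod instances, hence the int() conversions.
from modular_arithmetic import Uint8, simulate_int_type

counter = Uint8(250)
assert int(counter + 10) == 4  # 260 mod 256
assert int(-Uint8(1)) == 255  # wraps like unsigned underflow
uint16 = simulate_int_type("uint16_t")
assert int(uint16(65535) + 1) == 0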
@@ -14,20 +15,21 @@ class Mod: :param val: stored integer value Param mod: modulus """ - __slots__ = ['val','mod'] + + __slots__ = ["val", "mod"] def __init__(self, val, mod): if isinstance(val, Mod): val = val.val if not isinstance(val, int): - raise ValueError('Value must be integer') - if not isinstance(mod, int) or mod<=0: - raise ValueError('Modulo must be positive integer') + raise ValueError("Value must be integer") + if not isinstance(mod, int) or mod <= 0: + raise ValueError("Modulo must be positive integer") self.val = val % mod self.mod = mod def __repr__(self): - return 'Mod({}, {})'.format(self.val, self.mod) + return "Mod({}, {})".format(self.val, self.mod) def __int__(self): return self.val @@ -50,7 +52,7 @@ class Mod: def _check_operand(self, other): if not isinstance(other, (int, Mod)): - raise TypeError('Only integer and Mod operands are supported') + raise TypeError("Only integer and Mod operands are supported") def __pow__(self, other): self._check_operand(other) @@ -61,32 +63,46 @@ class Mod: return Mod(self.mod - self.val, self.mod) def __pos__(self): - return self # The unary plus operator does nothing. + return self # The unary plus operator does nothing. def __abs__(self): - return self # The value is always kept non-negative, so the abs function should do nothing. + return self # The value is always kept non-negative, so the abs function should do nothing. + # Helper functions to build common operands based on a template. # They need to be implemented as functions for the closures to work properly. def _make_op(opname): - op_fun = getattr(operator, opname) # Fetch the operator by name from the operator module + op_fun = getattr( + operator, opname + ) # Fetch the operator by name from the operator module + def op(self, other): self._check_operand(other) return Mod(op_fun(self.val, int(other)) % self.mod, self.mod) + return op + def _make_reflected_op(opname): op_fun = getattr(operator, opname) + def op(self, other): self._check_operand(other) return Mod(op_fun(int(other), self.val) % self.mod, self.mod) + return op + # Build the actual operator overload methods based on the template. -for opname, reflected_opname in [('__add__', '__radd__'), ('__sub__', '__rsub__'), ('__mul__', '__rmul__')]: +for opname, reflected_opname in [ + ("__add__", "__radd__"), + ("__sub__", "__rsub__"), + ("__mul__", "__rmul__"), +]: setattr(Mod, opname, _make_op(opname)) setattr(Mod, reflected_opname, _make_reflected_op(opname)) + class Uint8(Mod): __slots__ = [] @@ -94,7 +110,8 @@ class Uint8(Mod): super().__init__(val, 256) def __repr__(self): - return 'Uint8({})'.format(self.val) + return "Uint8({})".format(self.val) + class Uint16(Mod): __slots__ = [] @@ -103,7 +120,8 @@ class Uint16(Mod): super().__init__(val, 65536) def __repr__(self): - return 'Uint16({})'.format(self.val) + return "Uint16({})".format(self.val) + class Uint32(Mod): __slots__ = [] @@ -112,7 +130,8 @@ class Uint32(Mod): super().__init__(val, 4294967296) def __repr__(self): - return 'Uint32({})'.format(self.val) + return "Uint32({})".format(self.val) + class Uint64(Mod): __slots__ = [] @@ -121,7 +140,7 @@ class Uint64(Mod): super().__init__(val, 18446744073709551616) def __repr__(self): - return 'Uint64({})'.format(self.val) + return "Uint64({})".format(self.val) def simulate_int_type(int_type: str) -> Mod: @@ -131,12 +150,12 @@ def simulate_int_type(int_type: str) -> Mod: :param int_type: uint8_t / uint16_t / uint32_t / uint64_t :returns: `Mod` subclass, e.g. 
Uint8
    """
-    if int_type == 'uint8_t':
+    if int_type == "uint8_t":
         return Uint8
-    if int_type == 'uint16_t':
+    if int_type == "uint16_t":
         return Uint16
-    if int_type == 'uint32_t':
+    if int_type == "uint32_t":
         return Uint32
-    if int_type == 'uint64_t':
+    if int_type == "uint64_t":
         return Uint64
-    raise ValueError('unsupported integer type: {}'.format(int_type))
+    raise ValueError("unsupported integer type: {}".format(int_type))
diff --git a/lib/parameters.py b/lib/parameters.py
index 41e312a..8b562b6 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -21,8 +21,10 @@ def distinct_param_values(by_name, state_or_tran):
     write() or similar has not been called yet. Other parameters
     should always be initialized when leaving UNINITIALIZED.
     """
-    distinct_values = [OrderedDict() for i in range(len(by_name[state_or_tran]['param'][0]))]
-    for param_tuple in by_name[state_or_tran]['param']:
+    distinct_values = [
+        OrderedDict() for i in range(len(by_name[state_or_tran]["param"][0]))
+    ]
+    for param_tuple in by_name[state_or_tran]["param"]:
         for i in range(len(param_tuple)):
             distinct_values[i][param_tuple[i]] = True
 
@@ -30,8 +32,9 @@ def distinct_param_values(by_name, state_or_tran):
     distinct_values = list(map(lambda x: list(x.keys()), distinct_values))
     return distinct_values
 
+
 def _depends_on_param(corr_param, std_param, std_lut):
-    #if self.use_corrcoef:
+    # if self.use_corrcoef:
     if False:
         return corr_param > 0.1
     elif std_param == 0:
@@ -40,6 +43,7 @@ def _depends_on_param(corr_param, std_param, std_lut):
         return False
     return std_lut / std_param < 0.5
 
+
 def _reduce_param_matrix(matrix: np.ndarray, parameter_names: list) -> list:
     """
     :param matrix: parameter dependence matrix, M[(...)] == 1 iff (model attribute) is influenced by (parameter) for other parameter value index == (...)
@@ -53,7 +57,7 @@ def _reduce_param_matrix(matrix: np.ndarray, parameter_names: list) -> list:
     # This abort condition does not seem to be very smart yet...
     # With it, too much is filtered out (e.g. auto_ack! -> max_retry_count in "bin/analyze-timing.py ../data/20190815_122531_nRF24_no-rx.json" is not detected)
     # Without it, too little is filtered out (many dependencies are detected where a parameter dependency holds regardless of the other parameters' values)
-    #if not is_power_of_two(np.count_nonzero(matrix)):
+    # if not is_power_of_two(np.count_nonzero(matrix)):
     #    # cannot be reliably reduced to a list of parameters
     #    return list()
 
@@ -65,20 +69,23 @@ def _reduce_param_matrix(matrix: np.ndarray, parameter_names: list) -> list:
         return influential_parameters
 
     for axis in range(matrix.ndim):
-        candidate = _reduce_param_matrix(np.all(matrix, axis=axis), remove_index_from_tuple(parameter_names, axis))
+        candidate = _reduce_param_matrix(
+            np.all(matrix, axis=axis), remove_index_from_tuple(parameter_names, axis)
+        )
         if len(candidate):
             return candidate
 
     return list()
 
+
 def _codependent_parameters(param, lut_by_param_values, std_by_param_values):
     """
     Return list of parameters which affect whether a parameter affects a model attribute or not.
     """
     return list()
-    safe_div = np.vectorize(lambda x,y: 0.
if x == 0 else 1 - x/y) + safe_div = np.vectorize(lambda x, y: 0.0 if x == 0 else 1 - x / y) ratio_by_value = safe_div(lut_by_param_values, std_by_param_values) - err_mode = np.seterr('ignore') + err_mode = np.seterr("ignore") dep_by_value = ratio_by_value > 0.5 np.seterr(**err_mode) @@ -86,7 +93,10 @@ def _codependent_parameters(param, lut_by_param_values, std_by_param_values): influencer_parameters = _reduce_param_matrix(dep_by_value, other_param_list) return influencer_parameters -def _std_by_param(by_param, all_param_values, state_or_tran, attribute, param_index, verbose = False): + +def _std_by_param( + by_param, all_param_values, state_or_tran, attribute, param_index, verbose=False +): u""" Calculate standard deviations for a static model where all parameters but `param_index` are constant. @@ -130,7 +140,10 @@ def _std_by_param(by_param, all_param_values, state_or_tran, attribute, param_in param_partition = list() std_list = list() for k, v in by_param.items(): - if k[0] == state_or_tran and (*k[1][:param_index], *k[1][param_index+1:]) == param_value: + if ( + k[0] == state_or_tran + and (*k[1][:param_index], *k[1][param_index + 1 :]) == param_value + ): param_partition.extend(v[attribute]) std_list.append(np.std(v[attribute])) @@ -143,17 +156,26 @@ def _std_by_param(by_param, all_param_values, state_or_tran, attribute, param_in lut_matrix[matrix_index] = np.mean(std_list) # This can (and will) happen in normal operation, e.g. when a transition's # arguments are combined using 'zip' rather than 'cartesian'. - #elif len(param_partition) == 1: + # elif len(param_partition) == 1: # vprint(verbose, '[W] parameter value partition for {} contains only one element -- skipping'.format(param_value)) - #else: + # else: # vprint(verbose, '[W] parameter value partition for {} is empty'.format(param_value)) if np.all(np.isnan(stddev_matrix)): - print('[W] {}/{} parameter #{} has no data partitions -- how did this even happen?'.format(state_or_tran, attribute, param_index)) - print('stddev_matrix = {}'.format(stddev_matrix)) - return stddev_matrix, 0. + print( + "[W] {}/{} parameter #{} has no data partitions -- how did this even happen?".format( + state_or_tran, attribute, param_index + ) + ) + print("stddev_matrix = {}".format(stddev_matrix)) + return stddev_matrix, 0.0 + + return ( + stddev_matrix, + np.nanmean(stddev_matrix), + lut_matrix, + ) # np.mean([np.std(partition) for partition in partitions]) - return stddev_matrix, np.nanmean(stddev_matrix), lut_matrix #np.mean([np.std(partition) for partition in partitions]) def _corr_by_param(by_name, state_or_trans, attribute, param_index): """ @@ -169,22 +191,46 @@ def _corr_by_param(by_name, state_or_trans, attribute, param_index): :param param_index: index of parameter in `by_name[state_or_trans]['param']` """ if _all_params_are_numeric(by_name[state_or_trans], param_index): - param_values = np.array(list((map(lambda x: x[param_index], by_name[state_or_trans]['param'])))) + param_values = np.array( + list((map(lambda x: x[param_index], by_name[state_or_trans]["param"]))) + ) try: return np.corrcoef(by_name[state_or_trans][attribute], param_values)[0, 1] except FloatingPointError: # Typically happens when all parameter values are identical. # Building a correlation coefficient is pointless in this case # -> assume no correlation - return 0. + return 0.0 except ValueError: - print('[!] Exception in _corr_by_param(by_name, state_or_trans={}, attribute={}, param_index={})'.format(state_or_trans, attribute, param_index)) - print('[!] 
while executing np.corrcoef(by_name[{}][{}]={}, {}))'.format(state_or_trans, attribute, by_name[state_or_trans][attribute], param_values)) + print( + "[!] Exception in _corr_by_param(by_name, state_or_trans={}, attribute={}, param_index={})".format( + state_or_trans, attribute, param_index + ) + ) + print( + "[!] while executing np.corrcoef(by_name[{}][{}]={}, {}))".format( + state_or_trans, + attribute, + by_name[state_or_trans][attribute], + param_values, + ) + ) raise else: - return 0. - -def _compute_param_statistics(by_name, by_param, parameter_names, arg_count, state_or_trans, attribute, distinct_values, distinct_values_by_param_index, verbose = False): + return 0.0 + + +def _compute_param_statistics( + by_name, + by_param, + parameter_names, + arg_count, + state_or_trans, + attribute, + distinct_values, + distinct_values_by_param_index, + verbose=False, +): """ Compute standard deviation and correlation coefficient for various data partitions. @@ -223,87 +269,140 @@ def _compute_param_statistics(by_name, by_param, parameter_names, arg_count, sta Only set if state_or_trans appears in arg_count, empty dict otherwise. """ ret = { - 'std_static' : np.std(by_name[state_or_trans][attribute]), - 'std_param_lut' : np.mean([np.std(by_param[x][attribute]) for x in by_param.keys() if x[0] == state_or_trans]), - 'std_by_param' : {}, - 'std_by_param_values' : {}, - 'lut_by_param_values' : {}, - 'std_by_arg' : [], - 'std_by_arg_values' : [], - 'lut_by_arg_values' : [], - 'corr_by_param' : {}, - 'corr_by_arg' : [], - 'depends_on_param' : {}, - 'depends_on_arg' : [], - 'param_data' : {}, + "std_static": np.std(by_name[state_or_trans][attribute]), + "std_param_lut": np.mean( + [ + np.std(by_param[x][attribute]) + for x in by_param.keys() + if x[0] == state_or_trans + ] + ), + "std_by_param": {}, + "std_by_param_values": {}, + "lut_by_param_values": {}, + "std_by_arg": [], + "std_by_arg_values": [], + "lut_by_arg_values": [], + "corr_by_param": {}, + "corr_by_arg": [], + "depends_on_param": {}, + "depends_on_arg": [], + "param_data": {}, } - np.seterr('raise') + np.seterr("raise") for param_idx, param in enumerate(parameter_names): - std_matrix, mean_std, lut_matrix = _std_by_param(by_param, distinct_values_by_param_index, state_or_trans, attribute, param_idx, verbose) - ret['std_by_param'][param] = mean_std - ret['std_by_param_values'][param] = std_matrix - ret['lut_by_param_values'][param] = lut_matrix - ret['corr_by_param'][param] = _corr_by_param(by_name, state_or_trans, attribute, param_idx) - - ret['depends_on_param'][param] = _depends_on_param(ret['corr_by_param'][param], ret['std_by_param'][param], ret['std_param_lut']) - - if ret['depends_on_param'][param]: - ret['param_data'][param] = { - 'codependent_parameters': _codependent_parameters(param, lut_matrix, std_matrix), - 'depends_for_codependent_value': dict() + std_matrix, mean_std, lut_matrix = _std_by_param( + by_param, + distinct_values_by_param_index, + state_or_trans, + attribute, + param_idx, + verbose, + ) + ret["std_by_param"][param] = mean_std + ret["std_by_param_values"][param] = std_matrix + ret["lut_by_param_values"][param] = lut_matrix + ret["corr_by_param"][param] = _corr_by_param( + by_name, state_or_trans, attribute, param_idx + ) + + ret["depends_on_param"][param] = _depends_on_param( + ret["corr_by_param"][param], + ret["std_by_param"][param], + ret["std_param_lut"], + ) + + if ret["depends_on_param"][param]: + ret["param_data"][param] = { + "codependent_parameters": _codependent_parameters( + param, lut_matrix, 
std_matrix + ), + "depends_for_codependent_value": dict(), } # calculate parameter dependence for individual values of codependent parameters codependent_param_values = list() - for codependent_param in ret['param_data'][param]['codependent_parameters']: + for codependent_param in ret["param_data"][param]["codependent_parameters"]: codependent_param_values.append(distinct_values[codependent_param]) for combi in itertools.product(*codependent_param_values): by_name_part = deepcopy(by_name) - filter_list = list(zip(ret['param_data'][param]['codependent_parameters'], combi)) + filter_list = list( + zip(ret["param_data"][param]["codependent_parameters"], combi) + ) filter_aggregate_by_param(by_name_part, parameter_names, filter_list) by_param_part = by_name_to_by_param(by_name_part) # there may be no data for this specific parameter value combination if state_or_trans in by_name_part: - part_corr = _corr_by_param(by_name_part, state_or_trans, attribute, param_idx) - part_std_lut = np.mean([np.std(by_param_part[x][attribute]) for x in by_param_part.keys() if x[0] == state_or_trans]) - _, part_std_param, _ = _std_by_param(by_param_part, distinct_values_by_param_index, state_or_trans, attribute, param_idx, verbose) - ret['param_data'][param]['depends_for_codependent_value'][combi] = _depends_on_param(part_corr, part_std_param, part_std_lut) + part_corr = _corr_by_param( + by_name_part, state_or_trans, attribute, param_idx + ) + part_std_lut = np.mean( + [ + np.std(by_param_part[x][attribute]) + for x in by_param_part.keys() + if x[0] == state_or_trans + ] + ) + _, part_std_param, _ = _std_by_param( + by_param_part, + distinct_values_by_param_index, + state_or_trans, + attribute, + param_idx, + verbose, + ) + ret["param_data"][param]["depends_for_codependent_value"][ + combi + ] = _depends_on_param(part_corr, part_std_param, part_std_lut) if state_or_trans in arg_count: for arg_index in range(arg_count[state_or_trans]): - std_matrix, mean_std, lut_matrix = _std_by_param(by_param, distinct_values_by_param_index, state_or_trans, attribute, len(parameter_names) + arg_index, verbose) - ret['std_by_arg'].append(mean_std) - ret['std_by_arg_values'].append(std_matrix) - ret['lut_by_arg_values'].append(lut_matrix) - ret['corr_by_arg'].append(_corr_by_param(by_name, state_or_trans, attribute, len(parameter_names) + arg_index)) + std_matrix, mean_std, lut_matrix = _std_by_param( + by_param, + distinct_values_by_param_index, + state_or_trans, + attribute, + len(parameter_names) + arg_index, + verbose, + ) + ret["std_by_arg"].append(mean_std) + ret["std_by_arg_values"].append(std_matrix) + ret["lut_by_arg_values"].append(lut_matrix) + ret["corr_by_arg"].append( + _corr_by_param( + by_name, state_or_trans, attribute, len(parameter_names) + arg_index + ) + ) if False: - ret['depends_on_arg'].append(ret['corr_by_arg'][arg_index] > 0.1) - elif ret['std_by_arg'][arg_index] == 0: + ret["depends_on_arg"].append(ret["corr_by_arg"][arg_index] > 0.1) + elif ret["std_by_arg"][arg_index] == 0: # In general, std_param_lut < std_by_arg. So, if std_by_arg == 0, std_param_lut == 0 follows. 
# This means that the variation of arg does not affect the model quality -> no influence - ret['depends_on_arg'].append(False) + ret["depends_on_arg"].append(False) else: - ret['depends_on_arg'].append(ret['std_param_lut'] / ret['std_by_arg'][arg_index] < 0.5) + ret["depends_on_arg"].append( + ret["std_param_lut"] / ret["std_by_arg"][arg_index] < 0.5 + ) return ret + def _compute_param_statistics_parallel(arg): - return { - 'key' : arg['key'], - 'result': _compute_param_statistics(*arg['args']) - } + return {"key": arg["key"], "result": _compute_param_statistics(*arg["args"])} + def _all_params_are_numeric(data, param_idx): """Check if all `data['param'][*][param_idx]` elements are numeric, as reported by `utils.is_numeric`.""" - param_values = list(map(lambda x: x[param_idx], data['param'])) + param_values = list(map(lambda x: x[param_idx], data["param"])) if len(list(filter(is_numeric, param_values))) == len(param_values): return True return False -def prune_dependent_parameters(by_name, parameter_names, correlation_threshold = 0.5): + +def prune_dependent_parameters(by_name, parameter_names, correlation_threshold=0.5): """ Remove dependent parameters from aggregate. @@ -320,15 +419,17 @@ def prune_dependent_parameters(by_name, parameter_names, correlation_threshold = """ parameter_indices_to_remove = list() - for parameter_combination in itertools.product(range(len(parameter_names)), range(len(parameter_names))): + for parameter_combination in itertools.product( + range(len(parameter_names)), range(len(parameter_names)) + ): index_1, index_2 = parameter_combination if index_1 >= index_2: continue - parameter_values = [list(), list()] # both parameters have a value - parameter_values_1 = list() # parameter 1 has a value - parameter_values_2 = list() # parameter 2 has a value + parameter_values = [list(), list()] # both parameters have a value + parameter_values_1 = list() # parameter 1 has a value + parameter_values_2 = list() # parameter 2 has a value for name in by_name: - for measurement in by_name[name]['param']: + for measurement in by_name[name]["param"]: value_1 = measurement[index_1] value_2 = measurement[index_2] if is_numeric(value_1): @@ -342,16 +443,30 @@ def prune_dependent_parameters(by_name, parameter_names, correlation_threshold = # Calculating the correlation coefficient only makes sense when neither value is constant if np.std(parameter_values_1) != 0 and np.std(parameter_values_2) != 0: correlation = np.corrcoef(parameter_values)[0][1] - if correlation != np.nan and np.abs(correlation) > correlation_threshold: - print('[!] Parameters {} <-> {} are correlated with coefficcient {}'.format(parameter_names[index_1], parameter_names[index_2], correlation)) + if ( + correlation != np.nan + and np.abs(correlation) > correlation_threshold + ): + print( + "[!] 
Parameters {} <-> {} are correlated with coefficcient {}".format( + parameter_names[index_1], + parameter_names[index_2], + correlation, + ) + ) if len(parameter_values_1) < len(parameter_values_2): index_to_remove = index_1 else: index_to_remove = index_2 - print(' Removing parameter {}'.format(parameter_names[index_to_remove])) + print( + " Removing parameter {}".format( + parameter_names[index_to_remove] + ) + ) parameter_indices_to_remove.append(index_to_remove) remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove) + def remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove): """ Remove parameters listed in `parameter_indices` from aggregate `by_name` and `parameter_names`. @@ -365,12 +480,13 @@ def remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_ """ # Start removal from the end of the list to avoid renumbering of list elemenets - for parameter_index in sorted(parameter_indices_to_remove, reverse = True): + for parameter_index in sorted(parameter_indices_to_remove, reverse=True): for name in by_name: - for measurement in by_name[name]['param']: + for measurement in by_name[name]["param"]: measurement.pop(parameter_index) parameter_names.pop(parameter_index) + class ParamStats: """ :param stats: `stats[state_or_tran][attribute]` = std_static, std_param_lut, ... (see `compute_param_statistics`) @@ -378,7 +494,15 @@ class ParamStats: :param distinct_values_by_param_index: `distinct_values[state_or_tran][i]` = [distinct values in aggregate] """ - def __init__(self, by_name, by_param, parameter_names, arg_count, use_corrcoef = False, verbose = False): + def __init__( + self, + by_name, + by_param, + parameter_names, + arg_count, + use_corrcoef=False, + verbose=False, + ): """ Compute standard deviation and correlation coefficient on parameterized data partitions. 
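# A sketch of the correlation-based pruning idea from prune_dependent_parameters
# above, with invented values (not part of this diff). Two parameter columns
# that track each other almost linearly are redundant, so one can be dropped.
import numpy as np

temperature = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
voltage = np.array([3.1, 3.0, 2.9, 2.8, 2.7])  # falls linearly with temperature

r = np.corrcoef(temperature, voltage)[0, 1]
# np.isnan() is the reliable NaN check here; a comparison such as `r != np.nan`
# is always True, since NaN compares unequal to everything, including itself.
if not np.isnan(r) and abs(r) > 0.5:  # 0.5 matches the correlation_threshold default
    print("correlated (r = {:.2f}) -> drop one of the two parameters".format(r))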
@@ -411,24 +535,40 @@ class ParamStats: for state_or_tran in by_name.keys(): self.stats[state_or_tran] = dict() - self.distinct_values_by_param_index[state_or_tran] = distinct_param_values(by_name, state_or_tran) + self.distinct_values_by_param_index[state_or_tran] = distinct_param_values( + by_name, state_or_tran + ) self.distinct_values[state_or_tran] = dict() for i, param in enumerate(parameter_names): - self.distinct_values[state_or_tran][param] = self.distinct_values_by_param_index[state_or_tran][i] - for attribute in by_name[state_or_tran]['attributes']: - stats_queue.append({ - 'key': [state_or_tran, attribute], - 'args': [by_name, by_param, parameter_names, arg_count, state_or_tran, attribute, self.distinct_values[state_or_tran], self.distinct_values_by_param_index[state_or_tran], verbose], - }) + self.distinct_values[state_or_tran][ + param + ] = self.distinct_values_by_param_index[state_or_tran][i] + for attribute in by_name[state_or_tran]["attributes"]: + stats_queue.append( + { + "key": [state_or_tran, attribute], + "args": [ + by_name, + by_param, + parameter_names, + arg_count, + state_or_tran, + attribute, + self.distinct_values[state_or_tran], + self.distinct_values_by_param_index[state_or_tran], + verbose, + ], + } + ) with Pool() as pool: stats_results = pool.map(_compute_param_statistics_parallel, stats_queue) for stats in stats_results: - state_or_tran, attribute = stats['key'] - self.stats[state_or_tran][attribute] = stats['result'] + state_or_tran, attribute = stats["key"] + self.stats[state_or_tran][attribute] = stats["result"] - def can_be_fitted(self, state_or_tran = None) -> bool: + def can_be_fitted(self, state_or_tran=None) -> bool: """ Return whether a sufficient amount of distinct numeric parameter values is available, allowing a parameter-aware model to be generated. @@ -441,8 +581,27 @@ class ParamStats: for key in keys: for param in self._parameter_names: - if len(list(filter(lambda n: is_numeric(n), self.distinct_values[key][param]))) > 2: - print(key, param, list(filter(lambda n: is_numeric(n), self.distinct_values[key][param]))) + if ( + len( + list( + filter( + lambda n: is_numeric(n), + self.distinct_values[key][param], + ) + ) + ) + > 2 + ): + print( + key, + param, + list( + filter( + lambda n: is_numeric(n), + self.distinct_values[key][param], + ) + ), + ) return True return False @@ -456,7 +615,9 @@ class ParamStats: # TODO pass - def has_codependent_parameters(self, state_or_tran: str, attribute: str, param: str) -> bool: + def has_codependent_parameters( + self, state_or_tran: str, attribute: str, param: str + ) -> bool: """ Return whether there are parameters which determine whether `param` influences `state_or_tran` `attribute` or not. @@ -468,7 +629,9 @@ class ParamStats: return True return False - def codependent_parameters(self, state_or_tran: str, attribute: str, param: str) -> list: + def codependent_parameters( + self, state_or_tran: str, attribute: str, param: str + ) -> list: """ Return list of parameters which determine whether `param` influences `state_or_tran` `attribute` or not. 
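# A minimal sketch of the Pool fan-out used by ParamStats above: one work item
# per (name, attribute) pair, results re-keyed into a dict after the parallel
# map. All names and numbers are invented for illustration.
from multiprocessing import Pool

import numpy as np


def _stats_worker(arg):
    key, values = arg
    return {"key": key, "result": {"mean": float(np.mean(values)), "std": float(np.std(values))}}


if __name__ == "__main__":
    queue = [(("TX", "power"), [10.1, 9.9, 10.0]), (("RX", "power"), [3.3, 3.1, 3.2])]
    with Pool() as pool:
        results = pool.map(_stats_worker, queue)
    stats = {res["key"]: res["result"] for res in results}
    print(stats[("TX", "power")]["std"])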
@@ -476,12 +639,15 @@ class ParamStats: :param attribute: model attribute :param param: parameter name """ - if self.stats[state_or_tran][attribute]['depends_on_param'][param]: - return self.stats[state_or_tran][attribute]['param_data'][param]['codependent_parameters'] + if self.stats[state_or_tran][attribute]["depends_on_param"][param]: + return self.stats[state_or_tran][attribute]["param_data"][param][ + "codependent_parameters" + ] return list() - - def has_codependent_parameters_union(self, state_or_tran: str, attribute: str) -> bool: + def has_codependent_parameters_union( + self, state_or_tran: str, attribute: str + ) -> bool: """ Return whether there is a subset of parameters which decides whether `state_or_tran` `attribute` is static or parameter-dependent @@ -490,11 +656,14 @@ class ParamStats: """ depends_on_a_parameter = False for param in self._parameter_names: - if self.stats[state_or_tran][attribute]['depends_on_param'][param]: - print('{}/{} depends on {}'.format(state_or_tran, attribute, param)) + if self.stats[state_or_tran][attribute]["depends_on_param"][param]: + print("{}/{} depends on {}".format(state_or_tran, attribute, param)) depends_on_a_parameter = True - if len(self.codependent_parameters(state_or_tran, attribute, param)) == 0: - print('has no codependent parameters') + if ( + len(self.codependent_parameters(state_or_tran, attribute, param)) + == 0 + ): + print("has no codependent parameters") # Always depends on this parameter, regardless of other parameters' values return False return depends_on_a_parameter @@ -508,14 +677,21 @@ class ParamStats: """ codependent_parameters = set() for param in self._parameter_names: - if self.stats[state_or_tran][attribute]['depends_on_param'][param]: - if len(self.codependent_parameters(state_or_tran, attribute, param)) == 0: + if self.stats[state_or_tran][attribute]["depends_on_param"][param]: + if ( + len(self.codependent_parameters(state_or_tran, attribute, param)) + == 0 + ): return list(self._parameter_names) - for codependent_param in self.codependent_parameters(state_or_tran, attribute, param): + for codependent_param in self.codependent_parameters( + state_or_tran, attribute, param + ): codependent_parameters.add(codependent_param) return sorted(codependent_parameters) - def codependence_by_codependent_param_values(self, state_or_tran: str, attribute: str, param: str) -> dict: + def codependence_by_codependent_param_values( + self, state_or_tran: str, attribute: str, param: str + ) -> dict: """ Return dict mapping codependent parameter values to a boolean indicating whether `param` influences `state_or_tran` `attribute`. @@ -525,11 +701,15 @@ class ParamStats: :param attribute: model attribute :param param: parameter name """ - if self.stats[state_or_tran][attribute]['depends_on_param'][param]: - return self.stats[state_or_tran][attribute]['param_data'][param]['depends_for_codependent_value'] + if self.stats[state_or_tran][attribute]["depends_on_param"][param]: + return self.stats[state_or_tran][attribute]["param_data"][param][ + "depends_for_codependent_value" + ] return dict() - def codependent_parameter_value_dicts(self, state_or_tran: str, attribute: str, param: str, kind='dynamic'): + def codependent_parameter_value_dicts( + self, state_or_tran: str, attribute: str, param: str, kind="dynamic" + ): """ Return dicts of codependent parameter key-value mappings for which `param` influences (or does not influence) `state_or_tran` `attribute`. 
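# A sketch (invented names) of the codependence bookkeeping queried above: keys
# are tuples of codependent parameter values, the boolean says whether `param`
# influences the attribute for that combination. zip() restores readable dicts,
# as in the codependent_parameter_value_dicts generator that follows.
codependent_parameters = ["radio_on", "payload_size"]
depends_for_codependent_value = {
    (1, 16): True,  # param matters while the radio is on
    (0, 16): False,  # ... but not while it is off
}

for param_values, is_dynamic in depends_for_codependent_value.items():
    if is_dynamic:
        print(dict(zip(codependent_parameters, param_values)))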
@@ -538,16 +718,21 @@ class ParamStats: :param param: parameter name: :param kind: 'static' or 'dynamic'. If 'dynamic' (the default), returns codependent parameter values for which `param` influences `attribute`. If 'static', returns codependent parameter values for which `param` does not influence `attribute` """ - codependent_parameters = self.stats[state_or_tran][attribute]['param_data'][param]['codependent_parameters'] - codependence_info = self.stats[state_or_tran][attribute]['param_data'][param]['depends_for_codependent_value'] + codependent_parameters = self.stats[state_or_tran][attribute]["param_data"][ + param + ]["codependent_parameters"] + codependence_info = self.stats[state_or_tran][attribute]["param_data"][param][ + "depends_for_codependent_value" + ] if len(codependent_parameters) == 0: return else: for param_values, is_dynamic in codependence_info.items(): - if (is_dynamic and kind == 'dynamic') or (not is_dynamic and kind == 'static'): + if (is_dynamic and kind == "dynamic") or ( + not is_dynamic and kind == "static" + ): yield dict(zip(codependent_parameters, param_values)) - def _generic_param_independence_ratio(self, state_or_trans, attribute): """ Return the heuristic ratio of parameter independence for state_or_trans and attribute. @@ -559,9 +744,9 @@ class ParamStats: if self.use_corrcoef: # not supported raise ValueError - if statistics['std_static'] == 0: + if statistics["std_static"] == 0: return 0 - return statistics['std_param_lut'] / statistics['std_static'] + return statistics["std_param_lut"] / statistics["std_static"] def generic_param_dependence_ratio(self, state_or_trans, attribute): """ @@ -572,7 +757,9 @@ class ParamStats: """ return 1 - self._generic_param_independence_ratio(state_or_trans, attribute) - def _param_independence_ratio(self, state_or_trans: str, attribute: str, param: str) -> float: + def _param_independence_ratio( + self, state_or_trans: str, attribute: str, param: str + ) -> float: """ Return the heuristic ratio of parameter independence for state_or_trans, attribute, and param. @@ -580,17 +767,19 @@ class ParamStats: """ statistics = self.stats[state_or_trans][attribute] if self.use_corrcoef: - return 1 - np.abs(statistics['corr_by_param'][param]) - if statistics['std_by_param'][param] == 0: - if statistics['std_param_lut'] != 0: + return 1 - np.abs(statistics["corr_by_param"][param]) + if statistics["std_by_param"][param] == 0: + if statistics["std_param_lut"] != 0: raise RuntimeError("wat") # In general, std_param_lut < std_by_param. So, if std_by_param == 0, std_param_lut == 0 follows. # This means that the variation of param does not affect the model quality -> no influence, return 1 - return 1. + return 1.0 - return statistics['std_param_lut'] / statistics['std_by_param'][param] + return statistics["std_param_lut"] / statistics["std_by_param"][param] - def param_dependence_ratio(self, state_or_trans: str, attribute: str, param: str) -> float: + def param_dependence_ratio( + self, state_or_trans: str, attribute: str, param: str + ) -> float: """ Return the heuristic ratio of parameter dependence for state_or_trans, attribute, and param. 
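# A numeric sketch of the dependence heuristic above, with invented values: if
# grouping measurements by a parameter value (the LUT) removes most of the
# variance, the attribute is treated as parameter-dependent. This mirrors
# 1 - std_param_lut / std_static from generic_param_dependence_ratio.
import numpy as np

by_param_value = {8: [100.0, 101.0, 99.0], 16: [201.0, 199.0, 200.0]}

std_static = np.std([v for values in by_param_value.values() for v in values])
std_param_lut = np.mean([np.std(values) for values in by_param_value.values()])

print("dependence ratio: {:.2f}".format(1 - std_param_lut / std_static))  # ~0.98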
@@ -607,16 +796,18 @@ class ParamStats: def _arg_independence_ratio(self, state_or_trans, attribute, arg_index): statistics = self.stats[state_or_trans][attribute] if self.use_corrcoef: - return 1 - np.abs(statistics['corr_by_arg'][arg_index]) - if statistics['std_by_arg'][arg_index] == 0: - if statistics['std_param_lut'] != 0: + return 1 - np.abs(statistics["corr_by_arg"][arg_index]) + if statistics["std_by_arg"][arg_index] == 0: + if statistics["std_param_lut"] != 0: raise RuntimeError("wat") # In general, std_param_lut < std_by_arg. So, if std_by_arg == 0, std_param_lut == 0 follows. # This means that the variation of arg does not affect the model quality -> no influence, return 1 return 1 - return statistics['std_param_lut'] / statistics['std_by_arg'][arg_index] + return statistics["std_param_lut"] / statistics["std_by_arg"][arg_index] - def arg_dependence_ratio(self, state_or_trans: str, attribute: str, arg_index: int) -> float: + def arg_dependence_ratio( + self, state_or_trans: str, attribute: str, arg_index: int + ) -> float: return 1 - self._arg_independence_ratio(state_or_trans, attribute, arg_index) # This heuristic is very similar to the "function is not much better than @@ -625,10 +816,9 @@ class ParamStats: # --df, 2018-04-18 def depends_on_param(self, state_or_trans, attribute, param): """Return whether attribute of state_or_trans depens on param.""" - return self.stats[state_or_trans][attribute]['depends_on_param'][param] + return self.stats[state_or_trans][attribute]["depends_on_param"][param] # See notes on depends_on_param def depends_on_arg(self, state_or_trans, attribute, arg_index): """Return whether attribute of state_or_trans depens on arg_index.""" - return self.stats[state_or_trans][attribute]['depends_on_arg'][arg_index] - + return self.stats[state_or_trans][attribute]["depends_on_arg"][arg_index] diff --git a/lib/plotter.py b/lib/plotter.py index deed93a..16c0145 100755 --- a/lib/plotter.py +++ b/lib/plotter.py @@ -8,75 +8,89 @@ import re def is_state(aggregate, name): """Return true if name is a state and not UNINITIALIZED.""" - return aggregate[name]['isa'] == 'state' and name != 'UNINITIALIZED' + return aggregate[name]["isa"] == "state" and name != "UNINITIALIZED" def plot_states(model, aggregate): keys = [key for key in sorted(aggregate.keys()) if is_state(aggregate, key)] - data = [aggregate[key]['means'] for key in keys] - mdata = [int(model['state'][key]['power']['static']) for key in keys] - boxplot(keys, data, 'Zustand', 'µW', modeldata=mdata) + data = [aggregate[key]["means"] for key in keys] + mdata = [int(model["state"][key]["power"]["static"]) for key in keys] + boxplot(keys, data, "Zustand", "µW", modeldata=mdata) def plot_transitions(model, aggregate): - keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'transition'] - data = [aggregate[key]['rel_energies'] for key in keys] - mdata = [int(model['transition'][key]['rel_energy']['static']) for key in keys] - boxplot(keys, data, 'Transition', 'pJ (rel)', modeldata=mdata) - data = [aggregate[key]['energies'] for key in keys] - mdata = [int(model['transition'][key]['energy']['static']) for key in keys] - boxplot(keys, data, 'Transition', 'pJ', modeldata=mdata) + keys = [ + key for key in sorted(aggregate.keys()) if aggregate[key]["isa"] == "transition" + ] + data = [aggregate[key]["rel_energies"] for key in keys] + mdata = [int(model["transition"][key]["rel_energy"]["static"]) for key in keys] + boxplot(keys, data, "Transition", "pJ (rel)", modeldata=mdata) + data = 
[aggregate[key]["energies"] for key in keys] + mdata = [int(model["transition"][key]["energy"]["static"]) for key in keys] + boxplot(keys, data, "Transition", "pJ", modeldata=mdata) def plot_states_duration(model, aggregate): keys = [key for key in sorted(aggregate.keys()) if is_state(aggregate, key)] - data = [aggregate[key]['durations'] for key in keys] - boxplot(keys, data, 'Zustand', 'µs') + data = [aggregate[key]["durations"] for key in keys] + boxplot(keys, data, "Zustand", "µs") def plot_transitions_duration(model, aggregate): - keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'transition'] - data = [aggregate[key]['durations'] for key in keys] - boxplot(keys, data, 'Transition', 'µs') + keys = [ + key for key in sorted(aggregate.keys()) if aggregate[key]["isa"] == "transition" + ] + data = [aggregate[key]["durations"] for key in keys] + boxplot(keys, data, "Transition", "µs") def plot_transitions_timeout(model, aggregate): - keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'transition'] - data = [aggregate[key]['timeouts'] for key in keys] - boxplot(keys, data, 'Timeout', 'µs') + keys = [ + key for key in sorted(aggregate.keys()) if aggregate[key]["isa"] == "transition" + ] + data = [aggregate[key]["timeouts"] for key in keys] + boxplot(keys, data, "Timeout", "µs") def plot_states_clips(model, aggregate): keys = [key for key in sorted(aggregate.keys()) if is_state(aggregate, key)] - data = [np.array([100]) * aggregate[key]['clip_rate'] for key in keys] - boxplot(keys, data, 'Zustand', '% Clipping') + data = [np.array([100]) * aggregate[key]["clip_rate"] for key in keys] + boxplot(keys, data, "Zustand", "% Clipping") def plot_transitions_clips(model, aggregate): - keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'transition'] - data = [np.array([100]) * aggregate[key]['clip_rate'] for key in keys] - boxplot(keys, data, 'Transition', '% Clipping') + keys = [ + key for key in sorted(aggregate.keys()) if aggregate[key]["isa"] == "transition" + ] + data = [np.array([100]) * aggregate[key]["clip_rate"] for key in keys] + boxplot(keys, data, "Transition", "% Clipping") def plot_substate_thresholds(model, aggregate): keys = [key for key in sorted(aggregate.keys()) if is_state(aggregate, key)] - data = [aggregate[key]['sub_thresholds'] for key in keys] - boxplot(keys, data, 'Zustand', 'substate threshold (mW/dmW)') + data = [aggregate[key]["sub_thresholds"] for key in keys] + boxplot(keys, data, "Zustand", "substate threshold (mW/dmW)") def plot_histogram(data): - n, bins, patches = plt.hist(data, 1000, normed=1, facecolor='green', alpha=0.75) + n, bins, patches = plt.hist(data, 1000, normed=1, facecolor="green", alpha=0.75) plt.show() def plot_states_param(model, aggregate): - keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'state' and key[0] != 'UNINITIALIZED'] - data = [aggregate[key]['means'] for key in keys] - mdata = [int(model['state'][key[0]]['power']['static']) for key in keys] - boxplot(keys, data, 'Transition', 'µW', modeldata=mdata) - - -def plot_attribute(aggregate, attribute, attribute_unit='', key_filter=lambda x: True, **kwargs): + keys = [ + key + for key in sorted(aggregate.keys()) + if aggregate[key]["isa"] == "state" and key[0] != "UNINITIALIZED" + ] + data = [aggregate[key]["means"] for key in keys] + mdata = [int(model["state"][key[0]]["power"]["static"]) for key in keys] + boxplot(keys, data, "Transition", "µW", modeldata=mdata) + + +def plot_attribute( + 
aggregate, attribute, attribute_unit="", key_filter=lambda x: True, **kwargs +): """ Boxplot measurements of a single attribute according to the partitioning provided by aggregate. @@ -94,13 +108,17 @@ def plot_attribute(aggregate, attribute, attribute_unit='', key_filter=lambda x: def plot_substate_thresholds_p(model, aggregate): - keys = [key for key in sorted(aggregate.keys()) if aggregate[key]['isa'] == 'state' and key[0] != 'UNINITIALIZED'] - data = [aggregate[key]['sub_thresholds'] for key in keys] - boxplot(keys, data, 'Zustand', '% Clipping') + keys = [ + key + for key in sorted(aggregate.keys()) + if aggregate[key]["isa"] == "state" and key[0] != "UNINITIALIZED" + ] + data = [aggregate[key]["sub_thresholds"] for key in keys] + boxplot(keys, data, "Zustand", "% Clipping") def plot_y(Y, **kwargs): - if 'family' in kwargs and kwargs['family']: + if "family" in kwargs and kwargs["family"]: plot_xy(None, Y, **kwargs) else: plot_xy(np.arange(len(Y)), Y, **kwargs) @@ -116,26 +134,39 @@ def plot_xy(X, Y, xlabel=None, ylabel=None, title=None, output=None, family=Fals ax1.set_ylabel(ylabel) plt.subplots_adjust(left=0.1, bottom=0.1, right=0.99, top=0.99) if family: - cm = plt.get_cmap('brg', len(Y)) + cm = plt.get_cmap("brg", len(Y)) for i, YY in enumerate(Y): plt.plot(np.arange(len(YY)), YY, "-", markersize=2, color=cm(i)) else: plt.plot(X, Y, "bo", markersize=2) if output: plt.savefig(output) - with open('{}.txt'.format(output), 'w') as f: - print('X Y', file=f) + with open("{}.txt".format(output), "w") as f: + print("X Y", file=f) for i in range(len(X)): - print('{} {}'.format(X[i], Y[i]), file=f) + print("{} {}".format(X[i], Y[i]), file=f) else: plt.show() def _param_slice_eq(a, b, index): - return (*a[1][:index], *a[1][index + 1:]) == (*b[1][:index], *b[1][index + 1:]) and a[0] == b[0] - - -def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel=None, title=None, extra_function=None, output=None): + return (*a[1][:index], *a[1][index + 1 :]) == ( + *b[1][:index], + *b[1][index + 1 :], + ) and a[0] == b[0] + + +def plot_param( + model, + state_or_trans, + attribute, + param_idx, + xlabel=None, + ylabel=None, + title=None, + extra_function=None, + output=None, +): fig, ax1 = plt.subplots(figsize=(10, 6)) if title is not None: fig.canvas.set_window_title(title) @@ -147,8 +178,12 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel= param_name = model.param_name(param_idx) - function_filename = 'plot_param_{}_{}_{}.txt'.format(state_or_trans, attribute, param_name) - data_filename_base = 'measurements_{}_{}_{}'.format(state_or_trans, attribute, param_name) + function_filename = "plot_param_{}_{}_{}.txt".format( + state_or_trans, attribute, param_name + ) + data_filename_base = "measurements_{}_{}_{}".format( + state_or_trans, attribute, param_name + ) param_model, param_info = model.get_fitted() @@ -156,16 +191,18 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel= XX = [] - legend_sanitizer = re.compile(r'[^0-9a-zA-Z]+') + legend_sanitizer = re.compile(r"[^0-9a-zA-Z]+") for k, v in model.by_param.items(): if k[0] == state_or_trans: - other_param_key = (*k[1][:param_idx], *k[1][param_idx + 1:]) + other_param_key = (*k[1][:param_idx], *k[1][param_idx + 1 :]) if other_param_key not in by_other_param: - by_other_param[other_param_key] = {'X': [], 'Y': []} - by_other_param[other_param_key]['X'].extend([float(k[1][param_idx])] * len(v[attribute])) - by_other_param[other_param_key]['Y'].extend(v[attribute]) - 
XX.extend(by_other_param[other_param_key]['X']) + by_other_param[other_param_key] = {"X": [], "Y": []} + by_other_param[other_param_key]["X"].extend( + [float(k[1][param_idx])] * len(v[attribute]) + ) + by_other_param[other_param_key]["Y"].extend(v[attribute]) + XX.extend(by_other_param[other_param_key]["X"]) XX = np.array(XX) x_range = int((XX.max() - XX.min()) * 10) @@ -175,22 +212,22 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel= YY2 = [] YY2_legend = [] - cm = plt.get_cmap('brg', len(by_other_param)) + cm = plt.get_cmap("brg", len(by_other_param)) for i, k in sorted(enumerate(by_other_param), key=lambda x: x[1]): v = by_other_param[k] - v['X'] = np.array(v['X']) - v['Y'] = np.array(v['Y']) - plt.plot(v['X'], v['Y'], "ro", color=cm(i), markersize=3) - YY2_legend.append(legend_sanitizer.sub('_', 'X_{}'.format(k))) - YY2.append(v['X']) - YY2_legend.append(legend_sanitizer.sub('_', 'Y_{}'.format(k))) - YY2.append(v['Y']) - - sanitized_k = legend_sanitizer.sub('_', str(k)) - with open('{}_{}.txt'.format(data_filename_base, sanitized_k), 'w') as f: - print('X Y', file=f) - for i in range(len(v['X'])): - print('{} {}'.format(v['X'][i], v['Y'][i]), file=f) + v["X"] = np.array(v["X"]) + v["Y"] = np.array(v["Y"]) + plt.plot(v["X"], v["Y"], "ro", color=cm(i), markersize=3) + YY2_legend.append(legend_sanitizer.sub("_", "X_{}".format(k))) + YY2.append(v["X"]) + YY2_legend.append(legend_sanitizer.sub("_", "Y_{}".format(k))) + YY2.append(v["Y"]) + + sanitized_k = legend_sanitizer.sub("_", str(k)) + with open("{}_{}.txt".format(data_filename_base, sanitized_k), "w") as f: + print("X Y", file=f) + for i in range(len(v["X"])): + print("{} {}".format(v["X"][i], v["Y"][i]), file=f) # x_range = int((v['X'].max() - v['X'].min()) * 10) # xsp = np.linspace(v['X'].min(), v['X'].max(), x_range) @@ -201,21 +238,21 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel= ysp.append(param_model(state_or_trans, attribute, param=xarg)) plt.plot(xsp, ysp, "r-", color=cm(i), linewidth=0.5) YY.append(ysp) - YY_legend.append(legend_sanitizer.sub('_', 'regr_{}'.format(k))) + YY_legend.append(legend_sanitizer.sub("_", "regr_{}".format(k))) if extra_function is not None: ysp = [] - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): for x in xsp: xarg = [*k[:param_idx], x, *k[param_idx:]] ysp.append(extra_function(*xarg)) plt.plot(xsp, ysp, "r--", color=cm(i), linewidth=1, dashes=(3, 3)) YY.append(ysp) - YY_legend.append(legend_sanitizer.sub('_', 'symb_{}'.format(k))) + YY_legend.append(legend_sanitizer.sub("_", "symb_{}".format(k))) - with open(function_filename, 'w') as f: - print(' '.join(YY_legend), file=f) + with open(function_filename, "w") as f: + print(" ".join(YY_legend), file=f) for elem in np.array(YY).T: - print(' '.join(map(str, elem)), file=f) + print(" ".join(map(str, elem)), file=f) print(data_filename_base, function_filename) if output: @@ -224,7 +261,19 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel=None, ylabel= plt.show() -def plot_param_fit(function, name, fitfunc, funp, parameters, datatype, index, X, Y, xaxis=None, yaxis=None): +def plot_param_fit( + function, + name, + fitfunc, + funp, + parameters, + datatype, + index, + X, + Y, + xaxis=None, + yaxis=None, +): fig, ax1 = plt.subplots(figsize=(10, 6)) fig.canvas.set_window_title("fit %s" % (function)) plt.subplots_adjust(left=0.14, right=0.99, top=0.99, bottom=0.14) @@ -244,10 +293,10 @@ def 
plot_param_fit(function, name, fitfunc, funp, parameters, datatype, index, X if yaxis is not None: ax1.set_ylabel(yaxis) else: - ax1.set_ylabel('%s %s' % (name, datatype)) + ax1.set_ylabel("%s %s" % (name, datatype)) - otherparams = list(set(itertools.product(*X[:index], *X[index + 1:]))) - cm = plt.get_cmap('brg', len(otherparams)) + otherparams = list(set(itertools.product(*X[:index], *X[index + 1 :]))) + cm = plt.get_cmap("brg", len(otherparams)) for i in range(len(otherparams)): elem = otherparams[i] color = cm(i) @@ -268,18 +317,17 @@ def plot_param_fit(function, name, fitfunc, funp, parameters, datatype, index, X plt.show() -def boxplot(ticks, measurements, xlabel='', ylabel='', modeldata=None, output=None): +def boxplot(ticks, measurements, xlabel="", ylabel="", modeldata=None, output=None): fig, ax1 = plt.subplots(figsize=(10, 6)) - fig.canvas.set_window_title('DriverEval') + fig.canvas.set_window_title("DriverEval") plt.subplots_adjust(left=0.1, right=0.95, top=0.95, bottom=0.1) - bp = plt.boxplot(measurements, notch=0, sym='+', vert=1, whis=1.5) - plt.setp(bp['boxes'], color='black') - plt.setp(bp['whiskers'], color='black') - plt.setp(bp['fliers'], color='red', marker='+') + bp = plt.boxplot(measurements, notch=0, sym="+", vert=1, whis=1.5) + plt.setp(bp["boxes"], color="black") + plt.setp(bp["whiskers"], color="black") + plt.setp(bp["fliers"], color="red", marker="+") - ax1.yaxis.grid(True, linestyle='-', which='major', color='lightgrey', - alpha=0.5) + ax1.yaxis.grid(True, linestyle="-", which="major", color="lightgrey", alpha=0.5) ax1.set_axisbelow(True) # ax1.set_title('DriverEval') @@ -294,7 +342,7 @@ def boxplot(ticks, measurements, xlabel='', ylabel='', modeldata=None, output=No # boxColors = ['darkkhaki', 'royalblue'] medians = list(range(numBoxes)) for i in range(numBoxes): - box = bp['boxes'][i] + box = bp["boxes"][i] boxX = [] boxY = [] for j in range(5): @@ -306,21 +354,31 @@ def boxplot(ticks, measurements, xlabel='', ylabel='', modeldata=None, output=No # boxPolygon = Polygon(boxCoords, facecolor=boxColors[k]) # ax1.add_patch(boxPolygon) # Now draw the median lines back over what we just filled in - med = bp['medians'][i] + med = bp["medians"][i] medianX = [] medianY = [] for j in range(2): medianX.append(med.get_xdata()[j]) medianY.append(med.get_ydata()[j]) - plt.plot(medianX, medianY, 'k') + plt.plot(medianX, medianY, "k") medians[i] = medianY[0] # Finally, overplot the sample averages, with horizontal alignment # in the center of each box - plt.plot([np.average(med.get_xdata())], [np.average(measurements[i])], - color='w', marker='*', markeredgecolor='k') + plt.plot( + [np.average(med.get_xdata())], + [np.average(measurements[i])], + color="w", + marker="*", + markeredgecolor="k", + ) if modeldata: - plt.plot([np.average(med.get_xdata())], [modeldata[i]], - color='w', marker='o', markeredgecolor='k') + plt.plot( + [np.average(med.get_xdata())], + [modeldata[i]], + color="w", + marker="o", + markeredgecolor="k", + ) pos = np.arange(numBoxes) + 1 upperLabels = [str(np.round(s, 2)) for s in medians] @@ -330,16 +388,21 @@ def boxplot(ticks, measurements, xlabel='', ylabel='', modeldata=None, output=No y0, y1 = ax1.get_ylim() textpos = y0 + (y1 - y0) * 0.97 # ypos = ax1.get_ylim()[0] - ax1.text(pos[tick], textpos, upperLabels[tick], - horizontalalignment='center', size='small', - color='royalblue') + ax1.text( + pos[tick], + textpos, + upperLabels[tick], + horizontalalignment="center", + size="small", + color="royalblue", + ) if output: plt.savefig(output) - with 
open('{}.txt'.format(output), 'w') as f: - print('X Y', file=f) + with open("{}.txt".format(output), "w") as f: + print("X Y", file=f) for i, data in enumerate(measurements): for value in data: - print('{} {}'.format(ticks[i], value), file=f) + print("{} {}".format(ticks[i], value), file=f) else: plt.show() diff --git a/lib/protocol_benchmarks.py b/lib/protocol_benchmarks.py index e82af67..b42e821 100755 --- a/lib/protocol_benchmarks.py +++ b/lib/protocol_benchmarks.py @@ -18,40 +18,41 @@ import re import time from filelock import FileLock + class DummyProtocol: def __init__(self): self.max_serialized_bytes = None - self.enc_buf = '' - self.dec_buf = '' - self.dec_buf0 = '' - self.dec_buf1 = '' - self.dec_buf2 = '' + self.enc_buf = "" + self.dec_buf = "" + self.dec_buf0 = "" + self.dec_buf1 = "" + self.dec_buf2 = "" self.dec_index = 0 self.transition_map = dict() - def assign_and_kout(self, signature, assignment, transition_args = None): + def assign_and_kout(self, signature, assignment, transition_args=None): self.new_var(signature) - self.assign_var(assignment, transition_args = transition_args) + self.assign_var(assignment, transition_args=transition_args) self.kout_var() def new_var(self, signature): self.dec_index += 1 - self.dec_buf0 += '{} dec_{:d};\n'.format(signature, self.dec_index) + self.dec_buf0 += "{} dec_{:d};\n".format(signature, self.dec_index) - def assign_var(self, assignment, transition_args = None): - snippet = 'dec_{:d} = {};\n'.format(self.dec_index, assignment) + def assign_var(self, assignment, transition_args=None): + snippet = "dec_{:d} = {};\n".format(self.dec_index, assignment) self.dec_buf1 += snippet if transition_args: self.add_transition(snippet, transition_args) def get_var(self): - return 'dec_{:d}'.format(self.dec_index) + return "dec_{:d}".format(self.dec_index) def kout_var(self): - self.dec_buf2 += 'kout << dec_{:d};\n'.format(self.dec_index) + self.dec_buf2 += "kout << dec_{:d};\n".format(self.dec_index) def note_unsupported(self, value): - note = '// value {} has unsupported type {}\n'.format(value, type(value)) + note = "// value {} has unsupported type {}\n".format(value, type(value)) self.enc_buf += note self.dec_buf += note self.dec_buf1 += note @@ -60,31 +61,31 @@ class DummyProtocol: return True def get_encode(self): - return '' + return "" def get_buffer_declaration(self): - return '' + return "" def get_buffer_name(self): return '"none"' def get_serialize(self): - return '' + return "" def get_deserialize(self): - return '' + return "" def get_decode_and_output(self): - return '' + return "" def get_decode_vars(self): - return '' + return "" def get_decode(self): - return '' + return "" def get_decode_output(self): - return '' + return "" def get_extra_files(self): return dict() @@ -101,30 +102,34 @@ class DummyProtocol: return self.transition_map[code_snippet] return list() -class Avro(DummyProtocol): - def __init__(self, data, strip_schema = False): +class Avro(DummyProtocol): + def __init__(self, data, strip_schema=False): super().__init__() self.data = data self.strip_schema = strip_schema self.schema = { - 'namespace' : 'benchmark.avro', - 'type' : 'record', - 'name' : 'Benchmark', - 'fields' : [] + "namespace": "benchmark.avro", + "type": "record", + "name": "Benchmark", + "fields": [], } for key, value in data.items(): - self.add_to_dict(self.schema['fields'], key, value) + self.add_to_dict(self.schema["fields"], key, value) buf = io.BytesIO() try: - writer = avro.datafile.DataFileWriter(buf, avro.io.DatumWriter(), 
avro.schema.Parse(json.dumps(self.schema)))
+            writer = avro.datafile.DataFileWriter(
+                buf, avro.io.DatumWriter(), avro.schema.Parse(json.dumps(self.schema))
+            )
             writer.append(data)
             writer.flush()
         except avro.schema.SchemaParseException:
-            raise RuntimeError('Unsupported schema') from None
+            raise RuntimeError("Unsupported schema") from None
         self.serialized_data = buf.getvalue()
         if strip_schema:
-            self.serialized_data = self.serialized_data[self.serialized_data.find(b'}\x00')+2 : ]
+            self.serialized_data = self.serialized_data[
+                self.serialized_data.find(b"}\x00") + 2 :
+            ]
             # strip leading 16-byte sync marker
             self.serialized_data = self.serialized_data[16:]
             # strip trailing 16-byte sync marker
@@ -138,29 +143,30 @@ class Avro(DummyProtocol):
 
     def type_to_type_name(self, type_type):
         if type_type == int:
-            return 'int'
+            return "int"
         if type_type == float:
-            return 'float'
+            return "float"
         if type_type == str:
-            return 'string'
+            return "string"
         if type_type == list:
-            return 'array'
+            return "array"
         if type_type == dict:
-            return 'record'
+            return "record"
 
     def add_to_dict(self, fields, key, value):
-        new_field = {
-            'name' : key,
-            'type' : self.type_to_type_name(type(value))
-        }
-        if new_field['type'] == 'array':
-            new_field['type'] = {'type' : 'array', 'items' : self.type_to_type_name(type(value[0]))}
-        if new_field['type'] == 'record':
-            new_field['type'] = {'type' : 'record', 'name': key, 'fields' : []}
+        new_field = {"name": key, "type": self.type_to_type_name(type(value))}
+        if new_field["type"] == "array":
+            new_field["type"] = {
+                "type": "array",
+                "items": self.type_to_type_name(type(value[0])),
+            }
+        if new_field["type"] == "record":
+            new_field["type"] = {"type": "record", "name": key, "fields": []}
         for key, value in value.items():
-            self.add_to_dict(new_field['type']['fields'], key, value)
+            self.add_to_dict(new_field["type"]["fields"], key, value)
         fields.append(new_field)
 
+
 class Thrift(DummyProtocol):
     class_index = 1
 
@@ -169,10 +175,10 @@ class Thrift(DummyProtocol):
         super().__init__()
         self.data = data
         self._field_id = 1
-        self.proto_buf = ''
+        self.proto_buf = ""
         self.proto_from_json(data)
 
-        with open('/tmp/test.thrift', 'w') as f:
+        with open("/tmp/test.thrift", "w") as f:
             f.write(self.proto_buf)
 
         membuf = TCyMemoryBuffer()
@@ -180,7 +186,9 @@ class Thrift(DummyProtocol):
         # TODO some state is left over somewhere -> only the protocol loaded on
         # the very first call is taken into account; data which does not match
         # it is not serialized
-        test_thrift = thriftpy.load('/tmp/test.thrift', module_name='test{:d}_thrift'.format(Thrift.class_index))
+        test_thrift = thriftpy.load(
+            "/tmp/test.thrift", module_name="test{:d}_thrift".format(Thrift.class_index)
+        )
         Thrift.class_index += 1
         benchmark = test_thrift.Benchmark()
@@ -190,7 +198,7 @@ class Thrift(DummyProtocol):
         try:
             proto.write_struct(benchmark)
         except thriftpy.thrift.TDecodeException:
-            raise RuntimeError('Unsupported data layout') from None
+            raise RuntimeError("Unsupported data layout") from None
         membuf.flush()
         self.serialized_data = membuf.getvalue()
@@ -203,31 +211,31 @@ class Thrift(DummyProtocol):
     def type_to_type_name(self, value):
         type_type = type(value)
         if type_type == int:
-            return 'i32'
+            return "i32"
         if type_type == float:
-            return 'double'
+            return "double"
         if type_type == str:
-            return 'string'
+            return "string"
         if type_type == list:
-            return 'list<{}>'.format(self.type_to_type_name(value[0]))
+            return "list<{}>".format(self.type_to_type_name(value[0]))
         if type_type == dict:
             sub_value = list(value.values())[0]
-
return 'map<{},{}>'.format('string', self.type_to_type_name(sub_value)) + return "map<{},{}>".format("string", self.type_to_type_name(sub_value)) def add_to_dict(self, key, value): key_type = self.type_to_type_name(value) - self.proto_buf += '{:d}: {} {};\n'.format(self._field_id, key_type, key) + self.proto_buf += "{:d}: {} {};\n".format(self._field_id, key_type, key) self._field_id += 1 def proto_from_json(self, data): - self.proto_buf += 'struct Benchmark {\n' + self.proto_buf += "struct Benchmark {\n" for key, value in data.items(): self.add_to_dict(key, value) - self.proto_buf += '}\n' + self.proto_buf += "}\n" -class ArduinoJSON(DummyProtocol): - def __init__(self, data, bufsize = 255, int_type = 'uint16_t', float_type = 'float'): +class ArduinoJSON(DummyProtocol): + def __init__(self, data, bufsize=255, int_type="uint16_t", float_type="float"): super().__init__() self.data = data self.max_serialized_bytes = self.get_serialized_length() + 2 @@ -235,9 +243,12 @@ class ArduinoJSON(DummyProtocol): self.bufsize = bufsize self.int_type = int_type self.float_type = float_type - self.enc_buf += self.add_transition('ArduinoJson::StaticJsonBuffer<{:d}> jsonBuffer;\n'.format(bufsize), [bufsize]) - self.enc_buf += 'ArduinoJson::JsonObject& root = jsonBuffer.createObject();\n' - self.from_json(data, 'root') + self.enc_buf += self.add_transition( + "ArduinoJson::StaticJsonBuffer<{:d}> jsonBuffer;\n".format(bufsize), + [bufsize], + ) + self.enc_buf += "ArduinoJson::JsonObject& root = jsonBuffer.createObject();\n" + self.from_json(data, "root") def get_serialized_length(self): return len(json.dumps(self.data)) @@ -249,24 +260,33 @@ class ArduinoJSON(DummyProtocol): return self.enc_buf def get_buffer_declaration(self): - return 'char buf[{:d}];\n'.format(self.max_serialized_bytes) + return "char buf[{:d}];\n".format(self.max_serialized_bytes) def get_buffer_name(self): - return 'buf' + return "buf" def get_length_var(self): - return 'serialized_size' + return "serialized_size" def get_serialize(self): - return self.add_transition('uint16_t serialized_size = root.printTo(buf);\n', [self.max_serialized_bytes]) + return self.add_transition( + "uint16_t serialized_size = root.printTo(buf);\n", + [self.max_serialized_bytes], + ) def get_deserialize(self): - ret = self.add_transition('ArduinoJson::StaticJsonBuffer<{:d}> jsonBuffer;\n'.format(self.bufsize), [self.bufsize]) - ret += self.add_transition('ArduinoJson::JsonObject& root = jsonBuffer.parseObject(buf);\n', [self.max_serialized_bytes]) + ret = self.add_transition( + "ArduinoJson::StaticJsonBuffer<{:d}> jsonBuffer;\n".format(self.bufsize), + [self.bufsize], + ) + ret += self.add_transition( + "ArduinoJson::JsonObject& root = jsonBuffer.parseObject(buf);\n", + [self.max_serialized_bytes], + ) return ret def get_decode_and_output(self): - return 'kout << dec << "dec:";\n' + self.dec_buf + 'kout << endl;\n'; + return 'kout << dec << "dec:";\n' + self.dec_buf + "kout << endl;\n" def get_decode_vars(self): return self.dec_buf0 @@ -275,93 +295,158 @@ class ArduinoJSON(DummyProtocol): return self.dec_buf1 def get_decode_output(self): - return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n'; + return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n" def add_to_list(self, enc_node, dec_node, offset, value): if type(value) == str: - if len(value) and value[0] == '$': - self.enc_buf += '{}.add({});\n'.format(enc_node, value[1:]) - self.dec_buf += 'kout << {}[{:d}].as<{}>();\n'.format(dec_node, offset, self.int_type) - 
self.assign_and_kout(self.int_type, '{}[{:d}].as<{}>()'.format(dec_node, offset, self.int_type)) + if len(value) and value[0] == "$": + self.enc_buf += "{}.add({});\n".format(enc_node, value[1:]) + self.dec_buf += "kout << {}[{:d}].as<{}>();\n".format( + dec_node, offset, self.int_type + ) + self.assign_and_kout( + self.int_type, + "{}[{:d}].as<{}>()".format(dec_node, offset, self.int_type), + ) else: - self.enc_buf += self.add_transition('{}.add("{}");\n'.format(enc_node, value), [len(value)]) - self.dec_buf += 'kout << {}[{:d}].as<const char *>();\n'.format(dec_node, offset) - self.assign_and_kout('char const*', '{}[{:d}].as<char const *>()'.format(dec_node, offset), transition_args = [len(value)]) + self.enc_buf += self.add_transition( + '{}.add("{}");\n'.format(enc_node, value), [len(value)] + ) + self.dec_buf += "kout << {}[{:d}].as<const char *>();\n".format( + dec_node, offset + ) + self.assign_and_kout( + "char const*", + "{}[{:d}].as<char const *>()".format(dec_node, offset), + transition_args=[len(value)], + ) elif type(value) == list: - child = enc_node + 'l' + child = enc_node + "l" while child in self.children: - child += '_' - self.enc_buf += 'ArduinoJson::JsonArray& {} = {}.createNestedArray();\n'.format( - child, enc_node) + child += "_" + self.enc_buf += "ArduinoJson::JsonArray& {} = {}.createNestedArray();\n".format( + child, enc_node + ) self.children.add(child) self.from_json(value, child) elif type(value) == dict: - child = enc_node + 'o' + child = enc_node + "o" while child in self.children: - child += '_' - self.enc_buf += 'ArduinoJson::JsonObject& {} = {}.createNestedObject();\n'.format( - child, enc_node) + child += "_" + self.enc_buf += "ArduinoJson::JsonObject& {} = {}.createNestedObject();\n".format( + child, enc_node + ) self.children.add(child) self.from_json(value, child) elif type(value) == float: - self.enc_buf += '{}.add({});\n'.format(enc_node, value) - self.dec_buf += 'kout << {}[{:d}].as<{}>();\n'.format(dec_node, offset, self.float_type) - self.assign_and_kout(self.float_type, '{}[{:d}].as<{}>()'.format(dec_node, offset, self.float_type)) + self.enc_buf += "{}.add({});\n".format(enc_node, value) + self.dec_buf += "kout << {}[{:d}].as<{}>();\n".format( + dec_node, offset, self.float_type + ) + self.assign_and_kout( + self.float_type, + "{}[{:d}].as<{}>()".format(dec_node, offset, self.float_type), + ) elif type(value) == int: - self.enc_buf += '{}.add({});\n'.format(enc_node, value) - self.dec_buf += 'kout << {}[{:d}].as<{}>();\n'.format(dec_node, offset, self.int_type) - self.assign_and_kout(self.int_type, '{}[{:d}].as<{}>()'.format(dec_node, offset, self.int_type)) + self.enc_buf += "{}.add({});\n".format(enc_node, value) + self.dec_buf += "kout << {}[{:d}].as<{}>();\n".format( + dec_node, offset, self.int_type + ) + self.assign_and_kout( + self.int_type, + "{}[{:d}].as<{}>()".format(dec_node, offset, self.int_type), + ) else: self.note_unsupported(value) def add_to_dict(self, enc_node, dec_node, key, value): if type(value) == str: - if len(value) and value[0] == '$': - self.enc_buf += self.add_transition('{}["{}"] = {};\n'.format(enc_node, key, value[1:]), [len(key)]) - self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format(dec_node, key, self.int_type) - self.assign_and_kout(self.int_type, '{}["{}"].as<{}>()'.format(dec_node, key, self.int_type)) + if len(value) and value[0] == "$": + self.enc_buf += self.add_transition( + '{}["{}"] = {};\n'.format(enc_node, key, value[1:]), [len(key)] + ) + self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format( + 
dec_node, key, self.int_type + ) + self.assign_and_kout( + self.int_type, + '{}["{}"].as<{}>()'.format(dec_node, key, self.int_type), + ) else: - self.enc_buf += self.add_transition('{}["{}"] = "{}";\n'.format(enc_node, key, value), [len(key), len(value)]) - self.dec_buf += 'kout << {}["{}"].as<const char *>();\n'.format(dec_node, key) - self.assign_and_kout('char const*', '{}["{}"].as<const char *>()'.format(dec_node, key), transition_args = [len(key), len(value)]) + self.enc_buf += self.add_transition( + '{}["{}"] = "{}";\n'.format(enc_node, key, value), + [len(key), len(value)], + ) + self.dec_buf += 'kout << {}["{}"].as<const char *>();\n'.format( + dec_node, key + ) + self.assign_and_kout( + "char const*", + '{}["{}"].as<const char *>()'.format(dec_node, key), + transition_args=[len(key), len(value)], + ) elif type(value) == list: - child = enc_node + 'l' + child = enc_node + "l" while child in self.children: - child += '_' - self.enc_buf += self.add_transition('ArduinoJson::JsonArray& {} = {}.createNestedArray("{}");\n'.format( - child, enc_node, key), [len(key)]) + child += "_" + self.enc_buf += self.add_transition( + 'ArduinoJson::JsonArray& {} = {}.createNestedArray("{}");\n'.format( + child, enc_node, key + ), + [len(key)], + ) self.children.add(child) self.from_json(value, child, '{}["{}"]'.format(dec_node, key)) elif type(value) == dict: - child = enc_node + 'o' + child = enc_node + "o" while child in self.children: - child += '_' - self.enc_buf += self.add_transition('ArduinoJson::JsonObject& {} = {}.createNestedObject("{}");\n'.format( - child, enc_node, key), [len(key)]) + child += "_" + self.enc_buf += self.add_transition( + 'ArduinoJson::JsonObject& {} = {}.createNestedObject("{}");\n'.format( + child, enc_node, key + ), + [len(key)], + ) self.children.add(child) self.from_json(value, child, '{}["{}"]'.format(dec_node, key)) elif type(value) == float: - self.enc_buf += self.add_transition('{}["{}"] = {};\n'.format(enc_node, key, value), [len(key)]) - self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format(dec_node, key, self.float_type) - self.assign_and_kout(self.float_type, '{}["{}"].as<{}>()'.format(dec_node, key, self.float_type), transition_args = [len(key)]) + self.enc_buf += self.add_transition( + '{}["{}"] = {};\n'.format(enc_node, key, value), [len(key)] + ) + self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format( + dec_node, key, self.float_type + ) + self.assign_and_kout( + self.float_type, + '{}["{}"].as<{}>()'.format(dec_node, key, self.float_type), + transition_args=[len(key)], + ) elif type(value) == int: - self.enc_buf += self.add_transition('{}["{}"] = {};\n'.format(enc_node, key, value), [len(key)]) - self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format(dec_node, key, self.int_type) - self.assign_and_kout(self.int_type, '{}["{}"].as<{}>()'.format(dec_node, key, self.int_type), transition_args = [len(key)]) + self.enc_buf += self.add_transition( + '{}["{}"] = {};\n'.format(enc_node, key, value), [len(key)] + ) + self.dec_buf += 'kout << {}["{}"].as<{}>();\n'.format( + dec_node, key, self.int_type + ) + self.assign_and_kout( + self.int_type, + '{}["{}"].as<{}>()'.format(dec_node, key, self.int_type), + transition_args=[len(key)], + ) else: self.note_unsupported(value) - def from_json(self, data, enc_node = 'root', dec_node = 'root'): + def from_json(self, data, enc_node="root", dec_node="root"): if type(data) == dict: for key in sorted(data.keys()): self.add_to_dict(enc_node, dec_node, key, data[key]) @@ -371,8 +456,16 @@ class ArduinoJSON(DummyProtocol): 
class CapnProtoC(DummyProtocol): - - def __init__(self, data, max_serialized_bytes = 128, packed = False, trail = ['benchmark'], int_type = 'uint16_t', float_type = 'float', dec_index = 0): + def __init__( + self, + data, + max_serialized_bytes=128, + packed=False, + trail=["benchmark"], + int_type="uint16_t", + float_type="float", + dec_index=0, + ): super().__init__() self.data = data self.max_serialized_bytes = max_serialized_bytes @@ -384,169 +477,190 @@ class CapnProtoC(DummyProtocol): self.float_type = float_type self.proto_float_type = self.float_type_to_proto_type(float_type) self.dec_index = dec_index - self.trail_name = '_'.join(map(lambda x: x.capitalize(), trail)) - self.proto_buf = '' - self.enc_buf += 'struct {} {};\n'.format(self.trail_name, self.name) - self.cc_tail = '' + self.trail_name = "_".join(map(lambda x: x.capitalize(), trail)) + self.proto_buf = "" + self.enc_buf += "struct {} {};\n".format(self.trail_name, self.name) + self.cc_tail = "" self.key_counter = 0 self.from_json(data) def int_type_to_proto_type(self, int_type): - sign = '' - if int_type[0] == 'u': - sign = 'U' - if '8' in int_type: + sign = "" + if int_type[0] == "u": + sign = "U" + if "8" in int_type: self.int_bits = 8 - return sign + 'Int8' - if '16' in int_type: + return sign + "Int8" + if "16" in int_type: self.int_bits = 16 - return sign + 'Int16' - if '32' in int_type: + return sign + "Int16" + if "32" in int_type: self.int_bits = 32 - return sign + 'Int32' + return sign + "Int32" self.int_bits = 64 - return sign + 'Int64' + return sign + "Int64" def float_type_to_proto_type(self, float_type): - if float_type == 'float': + if float_type == "float": self.float_bits = 32 - return 'Float32' + return "Float32" self.float_bits = 64 - return 'Float64' + return "Float64" def is_ascii(self): return False def get_proto(self): - return '@0xad5b236043de2389;\n\n' + self.proto_buf + return "@0xad5b236043de2389;\n\n" + self.proto_buf def get_extra_files(self): - return { - 'capnp_c_bench.capnp' : self.get_proto() - } + return {"capnp_c_bench.capnp": self.get_proto()} def get_buffer_declaration(self): - ret = 'uint8_t buf[{:d}];\n'.format(self.max_serialized_bytes) - ret += 'uint16_t serialized_size;\n' + ret = "uint8_t buf[{:d}];\n".format(self.max_serialized_bytes) + ret += "uint16_t serialized_size;\n" return ret def get_buffer_name(self): - return 'buf' + return "buf" def get_encode(self): - ret = 'struct capn c;\n' - ret += 'capn_init_malloc(&c);\n' - ret += 'capn_ptr cr = capn_root(&c);\n' - ret += 'struct capn_segment *cs = cr.seg;\n\n' - ret += '{}_ptr {}_ptr = new_{}(cs);\n'.format( - self.trail_name, self.name, self.trail_name) + ret = "struct capn c;\n" + ret += "capn_init_malloc(&c);\n" + ret += "capn_ptr cr = capn_root(&c);\n" + ret += "struct capn_segment *cs = cr.seg;\n\n" + ret += "{}_ptr {}_ptr = new_{}(cs);\n".format( + self.trail_name, self.name, self.trail_name + ) - tail = 'write_{}(&{}, {}_ptr);\n'.format( - self.trail_name, self.name, self.name) - tail += 'capn_setp(cr, 0, {}_ptr.p);\n'.format(self.name) + tail = "write_{}(&{}, {}_ptr);\n".format(self.trail_name, self.name, self.name) + tail += "capn_setp(cr, 0, {}_ptr.p);\n".format(self.name) return ret + self.enc_buf + self.cc_tail + tail def get_serialize(self): - ret = 'serialized_size = capn_write_mem(&c, buf, sizeof(buf), {:d});\n'.format(self.packed) - ret += 'capn_free(&c);\n' + ret = "serialized_size = capn_write_mem(&c, buf, sizeof(buf), {:d});\n".format( + self.packed + ) + ret += "capn_free(&c);\n" return ret def 
get_deserialize(self): - ret = 'struct capn c;\n' - ret += 'capn_init_mem(&c, buf, serialized_size, 0);\n' + ret = "struct capn c;\n" + ret += "capn_init_mem(&c, buf, serialized_size, 0);\n" return ret def get_decode_and_output(self): - ret = '{}_ptr {}_ptr;\n'.format(self.trail_name, self.name) - ret += '{}_ptr.p = capn_getp(capn_root(&c), 0, 1);\n'.format(self.name) - ret += 'struct {} {};\n'.format(self.trail_name, self.name) + ret = "{}_ptr {}_ptr;\n".format(self.trail_name, self.name) + ret += "{}_ptr.p = capn_getp(capn_root(&c), 0, 1);\n".format(self.name) + ret += "struct {} {};\n".format(self.trail_name, self.name) ret += 'kout << dec << "dec:";\n' ret += self.dec_buf - ret += 'kout << endl;\n' - ret += 'capn_free(&c);\n' + ret += "kout << endl;\n" + ret += "capn_free(&c);\n" return ret def get_decode_vars(self): return self.dec_buf0 def get_decode(self): - ret = '{}_ptr {}_ptr;\n'.format(self.trail_name, self.name) - ret += '{}_ptr.p = capn_getp(capn_root(&c), 0, 1);\n'.format(self.name) - ret += 'struct {} {};\n'.format(self.trail_name, self.name) + ret = "{}_ptr {}_ptr;\n".format(self.trail_name, self.name) + ret += "{}_ptr.p = capn_getp(capn_root(&c), 0, 1);\n".format(self.name) + ret += "struct {} {};\n".format(self.trail_name, self.name) ret += self.dec_buf1 - ret += 'capn_free(&c);\n' + ret += "capn_free(&c);\n" return ret def get_decode_output(self): - return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n'; + return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n" def get_length_var(self): - return 'serialized_size' + return "serialized_size" def add_field(self, fieldtype, key, value): - extra = '' + extra = "" texttype = self.proto_int_type if fieldtype == str: - texttype = 'Text' + texttype = "Text" elif fieldtype == float: texttype = self.proto_float_type elif fieldtype == dict: texttype = key.capitalize() if type(value) == list: - texttype = 'List({})'.format(texttype) + texttype = "List({})".format(texttype) - self.proto_buf += '{} @{:d} :{};\n'.format( - key, self.key_counter, texttype) + self.proto_buf += "{} @{:d} :{};\n".format(key, self.key_counter, texttype) self.key_counter += 1 if fieldtype == str: - self.enc_buf += 'capn_text {}_text;\n'.format(key) - self.enc_buf += '{}_text.len = {:d};\n'.format(key, len(value)) + self.enc_buf += "capn_text {}_text;\n".format(key) + self.enc_buf += "{}_text.len = {:d};\n".format(key, len(value)) self.enc_buf += '{}_text.str = "{}";\n'.format(key, value) - self.enc_buf += '{}_text.seg = NULL;\n'.format(key) - self.enc_buf += '{}.{} = {}_text;\n\n'.format(self.name, key, key) - self.dec_buf += 'kout << {}.{}.str;\n'.format(self.name, key) - self.assign_and_kout('char const *', '{}.{}.str'.format(self.name, key)) + self.enc_buf += "{}_text.seg = NULL;\n".format(key) + self.enc_buf += "{}.{} = {}_text;\n\n".format(self.name, key, key) + self.dec_buf += "kout << {}.{}.str;\n".format(self.name, key) + self.assign_and_kout("char const *", "{}.{}.str".format(self.name, key)) elif fieldtype == dict: - pass # content is handled recursively in add_to_dict + pass # content is handled recursively in add_to_dict elif type(value) == list: if type(value[0]) == float: - self.enc_buf += self.add_transition('{}.{} = capn_new_list{:d}(cs, {:d});\n'.format( - self.name, key, self.float_bits, len(value)), [len(value)]) + self.enc_buf += self.add_transition( + "{}.{} = capn_new_list{:d}(cs, {:d});\n".format( + self.name, key, self.float_bits, len(value) + ), + [len(value)], + ) for i, elem in enumerate(value): - 
self.enc_buf += 'capn_set{:d}({}.{}, {:d}, capn_from_f{:d}({:f}));\n'.format( - self.float_bits, self.name, key, i, self.float_bits, elem) - self.dec_buf += 'kout << capn_to_f{:d}(capn_get{:d}({}.{}, {:d}));\n'.format( - self.float_bits, self.float_bits, self.name, key, i) - self.assign_and_kout(self.float_type, 'capn_to_f{:d}(capn_get{:d}({}.{}, {:d}))'.format(self.float_bits, self.float_bits, self.name, key, i)) + self.enc_buf += "capn_set{:d}({}.{}, {:d}, capn_from_f{:d}({:f}));\n".format( + self.float_bits, self.name, key, i, self.float_bits, elem + ) + self.dec_buf += "kout << capn_to_f{:d}(capn_get{:d}({}.{}, {:d}));\n".format( + self.float_bits, self.float_bits, self.name, key, i + ) + self.assign_and_kout( + self.float_type, + "capn_to_f{:d}(capn_get{:d}({}.{}, {:d}))".format( + self.float_bits, self.float_bits, self.name, key, i + ), + ) else: - self.enc_buf += self.add_transition('{}.{} = capn_new_list{:d}(cs, {:d});\n'.format( - self.name, key, self.int_bits, len(value)), [len(value)]) + self.enc_buf += self.add_transition( + "{}.{} = capn_new_list{:d}(cs, {:d});\n".format( + self.name, key, self.int_bits, len(value) + ), + [len(value)], + ) for i, elem in enumerate(value): - self.enc_buf += 'capn_set{:d}({}.{}, {:d}, {:d});\n'.format( - self.int_bits, self.name, key, i, elem) - self.dec_buf += 'kout << capn_get{:d}({}.{}, {:d});\n'.format( - self.int_bits, self.name, key, i) - self.assign_and_kout(self.int_type, 'capn_get{:d}({}.{}, {:d})'.format(self.int_bits, self.name, key, i)) + self.enc_buf += "capn_set{:d}({}.{}, {:d}, {:d});\n".format( + self.int_bits, self.name, key, i, elem + ) + self.dec_buf += "kout << capn_get{:d}({}.{}, {:d});\n".format( + self.int_bits, self.name, key, i + ) + self.assign_and_kout( + self.int_type, + "capn_get{:d}({}.{}, {:d})".format( + self.int_bits, self.name, key, i + ), + ) elif fieldtype == float: - self.enc_buf += '{}.{} = {};\n\n'.format(self.name, key, value) - self.dec_buf += 'kout << {}.{};\n'.format(self.name, key) - self.assign_and_kout(self.float_type, '{}.{}'.format(self.name, key)) + self.enc_buf += "{}.{} = {};\n\n".format(self.name, key, value) + self.dec_buf += "kout << {}.{};\n".format(self.name, key) + self.assign_and_kout(self.float_type, "{}.{}".format(self.name, key)) elif fieldtype == int: - self.enc_buf += '{}.{} = {};\n\n'.format(self.name, key, value) - self.dec_buf += 'kout << {}.{};\n'.format(self.name, key) - self.assign_and_kout(self.int_type, '{}.{}'.format(self.name, key)) + self.enc_buf += "{}.{} = {};\n\n".format(self.name, key, value) + self.dec_buf += "kout << {}.{};\n".format(self.name, key) + self.assign_and_kout(self.int_type, "{}.{}".format(self.name, key)) else: self.note_unsupported(value) def add_to_dict(self, key, value): if type(value) == str: - if len(value) and value[0] == '$': + if len(value) and value[0] == "$": self.add_field(int, key, value[1:]) else: self.add_field(str, key, value) @@ -555,19 +669,35 @@ class CapnProtoC(DummyProtocol): elif type(value) == dict: trail = list(self.trail) trail.append(key) - nested = CapnProtoC(value, trail = trail, int_type = self.int_type, float_type = self.float_type, dec_index = self.dec_index) + nested = CapnProtoC( + value, + trail=trail, + int_type=self.int_type, + float_type=self.float_type, + dec_index=self.dec_index, + ) self.add_field(dict, key, value) - self.enc_buf += '{}.{} = new_{}_{}(cs);\n'.format( - self.name, key, self.trail_name, key.capitalize()) + self.enc_buf += "{}.{} = new_{}_{}(cs);\n".format( + self.name, key, self.trail_name, 
key.capitalize() + ) self.enc_buf += nested.enc_buf - self.enc_buf += 'write_{}_{}(&{}, {}.{});\n'.format( - self.trail_name, key.capitalize(), key, self.name, key) - self.dec_buf += 'struct {}_{} {};\n'.format(self.trail_name, key.capitalize(), key) - self.dec_buf += 'read_{}_{}(&{}, {}.{});\n'.format(self.trail_name, key.capitalize(), key, self.name, key) + self.enc_buf += "write_{}_{}(&{}, {}.{});\n".format( + self.trail_name, key.capitalize(), key, self.name, key + ) + self.dec_buf += "struct {}_{} {};\n".format( + self.trail_name, key.capitalize(), key + ) + self.dec_buf += "read_{}_{}(&{}, {}.{});\n".format( + self.trail_name, key.capitalize(), key, self.name, key + ) self.dec_buf += nested.dec_buf self.dec_buf0 += nested.dec_buf0 - self.dec_buf1 += 'struct {}_{} {};\n'.format(self.trail_name, key.capitalize(), key) - self.dec_buf1 += 'read_{}_{}(&{}, {}.{});\n'.format(self.trail_name, key.capitalize(), key, self.name, key) + self.dec_buf1 += "struct {}_{} {};\n".format( + self.trail_name, key.capitalize(), key + ) + self.dec_buf1 += "read_{}_{}(&{}, {}.{});\n".format( + self.trail_name, key.capitalize(), key, self.name, key + ) self.dec_buf1 += nested.dec_buf1 self.dec_buf2 += nested.dec_buf2 self.dec_index = nested.dec_index @@ -576,19 +706,19 @@ class CapnProtoC(DummyProtocol): self.add_field(type(value), key, value) def from_json(self, data): - self.proto_buf += 'struct {} {{\n'.format(self.name.capitalize()) + self.proto_buf += "struct {} {{\n".format(self.name.capitalize()) if type(data) == dict: for key in sorted(data.keys()): self.add_to_dict(key, data[key]) - self.proto_buf += '}\n' + self.proto_buf += "}\n" -class ManualJSON(DummyProtocol): +class ManualJSON(DummyProtocol): def __init__(self, data): super().__init__() self.data = data self.max_serialized_bytes = self.get_serialized_length() + 2 - self.buf = 'BufferOutput<> bout(buf);\n' + self.buf = "BufferOutput<> bout(buf);\n" self.buf += 'bout << "{";\n' self.from_json(data) self.buf += 'bout << "}";\n' @@ -603,23 +733,25 @@ class ManualJSON(DummyProtocol): return True def get_buffer_declaration(self): - return 'char buf[{:d}];\n'.format(self.max_serialized_bytes); + return "char buf[{:d}];\n".format(self.max_serialized_bytes) def get_buffer_name(self): - return 'buf' + return "buf" def get_encode(self): return self.buf def get_length_var(self): - return 'bout.size()' + return "bout.size()" def add_to_list(self, value, is_last): if type(value) == str: - if len(value) and value[0] == '$': - self.buf += 'bout << dec << {}'.format(value[1:]) + if len(value) and value[0] == "$": + self.buf += "bout << dec << {}".format(value[1:]) else: - self.buf += self.add_transition('bout << "\\"{}\\""'.format(value), [len(value)]) + self.buf += self.add_transition( + 'bout << "\\"{}\\""'.format(value), [len(value)] + ) elif type(value) == list: self.buf += 'bout << "[";\n' @@ -632,36 +764,48 @@ class ManualJSON(DummyProtocol): self.buf += 'bout << "}"' else: - self.buf += 'bout << {}'.format(value) + self.buf += "bout << {}".format(value) if is_last: - self.buf += ';\n'; + self.buf += ";\n" else: self.buf += ' << ",";\n' def add_to_dict(self, key, value, is_last): if type(value) == str: - if len(value) and value[0] == '$': - self.buf += self.add_transition('bout << "\\"{}\\":" << dec << {}'.format(key, value[1:]), [len(key)]) + if len(value) and value[0] == "$": + self.buf += self.add_transition( + 'bout << "\\"{}\\":" << dec << {}'.format(key, value[1:]), + [len(key)], + ) else: - self.buf += self.add_transition('bout << 
"\\"{}\\":\\"{}\\""'.format(key, value), [len(key), len(value)]) + self.buf += self.add_transition( + 'bout << "\\"{}\\":\\"{}\\""'.format(key, value), + [len(key), len(value)], + ) elif type(value) == list: - self.buf += self.add_transition('bout << "\\"{}\\":[";\n'.format(key), [len(key)]) + self.buf += self.add_transition( + 'bout << "\\"{}\\":[";\n'.format(key), [len(key)] + ) self.from_json(value) self.buf += 'bout << "]"' elif type(value) == dict: # '{{' is an escaped '{' character - self.buf += self.add_transition('bout << "\\"{}\\":{{";\n'.format(key), [len(key)]) + self.buf += self.add_transition( + 'bout << "\\"{}\\":{{";\n'.format(key), [len(key)] + ) self.from_json(value) self.buf += 'bout << "}"' else: - self.buf += self.add_transition('bout << "\\"{}\\":" << {}'.format(key, value), [len(key)]) + self.buf += self.add_transition( + 'bout << "\\"{}\\":" << {}'.format(key, value), [len(key)] + ) if is_last: - self.buf += ';\n' + self.buf += ";\n" else: self.buf += ' << ",";\n' @@ -674,74 +818,74 @@ class ManualJSON(DummyProtocol): for i, elem in enumerate(data): self.add_to_list(elem, i == len(data) - 1) -class ModernJSON(DummyProtocol): - def __init__(self, data, output_format = 'json'): +class ModernJSON(DummyProtocol): + def __init__(self, data, output_format="json"): super().__init__() self.data = data self.output_format = output_format - self.buf = 'nlohmann::json js;\n' + self.buf = "nlohmann::json js;\n" self.from_json(data) def is_ascii(self): - if self.output_format == 'json': + if self.output_format == "json": return True return False def get_buffer_name(self): - return 'out' + return "out" def get_encode(self): return self.buf def get_serialize(self): - if self.output_format == 'json': - return 'std::string out = js.dump();\n' - elif self.output_format == 'bson': - return 'std::vector<std::uint8_t> out = nlohmann::json::to_bson(js);\n' - elif self.output_format == 'cbor': - return 'std::vector<std::uint8_t> out = nlohmann::json::to_cbor(js);\n' - elif self.output_format == 'msgpack': - return 'std::vector<std::uint8_t> out = nlohmann::json::to_msgpack(js);\n' - elif self.output_format == 'ubjson': - return 'std::vector<std::uint8_t> out = nlohmann::json::to_ubjson(js);\n' + if self.output_format == "json": + return "std::string out = js.dump();\n" + elif self.output_format == "bson": + return "std::vector<std::uint8_t> out = nlohmann::json::to_bson(js);\n" + elif self.output_format == "cbor": + return "std::vector<std::uint8_t> out = nlohmann::json::to_cbor(js);\n" + elif self.output_format == "msgpack": + return "std::vector<std::uint8_t> out = nlohmann::json::to_msgpack(js);\n" + elif self.output_format == "ubjson": + return "std::vector<std::uint8_t> out = nlohmann::json::to_ubjson(js);\n" else: - raise ValueError('invalid output format {}'.format(self.output_format)) + raise ValueError("invalid output format {}".format(self.output_format)) def get_serialized_length(self): - if self.output_format == 'json': + if self.output_format == "json": return len(json.dumps(self.data)) - elif self.output_format == 'bson': + elif self.output_format == "bson": return len(bson.BSON.encode(self.data)) - elif self.output_format == 'cbor': + elif self.output_format == "cbor": return len(cbor.dumps(self.data)) - elif self.output_format == 'msgpack': + elif self.output_format == "msgpack": return len(msgpack.dumps(self.data)) - elif self.output_format == 'ubjson': + elif self.output_format == "ubjson": return len(ubjson.dumpb(self.data)) else: - raise ValueError('invalid output format 
{}'.format(self.output_format))
+            raise ValueError("invalid output format {}".format(self.output_format))
 
     def can_get_serialized_length(self):
         return True
 
     def get_length_var(self):
-        return 'out.size()'
+        return "out.size()"
 
     def add_to_list(self, prefix, index, value):
         if type(value) == str:
-            if len(value) and value[0] == '$':
+            if len(value) and value[0] == "$":
-                self.buf += '{}[{:d}] = {};\n'.format(prefix, index, value[1:])
+                self.buf += "{}[{:d}] = {};\n".format(prefix, index, value[1:])
             else:
                 self.buf += '{}[{:d}] = "{}";\n'.format(prefix, index, value)
         else:
-            self.buf += '{}[{:d}] = {};\n'.format(prefix, index, value)
+            self.buf += "{}[{:d}] = {};\n".format(prefix, index, value)
 
     def add_to_dict(self, prefix, key, value):
         if type(value) == str:
-            if len(value) and value[0] == '$':
+            if len(value) and value[0] == "$":
                 self.buf += '{}["{}"] = {};\n'.format(prefix, key, value[1:])
             else:
                 self.buf += '{}["{}"] = "{}";\n'.format(prefix, key, value)
@@ -755,7 +899,7 @@
         else:
             self.buf += '{}["{}"] = {};\n'.format(prefix, key, value)
 
-    def from_json(self, data, prefix = 'js'):
+    def from_json(self, data, prefix="js"):
         if type(data) == dict:
             for key in sorted(data.keys()):
                 self.add_to_dict(prefix, key, data[key])
@@ -763,17 +907,20 @@
             for i, elem in enumerate(data):
                 self.add_to_list(prefix, i, elem)
 
-class MPack(DummyProtocol):
-
-    def __init__(self, data, int_type = 'uint16_t', float_type = 'float'):
+class MPack(DummyProtocol):
+    def __init__(self, data, int_type="uint16_t", float_type="float"):
         super().__init__()
         self.data = data
         self.max_serialized_bytes = self.get_serialized_length() + 2
         self.int_type = int_type
         self.float_type = float_type
-        self.enc_buf += 'mpack_writer_t writer;\n'
-        self.enc_buf += self.add_transition('mpack_writer_init(&writer, buf, sizeof(buf));\n', [self.max_serialized_bytes])
-        self.dec_buf0 += 'char strbuf[16];\n'
+        self.enc_buf += "mpack_writer_t writer;\n"
+        self.enc_buf += self.add_transition(
+            "mpack_writer_init(&writer, buf, sizeof(buf));\n",
+            [self.max_serialized_bytes],
+        )
+        self.dec_buf0 += "char strbuf[16];\n"
         self.from_json(data)
 
     def get_serialized_length(self):
@@ -786,34 +933,37 @@
         return False
 
     def get_buffer_declaration(self):
-        ret = 'char buf[{:d}];\n'.format(self.max_serialized_bytes)
-        ret += 'uint16_t serialized_size;\n'
+        ret = "char buf[{:d}];\n".format(self.max_serialized_bytes)
+        ret += "uint16_t serialized_size;\n"
         return ret
 
     def get_buffer_name(self):
-        return 'buf'
+        return "buf"
 
     def get_encode(self):
         return self.enc_buf
 
     def get_serialize(self):
-        ret = 'serialized_size = mpack_writer_buffer_used(&writer);\n'
+        ret = "serialized_size = mpack_writer_buffer_used(&writer);\n"
         # OptionalTimingAnalysis and other wrappers only wrap lines ending with a ;
        # We therefore deliberately do not use proper if { ... } syntax here
        # to make sure that these two statements are timed as well.
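         # For illustration: the two statements appended below concatenate to a
         # single brace-less C++ line, roughly
         #   if (mpack_writer_destroy(&writer) != mpack_ok) kout << "Encoding failed" << endl;
         # which line-based wrappers such as OptionalTimingAnalysis time as one unit.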
- ret += 'if (mpack_writer_destroy(&writer) != mpack_ok) ' + ret += "if (mpack_writer_destroy(&writer) != mpack_ok) " ret += 'kout << "Encoding failed" << endl;\n' return ret def get_deserialize(self): - ret = 'mpack_reader_t reader;\n' - ret += self.add_transition('mpack_reader_init_data(&reader, buf, serialized_size);\n', [self.max_serialized_bytes]) + ret = "mpack_reader_t reader;\n" + ret += self.add_transition( + "mpack_reader_init_data(&reader, buf, serialized_size);\n", + [self.max_serialized_bytes], + ) return ret def get_decode_and_output(self): ret = 'kout << dec << "dec:";\n' - ret += 'char strbuf[16];\n' - return ret + self.dec_buf + 'kout << endl;\n' + ret += "char strbuf[16];\n" + return ret + self.dec_buf + "kout << endl;\n" def get_decode_vars(self): return self.dec_buf0 @@ -822,65 +972,99 @@ class MPack(DummyProtocol): return self.dec_buf1 def get_decode_output(self): - return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n'; + return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n" def get_length_var(self): - return 'serialized_size' + return "serialized_size" def add_value(self, value): if type(value) == str: - if len(value) and value[0] == '$': - self.enc_buf += 'mpack_write(&writer, {});\n'.format(value[1:]) - self.dec_buf += 'kout << mpack_expect_uint(&reader);\n' - self.assign_and_kout(self.int_type, 'mpack_expect_uint(&reader)') + if len(value) and value[0] == "$": + self.enc_buf += "mpack_write(&writer, {});\n".format(value[1:]) + self.dec_buf += "kout << mpack_expect_uint(&reader);\n" + self.assign_and_kout(self.int_type, "mpack_expect_uint(&reader)") else: - self.enc_buf += self.add_transition('mpack_write_cstr_or_nil(&writer, "{}");\n'.format(value), [len(value)]) - self.dec_buf += 'mpack_expect_cstr(&reader, strbuf, sizeof(strbuf));\n' - self.dec_buf += 'kout << strbuf;\n' - self.dec_buf1 += self.add_transition('mpack_expect_cstr(&reader, strbuf, sizeof(strbuf));\n', [len(value)]) - self.dec_buf2 += 'kout << strbuf;\n' + self.enc_buf += self.add_transition( + 'mpack_write_cstr_or_nil(&writer, "{}");\n'.format(value), + [len(value)], + ) + self.dec_buf += "mpack_expect_cstr(&reader, strbuf, sizeof(strbuf));\n" + self.dec_buf += "kout << strbuf;\n" + self.dec_buf1 += self.add_transition( + "mpack_expect_cstr(&reader, strbuf, sizeof(strbuf));\n", + [len(value)], + ) + self.dec_buf2 += "kout << strbuf;\n" elif type(value) == list: self.from_json(value) elif type(value) == dict: self.from_json(value) elif type(value) == int: - self.enc_buf += 'mpack_write(&writer, ({}){:d});\n'.format(self.int_type, value) - self.dec_buf += 'kout << mpack_expect_uint(&reader);\n' - self.assign_and_kout(self.int_type, 'mpack_expect_uint(&reader)') + self.enc_buf += "mpack_write(&writer, ({}){:d});\n".format( + self.int_type, value + ) + self.dec_buf += "kout << mpack_expect_uint(&reader);\n" + self.assign_and_kout(self.int_type, "mpack_expect_uint(&reader)") elif type(value) == float: - self.enc_buf += 'mpack_write(&writer, ({}){:f});\n'.format(self.float_type, value) - self.dec_buf += 'kout << mpack_expect_float(&reader);\n' - self.assign_and_kout(self.float_type, 'mpack_expect_float(&reader)') + self.enc_buf += "mpack_write(&writer, ({}){:f});\n".format( + self.float_type, value + ) + self.dec_buf += "kout << mpack_expect_float(&reader);\n" + self.assign_and_kout(self.float_type, "mpack_expect_float(&reader)") else: self.note_unsupported(value) def from_json(self, data): if type(data) == dict: - self.enc_buf += 
self.add_transition('mpack_start_map(&writer, {:d});\n'.format(len(data)), [len(data)]) - self.dec_buf += 'mpack_expect_map_max(&reader, {:d});\n'.format(len(data)) - self.dec_buf1 += self.add_transition('mpack_expect_map_max(&reader, {:d});\n'.format(len(data)), [len(data)]) + self.enc_buf += self.add_transition( + "mpack_start_map(&writer, {:d});\n".format(len(data)), [len(data)] + ) + self.dec_buf += "mpack_expect_map_max(&reader, {:d});\n".format(len(data)) + self.dec_buf1 += self.add_transition( + "mpack_expect_map_max(&reader, {:d});\n".format(len(data)), [len(data)] + ) for key in sorted(data.keys()): - self.enc_buf += self.add_transition('mpack_write_cstr(&writer, "{}");\n'.format(key), [len(key)]) + self.enc_buf += self.add_transition( + 'mpack_write_cstr(&writer, "{}");\n'.format(key), [len(key)] + ) self.dec_buf += 'mpack_expect_cstr_match(&reader, "{}");\n'.format(key) - self.dec_buf1 += self.add_transition('mpack_expect_cstr_match(&reader, "{}");\n'.format(key), [len(key)]) + self.dec_buf1 += self.add_transition( + 'mpack_expect_cstr_match(&reader, "{}");\n'.format(key), [len(key)] + ) self.add_value(data[key]) - self.enc_buf += 'mpack_finish_map(&writer);\n' - self.dec_buf += 'mpack_done_map(&reader);\n' - self.dec_buf1 += 'mpack_done_map(&reader);\n' + self.enc_buf += "mpack_finish_map(&writer);\n" + self.dec_buf += "mpack_done_map(&reader);\n" + self.dec_buf1 += "mpack_done_map(&reader);\n" if type(data) == list: - self.enc_buf += self.add_transition('mpack_start_array(&writer, {:d});\n'.format(len(data)), [len(data)]) - self.dec_buf += 'mpack_expect_array_max(&reader, {:d});\n'.format(len(data)) - self.dec_buf1 += self.add_transition('mpack_expect_array_max(&reader, {:d});\n'.format(len(data)), [len(data)]) + self.enc_buf += self.add_transition( + "mpack_start_array(&writer, {:d});\n".format(len(data)), [len(data)] + ) + self.dec_buf += "mpack_expect_array_max(&reader, {:d});\n".format(len(data)) + self.dec_buf1 += self.add_transition( + "mpack_expect_array_max(&reader, {:d});\n".format(len(data)), + [len(data)], + ) for elem in data: - self.add_value(elem); - self.enc_buf += 'mpack_finish_array(&writer);\n' - self.dec_buf += 'mpack_done_array(&reader);\n' - self.dec_buf1 += 'mpack_done_array(&reader);\n' + self.add_value(elem) + self.enc_buf += "mpack_finish_array(&writer);\n" + self.dec_buf += "mpack_done_array(&reader);\n" + self.dec_buf1 += "mpack_done_array(&reader);\n" -class NanoPB(DummyProtocol): - def __init__(self, data, max_serialized_bytes = 256, cardinality = 'required', use_maps = False, max_string_length = None, cc_prefix = '', name = 'Benchmark', - int_type = 'uint16_t', float_type = 'float', dec_index = 0): +class NanoPB(DummyProtocol): + def __init__( + self, + data, + max_serialized_bytes=256, + cardinality="required", + use_maps=False, + max_string_length=None, + cc_prefix="", + name="Benchmark", + int_type="uint16_t", + float_type="float", + dec_index=0, + ): super().__init__() self.data = data self.max_serialized_bytes = max_serialized_bytes @@ -895,60 +1079,62 @@ class NanoPB(DummyProtocol): self.proto_float_type = self.float_type_to_proto_type(float_type) self.dec_index = dec_index self.fieldnum = 1 - self.proto_head = 'syntax = "proto2";\nimport "src/app/prototest/nanopb.proto";\n\n' - self.proto_fields = '' - self.proto_options = '' + self.proto_head = ( + 'syntax = "proto2";\nimport "src/app/prototest/nanopb.proto";\n\n' + ) + self.proto_fields = "" + self.proto_options = "" self.sub_protos = [] - self.cc_encoders = '' + self.cc_encoders = "" 
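         # For illustration (hypothetical input): with the default settings,
         # from_json() below turns {"level": 1} into the .proto field
         #   required uint32 level = 1 ;
         # plus the generated encoder statement
         #   msg.level = 1;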
self.from_json(data) def is_ascii(self): return False def int_type_to_proto_type(self, int_type): - sign = 'u' - if int_type[0] != 'u': - sign = '' - if '64' in int_type: + sign = "u" + if int_type[0] != "u": + sign = "" + if "64" in int_type: self.int_bits = 64 - return sign + 'int64' + return sign + "int64" # Protocol Buffers only have 32 and 64 bit integers, so we default to 32 self.int_bits = 32 - return sign + 'int32' + return sign + "int32" def float_type_to_proto_type(self, float_type): - if float_type == 'float': + if float_type == "float": self.float_bits = 32 else: self.float_bits = 64 return float_type def get_buffer_declaration(self): - ret = 'uint8_t buf[{:d}];\n'.format(self.max_serialized_bytes) - ret += 'uint16_t serialized_size;\n' + ret = "uint8_t buf[{:d}];\n".format(self.max_serialized_bytes) + ret += "uint16_t serialized_size;\n" return ret + self.get_cc_functions() def get_buffer_name(self): - return 'buf' + return "buf" def get_serialize(self): - ret = 'pb_ostream_t stream = pb_ostream_from_buffer(buf, sizeof(buf));\n' - ret += 'pb_encode(&stream, Benchmark_fields, &msg);\n' - ret += 'serialized_size = stream.bytes_written;\n' + ret = "pb_ostream_t stream = pb_ostream_from_buffer(buf, sizeof(buf));\n" + ret += "pb_encode(&stream, Benchmark_fields, &msg);\n" + ret += "serialized_size = stream.bytes_written;\n" return ret def get_deserialize(self): - ret = 'Benchmark msg = Benchmark_init_zero;\n' - ret += 'pb_istream_t stream = pb_istream_from_buffer(buf, serialized_size);\n' + ret = "Benchmark msg = Benchmark_init_zero;\n" + ret += "pb_istream_t stream = pb_istream_from_buffer(buf, serialized_size);\n" # OptionalTimingAnalysis and other wrappers only wrap lines ending with a ; # We therefore deliberately do not use proper if { ... } syntax here # to make sure that these two statements are timed as well. 
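         # For illustration: the emitted C++ again concatenates to a single
         # brace-less line, roughly
         #   if (pb_decode(&stream, Benchmark_fields, &msg) == false) kout << "deserialization failed" << endl;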
-        ret += 'if (pb_decode(&stream, Benchmark_fields, &msg) == false) '
+        ret += "if (pb_decode(&stream, Benchmark_fields, &msg) == false) "
         ret += 'kout << "deserialization failed" << endl;\n'
         return ret
 
     def get_decode_and_output(self):
-        return 'kout << dec << "dec:";\n' + self.dec_buf + 'kout << endl;\n'
+        return 'kout << dec << "dec:";\n' + self.dec_buf + "kout << endl;\n"
 
     def get_decode_vars(self):
         return self.dec_buf0
@@ -957,100 +1143,133 @@ class NanoPB(DummyProtocol):
         return self.dec_buf1
 
     def get_decode_output(self):
-        return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n';
+        return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n"
 
     def get_length_var(self):
-        return 'serialized_size'
+        return "serialized_size"
 
     def add_field(self, cardinality, fieldtype, key, value):
-        extra = ''
+        extra = ""
         texttype = self.proto_int_type
         dectype = self.int_type
         if fieldtype == str:
-            texttype = 'string'
+            texttype = "string"
         elif fieldtype == float:
            texttype = self.proto_float_type
             dectype = self.float_type
         elif fieldtype == dict:
            texttype = key.capitalize()
         if type(value) == list:
-            extra = '[(nanopb).max_count = {:d}]'.format(len(value))
-            self.enc_buf += 'msg.{}{}_count = {:d};\n'.format(self.cc_prefix, key, len(value))
-        self.proto_fields += '{} {} {} = {:d} {};\n'.format(
-            cardinality, texttype, key, self.fieldnum, extra)
+            extra = "[(nanopb).max_count = {:d}]".format(len(value))
+            self.enc_buf += "msg.{}{}_count = {:d};\n".format(
+                self.cc_prefix, key, len(value)
+            )
+        self.proto_fields += "{} {} {} = {:d} {};\n".format(
+            cardinality, texttype, key, self.fieldnum, extra
+        )
         self.fieldnum += 1
         if fieldtype == str:
-            if cardinality == 'optional':
-                self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key)
+            if cardinality == "optional":
+                self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key)
             if self.max_strlen:
-                self.proto_options += '{}.{} max_size:{:d}\n'.format(self.name, key, self.max_strlen)
+                self.proto_options += "{}.{} max_size:{:d}\n".format(
+                    self.name, key, self.max_strlen
+                )
                 i = -1
                 for i, character in enumerate(value):
-                    self.enc_buf += '''msg.{}{}[{:d}] = '{}';\n'''.format(self.cc_prefix, key, i, character)
-                self.enc_buf += 'msg.{}{}[{:d}] = 0;\n'.format(self.cc_prefix, key, i+1)
-                self.dec_buf += 'kout << msg.{}{};\n'.format(self.cc_prefix, key)
-                self.assign_and_kout('char *', 'msg.{}{}'.format(self.cc_prefix, key))
+                    self.enc_buf += """msg.{}{}[{:d}] = '{}';\n""".format(
+                        self.cc_prefix, key, i, character
+                    )
+                self.enc_buf += "msg.{}{}[{:d}] = 0;\n".format(
+                    self.cc_prefix, key, i + 1
+                )
+                self.dec_buf += "kout << msg.{}{};\n".format(self.cc_prefix, key)
+                self.assign_and_kout("char *", "msg.{}{}".format(self.cc_prefix, key))
             else:
-                self.cc_encoders += 'bool encode_{}(pb_ostream_t *stream, const pb_field_t *field, void * const *arg)\n'.format(key)
-                self.cc_encoders += '{\n'
-                self.cc_encoders += 'if (!pb_encode_tag_for_field(stream, field)) return false;\n'
-                self.cc_encoders += 'return pb_encode_string(stream, (uint8_t*)"{}", {:d});\n'.format(value, len(value))
-                self.cc_encoders += '}\n'
-                self.enc_buf += 'msg.{}{}.funcs.encode = encode_{};\n'.format(self.cc_prefix, key, key)
-                self.dec_buf += '// TODO decode string {}{} via callback\n'.format(self.cc_prefix, key)
-                self.dec_buf1 += '// TODO decode string {}{} via callback\n'.format(self.cc_prefix, key)
+                self.cc_encoders += "bool encode_{}(pb_ostream_t *stream, const pb_field_t *field, void * const *arg)\n".format(
+                    key
+                )
+                self.cc_encoders += "{\n"
+
self.cc_encoders += ( + "if (!pb_encode_tag_for_field(stream, field)) return false;\n" + ) + self.cc_encoders += 'return pb_encode_string(stream, (uint8_t*)"{}", {:d});\n'.format( + value, len(value) + ) + self.cc_encoders += "}\n" + self.enc_buf += "msg.{}{}.funcs.encode = encode_{};\n".format( + self.cc_prefix, key, key + ) + self.dec_buf += "// TODO decode string {}{} via callback\n".format( + self.cc_prefix, key + ) + self.dec_buf1 += "// TODO decode string {}{} via callback\n".format( + self.cc_prefix, key + ) elif fieldtype == dict: - if cardinality == 'optional': - self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key) + if cardinality == "optional": + self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key) # The rest is handled recursively in add_to_dict elif type(value) == list: for i, elem in enumerate(value): - self.enc_buf += 'msg.{}{}[{:d}] = {};\n'.format(self.cc_prefix, key, i, elem) - self.dec_buf += 'kout << msg.{}{}[{:d}];\n'.format(self.cc_prefix, key, i) + self.enc_buf += "msg.{}{}[{:d}] = {};\n".format( + self.cc_prefix, key, i, elem + ) + self.dec_buf += "kout << msg.{}{}[{:d}];\n".format( + self.cc_prefix, key, i + ) if fieldtype == float: - self.assign_and_kout(self.float_type, 'msg.{}{}[{:d}]'.format(self.cc_prefix, key, i)) + self.assign_and_kout( + self.float_type, "msg.{}{}[{:d}]".format(self.cc_prefix, key, i) + ) elif fieldtype == int: - self.assign_and_kout(self.int_type, 'msg.{}{}[{:d}]'.format(self.cc_prefix, key, i)) + self.assign_and_kout( + self.int_type, "msg.{}{}[{:d}]".format(self.cc_prefix, key, i) + ) elif fieldtype == int: - if cardinality == 'optional': - self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key) - self.enc_buf += 'msg.{}{} = {};\n'.format(self.cc_prefix, key, value) - self.dec_buf += 'kout << msg.{}{};\n'.format(self.cc_prefix, key) - self.assign_and_kout(self.int_type, 'msg.{}{}'.format(self.cc_prefix, key)) + if cardinality == "optional": + self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key) + self.enc_buf += "msg.{}{} = {};\n".format(self.cc_prefix, key, value) + self.dec_buf += "kout << msg.{}{};\n".format(self.cc_prefix, key) + self.assign_and_kout(self.int_type, "msg.{}{}".format(self.cc_prefix, key)) elif fieldtype == float: - if cardinality == 'optional': - self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key) - self.enc_buf += 'msg.{}{} = {};\n'.format(self.cc_prefix, key, value) - self.dec_buf += 'kout << msg.{}{};\n'.format(self.cc_prefix, key) - self.assign_and_kout(self.float_type, 'msg.{}{}'.format(self.cc_prefix, key)) + if cardinality == "optional": + self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key) + self.enc_buf += "msg.{}{} = {};\n".format(self.cc_prefix, key, value) + self.dec_buf += "kout << msg.{}{};\n".format(self.cc_prefix, key) + self.assign_and_kout( + self.float_type, "msg.{}{}".format(self.cc_prefix, key) + ) elif fieldtype == dict: - if cardinality == 'optional': - self.enc_buf += 'msg.{}has_{} = true;\n'.format(self.cc_prefix, key) - self.enc_buf += 'msg.{}{} = {};\n'.format(self.cc_prefix, key, value) - self.dec_buf += 'kout << msg.{}{};\n'.format(self.cc_prefix, key) - self.assign_and_kout(self.float_type, 'msg.{}{}'.format(self.cc_prefix, key)) + if cardinality == "optional": + self.enc_buf += "msg.{}has_{} = true;\n".format(self.cc_prefix, key) + self.enc_buf += "msg.{}{} = {};\n".format(self.cc_prefix, key, value) + self.dec_buf += "kout << msg.{}{};\n".format(self.cc_prefix, key) + self.assign_and_kout( + 
self.float_type, "msg.{}{}".format(self.cc_prefix, key) + ) else: self.note_unsupported(value) def get_proto(self): - return self.proto_head + '\n\n'.join(self.get_message_definitions('Benchmark')) + return self.proto_head + "\n\n".join(self.get_message_definitions("Benchmark")) def get_proto_options(self): return self.proto_options def get_extra_files(self): return { - 'nanopbbench.proto' : self.get_proto(), - 'nanopbbench.options' : self.get_proto_options() + "nanopbbench.proto": self.get_proto(), + "nanopbbench.options": self.get_proto_options(), } def get_message_definitions(self, msgname): ret = list(self.sub_protos) - ret.append('message {} {{\n'.format(msgname) + self.proto_fields + '}\n') + ret.append("message {} {{\n".format(msgname) + self.proto_fields + "}\n") return ret def get_encode(self): - ret = 'Benchmark msg = Benchmark_init_zero;\n' + ret = "Benchmark msg = Benchmark_init_zero;\n" return ret + self.enc_buf def get_cc_functions(self): @@ -1058,19 +1277,27 @@ class NanoPB(DummyProtocol): def add_to_dict(self, key, value): if type(value) == str: - if len(value) and value[0] == '$': + if len(value) and value[0] == "$": self.add_field(self.cardinality, int, key, value[1:]) else: self.add_field(self.cardinality, str, key, value) elif type(value) == list: - self.add_field('repeated', type(value[0]), key, value) + self.add_field("repeated", type(value[0]), key, value) elif type(value) == dict: nested_proto = NanoPB( - value, max_string_length = self.max_strlen, cardinality = self.cardinality, use_maps = self.use_maps, cc_prefix = - '{}{}.'.format(self.cc_prefix, key), name = key.capitalize(), - int_type = self.int_type, float_type = self.float_type, - dec_index = self.dec_index) - self.sub_protos.extend(nested_proto.get_message_definitions(key.capitalize())) + value, + max_string_length=self.max_strlen, + cardinality=self.cardinality, + use_maps=self.use_maps, + cc_prefix="{}{}.".format(self.cc_prefix, key), + name=key.capitalize(), + int_type=self.int_type, + float_type=self.float_type, + dec_index=self.dec_index, + ) + self.sub_protos.extend( + nested_proto.get_message_definitions(key.capitalize()) + ) self.proto_options += nested_proto.proto_options self.cc_encoders += nested_proto.cc_encoders self.add_field(self.cardinality, dict, key, value) @@ -1089,19 +1316,21 @@ class NanoPB(DummyProtocol): self.add_to_dict(key, data[key]) - class UBJ(DummyProtocol): - - def __init__(self, data, max_serialized_bytes = 255, int_type = 'uint16_t', float_type = 'float'): + def __init__( + self, data, max_serialized_bytes=255, int_type="uint16_t", float_type="float" + ): super().__init__() self.data = data self.max_serialized_bytes = self.get_serialized_length() + 2 self.int_type = int_type self.float_type = self.parse_float_type(float_type) - self.enc_buf += 'ubjw_context_t* ctx = ubjw_open_memory(buf, buf + sizeof(buf));\n' - self.enc_buf += 'ubjw_begin_object(ctx, UBJ_MIXED, 0);\n' - self.from_json('root', data) - self.enc_buf += 'ubjw_end(ctx);\n' + self.enc_buf += ( + "ubjw_context_t* ctx = ubjw_open_memory(buf, buf + sizeof(buf));\n" + ) + self.enc_buf += "ubjw_begin_object(ctx, UBJ_MIXED, 0);\n" + self.from_json("root", data) + self.enc_buf += "ubjw_end(ctx);\n" def get_serialized_length(self): return len(ubjson.dumpb(self.data)) @@ -1113,41 +1342,41 @@ class UBJ(DummyProtocol): return False def parse_float_type(self, float_type): - if float_type == 'float': + if float_type == "float": self.float_bits = 32 else: self.float_bits = 64 return float_type def get_buffer_declaration(self): 
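         # Declares the buffer that the UBJSON writer context serializes into,
         # plus the serialized_size variable reported by get_length_var().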
- ret = 'uint8_t buf[{:d}];\n'.format(self.max_serialized_bytes) - ret += 'uint16_t serialized_size;\n' + ret = "uint8_t buf[{:d}];\n".format(self.max_serialized_bytes) + ret += "uint16_t serialized_size;\n" return ret def get_buffer_name(self): - return 'buf' + return "buf" def get_length_var(self): - return 'serialized_size' + return "serialized_size" def get_serialize(self): - return 'serialized_size = ubjw_close_context(ctx);\n' + return "serialized_size = ubjw_close_context(ctx);\n" def get_encode(self): return self.enc_buf def get_deserialize(self): - ret = 'ubjr_context_t* ctx = ubjr_open_memory(buf, buf + serialized_size);\n' - ret += 'ubjr_dynamic_t dynamic_root = ubjr_read_dynamic(ctx);\n' - ret += 'ubjr_dynamic_t* root_values = (ubjr_dynamic_t*)dynamic_root.container_object.values;\n' + ret = "ubjr_context_t* ctx = ubjr_open_memory(buf, buf + serialized_size);\n" + ret += "ubjr_dynamic_t dynamic_root = ubjr_read_dynamic(ctx);\n" + ret += "ubjr_dynamic_t* root_values = (ubjr_dynamic_t*)dynamic_root.container_object.values;\n" return ret def get_decode_and_output(self): ret = 'kout << dec << "dec:";\n' ret += self.dec_buf - ret += 'kout << endl;\n' - ret += 'ubjr_cleanup_dynamic(&dynamic_root);\n' # This causes the data (including all strings) to be free'd - ret += 'ubjr_close_context(ctx);\n' + ret += "kout << endl;\n" + ret += "ubjr_cleanup_dynamic(&dynamic_root);\n" # This causes the data (including all strings) to be free'd + ret += "ubjr_close_context(ctx);\n" return ret def get_decode_vars(self): @@ -1157,90 +1386,144 @@ class UBJ(DummyProtocol): return self.dec_buf1 def get_decode_output(self): - ret = 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n' - ret += 'ubjr_cleanup_dynamic(&dynamic_root);\n' - ret += 'ubjr_close_context(ctx);\n' + ret = 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n" + ret += "ubjr_cleanup_dynamic(&dynamic_root);\n" + ret += "ubjr_close_context(ctx);\n" return ret def add_to_list(self, root, index, value): if type(value) == str: - if len(value) and value[0] == '$': - self.enc_buf += 'ubjw_write_integer(ctx, {});\n'.format(value[1:]) - self.dec_buf += 'kout << {}_values[{:d}].integer;\n'.format(root, index) - self.assign_and_kout(self.int_type, '{}_values[{:d}].integer'.format(root, index)) + if len(value) and value[0] == "$": + self.enc_buf += "ubjw_write_integer(ctx, {});\n".format(value[1:]) + self.dec_buf += "kout << {}_values[{:d}].integer;\n".format(root, index) + self.assign_and_kout( + self.int_type, "{}_values[{:d}].integer".format(root, index) + ) else: - self.enc_buf += self.add_transition('ubjw_write_string(ctx, "{}");\n'.format(value), [len(value)]) - self.dec_buf += 'kout << {}_values[{:d}].string;\n'.format(root, index) - self.assign_and_kout('char *', '{}_values[{:d}].string'.format(root, index)) + self.enc_buf += self.add_transition( + 'ubjw_write_string(ctx, "{}");\n'.format(value), [len(value)] + ) + self.dec_buf += "kout << {}_values[{:d}].string;\n".format(root, index) + self.assign_and_kout( + "char *", "{}_values[{:d}].string".format(root, index) + ) elif type(value) == list: - self.enc_buf += 'ubjw_begin_array(ctx, UBJ_MIXED, 0);\n' - self.dec_buf += '// decoding nested lists is not supported\n' - self.dec_buf1 += '// decoding nested lists is not supported\n' + self.enc_buf += "ubjw_begin_array(ctx, UBJ_MIXED, 0);\n" + self.dec_buf += "// decoding nested lists is not supported\n" + self.dec_buf1 += "// decoding nested lists is not supported\n" self.from_json(root, value) - self.enc_buf += 
'ubjw_end(ctx);\n' + self.enc_buf += "ubjw_end(ctx);\n" elif type(value) == dict: - self.enc_buf += 'ubjw_begin_object(ctx, UBJ_MIXED, 0);\n' - self.dec_buf += '// decoding objects in lists is not supported\n' - self.dec_buf1 += '// decoding objects in lists is not supported\n' + self.enc_buf += "ubjw_begin_object(ctx, UBJ_MIXED, 0);\n" + self.dec_buf += "// decoding objects in lists is not supported\n" + self.dec_buf1 += "// decoding objects in lists is not supported\n" self.from_json(root, value) - self.enc_buf += 'ubjw_end(ctx);\n' + self.enc_buf += "ubjw_end(ctx);\n" elif type(value) == float: - self.enc_buf += 'ubjw_write_float{:d}(ctx, {});\n'.format(self.float_bits, value) - self.dec_buf += 'kout << {}_values[{:d}].real;\n'.format(root, index) - self.assign_and_kout(self.float_type, '{}_values[{:d}].real'.format(root, index)) + self.enc_buf += "ubjw_write_float{:d}(ctx, {});\n".format( + self.float_bits, value + ) + self.dec_buf += "kout << {}_values[{:d}].real;\n".format(root, index) + self.assign_and_kout( + self.float_type, "{}_values[{:d}].real".format(root, index) + ) elif type(value) == int: - self.enc_buf += 'ubjw_write_integer(ctx, {});\n'.format(value) - self.dec_buf += 'kout << {}_values[{:d}].integer;\n'.format(root, index) - self.assign_and_kout(self.int_type, '{}_values[{:d}].integer'.format(root, index)) + self.enc_buf += "ubjw_write_integer(ctx, {});\n".format(value) + self.dec_buf += "kout << {}_values[{:d}].integer;\n".format(root, index) + self.assign_and_kout( + self.int_type, "{}_values[{:d}].integer".format(root, index) + ) else: - raise TypeError('Cannot handle {} of type {}'.format(value, type(value))) + raise TypeError("Cannot handle {} of type {}".format(value, type(value))) def add_to_dict(self, root, index, key, value): if type(value) == str: - if len(value) and value[0] == '$': - self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_write_integer(ctx, {});\n'.format(key, value[1:]), [len(key)]) - self.dec_buf += 'kout << {}_values[{:d}].integer;\n'.format(root, index) - self.assign_and_kout(self.int_type, '{}_values[{:d}].integer'.format(root, index)) + if len(value) and value[0] == "$": + self.enc_buf += self.add_transition( + 'ubjw_write_key(ctx, "{}"); ubjw_write_integer(ctx, {});\n'.format( + key, value[1:] + ), + [len(key)], + ) + self.dec_buf += "kout << {}_values[{:d}].integer;\n".format(root, index) + self.assign_and_kout( + self.int_type, "{}_values[{:d}].integer".format(root, index) + ) else: - self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_write_string(ctx, "{}");\n'.format(key, value), [len(key), len(value)]) - self.dec_buf += 'kout << {}_values[{:d}].string;\n'.format(root, index) - self.assign_and_kout('char *', '{}_values[{:d}].string'.format(root, index)) + self.enc_buf += self.add_transition( + 'ubjw_write_key(ctx, "{}"); ubjw_write_string(ctx, "{}");\n'.format( + key, value + ), + [len(key), len(value)], + ) + self.dec_buf += "kout << {}_values[{:d}].string;\n".format(root, index) + self.assign_and_kout( + "char *", "{}_values[{:d}].string".format(root, index) + ) elif type(value) == list: - self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_begin_array(ctx, UBJ_MIXED, 0);\n'.format(key), [len(key)]) - self.dec_buf += 'ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_array.values;\n'.format( - key, root, index) - self.dec_buf1 += 'ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_array.values;\n'.format( - key, root, index) + self.enc_buf += 
self.add_transition( + 'ubjw_write_key(ctx, "{}"); ubjw_begin_array(ctx, UBJ_MIXED, 0);\n'.format( + key + ), + [len(key)], + ) + self.dec_buf += "ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_array.values;\n".format( + key, root, index + ) + self.dec_buf1 += "ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_array.values;\n".format( + key, root, index + ) self.from_json(key, value) - self.enc_buf += 'ubjw_end(ctx);\n' + self.enc_buf += "ubjw_end(ctx);\n" elif type(value) == dict: - self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_begin_object(ctx, UBJ_MIXED, 0);\n'.format(key), [len(key)]) - self.dec_buf += 'ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_object.values;\n'.format( - key, root, index) - self.dec_buf1 += 'ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_object.values;\n'.format( - key, root, index) + self.enc_buf += self.add_transition( + 'ubjw_write_key(ctx, "{}"); ubjw_begin_object(ctx, UBJ_MIXED, 0);\n'.format( + key + ), + [len(key)], + ) + self.dec_buf += "ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_object.values;\n".format( + key, root, index + ) + self.dec_buf1 += "ubjr_dynamic_t *{}_values = (ubjr_dynamic_t*){}_values[{:d}].container_object.values;\n".format( + key, root, index + ) self.from_json(key, value) - self.enc_buf += 'ubjw_end(ctx);\n' + self.enc_buf += "ubjw_end(ctx);\n" elif type(value) == float: - self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_write_float{:d}(ctx, {});\n'.format(key, self.float_bits, value), [len(key)]) - self.dec_buf += 'kout << {}_values[{:d}].real;\n'.format(root, index) - self.assign_and_kout(self.float_type, '{}_values[{:d}].real'.format(root, index)) + self.enc_buf += self.add_transition( + 'ubjw_write_key(ctx, "{}"); ubjw_write_float{:d}(ctx, {});\n'.format( + key, self.float_bits, value + ), + [len(key)], + ) + self.dec_buf += "kout << {}_values[{:d}].real;\n".format(root, index) + self.assign_and_kout( + self.float_type, "{}_values[{:d}].real".format(root, index) + ) elif type(value) == int: - self.enc_buf += self.add_transition('ubjw_write_key(ctx, "{}"); ubjw_write_integer(ctx, {});\n'.format(key, value), [len(key)]) - self.dec_buf += 'kout << {}_values[{:d}].integer;\n'.format(root, index) - self.assign_and_kout(self.int_type, '{}_values[{:d}].integer'.format(root, index)) + self.enc_buf += self.add_transition( + 'ubjw_write_key(ctx, "{}"); ubjw_write_integer(ctx, {});\n'.format( + key, value + ), + [len(key)], + ) + self.dec_buf += "kout << {}_values[{:d}].integer;\n".format(root, index) + self.assign_and_kout( + self.int_type, "{}_values[{:d}].integer".format(root, index) + ) else: - raise TypeError('Cannot handle {} of type {}'.format(value, type(value))) + raise TypeError("Cannot handle {} of type {}".format(value, type(value))) def from_json(self, root, data): if type(data) == dict: @@ -1253,63 +1536,64 @@ class UBJ(DummyProtocol): class XDR(DummyProtocol): - - def __init__(self, data, max_serialized_bytes = 256, int_type = 'uint16_t', float_type = 'float'): + def __init__( + self, data, max_serialized_bytes=256, int_type="uint16_t", float_type="float" + ): super().__init__() self.data = data self.max_serialized_bytes = 256 self.enc_int_type = int_type self.dec_int_type = self.parse_int_type(int_type) self.float_type = self.parse_float_type(float_type) - self.enc_buf += 'XDRWriter xdrstream(buf);\n' - self.dec_buf += 'XDRReader xdrinput(buf);\n' - self.dec_buf0 += 
'XDRReader xdrinput(buf);\n' + self.enc_buf += "XDRWriter xdrstream(buf);\n" + self.dec_buf += "XDRReader xdrinput(buf);\n" + self.dec_buf0 += "XDRReader xdrinput(buf);\n" # By default, XDR does not even include a version / protocol specifier. # This seems rather impractical -> emulate that here. - #self.enc_buf += 'xdrstream << (uint32_t)22075;\n' - self.dec_buf += 'char strbuf[16];\n' - #self.dec_buf += 'xdrinput.get_uint32();\n' - self.dec_buf0 += 'char strbuf[16];\n' - #self.dec_buf0 += 'xdrinput.get_uint32();\n' + # self.enc_buf += 'xdrstream << (uint32_t)22075;\n' + self.dec_buf += "char strbuf[16];\n" + # self.dec_buf += 'xdrinput.get_uint32();\n' + self.dec_buf0 += "char strbuf[16];\n" + # self.dec_buf0 += 'xdrinput.get_uint32();\n' self.from_json(data) def is_ascii(self): return False def parse_int_type(self, int_type): - sign = '' - if int_type[0] == 'u': - sign = 'u' - if '64' in int_type: + sign = "" + if int_type[0] == "u": + sign = "u" + if "64" in int_type: self.int_bits = 64 - return sign + 'int64' + return sign + "int64" else: self.int_bits = 32 - return sign + 'int32' + return sign + "int32" def parse_float_type(self, float_type): - if float_type == 'float': + if float_type == "float": self.float_bits = 32 else: self.float_bits = 64 return float_type def get_buffer_declaration(self): - ret = 'uint16_t serialized_size;\n' - ret += 'char buf[{:d}];\n'.format(self.max_serialized_bytes) + ret = "uint16_t serialized_size;\n" + ret += "char buf[{:d}];\n".format(self.max_serialized_bytes) return ret def get_buffer_name(self): - return 'buf' + return "buf" def get_length_var(self): - return 'xdrstream.size()' + return "xdrstream.size()" def get_encode(self): return self.enc_buf def get_decode_and_output(self): - return 'kout << dec << "dec:";\n' + self.dec_buf + 'kout << endl;\n' + return 'kout << dec << "dec:";\n' + self.dec_buf + "kout << endl;\n" def get_decode_vars(self): return self.dec_buf0 @@ -1318,115 +1602,129 @@ class XDR(DummyProtocol): return self.dec_buf1 def get_decode_output(self): - return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n' + return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n" def from_json(self, data): if type(data) == dict: for key in sorted(data.keys()): self.from_json(data[key]) elif type(data) == list: - self.enc_buf += 'xdrstream.setNextArrayLen({});\n'.format(len(data)) - self.enc_buf += 'xdrstream.setVariableLength();\n' - self.enc_buf += 'xdrstream.startList();\n' - self.dec_buf += 'xdrinput.get_uint32();\n' - self.dec_buf1 += 'xdrinput.get_uint32();\n' + self.enc_buf += "xdrstream.setNextArrayLen({});\n".format(len(data)) + self.enc_buf += "xdrstream.setVariableLength();\n" + self.enc_buf += "xdrstream.startList();\n" + self.dec_buf += "xdrinput.get_uint32();\n" + self.dec_buf1 += "xdrinput.get_uint32();\n" for elem in data: self.from_json(elem) elif type(data) == str: - if len(data) and data[0] == '$': - self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.enc_int_type, data[1:]) - self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.dec_int_type) - self.dec_buf0 += '{} dec_{};\n'.format(self.enc_int_type, self.dec_index) - self.dec_buf1 += 'dec_{} = xdrinput.get_{}();;\n'.format(self.dec_index, self.dec_int_type) - self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index) + if len(data) and data[0] == "$": + self.enc_buf += "xdrstream.put(({}){});\n".format( + self.enc_int_type, data[1:] + ) + self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.dec_int_type) + self.dec_buf0 += "{} 
dec_{};\n".format( + self.enc_int_type, self.dec_index + ) + self.dec_buf1 += "dec_{} = xdrinput.get_{}();;\n".format( + self.dec_index, self.dec_int_type + ) + self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index) else: # Kodierte Strings haben nicht immer ein Nullbyte am Ende - self.enc_buf += 'xdrstream.setNextArrayLen({});\n'.format(len(data)) - self.enc_buf += 'xdrstream.setVariableLength();\n' - self.enc_buf += self.add_transition('xdrstream.put("{}");\n'.format(data), [len(data)]) - self.dec_buf += 'xdrinput.get_string(strbuf);\n' - self.dec_buf += 'kout << strbuf;\n' - self.dec_buf1 += 'xdrinput.get_string(strbuf);\n' - self.dec_buf2 += 'kout << strbuf;\n' + self.enc_buf += "xdrstream.setNextArrayLen({});\n".format(len(data)) + self.enc_buf += "xdrstream.setVariableLength();\n" + self.enc_buf += self.add_transition( + 'xdrstream.put("{}");\n'.format(data), [len(data)] + ) + self.dec_buf += "xdrinput.get_string(strbuf);\n" + self.dec_buf += "kout << strbuf;\n" + self.dec_buf1 += "xdrinput.get_string(strbuf);\n" + self.dec_buf2 += "kout << strbuf;\n" elif type(data) == float: - self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.float_type, data) - self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.float_type) - self.dec_buf0 += '{} dec_{};\n'.format(self.float_type, self.dec_index) - self.dec_buf1 += 'dec_{} = xdrinput.get_{}();\n'.format(self.dec_index, self.float_type) - self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index) + self.enc_buf += "xdrstream.put(({}){});\n".format(self.float_type, data) + self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.float_type) + self.dec_buf0 += "{} dec_{};\n".format(self.float_type, self.dec_index) + self.dec_buf1 += "dec_{} = xdrinput.get_{}();\n".format( + self.dec_index, self.float_type + ) + self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index) elif type(data) == int: - self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.enc_int_type, data) - self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.dec_int_type) - self.dec_buf0 += '{} dec_{};\n'.format(self.enc_int_type, self.dec_index) - self.dec_buf1 += 'dec_{} = xdrinput.get_{}();\n'.format(self.dec_index, self.dec_int_type) - self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index) + self.enc_buf += "xdrstream.put(({}){});\n".format(self.enc_int_type, data) + self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.dec_int_type) + self.dec_buf0 += "{} dec_{};\n".format(self.enc_int_type, self.dec_index) + self.dec_buf1 += "dec_{} = xdrinput.get_{}();\n".format( + self.dec_index, self.dec_int_type + ) + self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index) else: - self.enc_buf += 'xdrstream.put({});\n'.format(data) - self.dec_buf += '// unsupported type {} of {}\n'.format(type(data), data) - self.dec_buf1 += '// unsupported type {} of {}\n'.format(type(data), data) + self.enc_buf += "xdrstream.put({});\n".format(data) + self.dec_buf += "// unsupported type {} of {}\n".format(type(data), data) + self.dec_buf1 += "// unsupported type {} of {}\n".format(type(data), data) self.dec_index += 1 -class XDR16(DummyProtocol): - def __init__(self, data, max_serialized_bytes = 256, int_type = 'uint16_t', float_type = 'float'): +class XDR16(DummyProtocol): + def __init__( + self, data, max_serialized_bytes=256, int_type="uint16_t", float_type="float" + ): super().__init__() self.data = data self.max_serialized_bytes = 256 self.enc_int_type = int_type self.dec_int_type = self.parse_int_type(int_type) self.float_type = 
self.parse_float_type(float_type) - self.enc_buf += 'XDRWriter xdrstream(buf);\n' - self.dec_buf += 'XDRReader xdrinput(buf);\n' - self.dec_buf0 += 'XDRReader xdrinput(buf);\n' + self.enc_buf += "XDRWriter xdrstream(buf);\n" + self.dec_buf += "XDRReader xdrinput(buf);\n" + self.dec_buf0 += "XDRReader xdrinput(buf);\n" # By default, XDR does not even include a version / protocol specifier. # This seems rather impractical -> emulate that here. - #self.enc_buf += 'xdrstream << (uint32_t)22075;\n' - self.dec_buf += 'char strbuf[16];\n' - #self.dec_buf += 'xdrinput.get_uint32();\n' - self.dec_buf0 += 'char strbuf[16];\n' - #self.dec_buf0 += 'xdrinput.get_uint32();\n' + # self.enc_buf += 'xdrstream << (uint32_t)22075;\n' + self.dec_buf += "char strbuf[16];\n" + # self.dec_buf += 'xdrinput.get_uint32();\n' + self.dec_buf0 += "char strbuf[16];\n" + # self.dec_buf0 += 'xdrinput.get_uint32();\n' self.from_json(data) def is_ascii(self): return False def parse_int_type(self, int_type): - sign = '' - if int_type[0] == 'u': - sign = 'u' - if '64' in int_type: + sign = "" + if int_type[0] == "u": + sign = "u" + if "64" in int_type: self.int_bits = 64 - return sign + 'int64' - if '32' in int_type: + return sign + "int64" + if "32" in int_type: self.int_bits = 32 - return sign + 'int32' + return sign + "int32" else: self.int_bits = 16 - return sign + 'int16' + return sign + "int16" def parse_float_type(self, float_type): - if float_type == 'float': + if float_type == "float": self.float_bits = 32 else: self.float_bits = 64 return float_type def get_buffer_declaration(self): - ret = 'uint16_t serialized_size;\n' - ret += 'char buf[{:d}];\n'.format(self.max_serialized_bytes) + ret = "uint16_t serialized_size;\n" + ret += "char buf[{:d}];\n".format(self.max_serialized_bytes) return ret def get_buffer_name(self): - return 'buf' + return "buf" def get_length_var(self): - return 'xdrstream.size()' + return "xdrstream.size()" def get_encode(self): return self.enc_buf def get_decode_and_output(self): - return 'kout << dec << "dec:";\n' + self.dec_buf + 'kout << endl;\n' + return 'kout << dec << "dec:";\n' + self.dec_buf + "kout << endl;\n" def get_decode_vars(self): return self.dec_buf0 @@ -1435,76 +1733,99 @@ class XDR16(DummyProtocol): return self.dec_buf1 def get_decode_output(self): - return 'kout << dec << "dec:";\n' + self.dec_buf2 + 'kout << endl;\n'; + return 'kout << dec << "dec:";\n' + self.dec_buf2 + "kout << endl;\n" def from_json(self, data): if type(data) == dict: for key in sorted(data.keys()): self.from_json(data[key]) elif type(data) == list: - self.enc_buf += 'xdrstream.setNextArrayLen({});\n'.format(len(data)) - self.enc_buf += 'xdrstream.setVariableLength();\n' - self.enc_buf += 'xdrstream.startList();\n' - self.dec_buf += 'xdrinput.get_uint16();\n' - self.dec_buf1 += 'xdrinput.get_uint16();\n' + self.enc_buf += "xdrstream.setNextArrayLen({});\n".format(len(data)) + self.enc_buf += "xdrstream.setVariableLength();\n" + self.enc_buf += "xdrstream.startList();\n" + self.dec_buf += "xdrinput.get_uint16();\n" + self.dec_buf1 += "xdrinput.get_uint16();\n" for elem in data: self.from_json(elem) elif type(data) == str: - if len(data) and data[0] == '$': - self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.enc_int_type, data[1:]) - self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.dec_int_type) - self.dec_buf0 += '{} dec_{};\n'.format(self.enc_int_type, self.dec_index) - self.dec_buf1 += 'dec_{} = xdrinput.get_{}();;\n'.format(self.dec_index, self.dec_int_type) - self.dec_buf2 += 'kout << 
dec_{};\n'.format(self.dec_index)
+            if len(data) and data[0] == "$":
+                self.enc_buf += "xdrstream.put(({}){});\n".format(
+                    self.enc_int_type, data[1:]
+                )
+                self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.dec_int_type)
+                self.dec_buf0 += "{} dec_{};\n".format(
+                    self.enc_int_type, self.dec_index
+                )
+                self.dec_buf1 += "dec_{} = xdrinput.get_{}();;\n".format(
+                    self.dec_index, self.dec_int_type
+                )
+                self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index)
             else:
                 # Encoded strings do not always have a terminating null byte
-                self.enc_buf += 'xdrstream.setNextArrayLen({});\n'.format(len(data))
-                self.enc_buf += 'xdrstream.setVariableLength();\n'
-                self.enc_buf += self.add_transition('xdrstream.put("{}");\n'.format(data), [len(data)])
-                self.dec_buf += 'xdrinput.get_string(strbuf);\n'
-                self.dec_buf += 'kout << strbuf;\n'
-                self.dec_buf1 += 'xdrinput.get_string(strbuf);\n'
-                self.dec_buf2 += 'kout << strbuf;\n'
+                self.enc_buf += "xdrstream.setNextArrayLen({});\n".format(len(data))
+                self.enc_buf += "xdrstream.setVariableLength();\n"
+                self.enc_buf += self.add_transition(
+                    'xdrstream.put("{}");\n'.format(data), [len(data)]
+                )
+                self.dec_buf += "xdrinput.get_string(strbuf);\n"
+                self.dec_buf += "kout << strbuf;\n"
+                self.dec_buf1 += "xdrinput.get_string(strbuf);\n"
+                self.dec_buf2 += "kout << strbuf;\n"
         elif type(data) == float:
-            self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.float_type, data)
-            self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.float_type)
-            self.dec_buf0 += '{} dec_{};\n'.format(self.float_type, self.dec_index)
-            self.dec_buf1 += 'dec_{} = xdrinput.get_{}();\n'.format(self.dec_index, self.float_type)
-            self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index)
+            self.enc_buf += "xdrstream.put(({}){});\n".format(self.float_type, data)
+            self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.float_type)
+            self.dec_buf0 += "{} dec_{};\n".format(self.float_type, self.dec_index)
+            self.dec_buf1 += "dec_{} = xdrinput.get_{}();\n".format(
+                self.dec_index, self.float_type
+            )
+            self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index)
         elif type(data) == int:
-            self.enc_buf += 'xdrstream.put(({}){});\n'.format(self.enc_int_type, data)
-            self.dec_buf += 'kout << xdrinput.get_{}();\n'.format(self.dec_int_type)
-            self.dec_buf0 += '{} dec_{};\n'.format(self.enc_int_type, self.dec_index)
-            self.dec_buf1 += 'dec_{} = xdrinput.get_{}();\n'.format(self.dec_index, self.dec_int_type)
-            self.dec_buf2 += 'kout << dec_{};\n'.format(self.dec_index)
+            self.enc_buf += "xdrstream.put(({}){});\n".format(self.enc_int_type, data)
+            self.dec_buf += "kout << xdrinput.get_{}();\n".format(self.dec_int_type)
+            self.dec_buf0 += "{} dec_{};\n".format(self.enc_int_type, self.dec_index)
+            self.dec_buf1 += "dec_{} = xdrinput.get_{}();\n".format(
+                self.dec_index, self.dec_int_type
+            )
+            self.dec_buf2 += "kout << dec_{};\n".format(self.dec_index)
         else:
-            self.enc_buf += 'xdrstream.put({});\n'.format(data)
-            self.dec_buf += '// unsupported type {} of {}\n'.format(type(data), data)
-            self.dec_buf1 += '// unsupported type {} of {}\n'.format(type(data), data)
-            self.dec_index += 1;
+            self.enc_buf += "xdrstream.put({});\n".format(data)
+            self.dec_buf += "// unsupported type {} of {}\n".format(type(data), data)
+            self.dec_buf1 += "// unsupported type {} of {}\n".format(type(data), data)
+        self.dec_index += 1
 
 
-class Benchmark:
 
+class Benchmark:
     def __init__(self, logfile):
         self.atomic = True
         self.logfile = logfile
 
     def __enter__(self):
         self.atomic = False
-        with FileLock(self.logfile + 
'.lock'): + with FileLock(self.logfile + ".lock"): if os.path.exists(self.logfile): - with open(self.logfile, 'rb') as f: + with open(self.logfile, "rb") as f: self.data = ubjson.load(f) else: self.data = {} return self def __exit__(self, exc_type, exc_value, exc_traceback): - with FileLock(self.logfile + '.lock'): - with open(self.logfile, 'wb') as f: + with FileLock(self.logfile + ".lock"): + with open(self.logfile, "wb") as f: ubjson.dump(self.data, f) - def _add_log_entry(self, benchmark_data, arch, libkey, bench_name, bench_index, data, key, value, error): + def _add_log_entry( + self, + benchmark_data, + arch, + libkey, + bench_name, + bench_index, + data, + key, + value, + error, + ): if not libkey in benchmark_data: benchmark_data[libkey] = dict() if not bench_name in benchmark_data[libkey]: @@ -1514,94 +1835,127 @@ class Benchmark: this_result = benchmark_data[libkey][bench_name][bench_index] # data is unset for log(...) calls from postprocessing if data != None: - this_result['data'] = data + this_result["data"] = data if value != None: - this_result[key] = { - 'v' : value, - 'ts' : int(time.time()) - } - print('{} {} {} ({}) :: {} -> {}'.format( - libkey, bench_name, bench_index, data, key, value)) + this_result[key] = {"v": value, "ts": int(time.time())} + print( + "{} {} {} ({}) :: {} -> {}".format( + libkey, bench_name, bench_index, data, key, value + ) + ) else: - this_result[key] = { - 'e' : error, - 'ts' : int(time.time()) - } - print('{} {} {} ({}) :: {} -> [E] {}'.format( - libkey, bench_name, bench_index, data, key, error[:500])) - - def log(self, arch, library, library_options, bench_name, bench_index, data, key, value = None, error = None): + this_result[key] = {"e": error, "ts": int(time.time())} + print( + "{} {} {} ({}) :: {} -> [E] {}".format( + libkey, bench_name, bench_index, data, key, error[:500] + ) + ) + + def log( + self, + arch, + library, + library_options, + bench_name, + bench_index, + data, + key, + value=None, + error=None, + ): if not library_options: library_options = [] - libkey = '{}:{}:{}'.format(arch, library, ','.join(library_options)) + libkey = "{}:{}:{}".format(arch, library, ",".join(library_options)) # JSON does not differentiate between int and str keys -> always use # str(bench_index) bench_index = str(bench_index) if self.atomic: - with FileLock(self.logfile + '.lock'): + with FileLock(self.logfile + ".lock"): if os.path.exists(self.logfile): - with open(self.logfile, 'rb') as f: + with open(self.logfile, "rb") as f: benchmark_data = ubjson.load(f) else: benchmark_data = {} - self._add_log_entry(benchmark_data, arch, libkey, bench_name, bench_index, data, key, value, error) - with open(self.logfile, 'wb') as f: + self._add_log_entry( + benchmark_data, + arch, + libkey, + bench_name, + bench_index, + data, + key, + value, + error, + ) + with open(self.logfile, "wb") as f: ubjson.dump(benchmark_data, f) else: - self._add_log_entry(self.data, arch, libkey, bench_name, bench_index, data, key, value, error) + self._add_log_entry( + self.data, + arch, + libkey, + bench_name, + bench_index, + data, + key, + value, + error, + ) def get_snapshot(self): - with FileLock(self.logfile + '.lock'): + with FileLock(self.logfile + ".lock"): if os.path.exists(self.logfile): - with open(self.logfile, 'rb') as f: + with open(self.logfile, "rb") as f: benchmark_data = ubjson.load(f) else: benchmark_data = {} return benchmark_data + def codegen_for_lib(library, library_options, data): - if library == 'arduinojson': - return ArduinoJSON(data, bufsize = 512) + 
if library == "arduinojson": + return ArduinoJSON(data, bufsize=512) - if library == 'avro': + if library == "avro": strip_schema = bool(int(library_options[0])) - return Avro(data, strip_schema = strip_schema) + return Avro(data, strip_schema=strip_schema) - if library == 'capnproto_c': + if library == "capnproto_c": packed = bool(int(library_options[0])) - return CapnProtoC(data, packed = packed) + return CapnProtoC(data, packed=packed) - if library == 'manualjson': + if library == "manualjson": return ManualJSON(data) - if library == 'modernjson': - dataformat, = library_options + if library == "modernjson": + (dataformat,) = library_options return ModernJSON(data, dataformat) - if library == 'mpack': + if library == "mpack": return MPack(data) - if library == 'nanopb': + if library == "nanopb": cardinality, strbuf = library_options - if not len(strbuf) or strbuf == '0': + if not len(strbuf) or strbuf == "0": strbuf = None else: strbuf = int(strbuf) - return NanoPB(data, cardinality = cardinality, max_string_length = strbuf) + return NanoPB(data, cardinality=cardinality, max_string_length=strbuf) - if library == 'thrift': + if library == "thrift": return Thrift(data) - if library == 'ubjson': + if library == "ubjson": return UBJ(data) - if library == 'xdr': + if library == "xdr": return XDR(data) - if library == 'xdr16': + if library == "xdr16": return XDR16(data) - raise ValueError('Unsupported library: {}'.format(library)) + raise ValueError("Unsupported library: {}".format(library)) + def shorten_call(snippet, lib): """ @@ -1612,54 +1966,78 @@ def shorten_call(snippet, lib): """ # The following adjustments are protobench-specific # "xdrstream << (uint16_t)123" -> "xdrstream << (uint16_t" - if 'xdrstream << (' in snippet: - snippet = snippet.split(')')[0] + if "xdrstream << (" in snippet: + snippet = snippet.split(")")[0] # "xdrstream << variable << ..." 
-> "xdrstream << variable" - elif 'xdrstream << variable' in snippet: - snippet = '<<'.join(snippet.split('<<')[0:2]) - elif 'xdrstream.setNextArrayLen(' in snippet: - snippet = 'xdrstream.setNextArrayLen' - elif 'ubjw' in snippet: - snippet = re.sub('ubjw_write_key\(ctx, [^)]+\)', 'ubjw_write_key(ctx, ?)', snippet) - snippet = re.sub('ubjw_write_([^(]+)\(ctx, [^)]+\)', 'ubjw_write_\\1(ctx, ?)', snippet) + elif "xdrstream << variable" in snippet: + snippet = "<<".join(snippet.split("<<")[0:2]) + elif "xdrstream.setNextArrayLen(" in snippet: + snippet = "xdrstream.setNextArrayLen" + elif "ubjw" in snippet: + snippet = re.sub( + "ubjw_write_key\(ctx, [^)]+\)", "ubjw_write_key(ctx, ?)", snippet + ) + snippet = re.sub( + "ubjw_write_([^(]+)\(ctx, [^)]+\)", "ubjw_write_\\1(ctx, ?)", snippet + ) # mpack_write(&writer, (type)value) -> mpack_write(&writer, (type - elif 'mpack_write(' in snippet: - snippet = snippet.split(')')[0] + elif "mpack_write(" in snippet: + snippet = snippet.split(")")[0] # mpack_write_cstr(&writer, "foo") -> mpack_write_cstr(&writer, # same for mpack_write_cstr_or_nil - elif 'mpack_write_cstr' in snippet: + elif "mpack_write_cstr" in snippet: snippet = snippet.split('"')[0] # mpack_start_map(&writer, x) -> mpack_start_map(&writer # mpack_start_array(&writer, x) -> mpack_start_array(&writer - elif 'mpack_start_' in snippet: - snippet = snippet.split(',')[0] - #elif 'bout <<' in snippet: + elif "mpack_start_" in snippet: + snippet = snippet.split(",")[0] + # elif 'bout <<' in snippet: # if '\\":\\"' in snippet: # snippet = 'bout << key:str' # elif 'bout << "\\"' in snippet: # snippet = 'bout << key' # else: # snippet = 'bout << other' - elif 'msg.' in snippet: - snippet = re.sub('msg.(?:[^[]+)(?:\[.*?\])? = .*', 'msg.? = ?', snippet) - elif lib == 'arduinojson:': - snippet = re.sub('ArduinoJson::JsonObject& [^ ]+ = [^.]+.createNestedObject\([^")]*\);', 'ArduinoJson::JsonObject& ? = ?.createNestedObject();', snippet) - snippet = re.sub('ArduinoJson::JsonObject& [^ ]+ = [^.]+.createNestedObject\("[^")]*"\);', 'ArduinoJson::JsonObject& ? = ?.createNestedObject(?);', snippet) - snippet = re.sub('ArduinoJson::JsonArray& [^ ]+ = [^.]+.createNestedArray\([^")]*\);', 'ArduinoJson::JsonArray& ? = ?.createNestedArray();', snippet) - snippet = re.sub('ArduinoJson::JsonArray& [^ ]+ = [^.]+.createNestedArray\("[^")]*"\);', 'ArduinoJson::JsonArray& ? = ?.createNestedArray(?);', snippet) + elif "msg." in snippet: + snippet = re.sub("msg.(?:[^[]+)(?:\[.*?\])? = .*", "msg.? = ?", snippet) + elif lib == "arduinojson:": + snippet = re.sub( + 'ArduinoJson::JsonObject& [^ ]+ = [^.]+.createNestedObject\([^")]*\);', + "ArduinoJson::JsonObject& ? = ?.createNestedObject();", + snippet, + ) + snippet = re.sub( + 'ArduinoJson::JsonObject& [^ ]+ = [^.]+.createNestedObject\("[^")]*"\);', + "ArduinoJson::JsonObject& ? = ?.createNestedObject(?);", + snippet, + ) + snippet = re.sub( + 'ArduinoJson::JsonArray& [^ ]+ = [^.]+.createNestedArray\([^")]*\);', + "ArduinoJson::JsonArray& ? = ?.createNestedArray();", + snippet, + ) + snippet = re.sub( + 'ArduinoJson::JsonArray& [^ ]+ = [^.]+.createNestedArray\("[^")]*"\);', + "ArduinoJson::JsonArray& ? 
= ?.createNestedArray(?);",
+            snippet,
+        )
         snippet = re.sub('root[^[]*\["[^"]*"\] = [^";]+', 'root?["?"] = ?', snippet)
         snippet = re.sub('root[^[]*\["[^"]*"\] = "[^"]+"', 'root?["?"] = "?"', snippet)
-        snippet = re.sub('rootl.add\([^)]*\)', 'rootl.add(?)', snippet)
+        snippet = re.sub("rootl.add\([^)]*\)", "rootl.add(?)", snippet)
 
-    snippet = re.sub('^dec_[^ ]*', 'dec_?', snippet)
-    if lib == 'arduinojson:':
-        snippet = re.sub('root[^[]*\[[^]"]+\]\.as', 'root[?].as', snippet)
+    snippet = re.sub("^dec_[^ ]*", "dec_?", snippet)
+    if lib == "arduinojson:":
+        snippet = re.sub('root[^[]*\[[^]"]+\]\.as', "root[?].as", snippet)
         snippet = re.sub('root[^[]*\["[^]]+"\]\.as', 'root["?"].as', snippet)
-    elif 'nanopb:' in lib:
-        snippet = re.sub('= msg\.[^;]+;', '= msg.?;', snippet)
-    elif lib == 'mpack:':
-        snippet = re.sub('mpack_expect_([^_]+)_max\(&reader, [^)]+\)', 'mpack_expect_\\1_max(&reader, ?)', snippet)
-    elif lib == 'ubjson:':
-        snippet = re.sub('[^_ ]+_values[^.]+\.', '?_values[?].', snippet)
+    elif "nanopb:" in lib:
+        snippet = re.sub("= msg\.[^;]+;", "= msg.?;", snippet)
+    elif lib == "mpack:":
+        snippet = re.sub(
+            "mpack_expect_([^_]+)_max\(&reader, [^)]+\)",
+            "mpack_expect_\\1_max(&reader, ?)",
+            snippet,
+        )
+    elif lib == "ubjson:":
+        snippet = re.sub("[^_ ]+_values[^.]+\.", "?_values[?].", snippet)
 
     return snippet
diff --git a/lib/pubcode/__init__.py b/lib/pubcode/__init__.py
index 30c8490..3007bdf 100644
--- a/lib/pubcode/__init__.py
+++ b/lib/pubcode/__init__.py
@@ -1,6 +1,6 @@
 """A simple module for creating barcodes.
 """
-__version__ = '1.1.0'
+__version__ = "1.1.0"
 
 from .code128 import Code128
diff --git a/lib/pubcode/code128.py b/lib/pubcode/code128.py
index 1c37f37..4fd7aed 100644
--- a/lib/pubcode/code128.py
+++ b/lib/pubcode/code128.py
@@ -5,6 +5,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
 from builtins import *  # Use Python3-like builtins for Python2.
 import base64
 import io
+
 try:
     from PIL import Image
 except ImportError:
@@ -34,74 +35,189 @@ class Code128(object):
     # List of bar and space weights, indexed by symbol character values (0-105), and the STOP character (106).
     # The first weight is a bar and then it alternates.
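A quick illustration of these weight strings, as a minimal sketch mirroring the _iterate_modules helper used by the modules property further down; the convention 0 = black bar, 1 = white space is an assumption, chosen to be consistent with the white [1] * quiet_zone padding in image():

    def expand_weights(bars):
        # "212222" -> 2-wide bar, 1-wide space, 2-wide bar, ... (alternating)
        is_bar = True
        for weight in map(int, bars):
            yield from [0 if is_bar else 1] * weight
            is_bar = not is_bar

    assert list(expand_weights("212222")) == [0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1]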
_val2bars = [ - '212222', '222122', '222221', '121223', '121322', '131222', '122213', '122312', '132212', '221213', - '221312', '231212', '112232', '122132', '122231', '113222', '123122', '123221', '223211', '221132', - '221231', '213212', '223112', '312131', '311222', '321122', '321221', '312212', '322112', '322211', - '212123', '212321', '232121', '111323', '131123', '131321', '112313', '132113', '132311', '211313', - '231113', '231311', '112133', '112331', '132131', '113123', '113321', '133121', '313121', '211331', - '231131', '213113', '213311', '213131', '311123', '311321', '331121', '312113', '312311', '332111', - '314111', '221411', '431111', '111224', '111422', '121124', '121421', '141122', '141221', '112214', - '112412', '122114', '122411', '142112', '142211', '241211', '221114', '413111', '241112', '134111', - '111242', '121142', '121241', '114212', '124112', '124211', '411212', '421112', '421211', '212141', - '214121', '412121', '111143', '111341', '131141', '114113', '114311', '411113', '411311', '113141', - '114131', '311141', '411131', '211412', '211214', '211232', '2331112' + "212222", + "222122", + "222221", + "121223", + "121322", + "131222", + "122213", + "122312", + "132212", + "221213", + "221312", + "231212", + "112232", + "122132", + "122231", + "113222", + "123122", + "123221", + "223211", + "221132", + "221231", + "213212", + "223112", + "312131", + "311222", + "321122", + "321221", + "312212", + "322112", + "322211", + "212123", + "212321", + "232121", + "111323", + "131123", + "131321", + "112313", + "132113", + "132311", + "211313", + "231113", + "231311", + "112133", + "112331", + "132131", + "113123", + "113321", + "133121", + "313121", + "211331", + "231131", + "213113", + "213311", + "213131", + "311123", + "311321", + "331121", + "312113", + "312311", + "332111", + "314111", + "221411", + "431111", + "111224", + "111422", + "121124", + "121421", + "141122", + "141221", + "112214", + "112412", + "122114", + "122411", + "142112", + "142211", + "241211", + "221114", + "413111", + "241112", + "134111", + "111242", + "121142", + "121241", + "114212", + "124112", + "124211", + "411212", + "421112", + "421211", + "212141", + "214121", + "412121", + "111143", + "111341", + "131141", + "114113", + "114311", + "411113", + "411311", + "113141", + "114131", + "311141", + "411131", + "211412", + "211214", + "211232", + "2331112", ] class Special(object): """These are special characters used by the Code128 encoding.""" - START_A = '[Start Code A]' - START_B = '[Start Code B]' - START_C = '[Start Code C]' - CODE_A = '[Code A]' - CODE_B = '[Code B]' - CODE_C = '[Code C]' - SHIFT_A = '[Shift A]' - SHIFT_B = '[Shift B]' - FNC_1 = '[FNC 1]' - FNC_2 = '[FNC 2]' - FNC_3 = '[FNC 3]' - FNC_4 = '[FNC 4]' - STOP = '[Stop]' - - _start_codes = {'A': Special.START_A, 'B': Special.START_B, 'C': Special.START_C} - _char_codes = {'A': Special.CODE_A, 'B': Special.CODE_B, 'C': Special.CODE_C} + + START_A = "[Start Code A]" + START_B = "[Start Code B]" + START_C = "[Start Code C]" + CODE_A = "[Code A]" + CODE_B = "[Code B]" + CODE_C = "[Code C]" + SHIFT_A = "[Shift A]" + SHIFT_B = "[Shift B]" + FNC_1 = "[FNC 1]" + FNC_2 = "[FNC 2]" + FNC_3 = "[FNC 3]" + FNC_4 = "[FNC 4]" + STOP = "[Stop]" + + _start_codes = {"A": Special.START_A, "B": Special.START_B, "C": Special.START_C} + _char_codes = {"A": Special.CODE_A, "B": Special.CODE_B, "C": Special.CODE_C} # Lists mapping symbol values to characters in each character set. 
This defines the alphabet and Code128._sym2val
     # is derived from this structure.
     _val2sym = {
         # Code Set A includes ordinals 0 through 95 and 7 special characters. The ordinals include digits,
         # upper case characters, punctuation and control characters.
-        'A':
-            [chr(x) for x in range(32, 95 + 1)] +
-            [chr(x) for x in range(0, 31 + 1)] +
-            [
-                Special.FNC_3, Special.FNC_2, Special.SHIFT_B, Special.CODE_C,
-                Special.CODE_B, Special.FNC_4, Special.FNC_1,
-                Special.START_A, Special.START_B, Special.START_C, Special.STOP
-            ],
+        "A": [chr(x) for x in range(32, 95 + 1)]
+        + [chr(x) for x in range(0, 31 + 1)]
+        + [
+            Special.FNC_3,
+            Special.FNC_2,
+            Special.SHIFT_B,
+            Special.CODE_C,
+            Special.CODE_B,
+            Special.FNC_4,
+            Special.FNC_1,
+            Special.START_A,
+            Special.START_B,
+            Special.START_C,
+            Special.STOP,
+        ],
         # Code Set B includes ordinals 32 through 127 and 7 special characters. The ordinals include digits,
         # upper and lower case characters and punctuation.
-        'B':
-            [chr(x) for x in range(32, 127 + 1)] +
-            [
-                Special.FNC_3, Special.FNC_2, Special.SHIFT_A, Special.CODE_C,
-                Special.FNC_4, Special.CODE_A, Special.FNC_1,
-                Special.START_A, Special.START_B, Special.START_C, Special.STOP
-            ],
+        "B": [chr(x) for x in range(32, 127 + 1)]
+        + [
+            Special.FNC_3,
+            Special.FNC_2,
+            Special.SHIFT_A,
+            Special.CODE_C,
+            Special.FNC_4,
+            Special.CODE_A,
+            Special.FNC_1,
+            Special.START_A,
+            Special.START_B,
+            Special.START_C,
+            Special.STOP,
+        ],
         # Code Set C includes all pairs of 2 digits and 3 special characters.
-        'C':
-            ['%02d' % (x,) for x in range(0, 99 + 1)] +
-            [
-                Special.CODE_B, Special.CODE_A, Special.FNC_1,
-                Special.START_A, Special.START_B, Special.START_C, Special.STOP
-            ],
+        "C": ["%02d" % (x,) for x in range(0, 99 + 1)]
+        + [
+            Special.CODE_B,
+            Special.CODE_A,
+            Special.FNC_1,
+            Special.START_A,
+            Special.START_B,
+            Special.START_C,
+            Special.STOP,
+        ],
     }
 
     # Dicts mapping characters to symbol values in each character set.
     _sym2val = {
-        'A': {char: val for val, char in enumerate(_val2sym['A'])},
-        'B': {char: val for val, char in enumerate(_val2sym['B'])},
-        'C': {char: val for val, char in enumerate(_val2sym['C'])},
+        "A": {char: val for val, char in enumerate(_val2sym["A"])},
+        "B": {char: val for val, char in enumerate(_val2sym["B"])},
+        "C": {char: val for val, char in enumerate(_val2sym["C"])},
     }
 
     # How large the quiet zone is on either side of the barcode, when quiet zone is used.
@@ -121,13 +237,13 @@ class Code128(object):
         """
         self._validate_charset(data, charset)
 
-        if charset in ('A', 'B'):
+        if charset in ("A", "B"):
             charset *= len(data)
-        elif charset in ('C',):
-            charset *= (len(data) // 2)
+        elif charset in ("C",):
+            charset *= len(data) // 2
             if len(data) % 2 == 1:
                 # If there are an odd number of characters for charset C, encode the last character with charset B.
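A worked example of this rule: five characters with charset C yield the per-symbol charset string "CCB", so the pairs "12" and "34" become charset C symbols and the trailing "5" a charset B symbol. A minimal usage sketch (the constructor signature appears in a docstring elsewhere in this file; the exact symbol sequence, including start code and code switch, is an assumption based on these rules):

    barcode = Code128("12345", charset="C")
    # expected along the lines of: [Start Code C], "12", "34", [Code B], "5", ...
    print(barcode.symbols)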
- charset += 'B' + charset += "B" self.data = data self.symbol_values = self._encode(data, charset) @@ -148,13 +264,13 @@ class Code128(object): if len(charset) > 1: charset_data_length = 0 for symbol_charset in charset: - if symbol_charset not in ('A', 'B', 'C'): + if symbol_charset not in ("A", "B", "C"): raise Code128.CharsetError - charset_data_length += 2 if symbol_charset is 'C' else 1 + charset_data_length += 2 if symbol_charset is "C" else 1 if charset_data_length != len(data): raise Code128.CharsetLengthError elif len(charset) == 1: - if charset not in ('A', 'B', 'C'): + if charset not in ("A", "B", "C"): raise Code128.CharsetError elif charset is not None: raise Code128.CharsetError @@ -182,10 +298,12 @@ class Code128(object): if charset is not prev_charset: # Handle a special case of there being a single A in middle of two B's or the other way around, where # using a single shift character is more efficient than using two character set switches. - next_charset = charsets[symbol_num + 1] if symbol_num + 1 < len(charsets) else None - if charset == 'A' and prev_charset == next_charset == 'B': + next_charset = ( + charsets[symbol_num + 1] if symbol_num + 1 < len(charsets) else None + ) + if charset == "A" and prev_charset == next_charset == "B": result.append(cls._sym2val[prev_charset][cls.Special.SHIFT_A]) - elif charset == 'B' and prev_charset == next_charset == 'A': + elif charset == "B" and prev_charset == next_charset == "A": result.append(cls._sym2val[prev_charset][cls.Special.SHIFT_B]) else: # This is the normal case. @@ -193,7 +311,7 @@ class Code128(object): result.append(cls._sym2val[prev_charset][charset_symbol]) prev_charset = charset - nxt = cur + (2 if charset == 'C' else 1) + nxt = cur + (2 if charset == "C" else 1) symbol = data[cur:nxt] cur = nxt result.append(cls._sym2val[charset][symbol]) @@ -206,9 +324,10 @@ class Code128(object): @property def symbols(self): """List of the coded symbols as strings, with special characters included.""" + def _iter_symbols(symbol_values): # The initial charset doesn't matter, as the start codes have the same symbol values in all charsets. - charset = 'A' + charset = "A" shift_charset = None for symbol_value in symbol_values: @@ -219,15 +338,15 @@ class Code128(object): symbol = self._val2sym[charset][symbol_value] if symbol in (self.Special.START_A, self.Special.CODE_A): - charset = 'A' + charset = "A" elif symbol in (self.Special.START_B, self.Special.CODE_B): - charset = 'B' + charset = "B" elif symbol in (self.Special.START_C, self.Special.CODE_C): - charset = 'C' + charset = "C" elif symbol in (self.Special.SHIFT_A,): - shift_charset = 'A' + shift_charset = "A" elif symbol in (self.Special.SHIFT_B,): - shift_charset = 'B' + shift_charset = "B" yield symbol @@ -243,7 +362,7 @@ class Code128(object): :rtype: string """ - return ''.join(map((lambda val: self._val2bars[val]), self.symbol_values)) + return "".join(map((lambda val: self._val2bars[val]), self.symbol_values)) @property def modules(self): @@ -255,6 +374,7 @@ class Code128(object): :rtype: list[int] """ + def _iterate_modules(bars): is_bar = True for char in map(int, bars): @@ -288,7 +408,9 @@ class Code128(object): :return: A monochromatic image containing the barcode as black bars on white background. """ if Image is None: - raise Code128.MissingDependencyError("PIL module is required to use image method.") + raise Code128.MissingDependencyError( + "PIL module is required to use image method." 
+ ) modules = list(self.modules) if add_quiet_zone: @@ -296,7 +418,7 @@ class Code128(object): modules = [1] * self.quiet_zone + modules + [1] * self.quiet_zone width = len(modules) - img = Image.new(mode='1', size=(width, 1)) + img = Image.new(mode="1", size=(width, 1)) img.putdata(modules) if height == 1 and module_width == 1: @@ -305,7 +427,7 @@ class Code128(object): new_size = (width * module_width, height) return img.resize(new_size, resample=Image.NEAREST) - def data_url(self, image_format='png', add_quiet_zone=True): + def data_url(self, image_format="png", add_quiet_zone=True): """Get a data URL representing the barcode. >>> barcode = Code128('Hello!', charset='B') @@ -327,22 +449,21 @@ class Code128(object): # Using BMP can often result in smaller data URLs than PNG, but it isn't as widely supported by browsers as PNG. # GIFs result in data URLs 10 times bigger than PNG or BMP, possibly due to lack of support for monochrome GIFs # in Pillow, so they shouldn't be used. - if image_format == 'png': + if image_format == "png": # Unfortunately there is no way to avoid adding the zlib headers. # Using compress_level=0 sometimes results in a slightly bigger data size (by a few bytes), but there # doesn't appear to be a difference between levels 9 and 1, so let's just use 1. - pil_image.save(memory_file, format='png', compress_level=1) - elif image_format == 'bmp': - pil_image.save(memory_file, format='bmp') + pil_image.save(memory_file, format="png", compress_level=1) + elif image_format == "bmp": + pil_image.save(memory_file, format="bmp") else: - raise Code128.UnknownFormatError('Only png and bmp are supported.') + raise Code128.UnknownFormatError("Only png and bmp are supported.") # Encode the data in the BytesIO object and convert the result into unicode. - base64_image = base64.b64encode(memory_file.getvalue()).decode('ascii') + base64_image = base64.b64encode(memory_file.getvalue()).decode("ascii") - data_url = 'data:image/{format};base64,{base64_data}'.format( - format=image_format, - base64_data=base64_image + data_url = "data:image/{format};base64,{base64_data}".format( + format=image_format, base64_data=base64_image ) return data_url diff --git a/lib/runner.py b/lib/runner.py index 85df2b6..16f0a29 100644 --- a/lib/runner.py +++ b/lib/runner.py @@ -30,7 +30,7 @@ class SerialReader(serial.threaded.Protocol): def __init__(self, callback=None): """Create a new SerialReader object.""" self.callback = callback - self.recv_buf = '' + self.recv_buf = "" self.lines = [] def __call__(self): @@ -39,13 +39,13 @@ class SerialReader(serial.threaded.Protocol): def data_received(self, data): """Append newly received serial data to the line buffer.""" try: - str_data = data.decode('UTF-8') + str_data = data.decode("UTF-8") self.recv_buf += str_data # We may get anything between \r\n, \n\r and simple \n newlines. # We assume that \n is always present and use str.strip to remove leading/trailing \r symbols # Note: Do not call str.strip on lines[-1]! 
Otherwise, lines may be mangled
-            lines = self.recv_buf.split('\n')
+            lines = self.recv_buf.split("\n")
             if len(lines) > 1:
                 self.lines.extend(map(str.strip, lines[:-1]))
                 self.recv_buf = lines[-1]
@@ -94,14 +94,16 @@ class SerialMonitor:
         """
         self.ser = serial.serial_for_url(port, do_not_open=True)
         self.ser.baudrate = baud
-        self.ser.parity = 'N'
+        self.ser.parity = "N"
         self.ser.rtscts = False
         self.ser.xonxoff = False
 
         try:
             self.ser.open()
         except serial.SerialException as e:
-            sys.stderr.write('Could not open serial port {}: {}\n'.format(self.ser.name, e))
+            sys.stderr.write(
+                "Could not open serial port {}: {}\n".format(self.ser.name, e)
+            )
             sys.exit(1)
 
         self.reader = SerialReader(callback=callback)
@@ -131,6 +133,7 @@ class SerialMonitor:
         self.worker.stop()
         self.ser.close()
 
+
 # TODO Optional calibration with known resistors on GPIOs at the start
 # TODO Sync via LED? -> Pulse briefly before and, if needed, after each transition
 # TODO For consumers with low energy demand: supply power directly via GPIO
@@ -143,14 +146,14 @@ class EnergyTraceMonitor(SerialMonitor):
     def __init__(self, port: str, baud: int, callback=None, voltage=3.3):
         super().__init__(port=port, baud=baud, callback=callback)
         self._voltage = voltage
-        self._output = time.strftime('%Y%m%d-%H%M%S.etlog')
+        self._output = time.strftime("%Y%m%d-%H%M%S.etlog")
         self._start_energytrace()
 
     def _start_energytrace(self):
-        cmd = ['msp430-etv', '--save', self._output, '0']
-        self._logger = subprocess.Popen(cmd,
-                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
-                universal_newlines=True)
+        cmd = ["msp430-etv", "--save", self._output, "0"]
+        self._logger = subprocess.Popen(
+            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+        )
 
     def close(self):
         super().close()
@@ -162,14 +165,16 @@ class EnergyTraceMonitor(SerialMonitor):
 
     def get_config(self) -> dict:
         return {
-            'voltage': self._voltage,
+            "voltage": self._voltage,
         }
 
 
 class MIMOSAMonitor(SerialMonitor):
     """MIMOSAMonitor captures serial output and MIMOSA energy data for a specific amount of time."""
-    def __init__(self, port: str, baud: int, callback=None, offset=130, shunt=330, voltage=3.3):
+
+    def __init__(
+        self, port: str, baud: int, callback=None, offset=130, shunt=330, voltage=3.3
+    ):
         super().__init__(port=port, baud=baud, callback=callback)
         self._offset = offset
         self._shunt = shunt
@@ -177,39 +182,41 @@ class MIMOSAMonitor(SerialMonitor):
         self._start_mimosa()
 
     def _mimosactl(self, subcommand):
-        cmd = ['mimosactl']
+        cmd = ["mimosactl"]
         cmd.append(subcommand)
         res = subprocess.run(cmd)
         if res.returncode != 0:
             res = subprocess.run(cmd)
             if res.returncode != 0:
-                raise RuntimeError('{} returned {}'.format(' '.join(cmd), res.returncode))
+                raise RuntimeError(
+                    "{} returned {}".format(" ".join(cmd), res.returncode)
+                )
 
     def _mimosacmd(self, opts):
-        cmd = ['MimosaCMD']
+        cmd = ["MimosaCMD"]
         cmd.extend(opts)
         res = subprocess.run(cmd)
         if res.returncode != 0:
-            raise RuntimeError('{} returned {}'.format(' '.join(cmd), res.returncode))
+            raise RuntimeError("{} returned {}".format(" ".join(cmd), res.returncode))
 
     def _start_mimosa(self):
-        self._mimosactl('disconnect')
-        self._mimosacmd(['--start'])
-        self._mimosacmd(['--parameter', 'offset', str(self._offset)])
-        self._mimosacmd(['--parameter', 'shunt', str(self._shunt)])
-        self._mimosacmd(['--parameter', 'voltage', str(self._voltage)])
-        self._mimosacmd(['--mimosa-start'])
+        self._mimosactl("disconnect")
+        self._mimosacmd(["--start"])
+        self._mimosacmd(["--parameter", "offset", str(self._offset)])
+        
self._mimosacmd(["--parameter", "shunt", str(self._shunt)]) + self._mimosacmd(["--parameter", "voltage", str(self._voltage)]) + self._mimosacmd(["--mimosa-start"]) time.sleep(2) - self._mimosactl('1k') # 987 ohm + self._mimosactl("1k") # 987 ohm time.sleep(2) - self._mimosactl('100k') # 99.3 kohm + self._mimosactl("100k") # 99.3 kohm time.sleep(2) - self._mimosactl('connect') + self._mimosactl("connect") def _stop_mimosa(self): # Make sure the MIMOSA daemon has gathered all needed data time.sleep(2) - self._mimosacmd(['--mimosa-stop']) + self._mimosacmd(["--mimosa-stop"]) mtime_changed = True mim_file = None time.sleep(1) @@ -218,7 +225,7 @@ class MIMOSAMonitor(SerialMonitor): # files lying around in the directory will not confuse our # heuristic. for filename in sorted(os.listdir(), reverse=True): - if re.search(r'[.]mim$', filename): + if re.search(r"[.]mim$", filename): mim_file = filename break while mtime_changed: @@ -226,7 +233,7 @@ class MIMOSAMonitor(SerialMonitor): if time.time() - os.stat(mim_file).st_mtime < 3: mtime_changed = True time.sleep(1) - self._mimosacmd(['--stop']) + self._mimosacmd(["--stop"]) return mim_file def close(self): @@ -238,9 +245,9 @@ class MIMOSAMonitor(SerialMonitor): def get_config(self) -> dict: return { - 'offset': self._offset, - 'shunt': self._shunt, - 'voltage': self._voltage, + "offset": self._offset, + "shunt": self._shunt, + "voltage": self._voltage, } @@ -263,14 +270,17 @@ class ShellMonitor: stderr and return status are discarded at the moment. """ if type(timeout) != int: - raise ValueError('timeout argument must be int') - res = subprocess.run(['timeout', '{:d}s'.format(timeout), self.script], - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=True) + raise ValueError("timeout argument must be int") + res = subprocess.run( + ["timeout", "{:d}s".format(timeout), self.script], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + ) if self.callback: - for line in res.stdout.split('\n'): + for line in res.stdout.split("\n"): self.callback(line) - return res.stdout.split('\n') + return res.stdout.split("\n") def monitor(self): raise NotImplementedError @@ -285,28 +295,35 @@ class ShellMonitor: def build(arch, app, opts=[]): - command = ['make', 'arch={}'.format(arch), 'app={}'.format(app), 'clean'] + command = ["make", "arch={}".format(arch), "app={}".format(app), "clean"] command.extend(opts) - res = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=True) + res = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True + ) if res.returncode != 0: - raise RuntimeError('Build failure, executing {}:\n'.format(command) + res.stderr) - command = ['make', '-B', 'arch={}'.format(arch), 'app={}'.format(app)] + raise RuntimeError( + "Build failure, executing {}:\n".format(command) + res.stderr + ) + command = ["make", "-B", "arch={}".format(arch), "app={}".format(app)] command.extend(opts) - res = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=True) + res = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True + ) if res.returncode != 0: - raise RuntimeError('Build failure, executing {}:\n '.format(command) + res.stderr) + raise RuntimeError( + "Build failure, executing {}:\n ".format(command) + res.stderr + ) return command def flash(arch, app, opts=[]): - command = ['make', 'arch={}'.format(arch), 'app={}'.format(app), 'program'] + 
command = ["make", "arch={}".format(arch), "app={}".format(app), "program"] command.extend(opts) - res = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=True) + res = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True + ) if res.returncode != 0: - raise RuntimeError('Flash failure') + raise RuntimeError("Flash failure") return command @@ -316,13 +333,14 @@ def get_info(arch, opts: list = []) -> list: Returns a list. """ - command = ['make', 'arch={}'.format(arch), 'info'] + command = ["make", "arch={}".format(arch), "info"] command.extend(opts) - res = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=True) + res = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True + ) if res.returncode != 0: - raise RuntimeError('make info Failure') - return res.stdout.split('\n') + raise RuntimeError("make info Failure") + return res.stdout.split("\n") def get_monitor(arch: str, **kwargs) -> object: @@ -336,32 +354,32 @@ def get_monitor(arch: str, **kwargs) -> object: :param mimosa: `MIMOSAMonitor` options. Returns a MIMOSA monitor if not None. """ for line in get_info(arch): - if 'Monitor:' in line: - _, port, arg = line.split(' ') - if port == 'run': + if "Monitor:" in line: + _, port, arg = line.split(" ") + if port == "run": return ShellMonitor(arg, **kwargs) - elif 'mimosa' in kwargs and kwargs['mimosa'] is not None: - mimosa_kwargs = kwargs.pop('mimosa') + elif "mimosa" in kwargs and kwargs["mimosa"] is not None: + mimosa_kwargs = kwargs.pop("mimosa") return MIMOSAMonitor(port, arg, **mimosa_kwargs, **kwargs) - elif 'energytrace' in kwargs and kwargs['energytrace'] is not None: - energytrace_kwargs = kwargs.pop('energytrace') + elif "energytrace" in kwargs and kwargs["energytrace"] is not None: + energytrace_kwargs = kwargs.pop("energytrace") return EnergyTraceMonitor(port, arg, **energytrace_kwargs, **kwargs) else: - kwargs.pop('energytrace', None) - kwargs.pop('mimosa', None) + kwargs.pop("energytrace", None) + kwargs.pop("mimosa", None) return SerialMonitor(port, arg, **kwargs) - raise RuntimeError('Monitor failure') + raise RuntimeError("Monitor failure") def get_counter_limits(arch: str) -> tuple: """Return multipass max counter and max overflow value for arch.""" for line in get_info(arch): - match = re.match('Counter Overflow: ([^/]*)/(.*)', line) + match = re.match("Counter Overflow: ([^/]*)/(.*)", line) if match: overflow_value = int(match.group(1)) max_overflow = int(match.group(2)) return overflow_value, max_overflow - raise RuntimeError('Did not find Counter Overflow limits') + raise RuntimeError("Did not find Counter Overflow limits") def get_counter_limits_us(arch: str) -> tuple: @@ -370,13 +388,13 @@ def get_counter_limits_us(arch: str) -> tuple: overflow_value = 0 max_overflow = 0 for line in get_info(arch): - match = re.match(r'CPU\s+Freq:\s+(.*)\s+Hz', line) + match = re.match(r"CPU\s+Freq:\s+(.*)\s+Hz", line) if match: cpu_freq = int(match.group(1)) - match = re.match(r'Counter Overflow:\s+([^/]*)/(.*)', line) + match = re.match(r"Counter Overflow:\s+([^/]*)/(.*)", line) if match: overflow_value = int(match.group(1)) max_overflow = int(match.group(2)) if cpu_freq and overflow_value: return 1000000 / cpu_freq, overflow_value * 1000000 / cpu_freq, max_overflow - raise RuntimeError('Did not find Counter Overflow limits') + raise RuntimeError("Did not find Counter Overflow limits") diff --git 
a/lib/size_to_radio_energy.py b/lib/size_to_radio_energy.py
index a3cf7c5..10de1a3 100644
--- a/lib/size_to_radio_energy.py
+++ b/lib/size_to_radio_energy.py
@@ -7,38 +7,42 @@ class can convert a cycle count to an energy consumption.
 
 import numpy as np
 
+
 def get_class(radio_name: str):
     """Return model class for radio_name."""
-    if radio_name == 'CC1200tx':
+    if radio_name == "CC1200tx":
         return CC1200tx
-    if radio_name == 'CC1200rx':
+    if radio_name == "CC1200rx":
         return CC1200rx
-    if radio_name == 'NRF24L01tx':
+    if radio_name == "NRF24L01tx":
         return NRF24L01tx
-    if radio_name == 'NRF24L01dtx':
+    if radio_name == "NRF24L01dtx":
         return NRF24L01dtx
-    if radio_name == 'esp8266dtx':
+    if radio_name == "esp8266dtx":
         return ESP8266dtx
-    if radio_name == 'esp8266drx':
+    if radio_name == "esp8266drx":
         return ESP8266drx
 
+
 def _param_list_to_dict(device, param_list):
     param_dict = dict()
     for i, parameter in enumerate(sorted(device.parameters.keys())):
         param_dict[parameter] = param_list[i]
     return param_dict
 
+
 class CC1200tx:
     """CC1200 TX energy based on aemr measurements."""
-    name = 'CC1200tx'
+
+    name = "CC1200tx"
     parameters = {
-        'symbolrate' : [6, 12, 25, 50, 100, 200, 250], # ksps
-        'txbytes' : [],
-        'txpower' : [10, 20, 30, 40, 47], # dBm = f(txpower)
+        "symbolrate": [6, 12, 25, 50, 100, 200, 250],  # ksps
+        "txbytes": [],
+        "txpower": [10, 20, 30, 40, 47],  # dBm = f(txpower)
     }
     default_params = {
-        'symbolrate' : 100,
-        'txpower' : 47,
+        "symbolrate": 100,
+        "txpower": 47,
     }
 
     @staticmethod
@@ -48,106 +52,129 @@ class CC1200tx:
 
         # Mean TX power, fitted by AEMR
        # Measurement data recorded at 3.6V
-        power = 8.18053941e+04
-        power -= 1.24208376e+03 * np.sqrt(params['symbolrate'])
-        power -= 5.73742779e+02 * np.log(params['txbytes'])
-        power += 1.76945886e+01 * (params['txpower'])**2
-        power += 2.33469617e+02 * np.sqrt(params['symbolrate']) * np.log(params['txbytes'])
-        power -= 6.99137635e-01 * np.sqrt(params['symbolrate']) * (params['txpower'])**2
-        power -= 3.31365158e-01 * np.log(params['txbytes']) * (params['txpower'])**2
-        power += 1.32784945e-01 * np.sqrt(params['symbolrate']) * np.log(params['txbytes']) * (params['txpower'])**2
+        power = 8.18053941e04
+        power -= 1.24208376e03 * np.sqrt(params["symbolrate"])
+        power -= 5.73742779e02 * np.log(params["txbytes"])
+        power += 1.76945886e01 * (params["txpower"]) ** 2
+        power += (
+            2.33469617e02 * np.sqrt(params["symbolrate"]) * np.log(params["txbytes"])
+        )
+        power -= (
+            6.99137635e-01 * np.sqrt(params["symbolrate"]) * (params["txpower"]) ** 2
+        )
+        power -= 3.31365158e-01 * np.log(params["txbytes"]) * (params["txpower"]) ** 2
+        power += (
+            1.32784945e-01
+            * np.sqrt(params["symbolrate"])
+            * np.log(params["txbytes"])
+            * (params["txpower"]) ** 2
+        )
 
         # txDone timeout, fitted by AEMR
-        duration = 3.65513500e+02
-        duration += 8.01016526e+04 * 1/(params['symbolrate'])
-        duration -= 7.06364515e-03 * params['txbytes']
-        duration += 8.00029860e+03 * 1/(params['symbolrate']) * params['txbytes']
+        duration = 3.65513500e02
+        duration += 8.01016526e04 * 1 / (params["symbolrate"])
+        duration -= 7.06364515e-03 * params["txbytes"]
+        duration += 8.00029860e03 * 1 / (params["symbolrate"]) * params["txbytes"]
 
         # TX energy, fitted by AEMR
         # Note: energy is in µJ here, not in pJ (as is usual in AEMR transition models)
         # Measurement data recorded at 3.6V
-        energy = 1.74383259e+01
-        energy += 6.29922138e+03 * 1/(params['symbolrate'])
-        energy += 1.13307135e-02 * params['txbytes']
-        energy -= 1.28121377e-04 * (params['txpower'])**2
-        energy += 6.29080184e+02 * 
1/(params['symbolrate']) * params['txbytes'] - energy += 1.25647926e+00 * 1/(params['symbolrate']) * (params['txpower'])**2 - energy += 1.31996202e-05 * params['txbytes'] * (params['txpower'])**2 - energy += 1.25676966e-01 * 1/(params['symbolrate']) * params['txbytes'] * (params['txpower'])**2 + energy = 1.74383259e01 + energy += 6.29922138e03 * 1 / (params["symbolrate"]) + energy += 1.13307135e-02 * params["txbytes"] + energy -= 1.28121377e-04 * (params["txpower"]) ** 2 + energy += 6.29080184e02 * 1 / (params["symbolrate"]) * params["txbytes"] + energy += 1.25647926e00 * 1 / (params["symbolrate"]) * (params["txpower"]) ** 2 + energy += 1.31996202e-05 * params["txbytes"] * (params["txpower"]) ** 2 + energy += ( + 1.25676966e-01 + * 1 + / (params["symbolrate"]) + * params["txbytes"] + * (params["txpower"]) ** 2 + ) return energy * 1e-6 @staticmethod def get_energy_per_byte(params): - A = 8.18053941e+04 - A -= 1.24208376e+03 * np.sqrt(params['symbolrate']) - A += 1.76945886e+01 * (params['txpower'])**2 - A -= 6.99137635e-01 * np.sqrt(params['symbolrate']) * (params['txpower'])**2 - B = -5.73742779e+02 - B += 2.33469617e+02 * np.sqrt(params['symbolrate']) - B -= 3.31365158e-01 * (params['txpower'])**2 - B += 1.32784945e-01 * np.sqrt(params['symbolrate']) * (params['txpower'])**2 - C = 3.65513500e+02 - C += 8.01016526e+04 * 1/(params['symbolrate']) - D = -7.06364515e-03 - D += 8.00029860e+03 * 1/(params['symbolrate']) - - x = params['txbytes'] + A = 8.18053941e04 + A -= 1.24208376e03 * np.sqrt(params["symbolrate"]) + A += 1.76945886e01 * (params["txpower"]) ** 2 + A -= 6.99137635e-01 * np.sqrt(params["symbolrate"]) * (params["txpower"]) ** 2 + B = -5.73742779e02 + B += 2.33469617e02 * np.sqrt(params["symbolrate"]) + B -= 3.31365158e-01 * (params["txpower"]) ** 2 + B += 1.32784945e-01 * np.sqrt(params["symbolrate"]) * (params["txpower"]) ** 2 + C = 3.65513500e02 + C += 8.01016526e04 * 1 / (params["symbolrate"]) + D = -7.06364515e-03 + D += 8.00029860e03 * 1 / (params["symbolrate"]) + + x = params["txbytes"] # in pJ - de_dx = A * D + B * C * 1/x + B * D * (np.log(x) + 1) + de_dx = A * D + B * C * 1 / x + B * D * (np.log(x) + 1) # in µJ - de_dx = 1.13307135e-02 - de_dx += 6.29080184e+02 * 1/(params['symbolrate']) - de_dx += 1.31996202e-05 * (params['txpower'])**2 - de_dx += 1.25676966e-01 * 1/(params['symbolrate']) * (params['txpower'])**2 + de_dx = 1.13307135e-02 + de_dx += 6.29080184e02 * 1 / (params["symbolrate"]) + de_dx += 1.31996202e-05 * (params["txpower"]) ** 2 + de_dx += 1.25676966e-01 * 1 / (params["symbolrate"]) * (params["txpower"]) ** 2 - #de_dx = (B * 1/x) * (C + D * x) + (A + B * np.log(x)) * D + # de_dx = (B * 1/x) * (C + D * x) + (A + B * np.log(x)) * D return de_dx * 1e-6 + class CC1200rx: """CC1200 RX energy based on aemr measurements.""" - name = 'CC1200rx' + + name = "CC1200rx" parameters = { - 'symbolrate' : [6, 12, 25, 50, 100, 200, 250], # ksps - 'txbytes' : [], - 'txpower' : [10, 20, 30, 40, 47], # dBm = f(txpower) + "symbolrate": [6, 12, 25, 50, 100, 200, 250], # ksps + "txbytes": [], + "txpower": [10, 20, 30, 40, 47], # dBm = f(txpower) } default_params = { - 'symbolrate' : 100, - 'txpower' : 47, + "symbolrate": 100, + "txpower": 47, } @staticmethod def get_energy(params): # TODO - return params['txbytes'] * CC1200rx.get_energy_per_byte(params) + return params["txbytes"] * CC1200rx.get_energy_per_byte(params) @staticmethod def get_energy_per_byte(params): - #RX : 0 + regression_arg(0) + regression_arg(1) * np.log(parameter(symbolrate) + 1) + # RX : 0 + regression_arg(0) 
+ regression_arg(1) * np.log(parameter(symbolrate) + 1) # [84414.91636169 205.63323036] - de_dx = (84414.91636169 + 205.63323036 * np.log(params['symbolrate'] + 1)) * 8000 / params['symbolrate'] + de_dx = ( + (84414.91636169 + 205.63323036 * np.log(params["symbolrate"] + 1)) + * 8000 + / params["symbolrate"] + ) return de_dx * 1e-12 + class NRF24L01rx: """NRF24L01+ RX energy based on aemr measurements (using variable packet size)""" - name = 'NRF24L01' + + name = "NRF24L01" parameters = { - 'datarate' : [250, 1000, 2000], # kbps - 'txbytes' : [], - 'txpower' : [-18, -12, -6, 0], # dBm - 'voltage' : [1.9, 3.6], + "datarate": [250, 1000, 2000], # kbps + "txbytes": [], + "txpower": [-18, -12, -6, 0], # dBm + "voltage": [1.9, 3.6], } default_params = { - 'datarate' : 1000, - 'txpower' : -6, - 'voltage' : 3, + "datarate": 1000, + "txpower": -6, + "voltage": 3, } @staticmethod @@ -155,35 +182,40 @@ class NRF24L01rx: # RX : 0 + regression_arg(0) + regression_arg(1) * np.sqrt(parameter(datarate)) # [48530.73235537 117.25274402] - de_dx = (48530.73235537 + 117.25274402 * np.sqrt(params['datarate'])) * 8000 / params['datarate'] + de_dx = ( + (48530.73235537 + 117.25274402 * np.sqrt(params["datarate"])) + * 8000 + / params["datarate"] + ) return de_dx * 1e-12 + # PYTHONPATH=lib bin/analyze-archive.py --show-model=all --show-quality=table ../data/*_RF24_no_retries.tar class NRF24L01tx: """NRF24L01+ TX(*) energy based on aemr measurements (32B fixed packet size, (*)ack-await, no retries).""" - name = 'NRF24L01' + + name = "NRF24L01" parameters = { - 'datarate' : [250, 1000, 2000], # kbps - 'txbytes' : [], - 'txpower' : [-18, -12, -6, 0], # dBm - 'voltage' : [1.9, 3.6], + "datarate": [250, 1000, 2000], # kbps + "txbytes": [], + "txpower": [-18, -12, -6, 0], # dBm + "voltage": [1.9, 3.6], } default_params = { - 'datarate' : 1000, - 'txpower' : -6, - 'voltage' : 3, + "datarate": 1000, + "txpower": -6, + "voltage": 3, } -# AEMR: -# TX power / energy: -#TX : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * (19.47+parameter(txpower))**2 + regression_arg(3) * 1/(parameter(datarate)) * (19.47+parameter(txpower))**2 -# [6.30323056e+03 2.59889924e+06 7.82186268e+00 8.69746093e+03] -#TX : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * (19.47+parameter(txpower))**2 + regression_arg(3) * 1/(parameter(datarate)) * (19.47+parameter(txpower))**2 -# [7.67932887e+00 1.02969455e+04 4.55116475e-03 2.99786534e+01] -#epilogue : timeout : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) -# [ 1624.06589147 332251.93798766] - + # AEMR: + # TX power / energy: + # TX : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * (19.47+parameter(txpower))**2 + regression_arg(3) * 1/(parameter(datarate)) * (19.47+parameter(txpower))**2 + # [6.30323056e+03 2.59889924e+06 7.82186268e+00 8.69746093e+03] + # TX : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * (19.47+parameter(txpower))**2 + regression_arg(3) * 1/(parameter(datarate)) * (19.47+parameter(txpower))**2 + # [7.67932887e+00 1.02969455e+04 4.55116475e-03 2.99786534e+01] + # epilogue : timeout : 0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + # [ 1624.06589147 332251.93798766] @staticmethod def get_energy(params): @@ -192,31 +224,37 @@ class NRF24L01tx: # TX-Leistung, gefitted von AEMR # Messdaten erhoben bei 3.6V - power = 6.30323056e+03 - power += 2.59889924e+06 * 
1/params['datarate'] - power += 7.82186268e+00 * (19.47+params['txpower'])**2 - power += 8.69746093e+03 * 1/params['datarate'] * (19.47+params['txpower'])**2 + power = 6.30323056e03 + power += 2.59889924e06 * 1 / params["datarate"] + power += 7.82186268e00 * (19.47 + params["txpower"]) ** 2 + power += ( + 8.69746093e03 * 1 / params["datarate"] * (19.47 + params["txpower"]) ** 2 + ) # TX-Dauer, gefitted von AEMR duration = 1624.06589147 - duration += 332251.93798766 * 1/params['datarate'] + duration += 332251.93798766 * 1 / params["datarate"] # TX-Energie, gefitted von AEMR # Achtung: Energy ist in µJ, nicht (wie in AEMR-Transitionsmodellen üblich) in pJ # Messdaten erhoben bei 3.6V - energy = 7.67932887e+00 - energy += 1.02969455e+04 * 1/params['datarate'] - energy += 4.55116475e-03 * (19.47+params['txpower'])**2 - energy += 2.99786534e+01 * 1/params['datarate'] * (19.47+params['txpower'])**2 + energy = 7.67932887e00 + energy += 1.02969455e04 * 1 / params["datarate"] + energy += 4.55116475e-03 * (19.47 + params["txpower"]) ** 2 + energy += ( + 2.99786534e01 * 1 / params["datarate"] * (19.47 + params["txpower"]) ** 2 + ) - energy = power * 1e-6 * duration * 1e-6 * np.ceil(params['txbytes'] / 32) + energy = power * 1e-6 * duration * 1e-6 * np.ceil(params["txbytes"] / 32) return energy @staticmethod def get_energy_per_byte(params): if type(params) != dict: - return NRF24L01tx.get_energy_per_byte(_param_list_to_dict(NRF24L01tx, params)) + return NRF24L01tx.get_energy_per_byte( + _param_list_to_dict(NRF24L01tx, params) + ) # in µJ de_dx = 0 @@ -224,17 +262,18 @@ class NRF24L01tx: class NRF24L01dtx: """nRF24L01+ TX energy based on datasheet values (probably unerestimated)""" - name = 'NRF24L01' + + name = "NRF24L01" parameters = { - 'datarate' : [250, 1000, 2000], # kbps - 'txbytes' : [], - 'txpower' : [-18, -12, -6, 0], # dBm - 'voltage' : [1.9, 3.6], + "datarate": [250, 1000, 2000], # kbps + "txbytes": [], + "txpower": [-18, -12, -6, 0], # dBm + "voltage": [1.9, 3.6], } default_params = { - 'datarate' : 1000, - 'txpower' : -6, - 'voltage' : 3, + "datarate": 1000, + "txpower": -6, + "voltage": 3, } # 130 us RX settling: 8.9 mE @@ -248,35 +287,37 @@ class NRF24L01dtx: header_bytes = 7 # TX settling: 130 us @ 8 mA - energy = 8e-3 * params['voltage'] * 130e-6 + energy = 8e-3 * params["voltage"] * 130e-6 - if params['txpower'] == -18: + if params["txpower"] == -18: current = 7e-3 - elif params['txpower'] == -12: + elif params["txpower"] == -12: current = 7.5e-3 - elif params['txpower'] == -6: + elif params["txpower"] == -6: current = 9e-3 - elif params['txpower'] == 0: + elif params["txpower"] == 0: current = 11.3e-3 - energy += current * params['voltage'] * ((header_bytes + params['txbytes']) * 8 / (params['datarate'] * 1e3)) + energy += ( + current + * params["voltage"] + * ((header_bytes + params["txbytes"]) * 8 / (params["datarate"] * 1e3)) + ) return energy + class ESP8266dtx: """esp8266 TX energy based on (hardly documented) datasheet values""" - name = 'esp8266' + + name = "esp8266" parameters = { - 'voltage' : [2.5, 3.0, 3.3, 3.6], - 'txbytes' : [], - 'bitrate' : [65e6], - 'tx_current' : [120e-3], - } - default_params = { - 'voltage' : 3, - 'bitrate' : 65e6, - 'tx_current' : 120e-3 + "voltage": [2.5, 3.0, 3.3, 3.6], + "txbytes": [], + "bitrate": [65e6], + "tx_current": [120e-3], } + default_params = {"voltage": 3, "bitrate": 65e6, "tx_current": 120e-3} @staticmethod def get_energy(params): @@ -286,26 +327,26 @@ class ESP8266dtx: @staticmethod def get_energy_per_byte(params): if type(params) 
!= dict: - return ESP8266dtx.get_energy_per_byte(_param_list_to_dict(ESP8266dtx, params)) + return ESP8266dtx.get_energy_per_byte( + _param_list_to_dict(ESP8266dtx, params) + ) # TX in 802.11n MCS7 -> 64QAM, 65/72.2 Mbit/s @ 20MHz channel, 135/150 Mbit/s @ 40MHz # -> Value for 65 Mbit/s @ 20MHz channel - return params['tx_current'] * params['voltage'] / params['bitrate'] + return params["tx_current"] * params["voltage"] / params["bitrate"] + class ESP8266drx: """esp8266 RX energy based on (hardly documented) datasheet values""" - name = 'esp8266' + + name = "esp8266" parameters = { - 'voltage' : [2.5, 3.0, 3.3, 3.6], - 'txbytes' : [], - 'bitrate' : [65e6], - 'rx_current' : [56e-3], - } - default_params = { - 'voltage' : 3, - 'bitrate' : 65e6, - 'rx_current' : 56e-3 + "voltage": [2.5, 3.0, 3.3, 3.6], + "txbytes": [], + "bitrate": [65e6], + "rx_current": [56e-3], } + default_params = {"voltage": 3, "bitrate": 65e6, "rx_current": 56e-3} @staticmethod def get_energy(params): @@ -315,8 +356,10 @@ class ESP8266drx: @staticmethod def get_energy_per_byte(params): if type(params) != dict: - return ESP8266drx.get_energy_per_byte(_param_list_to_dict(ESP8266drx, params)) + return ESP8266drx.get_energy_per_byte( + _param_list_to_dict(ESP8266drx, params) + ) # TX in 802.11n MCS7 -> 64QAM, 65/72.2 Mbit/s @ 20MHz channel, 135/150 Mbit/s @ 40MHz # -> Value for 65 Mbit/s @ 20MHz channel - return params['rx_current'] * params['voltage'] / params['bitrate'] + return params["rx_current"] * params["voltage"] / params["bitrate"] diff --git a/lib/sly/__init__.py b/lib/sly/__init__.py index 3c1e708..3a2d92e 100644 --- a/lib/sly/__init__.py +++ b/lib/sly/__init__.py @@ -1,6 +1,5 @@ - from .lex import * from .yacc import * __version__ = "0.4" -__all__ = [ *lex.__all__, *yacc.__all__ ] +__all__ = [*lex.__all__, *yacc.__all__] diff --git a/lib/sly/ast.py b/lib/sly/ast.py index 7b79ac5..05802bd 100644 --- a/lib/sly/ast.py +++ b/lib/sly/ast.py @@ -1,25 +1,24 @@ # sly/ast.py import sys + class AST(object): - @classmethod def __init_subclass__(cls, **kwargs): mod = sys.modules[cls.__module__] - if not hasattr(cls, '__annotations__'): + if not hasattr(cls, "__annotations__"): return hints = list(cls.__annotations__.items()) def __init__(self, *args, **kwargs): if len(hints) != len(args): - raise TypeError(f'Expected {len(hints)} arguments') + raise TypeError(f"Expected {len(hints)} arguments") for arg, (name, val) in zip(args, hints): if isinstance(val, str): val = getattr(mod, val) if not isinstance(arg, val): - raise TypeError(f'{name} argument must be {val}') + raise TypeError(f"{name} argument must be {val}") setattr(self, name, arg) cls.__init__ = __init__ - diff --git a/lib/sly/docparse.py b/lib/sly/docparse.py index d5a83ce..0f35c97 100644 --- a/lib/sly/docparse.py +++ b/lib/sly/docparse.py @@ -2,7 +2,8 @@ # # Support doc-string parsing classes -__all__ = [ 'DocParseMeta' ] +__all__ = ["DocParseMeta"] + class DocParseMeta(type): ''' @@ -44,17 +45,17 @@ class DocParseMeta(type): @staticmethod def __new__(meta, clsname, bases, clsdict): - if '__doc__' in clsdict: + if "__doc__" in clsdict: lexer = meta.lexer() parser = meta.parser() lexer.cls_name = parser.cls_name = clsname - lexer.cls_qualname = parser.cls_qualname = clsdict['__qualname__'] - lexer.cls_module = parser.cls_module = clsdict['__module__'] - parsedict = parser.parse(lexer.tokenize(clsdict['__doc__'])) - assert isinstance(parsedict, dict), 'Parser must return a dictionary' + lexer.cls_qualname = parser.cls_qualname = clsdict["__qualname__"] + 
lexer.cls_module = parser.cls_module = clsdict["__module__"] + parsedict = parser.parse(lexer.tokenize(clsdict["__doc__"])) + assert isinstance(parsedict, dict), "Parser must return a dictionary" clsdict.update(parsedict) return super().__new__(meta, clsname, bases, clsdict) @classmethod def __init_subclass__(cls): - assert hasattr(cls, 'parser') and hasattr(cls, 'lexer') + assert hasattr(cls, "parser") and hasattr(cls, "lexer") diff --git a/lib/sly/lex.py b/lib/sly/lex.py index 246dd9e..0ab0160 100644 --- a/lib/sly/lex.py +++ b/lib/sly/lex.py @@ -31,51 +31,63 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ----------------------------------------------------------------------------- -__all__ = ['Lexer', 'LexerStateChange'] +__all__ = ["Lexer", "LexerStateChange"] import re import copy + class LexError(Exception): - ''' + """ Exception raised if an invalid character is encountered and no default error handler function is defined. The .text attribute of the exception contains all remaining untokenized text. The .error_index is the index location of the error. - ''' + """ + def __init__(self, message, text, error_index): self.args = (message,) self.text = text self.error_index = error_index + class PatternError(Exception): - ''' + """ Exception raised if there's some kind of problem with the specified regex patterns in the lexer. - ''' + """ + pass + class LexerBuildError(Exception): - ''' + """ Exception raised if there's some sort of problem building the lexer. - ''' + """ + pass + class LexerStateChange(Exception): - ''' + """ Exception raised to force a lexing state change - ''' + """ + def __init__(self, newstate, tok=None): self.newstate = newstate self.tok = tok + class Token(object): - ''' + """ Representation of a single token. - ''' - __slots__ = ('type', 'value', 'lineno', 'index') + """ + + __slots__ = ("type", "value", "lineno", "index") + def __repr__(self): - return f'Token(type={self.type!r}, value={self.value!r}, lineno={self.lineno}, index={self.index})' + return f"Token(type={self.type!r}, value={self.value!r}, lineno={self.lineno}, index={self.index})" + class TokenStr(str): @staticmethod @@ -95,35 +107,38 @@ class TokenStr(str): if self.remap is not None: self.remap[self.key, key] = self.key + class _Before: def __init__(self, tok, pattern): self.tok = tok self.pattern = pattern + class LexerMetaDict(dict): - ''' + """ Special dictionary that prohibits duplicate definitions in lexer specifications. 
- ''' + """ + def __init__(self): - self.before = { } - self.delete = [ ] - self.remap = { } + self.before = {} + self.delete = [] + self.remap = {} def __setitem__(self, key, value): if isinstance(value, str): value = TokenStr(value, key, self.remap) - + if isinstance(value, _Before): self.before[key] = value.tok value = TokenStr(value.pattern, key, self.remap) - + if key in self and not isinstance(value, property): prior = self[key] if isinstance(prior, str): if callable(value): value.pattern = prior else: - raise AttributeError(f'Name {key} redefined') + raise AttributeError(f"Name {key} redefined") super().__setitem__(key, value) @@ -135,41 +150,47 @@ class LexerMetaDict(dict): return super().__delitem__(key) def __getitem__(self, key): - if key not in self and key.split('ignore_')[-1].isupper() and key[:1] != '_': + if key not in self and key.split("ignore_")[-1].isupper() and key[:1] != "_": return TokenStr(key, key, self.remap) else: return super().__getitem__(key) + class LexerMeta(type): - ''' + """ Metaclass for collecting lexing rules - ''' + """ + @classmethod def __prepare__(meta, name, bases): d = LexerMetaDict() def _(pattern, *extra): patterns = [pattern, *extra] + def decorate(func): - pattern = '|'.join(f'({pat})' for pat in patterns ) - if hasattr(func, 'pattern'): - func.pattern = pattern + '|' + func.pattern + pattern = "|".join(f"({pat})" for pat in patterns) + if hasattr(func, "pattern"): + func.pattern = pattern + "|" + func.pattern else: func.pattern = pattern return func + return decorate - d['_'] = _ - d['before'] = _Before + d["_"] = _ + d["before"] = _Before return d def __new__(meta, clsname, bases, attributes): - del attributes['_'] - del attributes['before'] + del attributes["_"] + del attributes["before"] # Create attributes for use in the actual class body - cls_attributes = { str(key): str(val) if isinstance(val, TokenStr) else val - for key, val in attributes.items() } + cls_attributes = { + str(key): str(val) if isinstance(val, TokenStr) else val + for key, val in attributes.items() + } cls = super().__new__(meta, clsname, bases, cls_attributes) # Attach various metadata to the class @@ -180,11 +201,12 @@ class LexerMeta(type): cls._build() return cls + class Lexer(metaclass=LexerMeta): # These attributes may be defined in subclasses tokens = set() literals = set() - ignore = '' + ignore = "" reflags = 0 regex_module = re @@ -214,7 +236,7 @@ class Lexer(metaclass=LexerMeta): # Such functions can be created with the @_ decorator or by defining # function with the same name as a previously defined string. # - # This function is responsible for keeping rules in order. + # This function is responsible for keeping rules in order. 
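Rule ordering is observable behavior: _build() below joins every collected pattern into a single master regex, and Python's re module tries alternatives left to right. A minimal sketch of a lexer whose tokenization depends on the order this method preserves; the token names and patterns are illustrative, not taken from this repo:

    from sly import Lexer

    class CondLexer(Lexer):
        tokens = {NAME, NUMBER, LE, LT}
        ignore = " \t"

        # LE must precede LT, or "<=" would match as LT and leave a stray "=".
        LE = r"<="
        LT = r"<"
        NAME = r"[a-zA-Z_][a-zA-Z0-9_]*"

        @_(r"\d+")
        def NUMBER(self, t):
            t.value = int(t.value)
            return t

With this ordering, tokenizing "a <= 10" yields NAME, LE, NUMBER rather than an error on the stray "=".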
# Collect all previous rules from base classes rules = [] @@ -222,15 +244,21 @@ class Lexer(metaclass=LexerMeta): for base in cls.__bases__: if isinstance(base, LexerMeta): rules.extend(base._rules) - + # Dictionary of previous rules existing = dict(rules) for key, value in cls._attributes.items(): - if (key in cls._token_names) or key.startswith('ignore_') or hasattr(value, 'pattern'): - if callable(value) and not hasattr(value, 'pattern'): - raise LexerBuildError(f"function {value} doesn't have a regex pattern") - + if ( + (key in cls._token_names) + or key.startswith("ignore_") + or hasattr(value, "pattern") + ): + if callable(value) and not hasattr(value, "pattern"): + raise LexerBuildError( + f"function {value} doesn't have a regex pattern" + ) + if key in existing: # The definition matches something that already existed in the base class. # We replace it, but keep the original ordering @@ -252,21 +280,27 @@ class Lexer(metaclass=LexerMeta): rules.append((key, value)) existing[key] = value - elif isinstance(value, str) and not key.startswith('_') and key not in {'ignore', 'literals'}: - raise LexerBuildError(f'{key} does not match a name in tokens') + elif ( + isinstance(value, str) + and not key.startswith("_") + and key not in {"ignore", "literals"} + ): + raise LexerBuildError(f"{key} does not match a name in tokens") # Apply deletion rules - rules = [ (key, value) for key, value in rules if key not in cls._delete ] + rules = [(key, value) for key, value in rules if key not in cls._delete] cls._rules = rules @classmethod def _build(cls): - ''' + """ Build the lexer object from the collected tokens and regular expressions. Validate the rules to make sure they look sane. - ''' - if 'tokens' not in vars(cls): - raise LexerBuildError(f'{cls.__qualname__} class does not define a tokens attribute') + """ + if "tokens" not in vars(cls): + raise LexerBuildError( + f"{cls.__qualname__} class does not define a tokens attribute" + ) # Pull definitions created for any parent classes cls._token_names = cls._token_names | set(cls.tokens) @@ -282,17 +316,17 @@ class Lexer(metaclass=LexerMeta): remapped_toks = set() for d in cls._remapping.values(): remapped_toks.update(d.values()) - + undefined = remapped_toks - set(cls._token_names) if undefined: - missing = ', '.join(undefined) - raise LexerBuildError(f'{missing} not included in token(s)') + missing = ", ".join(undefined) + raise LexerBuildError(f"{missing} not included in token(s)") cls._collect_rules() parts = [] for tokname, value in cls._rules: - if tokname.startswith('ignore_'): + if tokname.startswith("ignore_"): tokname = tokname[7:] cls._ignored_tokens.add(tokname) @@ -301,20 +335,20 @@ class Lexer(metaclass=LexerMeta): elif callable(value): cls._token_funcs[tokname] = value - pattern = getattr(value, 'pattern') + pattern = getattr(value, "pattern") # Form the regular expression component - part = f'(?P<{tokname}>{pattern})' + part = f"(?P<{tokname}>{pattern})" # Make sure the individual regex compiles properly try: cpat = cls.regex_module.compile(part, cls.reflags) except Exception as e: - raise PatternError(f'Invalid regex for token {tokname}') from e + raise PatternError(f"Invalid regex for token {tokname}") from e # Verify that the pattern doesn't match the empty string - if cpat.match(''): - raise PatternError(f'Regex for token {tokname} matches empty input') + if cpat.match(""): + raise PatternError(f"Regex for token {tokname} matches empty input") parts.append(part) @@ -322,43 +356,45 @@ class Lexer(metaclass=LexerMeta): return # 
Form the master regular expression - #previous = ('|' + cls._master_re.pattern) if cls._master_re else '' + # previous = ('|' + cls._master_re.pattern) if cls._master_re else '' # cls._master_re = cls.regex_module.compile('|'.join(parts) + previous, cls.reflags) - cls._master_re = cls.regex_module.compile('|'.join(parts), cls.reflags) + cls._master_re = cls.regex_module.compile("|".join(parts), cls.reflags) # Verify that that ignore and literals specifiers match the input type if not isinstance(cls.ignore, str): - raise LexerBuildError('ignore specifier must be a string') + raise LexerBuildError("ignore specifier must be a string") if not all(isinstance(lit, str) for lit in cls.literals): - raise LexerBuildError('literals must be specified as strings') + raise LexerBuildError("literals must be specified as strings") def begin(self, cls): - ''' + """ Begin a new lexer state - ''' + """ assert isinstance(cls, LexerMeta), "state must be a subclass of Lexer" if self.__set_state: self.__set_state(cls) self.__class__ = cls def push_state(self, cls): - ''' + """ Push a new lexer state onto the stack - ''' + """ if self.__state_stack is None: self.__state_stack = [] self.__state_stack.append(type(self)) self.begin(cls) def pop_state(self): - ''' + """ Pop a lexer state from the stack - ''' + """ self.begin(self.__state_stack.pop()) def tokenize(self, text, lineno=1, index=0): - _ignored_tokens = _master_re = _ignore = _token_funcs = _literals = _remapping = None + _ignored_tokens = ( + _master_re + ) = _ignore = _token_funcs = _literals = _remapping = None def _set_state(cls): nonlocal _ignored_tokens, _master_re, _ignore, _token_funcs, _literals, _remapping @@ -419,7 +455,7 @@ class Lexer(metaclass=LexerMeta): # A lexing error self.index = index self.lineno = lineno - tok.type = 'ERROR' + tok.type = "ERROR" tok.value = text[index:] tok = self.error(tok) if tok is not None: @@ -436,4 +472,8 @@ class Lexer(metaclass=LexerMeta): # Default implementations of the error handler. May be changed in subclasses def error(self, t): - raise LexError(f'Illegal character {t.value[0]!r} at index {self.index}', t.value, self.index) + raise LexError( + f"Illegal character {t.value[0]!r} at index {self.index}", + t.value, + self.index, + ) diff --git a/lib/sly/yacc.py b/lib/sly/yacc.py index c30f13c..167168d 100644 --- a/lib/sly/yacc.py +++ b/lib/sly/yacc.py @@ -35,22 +35,25 @@ import sys import inspect from collections import OrderedDict, defaultdict -__all__ = [ 'Parser' ] +__all__ = ["Parser"] + class YaccError(Exception): - ''' + """ Exception raised for yacc-related build errors. - ''' + """ + pass -#----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- # === User configurable parameters === # -# Change these to modify the default behavior of yacc (if you wish). +# Change these to modify the default behavior of yacc (if you wish). # Move these parameters to the Yacc class itself. -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- -ERROR_COUNT = 3 # Number of symbols that must be shifted to leave recovery mode +ERROR_COUNT = 3 # Number of symbols that must be shifted to leave recovery mode MAXINT = sys.maxsize # This object is a stand-in for a logging object created by the @@ -59,20 +62,21 @@ MAXINT = sys.maxsize # information, they can create their own logging object and pass # it into SLY. 
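A hedged sketch of what "pass it into SLY" means in practice: Parser exposes this object as the class attribute log (defaulting to SlyLogger(sys.stderr), see the Parser class further down), so any object with the same debug/info/warning/error/critical interface can be substituted. LoggingAdapter is a hypothetical helper, not part of SLY:

    import logging

    class LoggingAdapter:
        # Duck-types SlyLogger: same method names, printf-style msg plus args.
        def __init__(self, logger):
            self.logger = logger

        def debug(self, msg, *args, **kwargs):
            self.logger.debug(msg, *args)

        info = debug

        def warning(self, msg, *args, **kwargs):
            self.logger.warning(msg, *args)

        def error(self, msg, *args, **kwargs):
            self.logger.error(msg, *args)

        critical = debug

    # In a Parser subclass:
    #     log = LoggingAdapter(logging.getLogger("sly"))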
+ class SlyLogger(object): def __init__(self, f): self.f = f def debug(self, msg, *args, **kwargs): - self.f.write((msg % args) + '\n') + self.f.write((msg % args) + "\n") info = debug def warning(self, msg, *args, **kwargs): - self.f.write('WARNING: ' + (msg % args) + '\n') + self.f.write("WARNING: " + (msg % args) + "\n") def error(self, msg, *args, **kwargs): - self.f.write('ERROR: ' + (msg % args) + '\n') + self.f.write("ERROR: " + (msg % args) + "\n") critical = debug @@ -86,6 +90,7 @@ class SlyLogger(object): # .index = Starting lex position # ---------------------------------------------------------------------- + class YaccSymbol: def __str__(self): return self.type @@ -93,19 +98,22 @@ class YaccSymbol: def __repr__(self): return str(self) + # ---------------------------------------------------------------------- # This class is a wrapper around the objects actually passed to each # grammar rule. Index lookup and assignment actually assign the # .value attribute of the underlying YaccSymbol object. # The lineno() method returns the line number of a given -# item (or 0 if not defined). +# item (or 0 if not defined). # ---------------------------------------------------------------------- + class YaccProduction: - __slots__ = ('_slice', '_namemap', '_stack') + __slots__ = ("_slice", "_namemap", "_stack") + def __init__(self, s, stack=None): self._slice = s - self._namemap = { } + self._namemap = {} self._stack = stack def __getitem__(self, n): @@ -128,34 +136,35 @@ class YaccProduction: for tok in self._slice: if isinstance(tok, YaccSymbol): continue - lineno = getattr(tok, 'lineno', None) + lineno = getattr(tok, "lineno", None) if lineno: return lineno - raise AttributeError('No line number found') + raise AttributeError("No line number found") @property def index(self): for tok in self._slice: if isinstance(tok, YaccSymbol): continue - index = getattr(tok, 'index', None) + index = getattr(tok, "index", None) if index is not None: return index - raise AttributeError('No index attribute found') + raise AttributeError("No index attribute found") def __getattr__(self, name): if name in self._namemap: return self._slice[self._namemap[name]].value else: - nameset = '{' + ', '.join(self._namemap) + '}' - raise AttributeError(f'No symbol {name}. Must be one of {nameset}.') + nameset = "{" + ", ".join(self._namemap) + "}" + raise AttributeError(f"No symbol {name}. 
Must be one of {nameset}.") def __setattr__(self, name, value): - if name[:1] == '_': + if name[:1] == "_": super().__setattr__(name, value) else: raise AttributeError(f"Can't reassign the value of attribute {name!r}") + # ----------------------------------------------------------------------------- # === Grammar Representation === # @@ -187,19 +196,23 @@ class YaccProduction: # usyms - Set of unique symbols found in the production # ----------------------------------------------------------------------------- + class Production(object): reduced = 0 - def __init__(self, number, name, prod, precedence=('right', 0), func=None, file='', line=0): - self.name = name - self.prod = tuple(prod) - self.number = number - self.func = func - self.file = file - self.line = line - self.prec = precedence - + + def __init__( + self, number, name, prod, precedence=("right", 0), func=None, file="", line=0 + ): + self.name = name + self.prod = tuple(prod) + self.number = number + self.func = func + self.file = file + self.line = line + self.prec = precedence + # Internal settings used during table construction - self.len = len(self.prod) # Length of the production + self.len = len(self.prod) # Length of the production # Create a list of unique production symbols used in the production self.usyms = [] @@ -216,33 +229,33 @@ class Production(object): m[key] = indices[0] else: for n, index in enumerate(indices): - m[key+str(n)] = index + m[key + str(n)] = index self.namemap = m - + # List of all LR items for the production self.lr_items = [] self.lr_next = None def __str__(self): if self.prod: - s = '%s -> %s' % (self.name, ' '.join(self.prod)) + s = "%s -> %s" % (self.name, " ".join(self.prod)) else: - s = f'{self.name} -> <empty>' + s = f"{self.name} -> <empty>" if self.prec[1]: - s += ' [precedence=%s, level=%d]' % self.prec + s += " [precedence=%s, level=%d]" % self.prec return s def __repr__(self): - return f'Production({self})' + return f"Production({self})" def __len__(self): return len(self.prod) def __nonzero__(self): - raise RuntimeError('Used') + raise RuntimeError("Used") return 1 def __getitem__(self, index): @@ -255,15 +268,16 @@ class Production(object): p = LRItem(self, n) # Precompute the list of productions immediately following. 
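        # (lr_after is consumed by lr0_closure() during table construction:
        #  for an item "A -> alpha . B beta" it caches every production of B,
        #  so the closure can append the matching "B -> . gamma" items
        #  without re-resolving B on each pass.)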
try: - p.lr_after = Prodnames[p.prod[n+1]] + p.lr_after = Prodnames[p.prod[n + 1]] except (IndexError, KeyError): p.lr_after = [] try: - p.lr_before = p.prod[n-1] + p.lr_before = p.prod[n - 1] except IndexError: p.lr_before = None return p + # ----------------------------------------------------------------------------- # class LRItem # @@ -288,27 +302,29 @@ class Production(object): # lr_before - Grammar symbol immediately before # ----------------------------------------------------------------------------- + class LRItem(object): def __init__(self, p, n): - self.name = p.name - self.prod = list(p.prod) - self.number = p.number - self.lr_index = n + self.name = p.name + self.prod = list(p.prod) + self.number = p.number + self.lr_index = n self.lookaheads = {} - self.prod.insert(n, '.') - self.prod = tuple(self.prod) - self.len = len(self.prod) - self.usyms = p.usyms + self.prod.insert(n, ".") + self.prod = tuple(self.prod) + self.len = len(self.prod) + self.usyms = p.usyms def __str__(self): if self.prod: - s = '%s -> %s' % (self.name, ' '.join(self.prod)) + s = "%s -> %s" % (self.name, " ".join(self.prod)) else: - s = f'{self.name} -> <empty>' + s = f"{self.name} -> <empty>" return s def __repr__(self): - return f'LRItem({self})' + return f"LRItem({self})" + # ----------------------------------------------------------------------------- # rightmost_terminal() @@ -323,6 +339,7 @@ def rightmost_terminal(symbols, terminals): i -= 1 return None + # ----------------------------------------------------------------------------- # === GRAMMAR CLASS === # @@ -331,45 +348,52 @@ def rightmost_terminal(symbols, terminals): # This data is used for critical parts of the table generation process later. # ----------------------------------------------------------------------------- + class GrammarError(YaccError): pass + class Grammar(object): def __init__(self, terminals): - self.Productions = [None] # A list of all of the productions. The first - # entry is always reserved for the purpose of - # building an augmented grammar + self.Productions = [None] # A list of all of the productions. The first + # entry is always reserved for the purpose of + # building an augmented grammar - self.Prodnames = {} # A dictionary mapping the names of nonterminals to a list of all - # productions of that nonterminal. + self.Prodnames = ( + {} + ) # A dictionary mapping the names of nonterminals to a list of all + # productions of that nonterminal. - self.Prodmap = {} # A dictionary that is only used to detect duplicate - # productions. + self.Prodmap = {} # A dictionary that is only used to detect duplicate + # productions. - self.Terminals = {} # A dictionary mapping the names of terminal symbols to a - # list of the rules where they are used. + self.Terminals = {} # A dictionary mapping the names of terminal symbols to a + # list of the rules where they are used. for term in terminals: self.Terminals[term] = [] - self.Terminals['error'] = [] - - self.Nonterminals = {} # A dictionary mapping names of nonterminals to a list - # of rule numbers where they are used. + self.Terminals["error"] = [] - self.First = {} # A dictionary of precomputed FIRST(x) symbols + self.Nonterminals = {} # A dictionary mapping names of nonterminals to a list + # of rule numbers where they are used. - self.Follow = {} # A dictionary of precomputed FOLLOW(x) symbols + self.First = {} # A dictionary of precomputed FIRST(x) symbols - self.Precedence = {} # Precedence rules for each terminal. 
Contains tuples of the - # form ('right',level) or ('nonassoc', level) or ('left',level) + self.Follow = {} # A dictionary of precomputed FOLLOW(x) symbols - self.UsedPrecedence = set() # Precedence rules that were actually used by the grammer. - # This is only used to provide error checking and to generate - # a warning about unused precedence rules. + self.Precedence = ( + {} + ) # Precedence rules for each terminal. Contains tuples of the + # form ('right',level) or ('nonassoc', level) or ('left',level) - self.Start = None # Starting symbol for the grammar + self.UsedPrecedence = ( + set() + ) # Precedence rules that were actually used by the grammer. + # This is only used to provide error checking and to generate + # a warning about unused precedence rules. + self.Start = None # Starting symbol for the grammar def __len__(self): return len(self.Productions) @@ -386,11 +410,15 @@ class Grammar(object): # ----------------------------------------------------------------------------- def set_precedence(self, term, assoc, level): - assert self.Productions == [None], 'Must call set_precedence() before add_production()' + assert self.Productions == [ + None + ], "Must call set_precedence() before add_production()" if term in self.Precedence: - raise GrammarError(f'Precedence already specified for terminal {term!r}') - if assoc not in ['left', 'right', 'nonassoc']: - raise GrammarError(f"Associativity of {term!r} must be one of 'left','right', or 'nonassoc'") + raise GrammarError(f"Precedence already specified for terminal {term!r}") + if assoc not in ["left", "right", "nonassoc"]: + raise GrammarError( + f"Associativity of {term!r} must be one of 'left','right', or 'nonassoc'" + ) self.Precedence[term] = (assoc, level) # ----------------------------------------------------------------------------- @@ -410,51 +438,65 @@ class Grammar(object): # are valid and that %prec is used correctly. # ----------------------------------------------------------------------------- - def add_production(self, prodname, syms, func=None, file='', line=0): + def add_production(self, prodname, syms, func=None, file="", line=0): if prodname in self.Terminals: - raise GrammarError(f'{file}:{line}: Illegal rule name {prodname!r}. Already defined as a token') - if prodname == 'error': - raise GrammarError(f'{file}:{line}: Illegal rule name {prodname!r}. error is a reserved word') + raise GrammarError( + f"{file}:{line}: Illegal rule name {prodname!r}. Already defined as a token" + ) + if prodname == "error": + raise GrammarError( + f"{file}:{line}: Illegal rule name {prodname!r}. error is a reserved word" + ) # Look for literal tokens for n, s in enumerate(syms): if s[0] in "'\"" and s[0] == s[-1]: c = s[1:-1] - if (len(c) != 1): - raise GrammarError(f'{file}:{line}: Literal token {s} in rule {prodname!r} may only be a single character') + if len(c) != 1: + raise GrammarError( + f"{file}:{line}: Literal token {s} in rule {prodname!r} may only be a single character" + ) if c not in self.Terminals: self.Terminals[c] = [] syms[n] = c continue # Determine the precedence level - if '%prec' in syms: - if syms[-1] == '%prec': - raise GrammarError(f'{file}:{line}: Syntax error. Nothing follows %%prec') - if syms[-2] != '%prec': - raise GrammarError(f'{file}:{line}: Syntax error. %prec can only appear at the end of a grammar rule') + if "%prec" in syms: + if syms[-1] == "%prec": + raise GrammarError( + f"{file}:{line}: Syntax error. 
Nothing follows %%prec" + ) + if syms[-2] != "%prec": + raise GrammarError( + f"{file}:{line}: Syntax error. %prec can only appear at the end of a grammar rule" + ) precname = syms[-1] prodprec = self.Precedence.get(precname) if not prodprec: - raise GrammarError(f'{file}:{line}: Nothing known about the precedence of {precname!r}') + raise GrammarError( + f"{file}:{line}: Nothing known about the precedence of {precname!r}" + ) else: self.UsedPrecedence.add(precname) - del syms[-2:] # Drop %prec from the rule + del syms[-2:] # Drop %prec from the rule else: # If no %prec, precedence is determined by the rightmost terminal symbol precname = rightmost_terminal(syms, self.Terminals) - prodprec = self.Precedence.get(precname, ('right', 0)) + prodprec = self.Precedence.get(precname, ("right", 0)) # See if the rule is already in the rulemap - map = '%s -> %s' % (prodname, syms) + map = "%s -> %s" % (prodname, syms) if map in self.Prodmap: m = self.Prodmap[map] - raise GrammarError(f'{file}:{line}: Duplicate rule {m}. ' + - f'Previous definition at {m.file}:{m.line}') + raise GrammarError( + f"{file}:{line}: Duplicate rule {m}. " + + f"Previous definition at {m.file}:{m.line}" + ) # From this point on, everything is valid. Create a new Production instance - pnumber = len(self.Productions) + pnumber = len(self.Productions) if prodname not in self.Nonterminals: self.Nonterminals[prodname] = [] @@ -493,7 +535,7 @@ class Grammar(object): start = self.Productions[1].name if start not in self.Nonterminals: - raise GrammarError(f'start symbol {start} undefined') + raise GrammarError(f"start symbol {start} undefined") self.Productions[0] = Production(0, "S'", [start]) self.Nonterminals[start].append(0) self.Start = start @@ -535,7 +577,7 @@ class Grammar(object): for t in self.Terminals: terminates[t] = True - terminates['$end'] = True + terminates["$end"] = True # Nonterminals: @@ -576,7 +618,7 @@ class Grammar(object): infinite = [] for (s, term) in terminates.items(): if not term: - if s not in self.Prodnames and s not in self.Terminals and s != 'error': + if s not in self.Prodnames and s not in self.Terminals and s != "error": # s is used-but-not-defined, and we've already warned of that, # so it would be overkill to say that it's also non-terminating. pass @@ -599,7 +641,7 @@ class Grammar(object): continue for s in p.prod: - if s not in self.Prodnames and s not in self.Terminals and s != 'error': + if s not in self.Prodnames and s not in self.Terminals and s != "error": result.append((s, p)) return result @@ -612,7 +654,7 @@ class Grammar(object): def unused_terminals(self): unused_tok = [] for s, v in self.Terminals.items(): - if s != 'error' and not v: + if s != "error" and not v: unused_tok.append(s) return unused_tok @@ -666,7 +708,7 @@ class Grammar(object): # Add all the non-<empty> symbols of First[x] to the result. for f in self.First[x]: - if f == '<empty>': + if f == "<empty>": x_produces_empty = True else: if f not in result: @@ -683,7 +725,7 @@ class Grammar(object): # There was no 'break' from the loop, # so x_produces_empty was true for all x in beta, # so beta produces empty as well. 
- result.append('<empty>') + result.append("<empty>") return result @@ -700,7 +742,7 @@ class Grammar(object): for t in self.Terminals: self.First[t] = [t] - self.First['$end'] = ['$end'] + self.First["$end"] = ["$end"] # Nonterminals: @@ -745,7 +787,7 @@ class Grammar(object): if not start: start = self.Productions[1].name - self.Follow[start] = ['$end'] + self.Follow[start] = ["$end"] while True: didadd = False @@ -754,15 +796,15 @@ class Grammar(object): for i, B in enumerate(p.prod): if B in self.Nonterminals: # Okay. We got a non-terminal in a production - fst = self._first(p.prod[i+1:]) + fst = self._first(p.prod[i + 1 :]) hasempty = False for f in fst: - if f != '<empty>' and f not in self.Follow[B]: + if f != "<empty>" and f not in self.Follow[B]: self.Follow[B].append(f) didadd = True - if f == '<empty>': + if f == "<empty>": hasempty = True - if hasempty or i == (len(p.prod)-1): + if hasempty or i == (len(p.prod) - 1): # Add elements of follow(a) to follow(b) for f in self.Follow[p.name]: if f not in self.Follow[B]: @@ -772,7 +814,6 @@ class Grammar(object): break return self.Follow - # ----------------------------------------------------------------------------- # build_lritems() # @@ -800,11 +841,11 @@ class Grammar(object): lri = LRItem(p, i) # Precompute the list of productions immediately following try: - lri.lr_after = self.Prodnames[lri.prod[i+1]] + lri.lr_after = self.Prodnames[lri.prod[i + 1]] except (IndexError, KeyError): lri.lr_after = [] try: - lri.lr_before = lri.prod[i-1] + lri.lr_before = lri.prod[i - 1] except IndexError: lri.lr_before = None @@ -816,33 +857,38 @@ class Grammar(object): i += 1 p.lr_items = lr_items - # ---------------------------------------------------------------------- # Debugging output. Printing the grammar will produce a detailed # description along with some diagnostics. 
# ---------------------------------------------------------------------- def __str__(self): out = [] - out.append('Grammar:\n') + out.append("Grammar:\n") for n, p in enumerate(self.Productions): - out.append(f'Rule {n:<5d} {p}') - + out.append(f"Rule {n:<5d} {p}") + unused_terminals = self.unused_terminals() if unused_terminals: - out.append('\nUnused terminals:\n') + out.append("\nUnused terminals:\n") for term in unused_terminals: - out.append(f' {term}') + out.append(f" {term}") - out.append('\nTerminals, with rules where they appear:\n') + out.append("\nTerminals, with rules where they appear:\n") for term in sorted(self.Terminals): - out.append('%-20s : %s' % (term, ' '.join(str(s) for s in self.Terminals[term]))) + out.append( + "%-20s : %s" % (term, " ".join(str(s) for s in self.Terminals[term])) + ) - out.append('\nNonterminals, with rules where they appear:\n') + out.append("\nNonterminals, with rules where they appear:\n") for nonterm in sorted(self.Nonterminals): - out.append('%-20s : %s' % (nonterm, ' '.join(str(s) for s in self.Nonterminals[nonterm]))) + out.append( + "%-20s : %s" + % (nonterm, " ".join(str(s) for s in self.Nonterminals[nonterm])) + ) + + out.append("") + return "\n".join(out) - out.append('') - return '\n'.join(out) # ----------------------------------------------------------------------------- # === LR Generator === @@ -868,6 +914,7 @@ class Grammar(object): # FP - Set-valued function # ------------------------------------------------------------------------------ + def digraph(X, R, FP): N = {} for x in X: @@ -879,13 +926,14 @@ def digraph(X, R, FP): traverse(x, N, stack, F, X, R, FP) return F + def traverse(x, N, stack, F, X, R, FP): stack.append(x) d = len(stack) N[x] = d - F[x] = FP(x) # F(X) <- F'(x) + F[x] = FP(x) # F(X) <- F'(x) - rel = R(x) # Get y's related to x + rel = R(x) # Get y's related to x for y in rel: if N[y] == 0: traverse(y, N, stack, F, X, R, FP) @@ -902,9 +950,11 @@ def traverse(x, N, stack, F, X, R, FP): F[stack[-1]] = F[x] element = stack.pop() + class LALRError(YaccError): pass + # ----------------------------------------------------------------------------- # == LRGeneratedTable == # @@ -912,26 +962,27 @@ class LALRError(YaccError): # public methods except for write() # ----------------------------------------------------------------------------- + class LRTable(object): def __init__(self, grammar): self.grammar = grammar # Internal attributes - self.lr_action = {} # Action table - self.lr_goto = {} # Goto table - self.lr_productions = grammar.Productions # Copy of grammar Production array - self.lr_goto_cache = {} # Cache of computed gotos - self.lr0_cidhash = {} # Cache of closures - self._add_count = 0 # Internal counter used to detect cycles + self.lr_action = {} # Action table + self.lr_goto = {} # Goto table + self.lr_productions = grammar.Productions # Copy of grammar Production array + self.lr_goto_cache = {} # Cache of computed gotos + self.lr0_cidhash = {} # Cache of closures + self._add_count = 0 # Internal counter used to detect cycles # Diagonistic information filled in by the table generator self.state_descriptions = OrderedDict() - self.sr_conflict = 0 - self.rr_conflict = 0 - self.conflicts = [] # List of conflicts + self.sr_conflict = 0 + self.rr_conflict = 0 + self.conflicts = [] # List of conflicts - self.sr_conflicts = [] - self.rr_conflicts = [] + self.sr_conflicts = [] + self.rr_conflicts = [] # Build the tables self.grammar.build_lritems() @@ -964,7 +1015,7 @@ class LRTable(object): didadd = False for j in 
J: for x in j.lr_after: - if getattr(x, 'lr0_added', 0) == self._add_count: + if getattr(x, "lr0_added", 0) == self._add_count: continue # Add B --> .G to J J.append(x.lr_next) @@ -1004,13 +1055,13 @@ class LRTable(object): s[id(n)] = s1 gs.append(n) s = s1 - g = s.get('$end') + g = s.get("$end") if not g: if gs: g = self.lr0_closure(gs) - s['$end'] = g + s["$end"] = g else: - s['$end'] = gs + s["$end"] = gs self.lr_goto_cache[(id(I), x)] = g return g @@ -1105,7 +1156,7 @@ class LRTable(object): for stateno, state in enumerate(C): for p in state: if p.lr_index < p.len - 1: - t = (stateno, p.prod[p.lr_index+1]) + t = (stateno, p.prod[p.lr_index + 1]) if t[1] in self.grammar.Nonterminals: if t not in trans: trans.append(t) @@ -1128,14 +1179,14 @@ class LRTable(object): g = self.lr0_goto(C[state], N) for p in g: if p.lr_index < p.len - 1: - a = p.prod[p.lr_index+1] + a = p.prod[p.lr_index + 1] if a in self.grammar.Terminals: if a not in terms: terms.append(a) # This extra bit is to handle the start state if state == 0 and N == self.grammar.Productions[0].prod[0]: - terms.append('$end') + terms.append("$end") return terms @@ -1189,8 +1240,8 @@ class LRTable(object): # ----------------------------------------------------------------------------- def compute_lookback_includes(self, C, trans, nullable): - lookdict = {} # Dictionary of lookback relations - includedict = {} # Dictionary of include relations + lookdict = {} # Dictionary of lookback relations + includedict = {} # Dictionary of include relations # Make a dictionary of non-terminal transitions dtrans = {} @@ -1223,7 +1274,7 @@ class LRTable(object): li = lr_index + 1 while li < p.len: if p.prod[li] in self.grammar.Terminals: - break # No forget it + break # No forget it if p.prod[li] not in nullable: break li = li + 1 @@ -1231,8 +1282,8 @@ class LRTable(object): # Appears to be a relation between (j,t) and (state,N) includes.append((j, t)) - g = self.lr0_goto(C[j], t) # Go to next set - j = self.lr0_cidhash.get(id(g), -1) # Go to next state + g = self.lr0_goto(C[j], t) # Go to next set + j = self.lr0_cidhash.get(id(g), -1) # Go to next state # When we get here, j is the final state, now we have to locate the production for r in C[j]: @@ -1243,7 +1294,7 @@ class LRTable(object): i = 0 # This look is comparing a production ". A B C" with "A B C ." while i < r.lr_index: - if r.prod[i] != p.prod[i+1]: + if r.prod[i] != p.prod[i + 1]: break i = i + 1 else: @@ -1270,7 +1321,7 @@ class LRTable(object): def compute_read_sets(self, C, ntrans, nullable): FP = lambda x: self.dr_relation(C, x, nullable) - R = lambda x: self.reads_relation(C, x, nullable) + R = lambda x: self.reads_relation(C, x, nullable) F = digraph(ntrans, R, FP) return F @@ -1292,7 +1343,7 @@ class LRTable(object): def compute_follow_sets(self, ntrans, readsets, inclsets): FP = lambda x: readsets[x] - R = lambda x: inclsets.get(x, []) + R = lambda x: inclsets.get(x, []) F = digraph(ntrans, R, FP) return F @@ -1352,11 +1403,11 @@ class LRTable(object): # ----------------------------------------------------------------------------- def lr_parse_table(self): Productions = self.grammar.Productions - Precedence = self.grammar.Precedence - goto = self.lr_goto # Goto array - action = self.lr_action # Action array + Precedence = self.grammar.Precedence + goto = self.lr_goto # Goto array + action = self.lr_action # Action array - actionp = {} # Action production array (temporary) + actionp = {} # Action production array (temporary) # Step 1: Construct C = { I0, I1, ... 
IN}, collection of LR(0) items # This determines the number of states @@ -1368,129 +1419,149 @@ class LRTable(object): for st, I in enumerate(C): descrip = [] # Loop over each production in I - actlist = [] # List of actions - st_action = {} + actlist = [] # List of actions + st_action = {} st_actionp = {} - st_goto = {} + st_goto = {} - descrip.append(f'\nstate {st}\n') + descrip.append(f"\nstate {st}\n") for p in I: - descrip.append(f' ({p.number}) {p}') + descrip.append(f" ({p.number}) {p}") for p in I: - if p.len == p.lr_index + 1: - if p.name == "S'": - # Start symbol. Accept! - st_action['$end'] = 0 - st_actionp['$end'] = p - else: - # We are at the end of a production. Reduce! - laheads = p.lookaheads[st] - for a in laheads: - actlist.append((a, p, f'reduce using rule {p.number} ({p})')) - r = st_action.get(a) - if r is not None: - # Have a shift/reduce or reduce/reduce conflict - if r > 0: - # Need to decide on shift or reduce here - # By default we favor shifting. Need to add - # some precedence rules here. - - # Shift precedence comes from the token - sprec, slevel = Precedence.get(a, ('right', 0)) - - # Reduce precedence comes from rule being reduced (p) - rprec, rlevel = Productions[p.number].prec - - if (slevel < rlevel) or ((slevel == rlevel) and (rprec == 'left')): - # We really need to reduce here. - st_action[a] = -p.number - st_actionp[a] = p - if not slevel and not rlevel: - descrip.append(f' ! shift/reduce conflict for {a} resolved as reduce') - self.sr_conflicts.append((st, a, 'reduce')) - Productions[p.number].reduced += 1 - elif (slevel == rlevel) and (rprec == 'nonassoc'): - st_action[a] = None - else: - # Hmmm. Guess we'll keep the shift - if not rlevel: - descrip.append(f' ! shift/reduce conflict for {a} resolved as shift') - self.sr_conflicts.append((st, a, 'shift')) - elif r <= 0: - # Reduce/reduce conflict. In this case, we favor the rule - # that was defined first in the grammar file - oldp = Productions[-r] - pp = Productions[p.number] - if oldp.line > pp.line: - st_action[a] = -p.number - st_actionp[a] = p - chosenp, rejectp = pp, oldp - Productions[p.number].reduced += 1 - Productions[oldp.number].reduced -= 1 - else: - chosenp, rejectp = oldp, pp - self.rr_conflicts.append((st, chosenp, rejectp)) - descrip.append(' ! reduce/reduce conflict for %s resolved using rule %d (%s)' % - (a, st_actionp[a].number, st_actionp[a])) + if p.len == p.lr_index + 1: + if p.name == "S'": + # Start symbol. Accept! + st_action["$end"] = 0 + st_actionp["$end"] = p + else: + # We are at the end of a production. Reduce! + laheads = p.lookaheads[st] + for a in laheads: + actlist.append( + (a, p, f"reduce using rule {p.number} ({p})") + ) + r = st_action.get(a) + if r is not None: + # Have a shift/reduce or reduce/reduce conflict + if r > 0: + # Need to decide on shift or reduce here + # By default we favor shifting. Need to add + # some precedence rules here. + + # Shift precedence comes from the token + sprec, slevel = Precedence.get(a, ("right", 0)) + + # Reduce precedence comes from rule being reduced (p) + rprec, rlevel = Productions[p.number].prec + + if (slevel < rlevel) or ( + (slevel == rlevel) and (rprec == "left") + ): + # We really need to reduce here. + st_action[a] = -p.number + st_actionp[a] = p + if not slevel and not rlevel: + descrip.append( + f" ! 
shift/reduce conflict for {a} resolved as reduce" + ) + self.sr_conflicts.append((st, a, "reduce")) + Productions[p.number].reduced += 1 + elif (slevel == rlevel) and (rprec == "nonassoc"): + st_action[a] = None + else: + # Hmmm. Guess we'll keep the shift + if not rlevel: + descrip.append( + f" ! shift/reduce conflict for {a} resolved as shift" + ) + self.sr_conflicts.append((st, a, "shift")) + elif r <= 0: + # Reduce/reduce conflict. In this case, we favor the rule + # that was defined first in the grammar file + oldp = Productions[-r] + pp = Productions[p.number] + if oldp.line > pp.line: + st_action[a] = -p.number + st_actionp[a] = p + chosenp, rejectp = pp, oldp + Productions[p.number].reduced += 1 + Productions[oldp.number].reduced -= 1 else: - raise LALRError(f'Unknown conflict in state {st}') + chosenp, rejectp = oldp, pp + self.rr_conflicts.append((st, chosenp, rejectp)) + descrip.append( + " ! reduce/reduce conflict for %s resolved using rule %d (%s)" + % (a, st_actionp[a].number, st_actionp[a]) + ) else: - st_action[a] = -p.number - st_actionp[a] = p - Productions[p.number].reduced += 1 - else: - i = p.lr_index - a = p.prod[i+1] # Get symbol right after the "." - if a in self.grammar.Terminals: - g = self.lr0_goto(I, a) - j = self.lr0_cidhash.get(id(g), -1) - if j >= 0: - # We are in a shift state - actlist.append((a, p, f'shift and go to state {j}')) - r = st_action.get(a) - if r is not None: - # Whoa have a shift/reduce or shift/shift conflict - if r > 0: - if r != j: - raise LALRError(f'Shift/shift conflict in state {st}') - elif r <= 0: - # Do a precedence check. - # - if precedence of reduce rule is higher, we reduce. - # - if precedence of reduce is same and left assoc, we reduce. - # - otherwise we shift - rprec, rlevel = Productions[st_actionp[a].number].prec - sprec, slevel = Precedence.get(a, ('right', 0)) - if (slevel > rlevel) or ((slevel == rlevel) and (rprec == 'right')): - # We decide to shift here... highest precedence to shift - Productions[st_actionp[a].number].reduced -= 1 - st_action[a] = j - st_actionp[a] = p - if not rlevel: - descrip.append(f' ! shift/reduce conflict for {a} resolved as shift') - self.sr_conflicts.append((st, a, 'shift')) - elif (slevel == rlevel) and (rprec == 'nonassoc'): - st_action[a] = None - else: - # Hmmm. Guess we'll keep the reduce - if not slevel and not rlevel: - descrip.append(f' ! shift/reduce conflict for {a} resolved as reduce') - self.sr_conflicts.append((st, a, 'reduce')) - + raise LALRError(f"Unknown conflict in state {st}") + else: + st_action[a] = -p.number + st_actionp[a] = p + Productions[p.number].reduced += 1 + else: + i = p.lr_index + a = p.prod[i + 1] # Get symbol right after the "." + if a in self.grammar.Terminals: + g = self.lr0_goto(I, a) + j = self.lr0_cidhash.get(id(g), -1) + if j >= 0: + # We are in a shift state + actlist.append((a, p, f"shift and go to state {j}")) + r = st_action.get(a) + if r is not None: + # Whoa have a shift/reduce or shift/shift conflict + if r > 0: + if r != j: + raise LALRError( + f"Shift/shift conflict in state {st}" + ) + elif r <= 0: + # Do a precedence check. + # - if precedence of reduce rule is higher, we reduce. + # - if precedence of reduce is same and left assoc, we reduce. + # - otherwise we shift + rprec, rlevel = Productions[ + st_actionp[a].number + ].prec + sprec, slevel = Precedence.get(a, ("right", 0)) + if (slevel > rlevel) or ( + (slevel == rlevel) and (rprec == "right") + ): + # We decide to shift here... 
highest precedence to shift + Productions[st_actionp[a].number].reduced -= 1 + st_action[a] = j + st_actionp[a] = p + if not rlevel: + descrip.append( + f" ! shift/reduce conflict for {a} resolved as shift" + ) + self.sr_conflicts.append((st, a, "shift")) + elif (slevel == rlevel) and (rprec == "nonassoc"): + st_action[a] = None else: - raise LALRError(f'Unknown conflict in state {st}') + # Hmmm. Guess we'll keep the reduce + if not slevel and not rlevel: + descrip.append( + f" ! shift/reduce conflict for {a} resolved as reduce" + ) + self.sr_conflicts.append((st, a, "reduce")) + else: - st_action[a] = j - st_actionp[a] = p + raise LALRError(f"Unknown conflict in state {st}") + else: + st_action[a] = j + st_actionp[a] = p # Print the actions associated with each terminal _actprint = {} for a, p, m in actlist: if a in st_action: if p is st_actionp[a]: - descrip.append(f' {a:<15s} {m}') + descrip.append(f" {a:<15s} {m}") _actprint[(a, m)] = 1 - descrip.append('') + descrip.append("") # Construct the goto table for this state nkeys = {} @@ -1503,12 +1574,12 @@ class LRTable(object): j = self.lr0_cidhash.get(id(g), -1) if j >= 0: st_goto[n] = j - descrip.append(f' {n:<30s} shift and go to state {j}') + descrip.append(f" {n:<30s} shift and go to state {j}") action[st] = st_action actionp[st] = st_actionp goto[st] = st_goto - self.state_descriptions[st] = '\n'.join(descrip) + self.state_descriptions[st] = "\n".join(descrip) # ---------------------------------------------------------------------- # Debugging output. Printing the LRTable object will produce a listing @@ -1518,28 +1589,33 @@ class LRTable(object): out = [] for descrip in self.state_descriptions.values(): out.append(descrip) - + if self.sr_conflicts or self.rr_conflicts: - out.append('\nConflicts:\n') + out.append("\nConflicts:\n") for state, tok, resolution in self.sr_conflicts: - out.append(f'shift/reduce conflict for {tok} in state {state} resolved as {resolution}') + out.append( + f"shift/reduce conflict for {tok} in state {state} resolved as {resolution}" + ) already_reported = set() for state, rule, rejected in self.rr_conflicts: if (state, id(rule), id(rejected)) in already_reported: continue - out.append(f'reduce/reduce conflict in state {state} resolved using rule {rule}') - out.append(f'rejected rule ({rejected}) in state {state}') + out.append( + f"reduce/reduce conflict in state {state} resolved using rule {rule}" + ) + out.append(f"rejected rule ({rejected}) in state {state}") already_reported.add((state, id(rule), id(rejected))) warned_never = set() for state, rule, rejected in self.rr_conflicts: if not rejected.reduced and (rejected not in warned_never): - out.append(f'Rule ({rejected}) is never reduced') + out.append(f"Rule ({rejected}) is never reduced") warned_never.add(rejected) - return '\n'.join(out) + return "\n".join(out) + # Collect grammar rules from a function def _collect_grammar_rules(func): @@ -1549,70 +1625,80 @@ def _collect_grammar_rules(func): unwrapped = inspect.unwrap(func) filename = unwrapped.__code__.co_filename lineno = unwrapped.__code__.co_firstlineno - for rule, lineno in zip(func.rules, range(lineno+len(func.rules)-1, 0, -1)): + for rule, lineno in zip(func.rules, range(lineno + len(func.rules) - 1, 0, -1)): syms = rule.split() - if syms[1:2] == [':'] or syms[1:2] == ['::=']: + if syms[1:2] == [":"] or syms[1:2] == ["::="]: grammar.append((func, filename, lineno, syms[0], syms[2:])) else: grammar.append((func, filename, lineno, prodname, syms)) - func = getattr(func, 'next_func', None) + 
func = getattr(func, "next_func", None) return grammar + class ParserMetaDict(dict): - ''' + """ Dictionary that allows decorated grammar rule functions to be overloaded - ''' + """ + def __setitem__(self, key, value): - if key in self and callable(value) and hasattr(value, 'rules'): + if key in self and callable(value) and hasattr(value, "rules"): value.next_func = self[key] - if not hasattr(value.next_func, 'rules'): - raise GrammarError(f'Redefinition of {key}. Perhaps an earlier {key} is missing @_') + if not hasattr(value.next_func, "rules"): + raise GrammarError( + f"Redefinition of {key}. Perhaps an earlier {key} is missing @_" + ) super().__setitem__(key, value) - + def __getitem__(self, key): - if key not in self and key.isupper() and key[:1] != '_': + if key not in self and key.isupper() and key[:1] != "_": return key.upper() else: return super().__getitem__(key) + class ParserMeta(type): @classmethod def __prepare__(meta, *args, **kwargs): d = ParserMetaDict() + def _(rule, *extra): rules = [rule, *extra] + def decorate(func): - func.rules = [ *getattr(func, 'rules', []), *rules[::-1] ] + func.rules = [*getattr(func, "rules", []), *rules[::-1]] return func + return decorate - d['_'] = _ + + d["_"] = _ return d def __new__(meta, clsname, bases, attributes): - del attributes['_'] + del attributes["_"] cls = super().__new__(meta, clsname, bases, attributes) cls._build(list(attributes.items())) return cls + class Parser(metaclass=ParserMeta): # Logging object where debugging/diagnostic messages are sent - log = SlyLogger(sys.stderr) + log = SlyLogger(sys.stderr) # Debugging filename where parsetab.out data can be written debugfile = None @classmethod def __validate_tokens(cls): - if not hasattr(cls, 'tokens'): - cls.log.error('No token list is defined') + if not hasattr(cls, "tokens"): + cls.log.error("No token list is defined") return False if not cls.tokens: - cls.log.error('tokens is empty') + cls.log.error("tokens is empty") return False - if 'error' in cls.tokens: + if "error" in cls.tokens: cls.log.error("Illegal token name 'error'. Is a reserved word") return False @@ -1620,28 +1706,32 @@ class Parser(metaclass=ParserMeta): @classmethod def __validate_precedence(cls): - if not hasattr(cls, 'precedence'): + if not hasattr(cls, "precedence"): cls.__preclist = [] return True preclist = [] if not isinstance(cls.precedence, (list, tuple)): - cls.log.error('precedence must be a list or tuple') + cls.log.error("precedence must be a list or tuple") return False for level, p in enumerate(cls.precedence, start=1): if not isinstance(p, (list, tuple)): - cls.log.error(f'Bad precedence table entry {p!r}. Must be a list or tuple') + cls.log.error( + f"Bad precedence table entry {p!r}. Must be a list or tuple" + ) return False if len(p) < 2: - cls.log.error(f'Malformed precedence entry {p!r}. Must be (assoc, term, ..., term)') + cls.log.error( + f"Malformed precedence entry {p!r}. 
Must be (assoc, term, ..., term)" + ) return False if not all(isinstance(term, str) for term in p): - cls.log.error('precedence items must be strings') + cls.log.error("precedence items must be strings") return False - + assoc = p[0] preclist.extend((term, assoc, level) for term in p[1:]) @@ -1650,9 +1740,9 @@ class Parser(metaclass=ParserMeta): @classmethod def __validate_specification(cls): - ''' + """ Validate various parts of the grammar specification - ''' + """ if not cls.__validate_tokens(): return False if not cls.__validate_precedence(): @@ -1661,14 +1751,14 @@ class Parser(metaclass=ParserMeta): @classmethod def __build_grammar(cls, rules): - ''' + """ Build the grammar from the grammar rules - ''' + """ grammar_rules = [] - errors = '' + errors = "" # Check for non-empty symbols if not rules: - raise YaccError('No grammar rules are defined') + raise YaccError("No grammar rules are defined") grammar = Grammar(cls.tokens) @@ -1677,95 +1767,110 @@ class Parser(metaclass=ParserMeta): try: grammar.set_precedence(term, assoc, level) except GrammarError as e: - errors += f'{e}\n' + errors += f"{e}\n" for name, func in rules: try: parsed_rule = _collect_grammar_rules(func) for pfunc, rulefile, ruleline, prodname, syms in parsed_rule: try: - grammar.add_production(prodname, syms, pfunc, rulefile, ruleline) + grammar.add_production( + prodname, syms, pfunc, rulefile, ruleline + ) except GrammarError as e: - errors += f'{e}\n' + errors += f"{e}\n" except SyntaxError as e: - errors += f'{e}\n' + errors += f"{e}\n" try: - grammar.set_start(getattr(cls, 'start', None)) + grammar.set_start(getattr(cls, "start", None)) except GrammarError as e: - errors += f'{e}\n' + errors += f"{e}\n" undefined_symbols = grammar.undefined_symbols() for sym, prod in undefined_symbols: - errors += '%s:%d: Symbol %r used, but not defined as a token or a rule\n' % (prod.file, prod.line, sym) + errors += ( + "%s:%d: Symbol %r used, but not defined as a token or a rule\n" + % (prod.file, prod.line, sym) + ) unused_terminals = grammar.unused_terminals() if unused_terminals: - unused_str = '{' + ','.join(unused_terminals) + '}' - cls.log.warning(f'Token{"(s)" if len(unused_terminals) >1 else ""} {unused_str} defined, but not used') + unused_str = "{" + ",".join(unused_terminals) + "}" + cls.log.warning( + f'Token{"(s)" if len(unused_terminals) >1 else ""} {unused_str} defined, but not used' + ) unused_rules = grammar.unused_rules() for prod in unused_rules: - cls.log.warning('%s:%d: Rule %r defined, but not used', prod.file, prod.line, prod.name) + cls.log.warning( + "%s:%d: Rule %r defined, but not used", prod.file, prod.line, prod.name + ) if len(unused_terminals) == 1: - cls.log.warning('There is 1 unused token') + cls.log.warning("There is 1 unused token") if len(unused_terminals) > 1: - cls.log.warning('There are %d unused tokens', len(unused_terminals)) + cls.log.warning("There are %d unused tokens", len(unused_terminals)) if len(unused_rules) == 1: - cls.log.warning('There is 1 unused rule') + cls.log.warning("There is 1 unused rule") if len(unused_rules) > 1: - cls.log.warning('There are %d unused rules', len(unused_rules)) + cls.log.warning("There are %d unused rules", len(unused_rules)) unreachable = grammar.find_unreachable() for u in unreachable: - cls.log.warning('Symbol %r is unreachable', u) + cls.log.warning("Symbol %r is unreachable", u) if len(undefined_symbols) == 0: infinite = grammar.infinite_cycles() for inf in infinite: - errors += 'Infinite recursion detected for symbol %r\n' % inf + errors += 
"Infinite recursion detected for symbol %r\n" % inf unused_prec = grammar.unused_precedence() for term, assoc in unused_prec: - errors += 'Precedence rule %r defined for unknown symbol %r\n' % (assoc, term) + errors += "Precedence rule %r defined for unknown symbol %r\n" % ( + assoc, + term, + ) cls._grammar = grammar if errors: - raise YaccError('Unable to build grammar.\n'+errors) + raise YaccError("Unable to build grammar.\n" + errors) @classmethod def __build_lrtables(cls): - ''' + """ Build the LR Parsing tables from the grammar - ''' + """ lrtable = LRTable(cls._grammar) num_sr = len(lrtable.sr_conflicts) # Report shift/reduce and reduce/reduce conflicts - if num_sr != getattr(cls, 'expected_shift_reduce', None): + if num_sr != getattr(cls, "expected_shift_reduce", None): if num_sr == 1: - cls.log.warning('1 shift/reduce conflict') + cls.log.warning("1 shift/reduce conflict") elif num_sr > 1: - cls.log.warning('%d shift/reduce conflicts', num_sr) + cls.log.warning("%d shift/reduce conflicts", num_sr) num_rr = len(lrtable.rr_conflicts) - if num_rr != getattr(cls, 'expected_reduce_reduce', None): + if num_rr != getattr(cls, "expected_reduce_reduce", None): if num_rr == 1: - cls.log.warning('1 reduce/reduce conflict') + cls.log.warning("1 reduce/reduce conflict") elif num_rr > 1: - cls.log.warning('%d reduce/reduce conflicts', num_rr) + cls.log.warning("%d reduce/reduce conflicts", num_rr) cls._lrtable = lrtable return True @classmethod def __collect_rules(cls, definitions): - ''' + """ Collect all of the tagged grammar rules - ''' - rules = [ (name, value) for name, value in definitions - if callable(value) and hasattr(value, 'rules') ] + """ + rules = [ + (name, value) + for name, value in definitions + if callable(value) and hasattr(value, "rules") + ] return rules # ---------------------------------------------------------------------- @@ -1775,7 +1880,7 @@ class Parser(metaclass=ParserMeta): # ---------------------------------------------------------------------- @classmethod def _build(cls, definitions): - if vars(cls).get('_build', False): + if vars(cls).get("_build", False): return # Collect all of the grammar rules from the class definition @@ -1783,77 +1888,89 @@ class Parser(metaclass=ParserMeta): # Validate other parts of the grammar specification if not cls.__validate_specification(): - raise YaccError('Invalid parser specification') + raise YaccError("Invalid parser specification") # Build the underlying grammar object cls.__build_grammar(rules) # Build the LR tables if not cls.__build_lrtables(): - raise YaccError('Can\'t build parsing tables') + raise YaccError("Can't build parsing tables") if cls.debugfile: - with open(cls.debugfile, 'w') as f: + with open(cls.debugfile, "w") as f: f.write(str(cls._grammar)) - f.write('\n') + f.write("\n") f.write(str(cls._lrtable)) - cls.log.info('Parser debugging for %s written to %s', cls.__qualname__, cls.debugfile) + cls.log.info( + "Parser debugging for %s written to %s", cls.__qualname__, cls.debugfile + ) # ---------------------------------------------------------------------- # Parsing Support. This is the parsing runtime that users use to # ---------------------------------------------------------------------- def error(self, token): - ''' + """ Default error handling function. This may be subclassed. 
- ''' + """ if token: - lineno = getattr(token, 'lineno', 0) + lineno = getattr(token, "lineno", 0) if lineno: - sys.stderr.write(f'sly: Syntax error at line {lineno}, token={token.type}\n') + sys.stderr.write( + f"sly: Syntax error at line {lineno}, token={token.type}\n" + ) else: - sys.stderr.write(f'sly: Syntax error, token={token.type}') + sys.stderr.write(f"sly: Syntax error, token={token.type}") else: - sys.stderr.write('sly: Parse error in input. EOF\n') - + sys.stderr.write("sly: Parse error in input. EOF\n") + def errok(self): - ''' + """ Clear the error status - ''' + """ self.errorok = True def restart(self): - ''' + """ Force the parser to restart from a fresh state. Clears the statestack - ''' + """ del self.statestack[:] del self.symstack[:] sym = YaccSymbol() - sym.type = '$end' + sym.type = "$end" self.symstack.append(sym) self.statestack.append(0) self.state = 0 def parse(self, tokens): - ''' + """ Parse the given input tokens. - ''' - lookahead = None # Current lookahead symbol - lookaheadstack = [] # Stack of lookahead symbols - actions = self._lrtable.lr_action # Local reference to action table (to avoid lookup on self.) - goto = self._lrtable.lr_goto # Local reference to goto table (to avoid lookup on self.) - prod = self._grammar.Productions # Local reference to production list (to avoid lookup on self.) - defaulted_states = self._lrtable.defaulted_states # Local reference to defaulted states - pslice = YaccProduction(None) # Production object passed to grammar rules - errorcount = 0 # Used during error recovery + """ + lookahead = None # Current lookahead symbol + lookaheadstack = [] # Stack of lookahead symbols + actions = ( + self._lrtable.lr_action + ) # Local reference to action table (to avoid lookup on self.) + goto = ( + self._lrtable.lr_goto + ) # Local reference to goto table (to avoid lookup on self.) + prod = ( + self._grammar.Productions + ) # Local reference to production list (to avoid lookup on self.) + defaulted_states = ( + self._lrtable.defaulted_states + ) # Local reference to defaulted states + pslice = YaccProduction(None) # Production object passed to grammar rules + errorcount = 0 # Used during error recovery # Set up the state and symbol stacks self.tokens = tokens - self.statestack = statestack = [] # Stack of parsing states - self.symstack = symstack = [] # Stack of grammar symbols - pslice._stack = symstack # Associate the stack with the production + self.statestack = statestack = [] # Stack of parsing states + self.symstack = symstack = [] # Stack of grammar symbols + pslice._stack = symstack # Associate the stack with the production self.restart() - errtoken = None # Err token + errtoken = None # Err token while True: # Get the next symbol on the input. If a lookahead symbol # is already set, we just use that. 
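For orientation, parse() is normally fed the token stream of a sly Lexer. A self-contained calculator sketch, again assuming the vendored package is importable as sly; CalcLexer and CalcParser are illustrative names. The two-level precedence table is what keeps the deliberately ambiguous expr rules from producing the shift/reduce conflicts handled by the table-construction code above:

    from sly import Lexer, Parser

    class CalcLexer(Lexer):
        tokens = { NUMBER, PLUS, TIMES }
        ignore = ' '
        PLUS = r'\+'
        TIMES = r'\*'

        @_(r'\d+')
        def NUMBER(self, t):
            t.value = int(t.value)
            return t

    class CalcParser(Parser):
        tokens = CalcLexer.tokens
        precedence = (
            ('left', PLUS),    # lowest binding strength
            ('left', TIMES),   # binds tighter than PLUS
        )

        @_('expr PLUS expr')
        def expr(self, p):
            return p.expr0 + p.expr1

        @_('expr TIMES expr')
        def expr(self, p):
            return p.expr0 * p.expr1

        @_('NUMBER')
        def expr(self, p):
            return p.NUMBER

    print(CalcParser().parse(CalcLexer().tokenize('2 + 3 * 4')))  # 14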
Otherwise, we'll pull @@ -1866,7 +1983,7 @@ class Parser(metaclass=ParserMeta): lookahead = lookaheadstack.pop() if not lookahead: lookahead = YaccSymbol() - lookahead.type = '$end' + lookahead.type = "$end" # Check the action table ltype = lookahead.type @@ -1892,14 +2009,14 @@ class Parser(metaclass=ParserMeta): # reduce a symbol on the stack, emit a production self.production = p = prod[-t] pname = p.name - plen = p.len + plen = p.len pslice._namemap = p.namemap # Call the production function pslice._slice = symstack[-plen:] if plen else [] sym = YaccSymbol() - sym.type = pname + sym.type = pname value = p.func(self, pslice) if value is pslice: value = (pname, *(s.value for s in pslice._slice)) @@ -1915,7 +2032,7 @@ class Parser(metaclass=ParserMeta): if t == 0: n = symstack[-1] - result = getattr(n, 'value', None) + result = getattr(n, "value", None) return result if t is None: @@ -1932,8 +2049,8 @@ class Parser(metaclass=ParserMeta): if errorcount == 0 or self.errorok: errorcount = ERROR_COUNT self.errorok = False - if lookahead.type == '$end': - errtoken = None # End of file! + if lookahead.type == "$end": + errtoken = None # End of file! else: errtoken = lookahead @@ -1957,7 +2074,7 @@ class Parser(metaclass=ParserMeta): # entire parse has been rolled back and we're completely hosed. The token is # discarded and we just keep going. - if len(statestack) <= 1 and lookahead.type != '$end': + if len(statestack) <= 1 and lookahead.type != "$end": lookahead = None self.state = 0 # Nuke the lookahead stack @@ -1968,13 +2085,13 @@ class Parser(metaclass=ParserMeta): # at the end of the file. nuke the top entry and generate an error token # Start nuking entries on the stack - if lookahead.type == '$end': + if lookahead.type == "$end": # Whoa. We're really hosed here. Bail out return - if lookahead.type != 'error': + if lookahead.type != "error": sym = symstack[-1] - if sym.type == 'error': + if sym.type == "error": # Hmmm. 
Error is on top of stack, we'll just nuke input # symbol and continue lookahead = None @@ -1982,11 +2099,11 @@ class Parser(metaclass=ParserMeta): # Create the error symbol for the first time and make it the new lookahead symbol t = YaccSymbol() - t.type = 'error' + t.type = "error" - if hasattr(lookahead, 'lineno'): + if hasattr(lookahead, "lineno"): t.lineno = lookahead.lineno - if hasattr(lookahead, 'index'): + if hasattr(lookahead, "index"): t.index = lookahead.index t.value = lookahead lookaheadstack.append(lookahead) @@ -1998,4 +2115,4 @@ class Parser(metaclass=ParserMeta): continue # Call an error function here - raise RuntimeError('sly: internal parser error!!!\n') + raise RuntimeError("sly: internal parser error!!!\n") diff --git a/lib/utils.py b/lib/utils.py index 26a591e..91dded0 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -21,15 +21,25 @@ def running_mean(x: np.ndarray, N: int) -> np.ndarray: :param x: 1-Dimensional NumPy array :param N: how many items to average """ + # FIXME np.insert(x, 0, [x[0] for i in range(N/2)]) + # FIXME np.insert(x, -1, [x[-1] for i in range(N/2)]) + # (dabei ungerade N beachten) cumsum = np.cumsum(np.insert(x, 0, 0)) return (cumsum[N:] - cumsum[:-N]) / N def human_readable(value, unit): - for prefix, factor in (('p', 1e-12), ('n', 1e-9), (u'µ', 1e-6), ('m', 1e-3), ('', 1), ('k', 1e3)): + for prefix, factor in ( + ("p", 1e-12), + ("n", 1e-9), + (u"µ", 1e-6), + ("m", 1e-3), + ("", 1), + ("k", 1e3), + ): if value < 1e3 * factor: - return '{:.2f} {}{}'.format(value * (1 / factor), prefix, unit) - return '{:.2f} {}'.format(value, unit) + return "{:.2f} {}{}".format(value * (1 / factor), prefix, unit) + return "{:.2f} {}".format(value, unit) def is_numeric(n): @@ -65,7 +75,7 @@ def soft_cast_int(n): If `n` is empty, returns None. If `n` is not numeric, it is left unchanged. """ - if n is None or n == '': + if n is None or n == "": return None try: return int(n) @@ -80,7 +90,7 @@ def soft_cast_float(n): If `n` is empty, returns None. If `n` is not numeric, it is left unchanged. """ - if n is None or n == '': + if n is None or n == "": return None try: return float(n) @@ -104,8 +114,8 @@ def parse_conf_str(conf_str): Values are casted to float if possible and kept as-is otherwise. 
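    An illustrative call (example values invented for this note):

    >>> parse_conf_str('trace=data,freq=16e6')
    {'trace': 'data', 'freq': 16000000.0}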
""" conf_dict = dict() - for option in conf_str.split(','): - key, value = option.split('=') + for option in conf_str.split(","): + key, value = option.split("=") conf_dict[key] = soft_cast_float(value) return conf_dict @@ -118,7 +128,7 @@ def remove_index_from_tuple(parameters, index): :param index: index of element which is to be removed :returns: parameters tuple without the element at index """ - return (*parameters[:index], *parameters[index + 1:]) + return (*parameters[:index], *parameters[index + 1 :]) def param_slice_eq(a, b, index): @@ -137,7 +147,9 @@ def param_slice_eq(a, b, index): ('foo', [1, 4]), ('foo', [2, 4]), 1 -> False """ - if (*a[1][:index], *a[1][index + 1:]) == (*b[1][:index], *b[1][index + 1:]) and a[0] == b[0]: + if (*a[1][:index], *a[1][index + 1 :]) == (*b[1][:index], *b[1][index + 1 :]) and a[ + 0 + ] == b[0]: return True return False @@ -164,20 +176,20 @@ def by_name_to_by_param(by_name: dict): """ by_param = dict() for name in by_name.keys(): - for i, parameters in enumerate(by_name[name]['param']): + for i, parameters in enumerate(by_name[name]["param"]): param_key = (name, tuple(parameters)) if param_key not in by_param: by_param[param_key] = dict() for key in by_name[name].keys(): by_param[param_key][key] = list() - by_param[param_key]['attributes'] = by_name[name]['attributes'] + by_param[param_key]["attributes"] = by_name[name]["attributes"] # special case for PTA models - if 'isa' in by_name[name]: - by_param[param_key]['isa'] = by_name[name]['isa'] - for attribute in by_name[name]['attributes']: + if "isa" in by_name[name]: + by_param[param_key]["isa"] = by_name[name]["isa"] + for attribute in by_name[name]["attributes"]: by_param[param_key][attribute].append(by_name[name][attribute][i]) # Required for match_parameter_valuse in _try_fits - by_param[param_key]['param'].append(by_name[name]['param'][i]) + by_param[param_key]["param"].append(by_name[name]["param"][i]) return by_param @@ -197,14 +209,26 @@ def filter_aggregate_by_param(aggregate, parameters, parameter_filter): param_value = soft_cast_int(param_name_and_value[1]) names_to_remove = set() for name in aggregate.keys(): - indices_to_keep = list(map(lambda x: x[param_index] == param_value, aggregate[name]['param'])) - aggregate[name]['param'] = list(map(lambda iv: iv[1], filter(lambda iv: indices_to_keep[iv[0]], enumerate(aggregate[name]['param'])))) + indices_to_keep = list( + map(lambda x: x[param_index] == param_value, aggregate[name]["param"]) + ) + aggregate[name]["param"] = list( + map( + lambda iv: iv[1], + filter( + lambda iv: indices_to_keep[iv[0]], + enumerate(aggregate[name]["param"]), + ), + ) + ) if len(indices_to_keep) == 0: - print('??? {}->{}'.format(parameter_filter, name)) + print("??? 
{}->{}".format(parameter_filter, name)) names_to_remove.add(name) else: - for attribute in aggregate[name]['attributes']: - aggregate[name][attribute] = aggregate[name][attribute][indices_to_keep] + for attribute in aggregate[name]["attributes"]: + aggregate[name][attribute] = aggregate[name][attribute][ + indices_to_keep + ] if len(aggregate[name][attribute]) == 0: names_to_remove.add(name) for name in names_to_remove: @@ -218,25 +242,25 @@ class OptionalTimingAnalysis: self.index = 1 def get_header(self): - ret = '' + ret = "" if self.enabled: - ret += '#define TIMEIT(index, functioncall) ' - ret += 'counter.start(); ' - ret += 'functioncall; ' - ret += 'counter.stop();' + ret += "#define TIMEIT(index, functioncall) " + ret += "counter.start(); " + ret += "functioncall; " + ret += "counter.stop();" ret += 'kout << endl << index << " :: " << counter.value << "/" << counter.overflow << endl;\n' return ret def wrap_codeblock(self, codeblock): if not self.enabled: return codeblock - lines = codeblock.split('\n') + lines = codeblock.split("\n") ret = list() for line in lines: - if re.fullmatch('.+;', line): - ret.append('TIMEIT( {:d}, {} )'.format(self.index, line)) + if re.fullmatch(".+;", line): + ret.append("TIMEIT( {:d}, {} )".format(self.index, line)) self.wrapped_lines.append(line) self.index += 1 else: ret.append(line) - return '\n'.join(ret) + return "\n".join(ret) |