diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2019-11-21 08:17:04 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2019-11-21 08:17:04 +0100 |
commit | 0d4616db3975919c46b73cfbd4c6054b94e55aa6 (patch) | |
tree | ce3b4eaadb62583d155b804113008794ab2535fd | |
parent | 68c54f846d941ab2bb7367c5ac5dad091b5fde9f (diff) |
flake8
-rwxr-xr-x | bin/analyze-archive.py | 45 | ||||
-rwxr-xr-x | bin/analyze-timing.py | 30 | ||||
-rwxr-xr-x | bin/analyze.py | 40 | ||||
-rwxr-xr-x | lib/automata.py | 164 |
4 files changed, 136 insertions, 143 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 787510d..8470ab6 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -83,13 +83,14 @@ import plotter import re import sys from dfatool import PTAModel, RawData, pta_trace_to_aggregate -from dfatool import soft_cast_int, is_numeric, gplearn_to_function +from dfatool import gplearn_to_function from dfatool import CrossValidator from utils import filter_aggregate_by_param from automata import PTA opts = {} + def print_model_quality(results): for state_or_tran in results.keys(): print() @@ -101,12 +102,14 @@ def print_model_quality(results): print('{:20s} {:15s} {:.0f}'.format( state_or_tran, key, result['mae'])) + def format_quality_measures(result): if 'smape' in result: return '{:6.2f}% / {:9.0f}'.format(result['smape'], result['mae']) else: return '{:6} {:9.0f}'.format('', result['mae']) + def model_quality_table(result_lists, info_list): for state_or_tran in result_lists[0]['by_name'].keys(): for key in result_lists[0]['by_name'][state_or_tran].keys(): @@ -114,13 +117,14 @@ def model_quality_table(result_lists, info_list): for i, results in enumerate(result_lists): info = info_list[i] buf += ' ||| ' - if info == None or info(state_or_tran, key): + if info is None or info(state_or_tran, key): result = results['by_name'][state_or_tran][key] buf += format_quality_measures(result) else: buf += '{:6}----{:9}'.format('', '') print(buf) + def model_summary_table(result_list): buf = 'transition duration' for results in result_list: @@ -171,6 +175,7 @@ def print_text_model_data(model, pm, pq, lm, lq, am, ai, aq): for arg_index in range(model._num_args[state_or_tran]): print('{} {} {:d} {:.8f}'.format(state_or_tran, attribute, arg_index, model.stats.arg_dependence_ratio(state_or_tran, attribute, arg_index))) + def print_html_model_data(model, pm, pq, lm, lq, am, ai, aq): state_attributes = model.attributes(model.states()[0]) @@ -204,6 +209,7 @@ def print_html_model_data(model, pm, pq, lm, lq, am, ai, aq): print('</tr>') print('</table>') + if __name__ == '__main__': ignored_trace_indexes = [] @@ -282,10 +288,10 @@ if __name__ == '__main__': filter_aggregate_by_param(by_name, parameters, opts['filter-param']) model = PTAModel(by_name, parameters, arg_count, - traces = preprocessed_data, - discard_outliers = discard_outliers, - function_override = function_override, - pta = pta) + traces=preprocessed_data, + discard_outliers=discard_outliers, + function_override=function_override, + pta=pta) if xv_method: xv = CrossValidator(PTAModel, by_name, parameters, arg_count) @@ -299,8 +305,8 @@ if __name__ == '__main__': if 'plot-unparam' in opts: for kv in opts['plot-unparam'].split(';'): state_or_trans, attribute, ylabel = kv.split(':') - fname = 'param_y_{}_{}.pdf'.format(state_or_trans,attribute) - plotter.plot_y(model.by_name[state_or_trans][attribute], xlabel = 'measurement #', ylabel = ylabel, output = fname) + fname = 'param_y_{}_{}.pdf'.format(state_or_trans, attribute) + plotter.plot_y(model.by_name[state_or_trans][attribute], xlabel='measurement #', ylabel=ylabel, output=fname) if len(show_models): print('--- simple static model ---') @@ -361,7 +367,7 @@ if __name__ == '__main__': if len(show_models): print('--- param model ---') - param_model, param_info = model.get_fitted(safe_functions_enabled = safe_functions_enabled) + param_model, param_info = model.get_fitted(safe_functions_enabled=safe_functions_enabled) if 'paramdetection' in show_models or 'all' in show_models: for state in model.states_and_transitions(): @@ -377,7 +383,7 @@ if __name__ == '__main__': print('{:10s} {:10s} {:10s} stddev {:f}'.format( state, attribute, param, model.stats.stats[state][attribute]['std_by_param'][param] )) - if info != None: + if info is not None: for param_name in sorted(info['fit_result'].keys(), key=str): param_fit = info['fit_result'][param_name]['results'] for function_type in sorted(param_fit.keys()): @@ -413,10 +419,20 @@ if __name__ == '__main__': if 'table' in show_quality or 'all' in show_quality: model_quality_table([static_quality, analytic_quality, lut_quality], [None, param_info, None]) + if 'overall' in show_quality or 'all' in show_quality: - print('overall MAE of static model: {} µW'.format(model.assess_states(static_model))) - print('overall MAE of param model: {} µW'.format(model.assess_states(param_model))) - print('overall MAE of LUT model: {} µW'.format(model.assess_states(lut_model))) + print('overall static/param/lut MAE assuming equal state distribution:') + print(' {:6.1f} / {:6.1f} / {:6.1f} µW'.format( + model.assess_states(static_model), + model.assess_states(param_model), + model.assess_states(lut_model))) + print('overall static/param/lut MAE assuming 95% STANDBY1:') + distrib = {'STANDBY1': 0.95, 'POWERDOWN': 0.03, 'TX': 0.01, 'RX': 0.01} + print(' {:6.1f} / {:6.1f} / {:6.1f} µW'.format( + model.assess_states(static_model, distribution=distrib), + model.assess_states(param_model, distribution=distrib), + model.assess_states(lut_model, distribution=distrib))) + if 'summary' in show_quality or 'all' in show_quality: model_summary_table([model.assess_on_traces(static_model), model.assess_on_traces(param_model), model.assess_on_traces(lut_model)]) @@ -435,7 +451,6 @@ if __name__ == '__main__': sys.exit(1) json_model = model.to_json() with open(opts['export-energymodel'], 'w') as f: - json.dump(json_model, f, indent = 2, sort_keys = True) - + json.dump(json_model, f, indent=2, sort_keys=True) sys.exit(0) diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py index 6c84a67..9a3aa41 100755 --- a/bin/analyze-timing.py +++ b/bin/analyze-timing.py @@ -79,14 +79,14 @@ import plotter import re import sys from dfatool import AnalyticModel, TimingData, pta_trace_to_aggregate -from dfatool import soft_cast_int, is_numeric, gplearn_to_function +from dfatool import gplearn_to_function from dfatool import CrossValidator from utils import filter_aggregate_by_param from parameters import prune_dependent_parameters -import utils opts = {} + def print_model_quality(results): for state_or_tran in results.keys(): print() @@ -98,12 +98,14 @@ def print_model_quality(results): print('{:20s} {:15s} {:.0f}'.format( state_or_tran, key, result['mae'])) + def format_quality_measures(result): if 'smape' in result: return '{:6.2f}% / {:9.0f}'.format(result['smape'], result['mae']) else: return '{:6} {:9.0f}'.format('', result['mae']) + def model_quality_table(result_lists, info_list): for state_or_tran in result_lists[0]['by_name'].keys(): for key in result_lists[0]['by_name'][state_or_tran].keys(): @@ -111,7 +113,7 @@ def model_quality_table(result_lists, info_list): for i, results in enumerate(result_lists): info = info_list[i] buf += ' ||| ' - if info == None or info(state_or_tran, key): + if info is None or info(state_or_tran, key): result = results['by_name'][state_or_tran][key] buf += format_quality_measures(result) else: @@ -136,6 +138,7 @@ def print_text_model_data(model, pm, pq, lm, lq, am, ai, aq): for arg_index in range(model._num_args[state_or_tran]): print('{} {} {:d} {:.8f}'.format(state_or_tran, attribute, arg_index, model.stats.arg_dependence_ratio(state_or_tran, attribute, arg_index))) + if __name__ == '__main__': ignored_trace_indexes = [] @@ -215,7 +218,7 @@ if __name__ == '__main__': filter_aggregate_by_param(by_name, parameters, opts['filter-param']) - model = AnalyticModel(by_name, parameters, arg_count, use_corrcoef = opts['corrcoef'], function_override = function_override) + model = AnalyticModel(by_name, parameters, arg_count, use_corrcoef=opts['corrcoef'], function_override=function_override) if xv_method: xv = CrossValidator(AnalyticModel, by_name, parameters, arg_count) @@ -229,8 +232,8 @@ if __name__ == '__main__': if 'plot-unparam' in opts: for kv in opts['plot-unparam'].split(';'): state_or_trans, attribute, ylabel = kv.split(':') - fname = 'param_y_{}_{}.pdf'.format(state_or_trans,attribute) - plotter.plot_y(model.by_name[state_or_trans][attribute], xlabel = 'measurement #', ylabel = ylabel) + fname = 'param_y_{}_{}.pdf'.format(state_or_trans, attribute) + plotter.plot_y(model.by_name[state_or_trans][attribute], xlabel='measurement #', ylabel=ylabel) if len(show_models): print('--- simple static model ---') @@ -247,6 +250,15 @@ if __name__ == '__main__': print('{:24s} co-dependencies: {:s}'.format('', ', '.join(model.stats.codependent_parameters(trans, 'duration', param)))) for param_dict in model.stats.codependent_parameter_value_dicts(trans, 'duration', param): print('{:24s} parameter-aware for {}'.format('', param_dict)) + # import numpy as np + # safe_div = np.vectorize(lambda x,y: 0. if x == 0 else 1 - x/y) + # ratio_by_value = safe_div(model.stats.stats['write']['duration']['lut_by_param_values']['max_retry_count'], model.stats.stats['write']['duration']['std_by_param_values']['max_retry_count']) + # err_mode = np.seterr('warn') + # dep_by_value = ratio_by_value > 0.5 + # np.seterr(**err_mode) + # Eigentlich sollte hier ein paar mal True stehen, ist aber nicht so... + # und warum ist da eine non-power-of-two Zahl von True-Einträgen in der Matrix? 3 stück ist komisch... + # print(dep_by_value) if xv_method == 'montecarlo': static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) @@ -265,7 +277,7 @@ if __name__ == '__main__': if len(show_models): print('--- param model ---') - param_model, param_info = model.get_fitted(safe_functions_enabled = safe_functions_enabled) + param_model, param_info = model.get_fitted(safe_functions_enabled=safe_functions_enabled) if 'paramdetection' in show_models or 'all' in show_models: for transition in model.names: @@ -289,7 +301,7 @@ if __name__ == '__main__': )) print('{:10s} {:10s} dependence on arg{:d}: {:.2f}'.format( transition, attribute, i, model.stats.arg_dependence_ratio(transition, attribute, i))) - if info != None: + if info is not None: for param_name in sorted(info['fit_result'].keys(), key=str): param_fit = info['fit_result'][param_name]['results'] for function_type in sorted(param_fit.keys()): @@ -325,6 +337,4 @@ if __name__ == '__main__': function = None plotter.plot_param(model, state_or_trans, attribute, model.param_index(param_name), extra_function=function) - - sys.exit(0) diff --git a/bin/analyze.py b/bin/analyze.py deleted file mode 100755 index 57803fe..0000000 --- a/bin/analyze.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env python3 - -import json -import numpy as np -import os -from scipy.cluster.vq import kmeans2 -import struct -import sys -import tarfile -from dfatool import running_mean, MIMOSA - -voltage = float(sys.argv[1]) -shunt = float(sys.argv[2]) -filename = sys.argv[3] - -mim = MIMOSA(voltage, shunt) - -charges, triggers = mim.load_data(filename) -trigidx = mim.trigger_edges(triggers) -triggers = [] -cal_edges = mim.calibration_edges(running_mean(mim.currents_nocal(charges[0:trigidx[0]]), 10)) -calfunc, caldata = mim.calibration_function(charges, cal_edges) -vcalfunc = np.vectorize(calfunc, otypes=[np.float64]) - -json_out = { - 'triggers' : len(trigidx), - 'first_trig' : trigidx[0] * 10, - 'calibration' : caldata, - 'trace' : mim.analyze_states(charges, trigidx, vcalfunc) -} - -basename, _ = os.path.splitext(filename) - -# TODO also look for interesting gradients inside each state - -with open(basename + ".json", "w") as f: - json.dump(json_out, f) - f.close() - -#print(kmeans2(charges[:firstidx], np.array([130 * ua_step, 3.6 / 987 * 1000000, 3.6 / 99300 * 1000000]))) diff --git a/lib/automata.py b/lib/automata.py index 73c0225..4aa9a97 100755 --- a/lib/automata.py +++ b/lib/automata.py @@ -4,27 +4,30 @@ from functions import AnalyticFunction, NormalizationFunction from utils import is_numeric import itertools import numpy as np -import json, yaml +import json +import yaml + def _dict_to_list(input_dict: dict) -> list: return [input_dict[x] for x in sorted(input_dict.keys())] + def _attribute_to_json(static_value: float, param_function: AnalyticFunction) -> dict: ret = { - 'static' : static_value + 'static': static_value } if param_function: ret['function'] = { - 'raw' : param_function._model_str, - 'regression_args' : list(param_function._regression_args) + 'raw': param_function._model_str, + 'regression_args': list(param_function._regression_args) } return ret + class State: """A single PTA state.""" - def __init__(self, name: str, power: float = 0, - power_function: AnalyticFunction = None): + def __init__(self, name: str, power: float = 0, power_function: AnalyticFunction = None): u""" Create a new PTA state. @@ -53,7 +56,7 @@ class State: return self.power_function.eval(_dict_to_list(param_dict)) * duration return self.power * duration - def set_random_energy_model(self, static_model = True): + def set_random_energy_model(self, static_model=True): """Set a random static energy value.""" self.power = int(np.random.sample() * 50000) @@ -83,10 +86,10 @@ class State: :returns: Transition object """ interrupts = filter(lambda x: x.is_interrupt, self.outgoing_transitions.values()) - interrupts = sorted(interrupts, key = lambda x: x.get_timeout(parameters)) + interrupts = sorted(interrupts, key=lambda x: x.get_timeout(parameters)) return interrupts[0] - def dfs(self, depth: int, with_arguments: bool = False, trace_filter = None, sleep: int = 0): + def dfs(self, depth: int, with_arguments=False, trace_filter=None, sleep=0): """ Return a generator object for depth-first search over all outgoing transitions. @@ -140,7 +143,7 @@ class State: new_trace_filter = None else: new_trace_filter = None - for suffix in trans.destination.dfs(depth - 1, with_arguments = with_arguments, trace_filter = new_trace_filter, sleep = sleep): + for suffix in trans.destination.dfs(depth - 1, with_arguments=with_arguments, trace_filter=new_trace_filter, sleep=sleep): if with_arguments: if trans.argument_combination == 'cartesian': for args in itertools.product(*trans.argument_values): @@ -173,11 +176,12 @@ class State: def to_json(self) -> dict: """Return JSON encoding of this state object.""" ret = { - 'name' : self.name, - 'power' : _attribute_to_json(self.power, self.power_function) + 'name': self.name, + 'power': _attribute_to_json(self.power, self.power_function) } return ret + class Transition: u""" A single PTA transition with one origin and one destination state. @@ -203,18 +207,18 @@ class Transition: """ def __init__(self, orig_state: State, dest_state: State, name: str, - energy: float = 0, energy_function: AnalyticFunction = None, - duration: float = 0, duration_function: AnalyticFunction = None, - timeout: float = 0, timeout_function: AnalyticFunction = None, - is_interrupt: bool = False, - arguments: list = [], - argument_values: list = [], - argument_combination: str = 'cartesian', # or 'zip' - param_update_function = None, - arg_to_param_map: dict = None, - set_param = None, - return_value_handlers: list = [], - codegen = dict()): + energy: float = 0, energy_function: AnalyticFunction = None, + duration: float = 0, duration_function: AnalyticFunction = None, + timeout: float = 0, timeout_function: AnalyticFunction = None, + is_interrupt: bool = False, + arguments: list = [], + argument_values: list = [], + argument_combination: str = 'cartesian', # or 'zip' + param_update_function=None, + arg_to_param_map: dict = None, + set_param=None, + return_value_handlers: list = [], + codegen=dict()): """ Create a new transition between two PTA states. @@ -245,7 +249,6 @@ class Transition: if 'formula' in handler: handler['formula'] = NormalizationFunction(handler['formula']) - def get_duration(self, param_dict: dict = {}, args: list = []) -> float: u""" Return transition duration in µs. @@ -270,7 +273,7 @@ class Transition: return self.energy_function.eval(_dict_to_list(param_dict), args) return self.energy - def set_random_energy_model(self, static_model = True): + def set_random_energy_model(self, static_model=True): self.energy = int(np.random.sample() * 50000) def get_timeout(self, param_dict: dict = {}) -> float: @@ -309,32 +312,35 @@ class Transition: def to_json(self) -> dict: """Return JSON encoding of this transition object.""" ret = { - 'name' : self.name, - 'origin' : self.origin.name, - 'destination' : self.destination.name, - 'is_interrupt' : self.is_interrupt, - 'arguments' : self.arguments, - 'argument_values' : self.argument_values, - 'argument_combination' : self.argument_combination, - 'arg_to_param_map' : self.arg_to_param_map, - 'set_param' : self.set_param, - 'duration' : _attribute_to_json(self.duration, self.duration_function), - 'energy' : _attribute_to_json(self.energy, self.energy_function), - 'timeout' : _attribute_to_json(self.timeout, self.timeout_function), + 'name': self.name, + 'origin': self.origin.name, + 'destination': self.destination.name, + 'is_interrupt': self.is_interrupt, + 'arguments': self.arguments, + 'argument_values': self.argument_values, + 'argument_combination': self.argument_combination, + 'arg_to_param_map': self.arg_to_param_map, + 'set_param': self.set_param, + 'duration': _attribute_to_json(self.duration, self.duration_function), + 'energy': _attribute_to_json(self.energy, self.energy_function), + 'timeout': _attribute_to_json(self.timeout, self.timeout_function), } return ret + def _json_function_to_analytic_function(base, attribute: str, parameters: list): if attribute in base and 'function' in base[attribute]: base = base[attribute]['function'] - return AnalyticFunction(base['raw'], parameters, 0, regression_args = base['regression_args']) + return AnalyticFunction(base['raw'], parameters, 0, regression_args=base['regression_args']) return None + def _json_get_static(base, attribute: str): if attribute in base: return base[attribute]['static'] return 0 + class PTA: """ A parameterized priced timed automaton. All states are accepting. @@ -354,9 +360,9 @@ class PTA: """ def __init__(self, state_names: list = [], - accepting_states: list = None, - parameters: list = [], initial_param_values: list = None, - codegen: dict = {}, parameter_normalization: dict = None): + accepting_states: list = None, + parameters: list = [], initial_param_values: list = None, + codegen: dict = {}, parameter_normalization: dict = None): """ Return a new PTA object. @@ -383,7 +389,7 @@ class PTA: self.initial_param_values = [None for x in self.parameters] self.transitions = [] - if not 'UNINITIALIZED' in state_names: + if 'UNINITIALIZED' not in state_names: self.state['UNINITIALIZED'] = State('UNINITIALIZED') if self.parameter_normalization: @@ -451,7 +457,7 @@ class PTA: pta = cls(**kwargs) for name, state in json_input['state'].items(): power_function = _json_function_to_analytic_function(state, 'power', pta.parameters) - pta.add_state(name, power = _json_get_static(state, 'power'), power_function = power_function) + pta.add_state(name, power=_json_get_static(state, 'power'), power_function=power_function) for transition in json_input['transitions']: duration_function = _json_function_to_analytic_function(transition, 'duration', pta.parameters) energy_function = _json_function_to_analytic_function(transition, 'energy', pta.parameters) @@ -465,15 +471,14 @@ class PTA: origins = [origins] for origin in origins: pta.add_transition(origin, transition['destination'], - transition['name'], - duration = _json_get_static(transition, 'duration'), - duration_function = duration_function, - energy = _json_get_static(transition, 'energy'), - energy_function = energy_function, - timeout = _json_get_static(transition, 'timeout'), - timeout_function = timeout_function, - **kwargs - ) + transition['name'], + duration=_json_get_static(transition, 'duration'), + duration_function=duration_function, + energy=_json_get_static(transition, 'energy'), + energy_function=energy_function, + timeout=_json_get_static(transition, 'timeout'), + timeout_function=timeout_function, + **kwargs) return pta @@ -485,7 +490,7 @@ class PTA: Compatible with the legacy dfatool/perl format. """ kwargs = { - 'parameters' : list(), + 'parameters': list(), 'initial_param_values': list(), } @@ -496,7 +501,7 @@ class PTA: pta = cls(**kwargs) for name, state in json_input['state'].items(): - pta.add_state(name, power = float(state['power']['static'])) + pta.add_state(name, power=float(state['power']['static'])) for trans_name in sorted(json_input['transition'].keys()): transition = json_input['transition'][trans_name] @@ -513,8 +518,9 @@ class PTA: argument_values.append(arg['values']) for origin in transition['origins']: pta.add_transition(origin, destination, trans_name, - arguments = arguments, argument_values = argument_values, - is_interrupt = is_interrupt) + arguments=arguments, + argument_values=argument_values, + is_interrupt=is_interrupt) return pta @@ -565,19 +571,21 @@ class PTA: if 'loop' in transition: for state_name in transition['loop']: pta.add_transition(state_name, state_name, trans_name, - arguments = arguments, argument_values = argument_values, - arg_to_param_map = arg_to_param_map, - **kwargs) + arguments=arguments, + argument_values=argument_values, + arg_to_param_map=arg_to_param_map, + **kwargs) else: - if not 'src' in transition: + if 'src' not in transition: transition['src'] = ['UNINITIALIZED'] - if not 'dst' in transition: + if 'dst' not in transition: transition['dst'] = 'UNINITIALIZED' for origin in transition['src']: pta.add_transition(origin, transition['dst'], trans_name, - arguments = arguments, argument_values = argument_values, - arg_to_param_map = arg_to_param_map, - **kwargs) + arguments=arguments, + argument_values=argument_values, + arg_to_param_map=arg_to_param_map, + **kwargs) return pta @@ -588,11 +596,11 @@ class PTA: Compatible with the from_json method. """ ret = { - 'parameters' : self.parameters, - 'initial_param_values' : self.initial_param_values, - 'state' : dict([[state.name, state.to_json()] for state in self.state.values()]), - 'transitions' : [trans.to_json() for trans in self.transitions], - 'accepting_states' : self.accepting_states, + 'parameters': self.parameters, + 'initial_param_values': self.initial_param_values, + 'state': dict([[state.name, state.to_json()] for state in self.state.values()]), + 'transitions': [trans.to_json() for trans in self.transitions], + 'accepting_states': self.accepting_states, } return ret @@ -602,9 +610,9 @@ class PTA: See the State() documentation for acceptable arguments. """ - if 'power_function' in kwargs and type(kwargs['power_function']) != AnalyticFunction and kwargs['power_function'] != None: + if 'power_function' in kwargs and type(kwargs['power_function']) != AnalyticFunction and kwargs['power_function'] is not None: kwargs['power_function'] = AnalyticFunction(kwargs['power_function'], - self.parameters, 0) + self.parameters, 0) self.state[state_name] = State(state_name, **kwargs) def add_transition(self, orig_state: str, dest_state: str, function_name: str, **kwargs): @@ -620,7 +628,7 @@ class PTA: orig_state = self.state[orig_state] dest_state = self.state[dest_state] for key in ('duration_function', 'energy_function', 'timeout_function'): - if key in kwargs and kwargs[key] != None and type(kwargs[key]) != AnalyticFunction: + if key in kwargs and kwargs[key] is not None and type(kwargs[key]) != AnalyticFunction: kwargs[key] = AnalyticFunction(kwargs[key], self.parameters, 0) new_transition = Transition(orig_state, dest_state, function_name, **kwargs) @@ -669,7 +677,7 @@ class PTA: def get_initial_param_dict(self): return dict([[self.parameters[i], self.initial_param_values[i]] for i in range(len(self.parameters))]) - def set_random_energy_model(self, static_model = True): + def set_random_energy_model(self, static_model=True): for state in self.state.values(): state.set_random_energy_model(static_model) for transition in self.transitions: @@ -740,12 +748,12 @@ class PTA: if with_parameters and not param_dict: param_dict = self.get_initial_param_dict() - if with_parameters and not 'with_arguments' in kwargs: + if with_parameters and 'with_arguments' not in kwargs: raise ValueError("with_parameters = True requires with_arguments = True") if self.accepting_states: generator = filter(lambda x: x[-1][0].destination.name in self.accepting_states, - self.state[orig_state].dfs(depth, **kwargs)) + self.state[orig_state].dfs(depth, **kwargs)) else: generator = self.state[orig_state].dfs(depth, **kwargs) @@ -754,7 +762,7 @@ class PTA: else: return generator - def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED', accounting = None): + def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED', accounting=None): u""" Simulate a single run through the PTA and return total energy, duration, final state, and resulting parameters. @@ -775,7 +783,7 @@ class PTA: function_args = function[1] else: function_name = function[0] - function_args = function[1 : ] + function_args = function[1:] if function_name is None: duration = function_args[0] total_energy += state.get_energy(duration, param_dict) |