summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/analyze-archive.py45
-rwxr-xr-xbin/analyze-timing.py30
-rwxr-xr-xbin/analyze.py40
-rwxr-xr-xlib/automata.py164
4 files changed, 136 insertions, 143 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index 787510d..8470ab6 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -83,13 +83,14 @@ import plotter
import re
import sys
from dfatool import PTAModel, RawData, pta_trace_to_aggregate
-from dfatool import soft_cast_int, is_numeric, gplearn_to_function
+from dfatool import gplearn_to_function
from dfatool import CrossValidator
from utils import filter_aggregate_by_param
from automata import PTA
opts = {}
+
def print_model_quality(results):
for state_or_tran in results.keys():
print()
@@ -101,12 +102,14 @@ def print_model_quality(results):
print('{:20s} {:15s} {:.0f}'.format(
state_or_tran, key, result['mae']))
+
def format_quality_measures(result):
if 'smape' in result:
return '{:6.2f}% / {:9.0f}'.format(result['smape'], result['mae'])
else:
return '{:6} {:9.0f}'.format('', result['mae'])
+
def model_quality_table(result_lists, info_list):
for state_or_tran in result_lists[0]['by_name'].keys():
for key in result_lists[0]['by_name'][state_or_tran].keys():
@@ -114,13 +117,14 @@ def model_quality_table(result_lists, info_list):
for i, results in enumerate(result_lists):
info = info_list[i]
buf += ' ||| '
- if info == None or info(state_or_tran, key):
+ if info is None or info(state_or_tran, key):
result = results['by_name'][state_or_tran][key]
buf += format_quality_measures(result)
else:
buf += '{:6}----{:9}'.format('', '')
print(buf)
+
def model_summary_table(result_list):
buf = 'transition duration'
for results in result_list:
@@ -171,6 +175,7 @@ def print_text_model_data(model, pm, pq, lm, lq, am, ai, aq):
for arg_index in range(model._num_args[state_or_tran]):
print('{} {} {:d} {:.8f}'.format(state_or_tran, attribute, arg_index, model.stats.arg_dependence_ratio(state_or_tran, attribute, arg_index)))
+
def print_html_model_data(model, pm, pq, lm, lq, am, ai, aq):
state_attributes = model.attributes(model.states()[0])
@@ -204,6 +209,7 @@ def print_html_model_data(model, pm, pq, lm, lq, am, ai, aq):
print('</tr>')
print('</table>')
+
if __name__ == '__main__':
ignored_trace_indexes = []
@@ -282,10 +288,10 @@ if __name__ == '__main__':
filter_aggregate_by_param(by_name, parameters, opts['filter-param'])
model = PTAModel(by_name, parameters, arg_count,
- traces = preprocessed_data,
- discard_outliers = discard_outliers,
- function_override = function_override,
- pta = pta)
+ traces=preprocessed_data,
+ discard_outliers=discard_outliers,
+ function_override=function_override,
+ pta=pta)
if xv_method:
xv = CrossValidator(PTAModel, by_name, parameters, arg_count)
@@ -299,8 +305,8 @@ if __name__ == '__main__':
if 'plot-unparam' in opts:
for kv in opts['plot-unparam'].split(';'):
state_or_trans, attribute, ylabel = kv.split(':')
- fname = 'param_y_{}_{}.pdf'.format(state_or_trans,attribute)
- plotter.plot_y(model.by_name[state_or_trans][attribute], xlabel = 'measurement #', ylabel = ylabel, output = fname)
+ fname = 'param_y_{}_{}.pdf'.format(state_or_trans, attribute)
+ plotter.plot_y(model.by_name[state_or_trans][attribute], xlabel='measurement #', ylabel=ylabel, output=fname)
if len(show_models):
print('--- simple static model ---')
@@ -361,7 +367,7 @@ if __name__ == '__main__':
if len(show_models):
print('--- param model ---')
- param_model, param_info = model.get_fitted(safe_functions_enabled = safe_functions_enabled)
+ param_model, param_info = model.get_fitted(safe_functions_enabled=safe_functions_enabled)
if 'paramdetection' in show_models or 'all' in show_models:
for state in model.states_and_transitions():
@@ -377,7 +383,7 @@ if __name__ == '__main__':
print('{:10s} {:10s} {:10s} stddev {:f}'.format(
state, attribute, param, model.stats.stats[state][attribute]['std_by_param'][param]
))
- if info != None:
+ if info is not None:
for param_name in sorted(info['fit_result'].keys(), key=str):
param_fit = info['fit_result'][param_name]['results']
for function_type in sorted(param_fit.keys()):
@@ -413,10 +419,20 @@ if __name__ == '__main__':
if 'table' in show_quality or 'all' in show_quality:
model_quality_table([static_quality, analytic_quality, lut_quality], [None, param_info, None])
+
if 'overall' in show_quality or 'all' in show_quality:
- print('overall MAE of static model: {} µW'.format(model.assess_states(static_model)))
- print('overall MAE of param model: {} µW'.format(model.assess_states(param_model)))
- print('overall MAE of LUT model: {} µW'.format(model.assess_states(lut_model)))
+ print('overall static/param/lut MAE assuming equal state distribution:')
+ print(' {:6.1f} / {:6.1f} / {:6.1f} µW'.format(
+ model.assess_states(static_model),
+ model.assess_states(param_model),
+ model.assess_states(lut_model)))
+ print('overall static/param/lut MAE assuming 95% STANDBY1:')
+ distrib = {'STANDBY1': 0.95, 'POWERDOWN': 0.03, 'TX': 0.01, 'RX': 0.01}
+ print(' {:6.1f} / {:6.1f} / {:6.1f} µW'.format(
+ model.assess_states(static_model, distribution=distrib),
+ model.assess_states(param_model, distribution=distrib),
+ model.assess_states(lut_model, distribution=distrib)))
+
if 'summary' in show_quality or 'all' in show_quality:
model_summary_table([model.assess_on_traces(static_model), model.assess_on_traces(param_model), model.assess_on_traces(lut_model)])
@@ -435,7 +451,6 @@ if __name__ == '__main__':
sys.exit(1)
json_model = model.to_json()
with open(opts['export-energymodel'], 'w') as f:
- json.dump(json_model, f, indent = 2, sort_keys = True)
-
+ json.dump(json_model, f, indent=2, sort_keys=True)
sys.exit(0)
diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py
index 6c84a67..9a3aa41 100755
--- a/bin/analyze-timing.py
+++ b/bin/analyze-timing.py
@@ -79,14 +79,14 @@ import plotter
import re
import sys
from dfatool import AnalyticModel, TimingData, pta_trace_to_aggregate
-from dfatool import soft_cast_int, is_numeric, gplearn_to_function
+from dfatool import gplearn_to_function
from dfatool import CrossValidator
from utils import filter_aggregate_by_param
from parameters import prune_dependent_parameters
-import utils
opts = {}
+
def print_model_quality(results):
for state_or_tran in results.keys():
print()
@@ -98,12 +98,14 @@ def print_model_quality(results):
print('{:20s} {:15s} {:.0f}'.format(
state_or_tran, key, result['mae']))
+
def format_quality_measures(result):
if 'smape' in result:
return '{:6.2f}% / {:9.0f}'.format(result['smape'], result['mae'])
else:
return '{:6} {:9.0f}'.format('', result['mae'])
+
def model_quality_table(result_lists, info_list):
for state_or_tran in result_lists[0]['by_name'].keys():
for key in result_lists[0]['by_name'][state_or_tran].keys():
@@ -111,7 +113,7 @@ def model_quality_table(result_lists, info_list):
for i, results in enumerate(result_lists):
info = info_list[i]
buf += ' ||| '
- if info == None or info(state_or_tran, key):
+ if info is None or info(state_or_tran, key):
result = results['by_name'][state_or_tran][key]
buf += format_quality_measures(result)
else:
@@ -136,6 +138,7 @@ def print_text_model_data(model, pm, pq, lm, lq, am, ai, aq):
for arg_index in range(model._num_args[state_or_tran]):
print('{} {} {:d} {:.8f}'.format(state_or_tran, attribute, arg_index, model.stats.arg_dependence_ratio(state_or_tran, attribute, arg_index)))
+
if __name__ == '__main__':
ignored_trace_indexes = []
@@ -215,7 +218,7 @@ if __name__ == '__main__':
filter_aggregate_by_param(by_name, parameters, opts['filter-param'])
- model = AnalyticModel(by_name, parameters, arg_count, use_corrcoef = opts['corrcoef'], function_override = function_override)
+ model = AnalyticModel(by_name, parameters, arg_count, use_corrcoef=opts['corrcoef'], function_override=function_override)
if xv_method:
xv = CrossValidator(AnalyticModel, by_name, parameters, arg_count)
@@ -229,8 +232,8 @@ if __name__ == '__main__':
if 'plot-unparam' in opts:
for kv in opts['plot-unparam'].split(';'):
state_or_trans, attribute, ylabel = kv.split(':')
- fname = 'param_y_{}_{}.pdf'.format(state_or_trans,attribute)
- plotter.plot_y(model.by_name[state_or_trans][attribute], xlabel = 'measurement #', ylabel = ylabel)
+ fname = 'param_y_{}_{}.pdf'.format(state_or_trans, attribute)
+ plotter.plot_y(model.by_name[state_or_trans][attribute], xlabel='measurement #', ylabel=ylabel)
if len(show_models):
print('--- simple static model ---')
@@ -247,6 +250,15 @@ if __name__ == '__main__':
print('{:24s} co-dependencies: {:s}'.format('', ', '.join(model.stats.codependent_parameters(trans, 'duration', param))))
for param_dict in model.stats.codependent_parameter_value_dicts(trans, 'duration', param):
print('{:24s} parameter-aware for {}'.format('', param_dict))
+ # import numpy as np
+ # safe_div = np.vectorize(lambda x,y: 0. if x == 0 else 1 - x/y)
+ # ratio_by_value = safe_div(model.stats.stats['write']['duration']['lut_by_param_values']['max_retry_count'], model.stats.stats['write']['duration']['std_by_param_values']['max_retry_count'])
+ # err_mode = np.seterr('warn')
+ # dep_by_value = ratio_by_value > 0.5
+ # np.seterr(**err_mode)
+ # Eigentlich sollte hier ein paar mal True stehen, ist aber nicht so...
+ # und warum ist da eine non-power-of-two Zahl von True-Einträgen in der Matrix? 3 stück ist komisch...
+ # print(dep_by_value)
if xv_method == 'montecarlo':
static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count)
@@ -265,7 +277,7 @@ if __name__ == '__main__':
if len(show_models):
print('--- param model ---')
- param_model, param_info = model.get_fitted(safe_functions_enabled = safe_functions_enabled)
+ param_model, param_info = model.get_fitted(safe_functions_enabled=safe_functions_enabled)
if 'paramdetection' in show_models or 'all' in show_models:
for transition in model.names:
@@ -289,7 +301,7 @@ if __name__ == '__main__':
))
print('{:10s} {:10s} dependence on arg{:d}: {:.2f}'.format(
transition, attribute, i, model.stats.arg_dependence_ratio(transition, attribute, i)))
- if info != None:
+ if info is not None:
for param_name in sorted(info['fit_result'].keys(), key=str):
param_fit = info['fit_result'][param_name]['results']
for function_type in sorted(param_fit.keys()):
@@ -325,6 +337,4 @@ if __name__ == '__main__':
function = None
plotter.plot_param(model, state_or_trans, attribute, model.param_index(param_name), extra_function=function)
-
-
sys.exit(0)
diff --git a/bin/analyze.py b/bin/analyze.py
deleted file mode 100755
index 57803fe..0000000
--- a/bin/analyze.py
+++ /dev/null
@@ -1,40 +0,0 @@
-#!/usr/bin/env python3
-
-import json
-import numpy as np
-import os
-from scipy.cluster.vq import kmeans2
-import struct
-import sys
-import tarfile
-from dfatool import running_mean, MIMOSA
-
-voltage = float(sys.argv[1])
-shunt = float(sys.argv[2])
-filename = sys.argv[3]
-
-mim = MIMOSA(voltage, shunt)
-
-charges, triggers = mim.load_data(filename)
-trigidx = mim.trigger_edges(triggers)
-triggers = []
-cal_edges = mim.calibration_edges(running_mean(mim.currents_nocal(charges[0:trigidx[0]]), 10))
-calfunc, caldata = mim.calibration_function(charges, cal_edges)
-vcalfunc = np.vectorize(calfunc, otypes=[np.float64])
-
-json_out = {
- 'triggers' : len(trigidx),
- 'first_trig' : trigidx[0] * 10,
- 'calibration' : caldata,
- 'trace' : mim.analyze_states(charges, trigidx, vcalfunc)
-}
-
-basename, _ = os.path.splitext(filename)
-
-# TODO also look for interesting gradients inside each state
-
-with open(basename + ".json", "w") as f:
- json.dump(json_out, f)
- f.close()
-
-#print(kmeans2(charges[:firstidx], np.array([130 * ua_step, 3.6 / 987 * 1000000, 3.6 / 99300 * 1000000])))
diff --git a/lib/automata.py b/lib/automata.py
index 73c0225..4aa9a97 100755
--- a/lib/automata.py
+++ b/lib/automata.py
@@ -4,27 +4,30 @@ from functions import AnalyticFunction, NormalizationFunction
from utils import is_numeric
import itertools
import numpy as np
-import json, yaml
+import json
+import yaml
+
def _dict_to_list(input_dict: dict) -> list:
return [input_dict[x] for x in sorted(input_dict.keys())]
+
def _attribute_to_json(static_value: float, param_function: AnalyticFunction) -> dict:
ret = {
- 'static' : static_value
+ 'static': static_value
}
if param_function:
ret['function'] = {
- 'raw' : param_function._model_str,
- 'regression_args' : list(param_function._regression_args)
+ 'raw': param_function._model_str,
+ 'regression_args': list(param_function._regression_args)
}
return ret
+
class State:
"""A single PTA state."""
- def __init__(self, name: str, power: float = 0,
- power_function: AnalyticFunction = None):
+ def __init__(self, name: str, power: float = 0, power_function: AnalyticFunction = None):
u"""
Create a new PTA state.
@@ -53,7 +56,7 @@ class State:
return self.power_function.eval(_dict_to_list(param_dict)) * duration
return self.power * duration
- def set_random_energy_model(self, static_model = True):
+ def set_random_energy_model(self, static_model=True):
"""Set a random static energy value."""
self.power = int(np.random.sample() * 50000)
@@ -83,10 +86,10 @@ class State:
:returns: Transition object
"""
interrupts = filter(lambda x: x.is_interrupt, self.outgoing_transitions.values())
- interrupts = sorted(interrupts, key = lambda x: x.get_timeout(parameters))
+ interrupts = sorted(interrupts, key=lambda x: x.get_timeout(parameters))
return interrupts[0]
- def dfs(self, depth: int, with_arguments: bool = False, trace_filter = None, sleep: int = 0):
+ def dfs(self, depth: int, with_arguments=False, trace_filter=None, sleep=0):
"""
Return a generator object for depth-first search over all outgoing transitions.
@@ -140,7 +143,7 @@ class State:
new_trace_filter = None
else:
new_trace_filter = None
- for suffix in trans.destination.dfs(depth - 1, with_arguments = with_arguments, trace_filter = new_trace_filter, sleep = sleep):
+ for suffix in trans.destination.dfs(depth - 1, with_arguments=with_arguments, trace_filter=new_trace_filter, sleep=sleep):
if with_arguments:
if trans.argument_combination == 'cartesian':
for args in itertools.product(*trans.argument_values):
@@ -173,11 +176,12 @@ class State:
def to_json(self) -> dict:
"""Return JSON encoding of this state object."""
ret = {
- 'name' : self.name,
- 'power' : _attribute_to_json(self.power, self.power_function)
+ 'name': self.name,
+ 'power': _attribute_to_json(self.power, self.power_function)
}
return ret
+
class Transition:
u"""
A single PTA transition with one origin and one destination state.
@@ -203,18 +207,18 @@ class Transition:
"""
def __init__(self, orig_state: State, dest_state: State, name: str,
- energy: float = 0, energy_function: AnalyticFunction = None,
- duration: float = 0, duration_function: AnalyticFunction = None,
- timeout: float = 0, timeout_function: AnalyticFunction = None,
- is_interrupt: bool = False,
- arguments: list = [],
- argument_values: list = [],
- argument_combination: str = 'cartesian', # or 'zip'
- param_update_function = None,
- arg_to_param_map: dict = None,
- set_param = None,
- return_value_handlers: list = [],
- codegen = dict()):
+ energy: float = 0, energy_function: AnalyticFunction = None,
+ duration: float = 0, duration_function: AnalyticFunction = None,
+ timeout: float = 0, timeout_function: AnalyticFunction = None,
+ is_interrupt: bool = False,
+ arguments: list = [],
+ argument_values: list = [],
+ argument_combination: str = 'cartesian', # or 'zip'
+ param_update_function=None,
+ arg_to_param_map: dict = None,
+ set_param=None,
+ return_value_handlers: list = [],
+ codegen=dict()):
"""
Create a new transition between two PTA states.
@@ -245,7 +249,6 @@ class Transition:
if 'formula' in handler:
handler['formula'] = NormalizationFunction(handler['formula'])
-
def get_duration(self, param_dict: dict = {}, args: list = []) -> float:
u"""
Return transition duration in µs.
@@ -270,7 +273,7 @@ class Transition:
return self.energy_function.eval(_dict_to_list(param_dict), args)
return self.energy
- def set_random_energy_model(self, static_model = True):
+ def set_random_energy_model(self, static_model=True):
self.energy = int(np.random.sample() * 50000)
def get_timeout(self, param_dict: dict = {}) -> float:
@@ -309,32 +312,35 @@ class Transition:
def to_json(self) -> dict:
"""Return JSON encoding of this transition object."""
ret = {
- 'name' : self.name,
- 'origin' : self.origin.name,
- 'destination' : self.destination.name,
- 'is_interrupt' : self.is_interrupt,
- 'arguments' : self.arguments,
- 'argument_values' : self.argument_values,
- 'argument_combination' : self.argument_combination,
- 'arg_to_param_map' : self.arg_to_param_map,
- 'set_param' : self.set_param,
- 'duration' : _attribute_to_json(self.duration, self.duration_function),
- 'energy' : _attribute_to_json(self.energy, self.energy_function),
- 'timeout' : _attribute_to_json(self.timeout, self.timeout_function),
+ 'name': self.name,
+ 'origin': self.origin.name,
+ 'destination': self.destination.name,
+ 'is_interrupt': self.is_interrupt,
+ 'arguments': self.arguments,
+ 'argument_values': self.argument_values,
+ 'argument_combination': self.argument_combination,
+ 'arg_to_param_map': self.arg_to_param_map,
+ 'set_param': self.set_param,
+ 'duration': _attribute_to_json(self.duration, self.duration_function),
+ 'energy': _attribute_to_json(self.energy, self.energy_function),
+ 'timeout': _attribute_to_json(self.timeout, self.timeout_function),
}
return ret
+
def _json_function_to_analytic_function(base, attribute: str, parameters: list):
if attribute in base and 'function' in base[attribute]:
base = base[attribute]['function']
- return AnalyticFunction(base['raw'], parameters, 0, regression_args = base['regression_args'])
+ return AnalyticFunction(base['raw'], parameters, 0, regression_args=base['regression_args'])
return None
+
def _json_get_static(base, attribute: str):
if attribute in base:
return base[attribute]['static']
return 0
+
class PTA:
"""
A parameterized priced timed automaton. All states are accepting.
@@ -354,9 +360,9 @@ class PTA:
"""
def __init__(self, state_names: list = [],
- accepting_states: list = None,
- parameters: list = [], initial_param_values: list = None,
- codegen: dict = {}, parameter_normalization: dict = None):
+ accepting_states: list = None,
+ parameters: list = [], initial_param_values: list = None,
+ codegen: dict = {}, parameter_normalization: dict = None):
"""
Return a new PTA object.
@@ -383,7 +389,7 @@ class PTA:
self.initial_param_values = [None for x in self.parameters]
self.transitions = []
- if not 'UNINITIALIZED' in state_names:
+ if 'UNINITIALIZED' not in state_names:
self.state['UNINITIALIZED'] = State('UNINITIALIZED')
if self.parameter_normalization:
@@ -451,7 +457,7 @@ class PTA:
pta = cls(**kwargs)
for name, state in json_input['state'].items():
power_function = _json_function_to_analytic_function(state, 'power', pta.parameters)
- pta.add_state(name, power = _json_get_static(state, 'power'), power_function = power_function)
+ pta.add_state(name, power=_json_get_static(state, 'power'), power_function=power_function)
for transition in json_input['transitions']:
duration_function = _json_function_to_analytic_function(transition, 'duration', pta.parameters)
energy_function = _json_function_to_analytic_function(transition, 'energy', pta.parameters)
@@ -465,15 +471,14 @@ class PTA:
origins = [origins]
for origin in origins:
pta.add_transition(origin, transition['destination'],
- transition['name'],
- duration = _json_get_static(transition, 'duration'),
- duration_function = duration_function,
- energy = _json_get_static(transition, 'energy'),
- energy_function = energy_function,
- timeout = _json_get_static(transition, 'timeout'),
- timeout_function = timeout_function,
- **kwargs
- )
+ transition['name'],
+ duration=_json_get_static(transition, 'duration'),
+ duration_function=duration_function,
+ energy=_json_get_static(transition, 'energy'),
+ energy_function=energy_function,
+ timeout=_json_get_static(transition, 'timeout'),
+ timeout_function=timeout_function,
+ **kwargs)
return pta
@@ -485,7 +490,7 @@ class PTA:
Compatible with the legacy dfatool/perl format.
"""
kwargs = {
- 'parameters' : list(),
+ 'parameters': list(),
'initial_param_values': list(),
}
@@ -496,7 +501,7 @@ class PTA:
pta = cls(**kwargs)
for name, state in json_input['state'].items():
- pta.add_state(name, power = float(state['power']['static']))
+ pta.add_state(name, power=float(state['power']['static']))
for trans_name in sorted(json_input['transition'].keys()):
transition = json_input['transition'][trans_name]
@@ -513,8 +518,9 @@ class PTA:
argument_values.append(arg['values'])
for origin in transition['origins']:
pta.add_transition(origin, destination, trans_name,
- arguments = arguments, argument_values = argument_values,
- is_interrupt = is_interrupt)
+ arguments=arguments,
+ argument_values=argument_values,
+ is_interrupt=is_interrupt)
return pta
@@ -565,19 +571,21 @@ class PTA:
if 'loop' in transition:
for state_name in transition['loop']:
pta.add_transition(state_name, state_name, trans_name,
- arguments = arguments, argument_values = argument_values,
- arg_to_param_map = arg_to_param_map,
- **kwargs)
+ arguments=arguments,
+ argument_values=argument_values,
+ arg_to_param_map=arg_to_param_map,
+ **kwargs)
else:
- if not 'src' in transition:
+ if 'src' not in transition:
transition['src'] = ['UNINITIALIZED']
- if not 'dst' in transition:
+ if 'dst' not in transition:
transition['dst'] = 'UNINITIALIZED'
for origin in transition['src']:
pta.add_transition(origin, transition['dst'], trans_name,
- arguments = arguments, argument_values = argument_values,
- arg_to_param_map = arg_to_param_map,
- **kwargs)
+ arguments=arguments,
+ argument_values=argument_values,
+ arg_to_param_map=arg_to_param_map,
+ **kwargs)
return pta
@@ -588,11 +596,11 @@ class PTA:
Compatible with the from_json method.
"""
ret = {
- 'parameters' : self.parameters,
- 'initial_param_values' : self.initial_param_values,
- 'state' : dict([[state.name, state.to_json()] for state in self.state.values()]),
- 'transitions' : [trans.to_json() for trans in self.transitions],
- 'accepting_states' : self.accepting_states,
+ 'parameters': self.parameters,
+ 'initial_param_values': self.initial_param_values,
+ 'state': dict([[state.name, state.to_json()] for state in self.state.values()]),
+ 'transitions': [trans.to_json() for trans in self.transitions],
+ 'accepting_states': self.accepting_states,
}
return ret
@@ -602,9 +610,9 @@ class PTA:
See the State() documentation for acceptable arguments.
"""
- if 'power_function' in kwargs and type(kwargs['power_function']) != AnalyticFunction and kwargs['power_function'] != None:
+ if 'power_function' in kwargs and type(kwargs['power_function']) != AnalyticFunction and kwargs['power_function'] is not None:
kwargs['power_function'] = AnalyticFunction(kwargs['power_function'],
- self.parameters, 0)
+ self.parameters, 0)
self.state[state_name] = State(state_name, **kwargs)
def add_transition(self, orig_state: str, dest_state: str, function_name: str, **kwargs):
@@ -620,7 +628,7 @@ class PTA:
orig_state = self.state[orig_state]
dest_state = self.state[dest_state]
for key in ('duration_function', 'energy_function', 'timeout_function'):
- if key in kwargs and kwargs[key] != None and type(kwargs[key]) != AnalyticFunction:
+ if key in kwargs and kwargs[key] is not None and type(kwargs[key]) != AnalyticFunction:
kwargs[key] = AnalyticFunction(kwargs[key], self.parameters, 0)
new_transition = Transition(orig_state, dest_state, function_name, **kwargs)
@@ -669,7 +677,7 @@ class PTA:
def get_initial_param_dict(self):
return dict([[self.parameters[i], self.initial_param_values[i]] for i in range(len(self.parameters))])
- def set_random_energy_model(self, static_model = True):
+ def set_random_energy_model(self, static_model=True):
for state in self.state.values():
state.set_random_energy_model(static_model)
for transition in self.transitions:
@@ -740,12 +748,12 @@ class PTA:
if with_parameters and not param_dict:
param_dict = self.get_initial_param_dict()
- if with_parameters and not 'with_arguments' in kwargs:
+ if with_parameters and 'with_arguments' not in kwargs:
raise ValueError("with_parameters = True requires with_arguments = True")
if self.accepting_states:
generator = filter(lambda x: x[-1][0].destination.name in self.accepting_states,
- self.state[orig_state].dfs(depth, **kwargs))
+ self.state[orig_state].dfs(depth, **kwargs))
else:
generator = self.state[orig_state].dfs(depth, **kwargs)
@@ -754,7 +762,7 @@ class PTA:
else:
return generator
- def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED', accounting = None):
+ def simulate(self, trace: list, orig_state: str = 'UNINITIALIZED', accounting=None):
u"""
Simulate a single run through the PTA and return total energy, duration, final state, and resulting parameters.
@@ -775,7 +783,7 @@ class PTA:
function_args = function[1]
else:
function_name = function[0]
- function_args = function[1 : ]
+ function_args = function[1:]
if function_name is None:
duration = function_args[0]
total_energy += state.get_energy(duration, param_dict)