summaryrefslogtreecommitdiff
path: root/lib/utils.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2020-05-28 12:04:37 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2020-05-28 12:04:37 +0200
commitc69331e4d925658b2bf26dcb387981f6530d7b9e (patch)
treed19c7f9b0bf51f68c104057e013630e009835268 /lib/utils.py
parent23927051ac3e64cabbaa6c30e8356dfe90ebfa6c (diff)
use black(1) for uniform code formatting
Diffstat (limited to 'lib/utils.py')
-rw-r--r--lib/utils.py82
1 files changed, 53 insertions, 29 deletions
diff --git a/lib/utils.py b/lib/utils.py
index 26a591e..91dded0 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -21,15 +21,25 @@ def running_mean(x: np.ndarray, N: int) -> np.ndarray:
:param x: 1-Dimensional NumPy array
:param N: how many items to average
"""
+ # FIXME np.insert(x, 0, [x[0] for i in range(N/2)])
+ # FIXME np.insert(x, -1, [x[-1] for i in range(N/2)])
+ # (dabei ungerade N beachten)
cumsum = np.cumsum(np.insert(x, 0, 0))
return (cumsum[N:] - cumsum[:-N]) / N
def human_readable(value, unit):
- for prefix, factor in (('p', 1e-12), ('n', 1e-9), (u'µ', 1e-6), ('m', 1e-3), ('', 1), ('k', 1e3)):
+ for prefix, factor in (
+ ("p", 1e-12),
+ ("n", 1e-9),
+ (u"µ", 1e-6),
+ ("m", 1e-3),
+ ("", 1),
+ ("k", 1e3),
+ ):
if value < 1e3 * factor:
- return '{:.2f} {}{}'.format(value * (1 / factor), prefix, unit)
- return '{:.2f} {}'.format(value, unit)
+ return "{:.2f} {}{}".format(value * (1 / factor), prefix, unit)
+ return "{:.2f} {}".format(value, unit)
def is_numeric(n):
@@ -65,7 +75,7 @@ def soft_cast_int(n):
If `n` is empty, returns None.
If `n` is not numeric, it is left unchanged.
"""
- if n is None or n == '':
+ if n is None or n == "":
return None
try:
return int(n)
@@ -80,7 +90,7 @@ def soft_cast_float(n):
If `n` is empty, returns None.
If `n` is not numeric, it is left unchanged.
"""
- if n is None or n == '':
+ if n is None or n == "":
return None
try:
return float(n)
@@ -104,8 +114,8 @@ def parse_conf_str(conf_str):
Values are casted to float if possible and kept as-is otherwise.
"""
conf_dict = dict()
- for option in conf_str.split(','):
- key, value = option.split('=')
+ for option in conf_str.split(","):
+ key, value = option.split("=")
conf_dict[key] = soft_cast_float(value)
return conf_dict
@@ -118,7 +128,7 @@ def remove_index_from_tuple(parameters, index):
:param index: index of element which is to be removed
:returns: parameters tuple without the element at index
"""
- return (*parameters[:index], *parameters[index + 1:])
+ return (*parameters[:index], *parameters[index + 1 :])
def param_slice_eq(a, b, index):
@@ -137,7 +147,9 @@ def param_slice_eq(a, b, index):
('foo', [1, 4]), ('foo', [2, 4]), 1 -> False
"""
- if (*a[1][:index], *a[1][index + 1:]) == (*b[1][:index], *b[1][index + 1:]) and a[0] == b[0]:
+ if (*a[1][:index], *a[1][index + 1 :]) == (*b[1][:index], *b[1][index + 1 :]) and a[
+ 0
+ ] == b[0]:
return True
return False
@@ -164,20 +176,20 @@ def by_name_to_by_param(by_name: dict):
"""
by_param = dict()
for name in by_name.keys():
- for i, parameters in enumerate(by_name[name]['param']):
+ for i, parameters in enumerate(by_name[name]["param"]):
param_key = (name, tuple(parameters))
if param_key not in by_param:
by_param[param_key] = dict()
for key in by_name[name].keys():
by_param[param_key][key] = list()
- by_param[param_key]['attributes'] = by_name[name]['attributes']
+ by_param[param_key]["attributes"] = by_name[name]["attributes"]
# special case for PTA models
- if 'isa' in by_name[name]:
- by_param[param_key]['isa'] = by_name[name]['isa']
- for attribute in by_name[name]['attributes']:
+ if "isa" in by_name[name]:
+ by_param[param_key]["isa"] = by_name[name]["isa"]
+ for attribute in by_name[name]["attributes"]:
by_param[param_key][attribute].append(by_name[name][attribute][i])
# Required for match_parameter_valuse in _try_fits
- by_param[param_key]['param'].append(by_name[name]['param'][i])
+ by_param[param_key]["param"].append(by_name[name]["param"][i])
return by_param
@@ -197,14 +209,26 @@ def filter_aggregate_by_param(aggregate, parameters, parameter_filter):
param_value = soft_cast_int(param_name_and_value[1])
names_to_remove = set()
for name in aggregate.keys():
- indices_to_keep = list(map(lambda x: x[param_index] == param_value, aggregate[name]['param']))
- aggregate[name]['param'] = list(map(lambda iv: iv[1], filter(lambda iv: indices_to_keep[iv[0]], enumerate(aggregate[name]['param']))))
+ indices_to_keep = list(
+ map(lambda x: x[param_index] == param_value, aggregate[name]["param"])
+ )
+ aggregate[name]["param"] = list(
+ map(
+ lambda iv: iv[1],
+ filter(
+ lambda iv: indices_to_keep[iv[0]],
+ enumerate(aggregate[name]["param"]),
+ ),
+ )
+ )
if len(indices_to_keep) == 0:
- print('??? {}->{}'.format(parameter_filter, name))
+ print("??? {}->{}".format(parameter_filter, name))
names_to_remove.add(name)
else:
- for attribute in aggregate[name]['attributes']:
- aggregate[name][attribute] = aggregate[name][attribute][indices_to_keep]
+ for attribute in aggregate[name]["attributes"]:
+ aggregate[name][attribute] = aggregate[name][attribute][
+ indices_to_keep
+ ]
if len(aggregate[name][attribute]) == 0:
names_to_remove.add(name)
for name in names_to_remove:
@@ -218,25 +242,25 @@ class OptionalTimingAnalysis:
self.index = 1
def get_header(self):
- ret = ''
+ ret = ""
if self.enabled:
- ret += '#define TIMEIT(index, functioncall) '
- ret += 'counter.start(); '
- ret += 'functioncall; '
- ret += 'counter.stop();'
+ ret += "#define TIMEIT(index, functioncall) "
+ ret += "counter.start(); "
+ ret += "functioncall; "
+ ret += "counter.stop();"
ret += 'kout << endl << index << " :: " << counter.value << "/" << counter.overflow << endl;\n'
return ret
def wrap_codeblock(self, codeblock):
if not self.enabled:
return codeblock
- lines = codeblock.split('\n')
+ lines = codeblock.split("\n")
ret = list()
for line in lines:
- if re.fullmatch('.+;', line):
- ret.append('TIMEIT( {:d}, {} )'.format(self.index, line))
+ if re.fullmatch(".+;", line):
+ ret.append("TIMEIT( {:d}, {} )".format(self.index, line))
self.wrapped_lines.append(line)
self.index += 1
else:
ret.append(line)
- return '\n'.join(ret)
+ return "\n".join(ret)