diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2020-05-28 12:04:37 +0200 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2020-05-28 12:04:37 +0200 |
commit | c69331e4d925658b2bf26dcb387981f6530d7b9e (patch) | |
tree | d19c7f9b0bf51f68c104057e013630e009835268 /lib/utils.py | |
parent | 23927051ac3e64cabbaa6c30e8356dfe90ebfa6c (diff) |
use black(1) for uniform code formatting
Diffstat (limited to 'lib/utils.py')
-rw-r--r-- | lib/utils.py | 82 |
1 files changed, 53 insertions, 29 deletions
diff --git a/lib/utils.py b/lib/utils.py index 26a591e..91dded0 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -21,15 +21,25 @@ def running_mean(x: np.ndarray, N: int) -> np.ndarray: :param x: 1-Dimensional NumPy array :param N: how many items to average """ + # FIXME np.insert(x, 0, [x[0] for i in range(N/2)]) + # FIXME np.insert(x, -1, [x[-1] for i in range(N/2)]) + # (dabei ungerade N beachten) cumsum = np.cumsum(np.insert(x, 0, 0)) return (cumsum[N:] - cumsum[:-N]) / N def human_readable(value, unit): - for prefix, factor in (('p', 1e-12), ('n', 1e-9), (u'µ', 1e-6), ('m', 1e-3), ('', 1), ('k', 1e3)): + for prefix, factor in ( + ("p", 1e-12), + ("n", 1e-9), + (u"µ", 1e-6), + ("m", 1e-3), + ("", 1), + ("k", 1e3), + ): if value < 1e3 * factor: - return '{:.2f} {}{}'.format(value * (1 / factor), prefix, unit) - return '{:.2f} {}'.format(value, unit) + return "{:.2f} {}{}".format(value * (1 / factor), prefix, unit) + return "{:.2f} {}".format(value, unit) def is_numeric(n): @@ -65,7 +75,7 @@ def soft_cast_int(n): If `n` is empty, returns None. If `n` is not numeric, it is left unchanged. """ - if n is None or n == '': + if n is None or n == "": return None try: return int(n) @@ -80,7 +90,7 @@ def soft_cast_float(n): If `n` is empty, returns None. If `n` is not numeric, it is left unchanged. """ - if n is None or n == '': + if n is None or n == "": return None try: return float(n) @@ -104,8 +114,8 @@ def parse_conf_str(conf_str): Values are casted to float if possible and kept as-is otherwise. """ conf_dict = dict() - for option in conf_str.split(','): - key, value = option.split('=') + for option in conf_str.split(","): + key, value = option.split("=") conf_dict[key] = soft_cast_float(value) return conf_dict @@ -118,7 +128,7 @@ def remove_index_from_tuple(parameters, index): :param index: index of element which is to be removed :returns: parameters tuple without the element at index """ - return (*parameters[:index], *parameters[index + 1:]) + return (*parameters[:index], *parameters[index + 1 :]) def param_slice_eq(a, b, index): @@ -137,7 +147,9 @@ def param_slice_eq(a, b, index): ('foo', [1, 4]), ('foo', [2, 4]), 1 -> False """ - if (*a[1][:index], *a[1][index + 1:]) == (*b[1][:index], *b[1][index + 1:]) and a[0] == b[0]: + if (*a[1][:index], *a[1][index + 1 :]) == (*b[1][:index], *b[1][index + 1 :]) and a[ + 0 + ] == b[0]: return True return False @@ -164,20 +176,20 @@ def by_name_to_by_param(by_name: dict): """ by_param = dict() for name in by_name.keys(): - for i, parameters in enumerate(by_name[name]['param']): + for i, parameters in enumerate(by_name[name]["param"]): param_key = (name, tuple(parameters)) if param_key not in by_param: by_param[param_key] = dict() for key in by_name[name].keys(): by_param[param_key][key] = list() - by_param[param_key]['attributes'] = by_name[name]['attributes'] + by_param[param_key]["attributes"] = by_name[name]["attributes"] # special case for PTA models - if 'isa' in by_name[name]: - by_param[param_key]['isa'] = by_name[name]['isa'] - for attribute in by_name[name]['attributes']: + if "isa" in by_name[name]: + by_param[param_key]["isa"] = by_name[name]["isa"] + for attribute in by_name[name]["attributes"]: by_param[param_key][attribute].append(by_name[name][attribute][i]) # Required for match_parameter_valuse in _try_fits - by_param[param_key]['param'].append(by_name[name]['param'][i]) + by_param[param_key]["param"].append(by_name[name]["param"][i]) return by_param @@ -197,14 +209,26 @@ def filter_aggregate_by_param(aggregate, parameters, parameter_filter): param_value = soft_cast_int(param_name_and_value[1]) names_to_remove = set() for name in aggregate.keys(): - indices_to_keep = list(map(lambda x: x[param_index] == param_value, aggregate[name]['param'])) - aggregate[name]['param'] = list(map(lambda iv: iv[1], filter(lambda iv: indices_to_keep[iv[0]], enumerate(aggregate[name]['param'])))) + indices_to_keep = list( + map(lambda x: x[param_index] == param_value, aggregate[name]["param"]) + ) + aggregate[name]["param"] = list( + map( + lambda iv: iv[1], + filter( + lambda iv: indices_to_keep[iv[0]], + enumerate(aggregate[name]["param"]), + ), + ) + ) if len(indices_to_keep) == 0: - print('??? {}->{}'.format(parameter_filter, name)) + print("??? {}->{}".format(parameter_filter, name)) names_to_remove.add(name) else: - for attribute in aggregate[name]['attributes']: - aggregate[name][attribute] = aggregate[name][attribute][indices_to_keep] + for attribute in aggregate[name]["attributes"]: + aggregate[name][attribute] = aggregate[name][attribute][ + indices_to_keep + ] if len(aggregate[name][attribute]) == 0: names_to_remove.add(name) for name in names_to_remove: @@ -218,25 +242,25 @@ class OptionalTimingAnalysis: self.index = 1 def get_header(self): - ret = '' + ret = "" if self.enabled: - ret += '#define TIMEIT(index, functioncall) ' - ret += 'counter.start(); ' - ret += 'functioncall; ' - ret += 'counter.stop();' + ret += "#define TIMEIT(index, functioncall) " + ret += "counter.start(); " + ret += "functioncall; " + ret += "counter.stop();" ret += 'kout << endl << index << " :: " << counter.value << "/" << counter.overflow << endl;\n' return ret def wrap_codeblock(self, codeblock): if not self.enabled: return codeblock - lines = codeblock.split('\n') + lines = codeblock.split("\n") ret = list() for line in lines: - if re.fullmatch('.+;', line): - ret.append('TIMEIT( {:d}, {} )'.format(self.index, line)) + if re.fullmatch(".+;", line): + ret.append("TIMEIT( {:d}, {} )".format(self.index, line)) self.wrapped_lines.append(line) self.index += 1 else: ret.append(line) - return '\n'.join(ret) + return "\n".join(ret) |