diff options
Diffstat (limited to 'lib/dfatool.py')
-rwxr-xr-x | lib/dfatool.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 38e140d..a089c1d 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -19,7 +19,7 @@ from utils import * arg_support_enabled = True -def running_mean(x, N): +def running_mean(x: np.ndarray, N: int) -> np.ndarray: """ Compute running average. @@ -44,7 +44,7 @@ def soft_cast_int(n): except ValueError: return n -def vprint(verbose, string): +def vprint(verbose: bool, string: str): """ Print string if verbose. @@ -69,7 +69,7 @@ def vprint(verbose, string): return x / y return 1. -def gplearn_to_function(function_str): +def gplearn_to_function(function_str: str): """ Convert gplearn-style function string to Python function. @@ -109,7 +109,7 @@ def gplearn_to_function(function_str): print(eval_str) return eval(eval_str, eval_globals) -def _elem_param_and_arg_list(elem): +def _elem_param_and_arg_list(elem: dict): param_dict = elem['parameter'] paramkeys = sorted(param_dict.keys()) paramvalue = [soft_cast_int(param_dict[x]) for x in paramkeys] @@ -117,10 +117,10 @@ def _elem_param_and_arg_list(elem): paramvalue.extend(map(soft_cast_int, elem['args'])) return paramvalue -def _arg_name(arg_index): +def _arg_name(arg_index: int) -> str: return '~arg{:02}'.format(arg_index) -def append_if_set(aggregate, data, key): +def append_if_set(aggregate: dict, data: dict, key: str): """Append data[key] to aggregate if key in data.""" if key in data: aggregate.append(data[key]) @@ -131,7 +131,7 @@ def mean_or_none(arr): return np.mean(arr) return -1 -def aggregate_measures(aggregate, actual): +def aggregate_measures(aggregate: float, actual: list) -> dict: """ Calculate error measures for model value on data list. @@ -145,7 +145,7 @@ def aggregate_measures(aggregate, actual): aggregate_array = np.array([aggregate] * len(actual)) return regression_measures(aggregate_array, np.array(actual)) -def regression_measures(predicted, actual): +def regression_measures(predicted: np.ndarray, actual: np.ndarray): """ Calculate error measures by comparing model values to reference values. @@ -204,7 +204,7 @@ class KeysightCSV: """Create a new KeysightCSV object.""" pass - def load_data(self, filename): + def load_data(self, filename: str): """ Load log data from filename, return timestamps and currents. @@ -225,7 +225,7 @@ class KeysightCSV: currents[i] = float(row[2]) * -1 return timestamps, currents -def by_name_to_by_param(by_name): +def by_name_to_by_param(by_name: dict): """ Convert aggregation by name to aggregation by name and parameter values. """ |