diff options
Diffstat (limited to 'lib')
-rwxr-xr-x | lib/dfatool.py | 16 | ||||
-rw-r--r-- | lib/functions.py | 16 | ||||
-rw-r--r-- | lib/utils.py | 37 |
3 files changed, 46 insertions, 23 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 95e76e7..abf8c10 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -21,11 +21,10 @@ arg_support_enabled = True def running_mean(x: np.ndarray, N: int) -> np.ndarray: """ - Compute running average. + Compute `N` elements wide running average over `x`. - arguments: - x -- NumPy array - N -- how many items to average + :param x: 1-Dimensional NumPy array + :param N: how many items to average """ cumsum = np.cumsum(np.insert(x, 0, 0)) return (cumsum[N:] - cumsum[:-N]) / N @@ -71,16 +70,17 @@ def gplearn_to_function(function_str: str): print(eval_str) return eval(eval_str, eval_globals) -def _arg_name(arg_index: int) -> str: - return '~arg{:02}'.format(arg_index) - def append_if_set(aggregate: dict, data: dict, key: str): """Append data[key] to aggregate if key in data.""" if key in data: aggregate.append(data[key]) def mean_or_none(arr): - """Compute mean of NumPy array arr, return -1 if empty.""" + """ + Compute mean of NumPy array `arr`, return -1 if empty. + + :param arr: 1-Dimensional NumPy array + """ if len(arr): return np.mean(arr) return -1 diff --git a/lib/functions.py b/lib/functions.py index 9c58b58..76be562 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -14,11 +14,11 @@ arg_support_enabled = True def powerset(iterable): """ - Calculate powerset of given items. + Calculate powerset of `iterable` elements. Returns an iterable containing one tuple for each powerset element. - Example: powerset([1, 2]) -> [(), (1), (2), (1, 2)] + Example: `powerset([1, 2])` -> `[(), (1), (2), (1, 2)]` """ s = list(iterable) return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) @@ -86,14 +86,24 @@ class ParamFunction: class NormalizationFunction: """ - Hi + Wrapper for parameter normalization functions used in YAML PTA/DFA models. """ def __init__(self, function_str): + """ + Create a new normalization function from `function_str`. + + :param function_str: Function string. Signature: (param) -> float + """ self._function_str = function_str self._function = eval('lambda param: ' + function_str) def eval(self, param_value): + """ + Evaluate the normalization function and return its output. + + :param param_value: Parameter value + """ return self._function(param_value) class AnalyticFunction: diff --git a/lib/utils.py b/lib/utils.py index abea67e..f54fb99 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -6,7 +6,7 @@ arg_support_enabled = True def vprint(verbose, string): """ - Print string if verbose. + Print `string` if `verbose`. Prints string if verbose is a True value """ @@ -14,7 +14,7 @@ def vprint(verbose, string): print(string) def is_numeric(n): - """Check if n is numeric (i.e., can be converted to int).""" + """Check if `n` is numeric (i.e., it can be converted to int).""" if n == None: return False try: @@ -24,7 +24,7 @@ def is_numeric(n): return False def float_or_nan(n): - """Convert to float (if numeric) or NaN.""" + """Convert `n` to float (if numeric) or NaN.""" if n == None: return np.nan try: @@ -34,10 +34,10 @@ def float_or_nan(n): def soft_cast_int(n): """ - Convert to int, if possible. + Convert `n` to int, if possible. - If it is empty, returns None. - If it is not numeric, it is left unchanged. + If `n` is empty, returns None. + If `n` is not numeric, it is left unchanged. """ if n == None or n == '': return None @@ -48,10 +48,10 @@ def soft_cast_int(n): def soft_cast_float(n): """ - Convert to float, if possible. + Convert `n` to float, if possible. - If it is empty, returns None. - If it is not numeric, it is left unchanged. + If `n` is empty, returns None. + If `n` is not numeric, it is left unchanged. """ if n == None or n == '': return None @@ -70,6 +70,11 @@ def flatten(somelist): return [item for sublist in somelist for item in sublist] def parse_conf_str(conf_str): + """ + Parse a configuration string `k1=v1,k2=v2`... and return a dict `{'k1': v1, 'k2': v2}`... + + Values are casted to float if possible and kept as-is otherwise. + """ conf_dict = dict() for option in conf_str.split(','): key, value = option.split('=') @@ -77,6 +82,12 @@ def parse_conf_str(conf_str): return conf_dict def remove_index_from_tuple(parameters, index): + """ + Remove the element at `index` from tuple `parameters` (edited in-place). + + :param parameters: tuple (edited in-place) + :param index: index of element which is to be removed + """ return (*parameters[:index], *parameters[index+1:]) def param_slice_eq(a, b, index): @@ -99,7 +110,7 @@ def param_slice_eq(a, b, index): return True return False -def prune_dependent_parameters(by_name, parameter_names): +def prune_dependent_parameters(by_name, parameter_names, correlation_threshold = 0.5): """ Remove dependent parameters from aggregate. @@ -108,8 +119,9 @@ def prune_dependent_parameters(by_name, parameter_names): by_name[stanamete_or_trans]['param'] must be a list of parameter values. Other dict members are left as-is :param parameter_names: List of parameter names in the order they are used in by_name[name]['param'], edited in-place. + :param correlation_threshold: Remove parameter if absolute correlation exceeds this threshold (default: 0.5) - Model generation (and its components, such as relevant parameter detection and least squares optimization) only work if input variables (i.e., parameters) + Model generation (and its components, such as relevant parameter detection and least squares optimization) only works if input variables (i.e., parameters) are independent of each other. This function computes the correlation coefficient for each pair of parameters and removes those which depend on each other. For each pair of dependent parameters, the lexically greater one is removed (e.g. "a" and "b" -> "b" is removed). """ @@ -137,7 +149,7 @@ def prune_dependent_parameters(by_name, parameter_names): # Calculating the correlation coefficient only makes sense when neither value is constant if np.std(parameter_values_1) != 0 and np.std(parameter_values_2) != 0: correlation = np.corrcoef(parameter_values)[0][1] - if correlation != np.nan and np.abs(correlation) > 0.5: + if correlation != np.nan and np.abs(correlation) > correlation_threshold: print('[!] Parameters {} <-> {} are correlated with coefficcient {}'.format(parameter_names[index_1], parameter_names[index_2], correlation)) if len(parameter_values_1) < len(parameter_values_2): index_to_remove = index_1 @@ -278,6 +290,7 @@ def _corr_by_param(by_name, state_or_trans, attribute, param_index): return 0. def _all_params_are_numeric(data, param_idx): + """Check if all `data['param'][*][param_idx]` elements are numeric, as reported by `utils.is_numeric`.""" param_values = list(map(lambda x: x[param_idx], data['param'])) if len(list(filter(is_numeric, param_values))) == len(param_values): return True |