summaryrefslogtreecommitdiff
path: root/lib/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/utils.py')
-rw-r--r--lib/utils.py37
1 files changed, 25 insertions, 12 deletions
diff --git a/lib/utils.py b/lib/utils.py
index abea67e..f54fb99 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -6,7 +6,7 @@ arg_support_enabled = True
def vprint(verbose, string):
"""
- Print string if verbose.
+ Print `string` if `verbose`.
Prints string if verbose is a True value
"""
@@ -14,7 +14,7 @@ def vprint(verbose, string):
print(string)
def is_numeric(n):
- """Check if n is numeric (i.e., can be converted to int)."""
+ """Check if `n` is numeric (i.e., it can be converted to int)."""
if n == None:
return False
try:
@@ -24,7 +24,7 @@ def is_numeric(n):
return False
def float_or_nan(n):
- """Convert to float (if numeric) or NaN."""
+ """Convert `n` to float (if numeric) or NaN."""
if n == None:
return np.nan
try:
@@ -34,10 +34,10 @@ def float_or_nan(n):
def soft_cast_int(n):
"""
- Convert to int, if possible.
+ Convert `n` to int, if possible.
- If it is empty, returns None.
- If it is not numeric, it is left unchanged.
+ If `n` is empty, returns None.
+ If `n` is not numeric, it is left unchanged.
"""
if n == None or n == '':
return None
@@ -48,10 +48,10 @@ def soft_cast_int(n):
def soft_cast_float(n):
"""
- Convert to float, if possible.
+ Convert `n` to float, if possible.
- If it is empty, returns None.
- If it is not numeric, it is left unchanged.
+ If `n` is empty, returns None.
+ If `n` is not numeric, it is left unchanged.
"""
if n == None or n == '':
return None
@@ -70,6 +70,11 @@ def flatten(somelist):
return [item for sublist in somelist for item in sublist]
def parse_conf_str(conf_str):
+ """
+ Parse a configuration string `k1=v1,k2=v2`... and return a dict `{'k1': v1, 'k2': v2}`...
+
+ Values are casted to float if possible and kept as-is otherwise.
+ """
conf_dict = dict()
for option in conf_str.split(','):
key, value = option.split('=')
@@ -77,6 +82,12 @@ def parse_conf_str(conf_str):
return conf_dict
def remove_index_from_tuple(parameters, index):
+ """
+ Remove the element at `index` from tuple `parameters` (edited in-place).
+
+ :param parameters: tuple (edited in-place)
+ :param index: index of element which is to be removed
+ """
return (*parameters[:index], *parameters[index+1:])
def param_slice_eq(a, b, index):
@@ -99,7 +110,7 @@ def param_slice_eq(a, b, index):
return True
return False
-def prune_dependent_parameters(by_name, parameter_names):
+def prune_dependent_parameters(by_name, parameter_names, correlation_threshold = 0.5):
"""
Remove dependent parameters from aggregate.
@@ -108,8 +119,9 @@ def prune_dependent_parameters(by_name, parameter_names):
by_name[stanamete_or_trans]['param'] must be a list of parameter values.
Other dict members are left as-is
:param parameter_names: List of parameter names in the order they are used in by_name[name]['param'], edited in-place.
+ :param correlation_threshold: Remove parameter if absolute correlation exceeds this threshold (default: 0.5)
- Model generation (and its components, such as relevant parameter detection and least squares optimization) only work if input variables (i.e., parameters)
+ Model generation (and its components, such as relevant parameter detection and least squares optimization) only works if input variables (i.e., parameters)
are independent of each other. This function computes the correlation coefficient for each pair of parameters and removes those which depend on each other.
For each pair of dependent parameters, the lexically greater one is removed (e.g. "a" and "b" -> "b" is removed).
"""
@@ -137,7 +149,7 @@ def prune_dependent_parameters(by_name, parameter_names):
# Calculating the correlation coefficient only makes sense when neither value is constant
if np.std(parameter_values_1) != 0 and np.std(parameter_values_2) != 0:
correlation = np.corrcoef(parameter_values)[0][1]
- if correlation != np.nan and np.abs(correlation) > 0.5:
+ if correlation != np.nan and np.abs(correlation) > correlation_threshold:
print('[!] Parameters {} <-> {} are correlated with coefficcient {}'.format(parameter_names[index_1], parameter_names[index_2], correlation))
if len(parameter_values_1) < len(parameter_values_2):
index_to_remove = index_1
@@ -278,6 +290,7 @@ def _corr_by_param(by_name, state_or_trans, attribute, param_index):
return 0.
def _all_params_are_numeric(data, param_idx):
+ """Check if all `data['param'][*][param_idx]` elements are numeric, as reported by `utils.is_numeric`."""
param_values = list(map(lambda x: x[param_idx], data['param']))
if len(list(filter(is_numeric, param_values))) == len(param_values):
return True