summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/utils.py20
1 files changed, 11 insertions, 9 deletions
diff --git a/lib/utils.py b/lib/utils.py
index 8d1b817..abea67e 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -134,15 +134,17 @@ def prune_dependent_parameters(by_name, parameter_names):
parameter_values[0].append(value_1)
parameter_values[1].append(value_2)
if len(parameter_values[0]):
- correlation = np.corrcoef(parameter_values)[0][1]
- if correlation != np.nan and np.abs(correlation) > 0.5:
- print('[!] Parameters {} <-> {} are correlated with coefficcient {}'.format(parameter_names[index_1], parameter_names[index_2], correlation))
- if len(parameter_values_1) < len(parameter_values_2):
- index_to_remove = index_1
- else:
- index_to_remove = index_2
- print(' Removing parameter {}'.format(parameter_names[index_to_remove]))
- parameter_indices_to_remove.append(index_to_remove)
+ # Calculating the correlation coefficient only makes sense when neither value is constant
+ if np.std(parameter_values_1) != 0 and np.std(parameter_values_2) != 0:
+ correlation = np.corrcoef(parameter_values)[0][1]
+ if correlation != np.nan and np.abs(correlation) > 0.5:
+ print('[!] Parameters {} <-> {} are correlated with coefficcient {}'.format(parameter_names[index_1], parameter_names[index_2], correlation))
+ if len(parameter_values_1) < len(parameter_values_2):
+ index_to_remove = index_1
+ else:
+ index_to_remove = index_2
+ print(' Removing parameter {}'.format(parameter_names[index_to_remove]))
+ parameter_indices_to_remove.append(index_to_remove)
remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove)
def remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove):