diff options
Diffstat (limited to 'lib/utils.py')
-rw-r--r-- | lib/utils.py | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/lib/utils.py b/lib/utils.py index 8d1b817..abea67e 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -134,15 +134,17 @@ def prune_dependent_parameters(by_name, parameter_names): parameter_values[0].append(value_1) parameter_values[1].append(value_2) if len(parameter_values[0]): - correlation = np.corrcoef(parameter_values)[0][1] - if correlation != np.nan and np.abs(correlation) > 0.5: - print('[!] Parameters {} <-> {} are correlated with coefficcient {}'.format(parameter_names[index_1], parameter_names[index_2], correlation)) - if len(parameter_values_1) < len(parameter_values_2): - index_to_remove = index_1 - else: - index_to_remove = index_2 - print(' Removing parameter {}'.format(parameter_names[index_to_remove])) - parameter_indices_to_remove.append(index_to_remove) + # Calculating the correlation coefficient only makes sense when neither value is constant + if np.std(parameter_values_1) != 0 and np.std(parameter_values_2) != 0: + correlation = np.corrcoef(parameter_values)[0][1] + if correlation != np.nan and np.abs(correlation) > 0.5: + print('[!] Parameters {} <-> {} are correlated with coefficcient {}'.format(parameter_names[index_1], parameter_names[index_2], correlation)) + if len(parameter_values_1) < len(parameter_values_2): + index_to_remove = index_1 + else: + index_to_remove = index_2 + print(' Removing parameter {}'.format(parameter_names[index_to_remove])) + parameter_indices_to_remove.append(index_to_remove) remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove) def remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove): |