summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/utils.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/lib/utils.py b/lib/utils.py
index f6a34e3..4208c3a 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -45,6 +45,11 @@ def compute_param_statistics(by_name, by_param, parameter_names, arg_count, stat
"""
Compute standard deviation and correlation coefficient for various data partitions.
+ It is strongly recommended to vary all parameter values evenly across partitions.
+ For instance, given two parameters, providing only the combinations
+ (1, 1), (5, 1), (7, 1,) (10, 1), (1, 2), (1, 6) will lead to bogus results.
+ It is better to provide (1, 1), (5, 1), (1, 2), (5, 2), ... (i.e. a cross product of all individual parameter values)
+
arguments:
by_name -- ground truth partitioned by state/transition name.
by_name[state_or_trans][attribute] must be a list or 1-D numpy array.
@@ -83,6 +88,8 @@ def compute_param_statistics(by_name, by_param, parameter_names, arg_count, stat
'corr_by_arg' : [],
}
+ np.seterr('raise')
+
for param_idx, param in enumerate(parameter_names):
ret['std_by_param'][param] = _mean_std_by_param(by_param, state_or_trans, attribute, param_idx)
ret['corr_by_param'][param] = _corr_by_param(by_name, state_or_trans, attribute, param_idx)
@@ -117,8 +124,10 @@ def _mean_std_by_param(by_param, state_or_tran, attribute, param_index):
for k, v in by_param.items():
if param_slice_eq(k, param_value, param_index):
param_partition.extend(v[attribute])
- if len(param_partition):
+ if len(param_partition) > 1:
partitions.append(param_partition)
+ elif len(param_partition) == 1:
+ print('[W] parameter value partition for {} contains only one element -- skipping'.format(param_value))
else:
print('[W] parameter value partition for {} is empty'.format(param_value))
return np.mean([np.std(partition) for partition in partitions])