diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/utils.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/lib/utils.py b/lib/utils.py index f6a34e3..4208c3a 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -45,6 +45,11 @@ def compute_param_statistics(by_name, by_param, parameter_names, arg_count, stat """ Compute standard deviation and correlation coefficient for various data partitions. + It is strongly recommended to vary all parameter values evenly across partitions. + For instance, given two parameters, providing only the combinations + (1, 1), (5, 1), (7, 1,) (10, 1), (1, 2), (1, 6) will lead to bogus results. + It is better to provide (1, 1), (5, 1), (1, 2), (5, 2), ... (i.e. a cross product of all individual parameter values) + arguments: by_name -- ground truth partitioned by state/transition name. by_name[state_or_trans][attribute] must be a list or 1-D numpy array. @@ -83,6 +88,8 @@ def compute_param_statistics(by_name, by_param, parameter_names, arg_count, stat 'corr_by_arg' : [], } + np.seterr('raise') + for param_idx, param in enumerate(parameter_names): ret['std_by_param'][param] = _mean_std_by_param(by_param, state_or_trans, attribute, param_idx) ret['corr_by_param'][param] = _corr_by_param(by_name, state_or_trans, attribute, param_idx) @@ -117,8 +124,10 @@ def _mean_std_by_param(by_param, state_or_tran, attribute, param_index): for k, v in by_param.items(): if param_slice_eq(k, param_value, param_index): param_partition.extend(v[attribute]) - if len(param_partition): + if len(param_partition) > 1: partitions.append(param_partition) + elif len(param_partition) == 1: + print('[W] parameter value partition for {} contains only one element -- skipping'.format(param_value)) else: print('[W] parameter value partition for {} is empty'.format(param_value)) return np.mean([np.std(partition) for partition in partitions]) |