From 719734d3eef6e0d0a603523055334a77ee717f9f Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Tue, 5 Feb 2019 11:18:46 +0100 Subject: compute_param_statistics: Warn when encountering useless data --- lib/utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'lib') diff --git a/lib/utils.py b/lib/utils.py index f6a34e3..4208c3a 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -45,6 +45,11 @@ def compute_param_statistics(by_name, by_param, parameter_names, arg_count, stat """ Compute standard deviation and correlation coefficient for various data partitions. + It is strongly recommended to vary all parameter values evenly across partitions. + For instance, given two parameters, providing only the combinations + (1, 1), (5, 1), (7, 1), (10, 1), (1, 2), (1, 6) will lead to bogus results. + It is better to provide (1, 1), (5, 1), (1, 2), (5, 2), ... (i.e. a cross product of all individual parameter values) + arguments: by_name -- ground truth partitioned by state/transition name. by_name[state_or_trans][attribute] must be a list or 1-D numpy array. 
@@ -83,6 +88,8 @@ def compute_param_statistics(by_name, by_param, parameter_names, arg_count, stat 'corr_by_arg' : [], } + np.seterr('raise') + for param_idx, param in enumerate(parameter_names): ret['std_by_param'][param] = _mean_std_by_param(by_param, state_or_trans, attribute, param_idx) ret['corr_by_param'][param] = _corr_by_param(by_name, state_or_trans, attribute, param_idx) @@ -117,8 +124,10 @@ def _mean_std_by_param(by_param, state_or_tran, attribute, param_index): for k, v in by_param.items(): if param_slice_eq(k, param_value, param_index): param_partition.extend(v[attribute]) - if len(param_partition): + if len(param_partition) > 1: partitions.append(param_partition) + elif len(param_partition) == 1: + print('[W] parameter value partition for {} contains only one element -- skipping'.format(param_value)) else: print('[W] parameter value partition for {} is empty'.format(param_value)) return np.mean([np.std(partition) for partition in partitions]) -- cgit v1.2.3