diff options
| author | Daniel Friesel <derf@finalrewind.org> | 2017-05-30 15:58:32 +0200 | 
|---|---|---|
| committer | Daniel Friesel <derf@finalrewind.org> | 2017-05-30 15:58:32 +0200 | 
| commit | 6b45e9720286d21aee1de6e381d9200002812491 (patch) | |
| tree | b6293ee750b6198c593d33dffee2df7823947f89 /bin | |
| parent | d2057ddd183163973613a6ce0be01fb656cc799d (diff) | |
calculate Spearman rank-order correlation coefficient
Diffstat (limited to 'bin')
| -rwxr-xr-x | bin/merge.py | 20 | 
1 files changed, 18 insertions, 2 deletions
| diff --git a/bin/merge.py b/bin/merge.py index e5365b2..551bc9e 100755 --- a/bin/merge.py +++ b/bin/merge.py @@ -9,9 +9,10 @@ import sys  import plotter  from copy import deepcopy  from dfatool import aggregate_measures, regression_measures, is_numeric, powerset -from dfatool import append_if_set, mean_or_none +from dfatool import append_if_set, mean_or_none, float_or_nan  from matplotlib.patches import Polygon -from scipy import optimize +from scipy import optimize, stats +#import pickle  opts = {} @@ -389,6 +390,12 @@ def mean_std_by_param(data, keys, name, what, index):          partitions.append(partition)      return np.mean([np.std(partition) for partition in partitions]) +def spearmanr_by_param(name, what, index): +    sr = stats.spearmanr(by_name[name][what], list(map(lambda x : float_or_nan(x[index]), by_name[name]['param'])))[0] +    if sr == np.nan: +        return None +    return sr +  # returns the mean standard deviation of all measurements of 'what'  # (e.g. energy or duration) for transition 'name' where  # the 'index'th argumetn is dynamic and all other arguments are fixed. @@ -584,6 +591,7 @@ def keydata(name, val, argdata, paramdata, tracedata, key):          'std_param' : np.mean([np.std(paramdata[x][key]) for x in paramdata.keys() if x[0] == name]),          'std_trace' : np.mean([np.std(tracedata[x][key]) for x in tracedata.keys() if x[0] == name]),          'std_by_param' : {}, +        'spearmanr_by_param' : {},          'fit_guess' : {},          'function' : {},      } @@ -844,6 +852,7 @@ def crossvalidate(by_name, by_param, by_trace, model, parameters):  def analyze_by_param(aggval, by_param, allvalues, name, key1, key2, param, param_idx):      aggval[key1]['std_by_param'][param] = mean_std_by_param(          by_param, allvalues, name, key2, param_idx) +    aggval[key1]['spearmanr_by_param'][param] = spearmanr_by_param(name, key2, param_idx)      if aggval[key1]['std_by_param'][param] > 0 and aggval[key1]['std_param'] / aggval[key1]['std_by_param'][param] < 0.6:          aggval[key1]['fit_guess'][param] = try_fits(name, key2, param_idx, by_param) @@ -1000,6 +1009,13 @@ for arg in args:                  if elem['name'] != 'UNINITIALIZED':                      load_run_elem(i, elem, run['trace'], by_name, by_arg, by_param, by_trace) +#with open('/tmp/by_name.pickle', 'wb') as f: +#    pickle.dump(by_name, f, pickle.HIGHEST_PROTOCOL) +#with open('/tmp/by_arg.pickle', 'wb') as f: +#    pickle.dump(by_arg, f, pickle.HIGHEST_PROTOCOL) +#with open('/tmp/by_param.pickle', 'wb') as f: +#    pickle.dump(by_param, f, pickle.HIGHEST_PROTOCOL) +  if 'states' in opts:      if 'params' in opts:          plotter.plot_states_param(data['model'], by_param) | 
