diff options
Diffstat (limited to 'bin')
-rwxr-xr-x | bin/merge.py | 20 |
1 files changed, 18 insertions, 2 deletions
diff --git a/bin/merge.py b/bin/merge.py index e5365b2..551bc9e 100755 --- a/bin/merge.py +++ b/bin/merge.py @@ -9,9 +9,10 @@ import sys import plotter from copy import deepcopy from dfatool import aggregate_measures, regression_measures, is_numeric, powerset -from dfatool import append_if_set, mean_or_none +from dfatool import append_if_set, mean_or_none, float_or_nan from matplotlib.patches import Polygon -from scipy import optimize +from scipy import optimize, stats +#import pickle opts = {} @@ -389,6 +390,12 @@ def mean_std_by_param(data, keys, name, what, index): partitions.append(partition) return np.mean([np.std(partition) for partition in partitions]) +def spearmanr_by_param(name, what, index): + sr = stats.spearmanr(by_name[name][what], list(map(lambda x : float_or_nan(x[index]), by_name[name]['param'])))[0] + if sr == np.nan: + return None + return sr + # returns the mean standard deviation of all measurements of 'what' # (e.g. energy or duration) for transition 'name' where # the 'index'th argumetn is dynamic and all other arguments are fixed. @@ -584,6 +591,7 @@ def keydata(name, val, argdata, paramdata, tracedata, key): 'std_param' : np.mean([np.std(paramdata[x][key]) for x in paramdata.keys() if x[0] == name]), 'std_trace' : np.mean([np.std(tracedata[x][key]) for x in tracedata.keys() if x[0] == name]), 'std_by_param' : {}, + 'spearmanr_by_param' : {}, 'fit_guess' : {}, 'function' : {}, } @@ -844,6 +852,7 @@ def crossvalidate(by_name, by_param, by_trace, model, parameters): def analyze_by_param(aggval, by_param, allvalues, name, key1, key2, param, param_idx): aggval[key1]['std_by_param'][param] = mean_std_by_param( by_param, allvalues, name, key2, param_idx) + aggval[key1]['spearmanr_by_param'][param] = spearmanr_by_param(name, key2, param_idx) if aggval[key1]['std_by_param'][param] > 0 and aggval[key1]['std_param'] / aggval[key1]['std_by_param'][param] < 0.6: aggval[key1]['fit_guess'][param] = try_fits(name, key2, param_idx, by_param) @@ -1000,6 +1009,13 @@ for arg in args: if elem['name'] != 'UNINITIALIZED': load_run_elem(i, elem, run['trace'], by_name, by_arg, by_param, by_trace) +#with open('/tmp/by_name.pickle', 'wb') as f: +# pickle.dump(by_name, f, pickle.HIGHEST_PROTOCOL) +#with open('/tmp/by_arg.pickle', 'wb') as f: +# pickle.dump(by_arg, f, pickle.HIGHEST_PROTOCOL) +#with open('/tmp/by_param.pickle', 'wb') as f: +# pickle.dump(by_param, f, pickle.HIGHEST_PROTOCOL) + if 'states' in opts: if 'params' in opts: plotter.plot_states_param(data['model'], by_param) |