summaryrefslogtreecommitdiff
path: root/bin
diff options
context:
space:
mode:
Diffstat (limited to 'bin')
-rwxr-xr-xbin/merge.py20
1 files changed, 18 insertions, 2 deletions
diff --git a/bin/merge.py b/bin/merge.py
index e5365b2..551bc9e 100755
--- a/bin/merge.py
+++ b/bin/merge.py
@@ -9,9 +9,10 @@ import sys
import plotter
from copy import deepcopy
from dfatool import aggregate_measures, regression_measures, is_numeric, powerset
-from dfatool import append_if_set, mean_or_none
+from dfatool import append_if_set, mean_or_none, float_or_nan
from matplotlib.patches import Polygon
-from scipy import optimize
+from scipy import optimize, stats
+#import pickle
opts = {}
@@ -389,6 +390,12 @@ def mean_std_by_param(data, keys, name, what, index):
partitions.append(partition)
return np.mean([np.std(partition) for partition in partitions])
+def spearmanr_by_param(name, what, index):
+ sr = stats.spearmanr(by_name[name][what], list(map(lambda x : float_or_nan(x[index]), by_name[name]['param'])))[0]
+ if sr == np.nan:
+ return None
+ return sr
+
# returns the mean standard deviation of all measurements of 'what'
# (e.g. energy or duration) for transition 'name' where
# the 'index'th argumetn is dynamic and all other arguments are fixed.
@@ -584,6 +591,7 @@ def keydata(name, val, argdata, paramdata, tracedata, key):
'std_param' : np.mean([np.std(paramdata[x][key]) for x in paramdata.keys() if x[0] == name]),
'std_trace' : np.mean([np.std(tracedata[x][key]) for x in tracedata.keys() if x[0] == name]),
'std_by_param' : {},
+ 'spearmanr_by_param' : {},
'fit_guess' : {},
'function' : {},
}
@@ -844,6 +852,7 @@ def crossvalidate(by_name, by_param, by_trace, model, parameters):
def analyze_by_param(aggval, by_param, allvalues, name, key1, key2, param, param_idx):
aggval[key1]['std_by_param'][param] = mean_std_by_param(
by_param, allvalues, name, key2, param_idx)
+ aggval[key1]['spearmanr_by_param'][param] = spearmanr_by_param(name, key2, param_idx)
if aggval[key1]['std_by_param'][param] > 0 and aggval[key1]['std_param'] / aggval[key1]['std_by_param'][param] < 0.6:
aggval[key1]['fit_guess'][param] = try_fits(name, key2, param_idx, by_param)
@@ -1000,6 +1009,13 @@ for arg in args:
if elem['name'] != 'UNINITIALIZED':
load_run_elem(i, elem, run['trace'], by_name, by_arg, by_param, by_trace)
+#with open('/tmp/by_name.pickle', 'wb') as f:
+# pickle.dump(by_name, f, pickle.HIGHEST_PROTOCOL)
+#with open('/tmp/by_arg.pickle', 'wb') as f:
+# pickle.dump(by_arg, f, pickle.HIGHEST_PROTOCOL)
+#with open('/tmp/by_param.pickle', 'wb') as f:
+# pickle.dump(by_param, f, pickle.HIGHEST_PROTOCOL)
+
if 'states' in opts:
if 'params' in opts:
plotter.plot_states_param(data['model'], by_param)