summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2017-05-30 15:58:32 +0200
committerDaniel Friesel <derf@finalrewind.org>2017-05-30 15:58:32 +0200
commit6b45e9720286d21aee1de6e381d9200002812491 (patch)
treeb6293ee750b6198c593d33dffee2df7823947f89
parentd2057ddd183163973613a6ce0be01fb656cc799d (diff)
calculate Spearman rank-order correlation coefficient
-rwxr-xr-xbin/merge.py20
-rw-r--r--lib/Kratos/DFADriver.pm5
-rwxr-xr-xlib/dfatool.py8
3 files changed, 31 insertions, 2 deletions
diff --git a/bin/merge.py b/bin/merge.py
index e5365b2..551bc9e 100755
--- a/bin/merge.py
+++ b/bin/merge.py
@@ -9,9 +9,10 @@ import sys
import plotter
from copy import deepcopy
from dfatool import aggregate_measures, regression_measures, is_numeric, powerset
-from dfatool import append_if_set, mean_or_none
+from dfatool import append_if_set, mean_or_none, float_or_nan
from matplotlib.patches import Polygon
-from scipy import optimize
+from scipy import optimize, stats
+#import pickle
opts = {}
@@ -389,6 +390,12 @@ def mean_std_by_param(data, keys, name, what, index):
partitions.append(partition)
return np.mean([np.std(partition) for partition in partitions])
+def spearmanr_by_param(name, what, index):
+ sr = stats.spearmanr(by_name[name][what], list(map(lambda x : float_or_nan(x[index]), by_name[name]['param'])))[0]
+ if sr == np.nan:
+ return None
+ return sr
+
# returns the mean standard deviation of all measurements of 'what'
# (e.g. energy or duration) for transition 'name' where
# the 'index'th argumetn is dynamic and all other arguments are fixed.
@@ -584,6 +591,7 @@ def keydata(name, val, argdata, paramdata, tracedata, key):
'std_param' : np.mean([np.std(paramdata[x][key]) for x in paramdata.keys() if x[0] == name]),
'std_trace' : np.mean([np.std(tracedata[x][key]) for x in tracedata.keys() if x[0] == name]),
'std_by_param' : {},
+ 'spearmanr_by_param' : {},
'fit_guess' : {},
'function' : {},
}
@@ -844,6 +852,7 @@ def crossvalidate(by_name, by_param, by_trace, model, parameters):
def analyze_by_param(aggval, by_param, allvalues, name, key1, key2, param, param_idx):
aggval[key1]['std_by_param'][param] = mean_std_by_param(
by_param, allvalues, name, key2, param_idx)
+ aggval[key1]['spearmanr_by_param'][param] = spearmanr_by_param(name, key2, param_idx)
if aggval[key1]['std_by_param'][param] > 0 and aggval[key1]['std_param'] / aggval[key1]['std_by_param'][param] < 0.6:
aggval[key1]['fit_guess'][param] = try_fits(name, key2, param_idx, by_param)
@@ -1000,6 +1009,13 @@ for arg in args:
if elem['name'] != 'UNINITIALIZED':
load_run_elem(i, elem, run['trace'], by_name, by_arg, by_param, by_trace)
+#with open('/tmp/by_name.pickle', 'wb') as f:
+# pickle.dump(by_name, f, pickle.HIGHEST_PROTOCOL)
+#with open('/tmp/by_arg.pickle', 'wb') as f:
+# pickle.dump(by_arg, f, pickle.HIGHEST_PROTOCOL)
+#with open('/tmp/by_param.pickle', 'wb') as f:
+# pickle.dump(by_param, f, pickle.HIGHEST_PROTOCOL)
+
if 'states' in opts:
if 'params' in opts:
plotter.plot_states_param(data['model'], by_param)
diff --git a/lib/Kratos/DFADriver.pm b/lib/Kratos/DFADriver.pm
index f238ba4..ade6067 100644
--- a/lib/Kratos/DFADriver.pm
+++ b/lib/Kratos/DFADriver.pm
@@ -342,6 +342,7 @@ sub printf_parameterized {
my $std_by_arg = $hash->{std_by_arg} // {};
my $std_by_param = $hash->{std_by_param};
my $std_by_trace = $hash->{std_by_trace} // {};
+ my $r_by_param = $hash->{spearmanr_by_param} // {};
my $arg_ratio;
my $param_ratio;
my $trace_ratio;
@@ -423,6 +424,10 @@ sub printf_parameterized {
$key, $status, $param, $std_ind_param, $std_this, $ratio,
$fline );
}
+ if (exists $r_by_param->{$param}) {
+ printf(" %s: spearman_r for global %s is %.3f (p = %.3f)\n",
+ $key, $param, $r_by_param->{$param}, -1);
+ }
}
for my $arg ( sort keys %{$std_by_arg} ) {
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 66be4fd..fcbac94 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -23,6 +23,14 @@ def is_numeric(n):
except ValueError:
return False
+def float_or_nan(n):
+ if n == None:
+ return np.nan
+ try:
+ return float(n)
+ except ValueError:
+ return np.nan
+
def append_if_set(aggregate, data, key):
if key in data:
aggregate.append(data[key])