summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/utils.py51
1 files changed, 48 insertions, 3 deletions
diff --git a/lib/utils.py b/lib/utils.py
index d8df9cd..e3eda16 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -83,10 +83,11 @@ def parse_conf_str(conf_str):
def remove_index_from_tuple(parameters, index):
"""
- Remove the element at `index` from tuple `parameters` (edited in-place).
+ Remove the element at `index` from tuple `parameters`.
- :param parameters: tuple (edited in-place)
+ :param parameters: tuple
:param index: index of element which is to be removed
+ :returns: parameters tuple without the element at index
"""
return (*parameters[:index], *parameters[index+1:])
@@ -202,7 +203,7 @@ def compute_param_statistics(by_name, by_param, parameter_names, arg_count, stat
:param attribute: model attribute, e.g. 'power' or 'duration'
:param verbose: print warning if some parameter partitions are too small for fitting
- :return: a dict with the following content:
+ :returns: a dict with the following content:
std_static -- static parameter-unaware model error: stddev of by_name[state_or_trans][attribute]
std_param_lut -- static parameter-aware model error: mean stddev of by_param[(state_or_trans, *)][attribute]
std_by_param -- static parameter-aware model error ignoring a single parameter.
@@ -237,6 +238,26 @@ def compute_param_statistics(by_name, by_param, parameter_names, arg_count, stat
return ret
+def _param_values(by_param, state_or_tran):
+ """
+ Return the distinct values of each parameter in by_param.
+
+ E.g. if by_param.keys() contains the distinct parameter values (1, 1), (1, 2), (1, 3), (0, 3),
+ this function returns [[1, 0], [1, 2, 3]].
+ Note that the order is not deterministic at the moment.
+ """
+ param_tuples = list(map(lambda x: x[1], filter(lambda x: x[0] == state_or_tran, by_param.keys())))
+ distinct_values = [set() for i in range(len(param_tuples[0]))]
+ for param_tuple in param_tuples:
+ for i in range(len(param_tuple)):
+ distinct_values[i].add(param_tuple[i])
+
+ # TODO returned values must have a deterministic order
+
+ # Convert sets to lists
+ distinct_values = list(map(list, distinct_values))
+ return distinct_values
+
def _mean_std_by_param(by_param, state_or_tran, attribute, param_index, verbose = False):
u"""
Calculate the mean standard deviation for a static model where all parameters but param_index are constant.
@@ -254,8 +275,32 @@ def _mean_std_by_param(by_param, state_or_tran, attribute, param_index, verbose
I.e., if parameters are a, b, c ∈ {1,2,3} and 'index' corresponds to b, then
this function returns the mean of the standard deviations of (a=1, b=*, c=1),
(a=1, b=*, c=2), and so on.
+ Also returns an (n-1)-dimensional array (where n is the number of parameters)
+ giving the standard deviation of each individual partition. E.g. for
+ param_index == 2 and 4 parameters, array[a][b][d] is the
+ stddev of measurements with param0 == a, param1 == b, param2 variable,
+ and param3 == d.
"""
partitions = []
+
+ # TODO precalculate or cache info_shape (it only depends on state_or_tran)
+ param_values = list(remove_index_from_tuple(_param_values(by_param, state_or_tran), param_index))
+ info_shape = tuple(map(len, param_values))
+ stddev_matrix = np.full(info_shape, np.nan)
+
+ for param_value in itertools.product(*param_values):
+ param_partition = list()
+ for k, v in by_param.items():
+ if k[0] == state_or_tran and (*k[1][:param_index], *k[1][param_index+1:]) == param_value:
+ param_partition.extend(v[attribute])
+
+ if len(param_partition) > 1:
+ matrix_index = list(range(len(param_value)))
+ for i in range(len(param_value)):
+ matrix_index[i] = param_values[i].index(param_value[i])
+ matrix_index = tuple(matrix_index)
+ stddev_matrix[matrix_index] = np.std(param_partition)
+
for param_value in filter(lambda x: x[0] == state_or_tran, by_param.keys()):
param_partition = []
for k, v in by_param.items():