diff options
-rwxr-xr-x | lib/dfatool.py | 42 | ||||
-rw-r--r-- | lib/utils.py | 3 |
2 files changed, 20 insertions, 25 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index ecd3051..853eb13 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -15,7 +15,7 @@ from multiprocessing import Pool from automata import PTA from functions import analytic from functions import AnalyticFunction -from utils import vprint, is_numeric, soft_cast_int, param_slice_eq, compute_param_statistics +from utils import vprint, is_numeric, soft_cast_int, param_slice_eq, compute_param_statistics, remove_index_from_tuple arg_support_enabled = True @@ -927,27 +927,18 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi median_rmsd -- mean Root Mean Square Deviation of a reference model using the median of its respective input data as model value results -- mean goodness-of-fit measures for the individual functions. See `analytic.functions` for keys and `aggregate_measures` for values - arguments - --- - - by_param: measurements partitioned by state/transition/... name and parameter values. - Example: `{('foo', (0, 2)): {'bar': [2]}, ('foo', (0, 4)): {'bar': [4]}, ('foo', (0, 6)): {'bar': [6]}}` - - state_or_tran: state/transition/... name for which goodness-of-fit will be calculated (first element of by_param key tuple). - Example: `'foo'` - - model_attribute: attribute for which goodness-of-fit will be calculated. - Example: `'bar'` - - param_index -- index of the parameter used as model input - safe_functions_enabled -- Include "safe" variants of functions with limited argument range. + :param by_param: measurements partitioned by state/transition/... name and parameter values. + Example: `{('foo', (0, 2)): {'bar': [2]}, ('foo', (0, 4)): {'bar': [4]}, ('foo', (0, 6)): {'bar': [6]}}` + :param state_or_tran: state/transition/... name for which goodness-of-fit will be calculated (first element of by_param key tuple). + Example: `'foo'` + :param model_attribute: attribute for which goodness-of-fit will be calculated. + Example: `'bar'` + :param param_index: index of the parameter used as model input + :param safe_functions_enabled: Include "safe" variants of functions with limited argument range. """ functions = analytic.functions(safe_functions_enabled = safe_functions_enabled) - #print('_try_fits(..., {}, {}, {})'.format(state_or_tran, model_attribute, param_index)) - - for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()): # We might remove elements from 'functions' while iterating over # its keys. A generator will not allow this, so we need to @@ -966,18 +957,19 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi results = {} results_by_param = {} - # TODO diese Funktion ist unfair, wenn ein Parameter in einer Variante deutlich mehr unterschiedliche Werte - # aufweist als bei der Kombination mit anderen Parametern. Gibt es z.B. die Parameterkombinationen - # (0,2), (0, 4), (0,6), (0,8), (0, 10), 0,12), (2, 2), (2, 4), (2, 6) und wird der Parameter mit Index 1 bestimmt, - # so haben die Messwerte für Parameter-Index 0 == 0 mehr Gewicht als die für Parameter-Index 0 == 2. - # Bei klassischen AEMR-generierten Benchmarks macht das nichts, weil für alle Kombinationen die gleichen Parameterwerte - # genutzt werden, das kann sich aber noch ändern... + seen_parameter_combinations = set() + # for each parameter combination: - for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()): + for param_key in filter(lambda x: x[0] == state_or_tran and remove_index_from_tuple(x[1], param_index) not in seen_parameter_combinations, by_param.keys()): X = [] Y = [] num_valid = 0 num_total = 0 + + # Ensure that each parameter combination is only optimized once. Otherwise, with parameters (1, 2, 5), (1, 3, 5), (1, 4, 5) and param_index == 1, + # the parameter combination (1, *, 5) would be optimized three times + seen_parameter_combinations.add(remove_index_from_tuple(param_key[1], param_index)) + # for each value of the parameter denoted by param_index (all other parameters remain the same): for k, v in filter(lambda kv: param_slice_eq(kv[0], param_key, param_index), by_param.items()): num_total += 1 diff --git a/lib/utils.py b/lib/utils.py index f31aa8e..3ac4792 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -75,6 +75,9 @@ def parse_conf_str(conf_str): conf_dict[key] = soft_cast_float(value) return conf_dict +def remove_index_from_tuple(parameters, index): + return (*parameters[:index], *parameters[index+1:]) + def param_slice_eq(a, b, index): """ Check if by_param keys a and b are identical, ignoring the parameter at index. |