summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rwxr-xr-xlib/dfatool.py42
-rw-r--r--lib/utils.py3
2 files changed, 20 insertions, 25 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index ecd3051..853eb13 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -15,7 +15,7 @@ from multiprocessing import Pool
from automata import PTA
from functions import analytic
from functions import AnalyticFunction
-from utils import vprint, is_numeric, soft_cast_int, param_slice_eq, compute_param_statistics
+from utils import vprint, is_numeric, soft_cast_int, param_slice_eq, compute_param_statistics, remove_index_from_tuple
arg_support_enabled = True
@@ -927,27 +927,18 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
median_rmsd -- mean Root Mean Square Deviation of a reference model using the median of its respective input data as model value
results -- mean goodness-of-fit measures for the individual functions. See `analytic.functions` for keys and `aggregate_measures` for values
- arguments
- ---
-
- by_param: measurements partitioned by state/transition/... name and parameter values.
- Example: `{('foo', (0, 2)): {'bar': [2]}, ('foo', (0, 4)): {'bar': [4]}, ('foo', (0, 6)): {'bar': [6]}}`
-
- state_or_tran: state/transition/... name for which goodness-of-fit will be calculated (first element of by_param key tuple).
- Example: `'foo'`
-
- model_attribute: attribute for which goodness-of-fit will be calculated.
- Example: `'bar'`
-
- param_index -- index of the parameter used as model input
- safe_functions_enabled -- Include "safe" variants of functions with limited argument range.
+ :param by_param: measurements partitioned by state/transition/... name and parameter values.
+ Example: `{('foo', (0, 2)): {'bar': [2]}, ('foo', (0, 4)): {'bar': [4]}, ('foo', (0, 6)): {'bar': [6]}}`
+ :param state_or_tran: state/transition/... name for which goodness-of-fit will be calculated (first element of by_param key tuple).
+ Example: `'foo'`
+ :param model_attribute: attribute for which goodness-of-fit will be calculated.
+ Example: `'bar'`
+ :param param_index: index of the parameter used as model input
+ :param safe_functions_enabled: Include "safe" variants of functions with limited argument range.
"""
functions = analytic.functions(safe_functions_enabled = safe_functions_enabled)
- #print('_try_fits(..., {}, {}, {})'.format(state_or_tran, model_attribute, param_index))
-
-
for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()):
# We might remove elements from 'functions' while iterating over
# its keys. A generator will not allow this, so we need to
@@ -966,18 +957,19 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi
results = {}
results_by_param = {}
- # TODO diese Funktion ist unfair, wenn ein Parameter in einer Variante deutlich mehr unterschiedliche Werte
- # aufweist als bei der Kombination mit anderen Parametern. Gibt es z.B. die Parameterkombinationen
- # (0,2), (0, 4), (0,6), (0,8), (0, 10), 0,12), (2, 2), (2, 4), (2, 6) und wird der Parameter mit Index 1 bestimmt,
- # so haben die Messwerte für Parameter-Index 0 == 0 mehr Gewicht als die für Parameter-Index 0 == 2.
- # Bei klassischen AEMR-generierten Benchmarks macht das nichts, weil für alle Kombinationen die gleichen Parameterwerte
- # genutzt werden, das kann sich aber noch ändern...
+ seen_parameter_combinations = set()
+
# for each parameter combination:
- for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()):
+ for param_key in filter(lambda x: x[0] == state_or_tran and remove_index_from_tuple(x[1], param_index) not in seen_parameter_combinations, by_param.keys()):
X = []
Y = []
num_valid = 0
num_total = 0
+
+ # Ensure that each parameter combination is only optimized once. Otherwise, with parameters (1, 2, 5), (1, 3, 5), (1, 4, 5) and param_index == 1,
+ # the parameter combination (1, *, 5) would be optimized three times
+ seen_parameter_combinations.add(remove_index_from_tuple(param_key[1], param_index))
+
# for each value of the parameter denoted by param_index (all other parameters remain the same):
for k, v in filter(lambda kv: param_slice_eq(kv[0], param_key, param_index), by_param.items()):
num_total += 1
diff --git a/lib/utils.py b/lib/utils.py
index f31aa8e..3ac4792 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -75,6 +75,9 @@ def parse_conf_str(conf_str):
conf_dict[key] = soft_cast_float(value)
return conf_dict
+def remove_index_from_tuple(parameters, index):
+ return (*parameters[:index], *parameters[index+1:])
+
def param_slice_eq(a, b, index):
"""
Check if by_param keys a and b are identical, ignoring the parameter at index.