summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py68
1 files changed, 27 insertions, 41 deletions
diff --git a/lib/model.py b/lib/model.py
index 192cea3..1190fb0 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -114,17 +114,17 @@ class ParallelParamFit:
This causes fit() to compute the best-fitting function for this model part.
"""
+ # Transform by_param[(state_or_tran, param_value)][attribute] = ...
+ # into n_by_param[param_value] = ...
+ # (param_value is dynamic, the rest is fixed)
+ n_by_param = dict()
+ for k, v in self.by_param.items():
+ if k[0] == state_or_tran:
+ n_by_param[k[1]] = v[attribute]
self.fit_queue.append(
{
"key": [state_or_tran, attribute, param_name, param_filter],
- "args": [
- self.by_param,
- state_or_tran,
- attribute,
- param_index,
- safe_functions_enabled,
- param_filter,
- ],
+ "args": [n_by_param, param_index, safe_functions_enabled, param_filter],
}
)
@@ -201,20 +201,15 @@ def _try_fits_parallel(arg):
def _try_fits(
- by_param,
- state_or_tran,
- model_attribute,
- param_index,
- safe_functions_enabled=False,
- param_filter: dict = None,
+ n_by_param, param_index, safe_functions_enabled=False, param_filter: dict = None
):
"""
- Determine goodness-of-fit for prediction of `by_param[(state_or_tran, *)][model_attribute]` dependence on `param_index` using various functions.
+ Determine goodness-of-fit for prediction of `n_by_param[(param1_value, param2_value, ...)]` dependence on `param_index` using various functions.
This is done by varying `param_index` while keeping all other parameters constant and doing one least squares optimization for each function and for each combination of the remaining parameters.
The value of the parameter corresponding to `param_index` (e.g. txpower or packet length) is the sole input to the model function.
Only numeric parameter values (as determined by `utils.is_numeric`) are used for fitting, non-numeric values such as None or enum strings are ignored.
- Fitting is only performed if at least three distinct parameter values exist in `by_param[(state_or_tran, *)]`.
+ Fitting is only performed if at least three distinct parameter values exist in `by_param[*]`.
:returns: a dictionary with the following elements:
best -- name of the best-fitting function (see `analytic.functions`). `None` in case of insufficient data.
@@ -223,14 +218,8 @@ def _try_fits(
median_rmsd -- mean Root Mean Square Deviation of a reference model using the median of its respective input data as model value
results -- mean goodness-of-fit measures for the individual functions. See `analytic.functions` for keys and `aggregate_measures` for values
- :param by_param: measurements partitioned by state/transition/... name and parameter values.
- Example: `{('foo', (0, 2)): {'bar': [2]}, ('foo', (0, 4)): {'bar': [4]}, ('foo', (0, 6)): {'bar': [6]}}`
-
- :param state_or_tran: state/transition/... name for which goodness-of-fit will be calculated (first element of by_param key tuple).
- Example: `'foo'`
-
- :param model_attribute: attribute for which goodness-of-fit will be calculated.
- Example: `'bar'`
+ :param n_by_param: measurements of a specific model attribute partitioned by parameter values.
+ Example: `{(0, 2): [2], (0, 4): [4], (0, 6): [6]}`
:param param_index: index of the parameter used as model input
:param safe_functions_enabled: Include "safe" variants of functions with limited argument range.
@@ -239,15 +228,15 @@ def _try_fits(
functions = analytic.functions(safe_functions_enabled=safe_functions_enabled)
- for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()):
+ for param_key in n_by_param.keys():
# We might remove elements from 'functions' while iterating over
# its keys. A generator will not allow this, so we need to
# convert to a list.
function_names = list(functions.keys())
for function_name in function_names:
function_object = functions[function_name]
- if is_numeric(param_key[1][param_index]) and not function_object.is_valid(
- param_key[1][param_index]
+ if is_numeric(param_key[param_index]) and not function_object.is_valid(
+ param_key[param_index]
):
functions.pop(function_name, None)
@@ -261,12 +250,11 @@ def _try_fits(
# for each parameter combination:
for param_key in filter(
- lambda x: x[0] == state_or_tran
- and remove_index_from_tuple(x[1], param_index)
+ lambda x: remove_index_from_tuple(x, param_index)
not in seen_parameter_combinations
- and len(by_param[x]["param"])
- and match_parameter_values(by_param[x]["param"][0], param_filter),
- by_param.keys(),
+ and len(n_by_param[x])
+ and match_parameter_values(n_by_param[x][0], param_filter),
+ n_by_param.keys(),
):
X = []
Y = []
@@ -275,24 +263,22 @@ def _try_fits(
# Ensure that each parameter combination is only optimized once. Otherwise, with parameters (1, 2, 5), (1, 3, 5), (1, 4, 5) and param_index == 1,
# the parameter combination (1, *, 5) would be optimized three times, both wasting time and biasing results towards more frequently occuring combinations of non-param_index parameters
- seen_parameter_combinations.add(
- remove_index_from_tuple(param_key[1], param_index)
- )
+ seen_parameter_combinations.add(remove_index_from_tuple(param_key, param_index))
# for each value of the parameter denoted by param_index (all other parameters remain the same):
for k, v in filter(
- lambda kv: param_slice_eq(kv[0], param_key, param_index), by_param.items()
+ lambda kv: param_slice_eq(kv[0], param_key, param_index), n_by_param.items()
):
num_total += 1
- if is_numeric(k[1][param_index]):
+ if is_numeric(k[param_index]):
num_valid += 1
- X.extend([float(k[1][param_index])] * len(v[model_attribute]))
- Y.extend(v[model_attribute])
+ X.extend([float(k[param_index])] * len(v))
+ Y.extend(v)
if num_valid > 2:
X = np.array(X)
Y = np.array(Y)
- other_parameters = remove_index_from_tuple(k[1], param_index)
+ other_parameters = remove_index_from_tuple(k, param_index)
raw_results_by_param[other_parameters] = dict()
results_by_param[other_parameters] = dict()
for function_name, param_function in functions.items():
@@ -318,7 +304,7 @@ def _try_fits(
if not len(ref_results["mean"]):
# Insufficient data for fitting
- # print('[W] Insufficient data for fitting {}/{}/{}'.format(state_or_tran, model_attribute, param_index))
+ # print('[W] Insufficient data for fitting {}'.format(param_index))
return {"best": None, "best_rmsd": np.inf, "results": results}
for (