diff options
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 68 |
1 files changed, 27 insertions, 41 deletions
diff --git a/lib/model.py b/lib/model.py index 192cea3..1190fb0 100644 --- a/lib/model.py +++ b/lib/model.py @@ -114,17 +114,17 @@ class ParallelParamFit: This causes fit() to compute the best-fitting function for this model part. """ + # Transform by_param[(state_or_tran, param_value)][attribute] = ... + # into n_by_param[param_value] = ... + # (param_value is dynamic, the rest is fixed) + n_by_param = dict() + for k, v in self.by_param.items(): + if k[0] == state_or_tran: + n_by_param[k[1]] = v[attribute] self.fit_queue.append( { "key": [state_or_tran, attribute, param_name, param_filter], - "args": [ - self.by_param, - state_or_tran, - attribute, - param_index, - safe_functions_enabled, - param_filter, - ], + "args": [n_by_param, param_index, safe_functions_enabled, param_filter], } ) @@ -201,20 +201,15 @@ def _try_fits_parallel(arg): def _try_fits( - by_param, - state_or_tran, - model_attribute, - param_index, - safe_functions_enabled=False, - param_filter: dict = None, + n_by_param, param_index, safe_functions_enabled=False, param_filter: dict = None ): """ - Determine goodness-of-fit for prediction of `by_param[(state_or_tran, *)][model_attribute]` dependence on `param_index` using various functions. + Determine goodness-of-fit for prediction of `n_by_param[(param1_value, param2_value, ...)]` dependence on `param_index` using various functions. This is done by varying `param_index` while keeping all other parameters constant and doing one least squares optimization for each function and for each combination of the remaining parameters. The value of the parameter corresponding to `param_index` (e.g. txpower or packet length) is the sole input to the model function. Only numeric parameter values (as determined by `utils.is_numeric`) are used for fitting, non-numeric values such as None or enum strings are ignored. - Fitting is only performed if at least three distinct parameter values exist in `by_param[(state_or_tran, *)]`. + Fitting is only performed if at least three distinct parameter values exist in `by_param[*]`. :returns: a dictionary with the following elements: best -- name of the best-fitting function (see `analytic.functions`). `None` in case of insufficient data. @@ -223,14 +218,8 @@ def _try_fits( median_rmsd -- mean Root Mean Square Deviation of a reference model using the median of its respective input data as model value results -- mean goodness-of-fit measures for the individual functions. See `analytic.functions` for keys and `aggregate_measures` for values - :param by_param: measurements partitioned by state/transition/... name and parameter values. - Example: `{('foo', (0, 2)): {'bar': [2]}, ('foo', (0, 4)): {'bar': [4]}, ('foo', (0, 6)): {'bar': [6]}}` - - :param state_or_tran: state/transition/... name for which goodness-of-fit will be calculated (first element of by_param key tuple). - Example: `'foo'` - - :param model_attribute: attribute for which goodness-of-fit will be calculated. - Example: `'bar'` + :param n_by_param: measurements of a specific model attribute partitioned by parameter values. + Example: `{(0, 2): [2], (0, 4): [4], (0, 6): [6]}` :param param_index: index of the parameter used as model input :param safe_functions_enabled: Include "safe" variants of functions with limited argument range. @@ -239,15 +228,15 @@ def _try_fits( functions = analytic.functions(safe_functions_enabled=safe_functions_enabled) - for param_key in filter(lambda x: x[0] == state_or_tran, by_param.keys()): + for param_key in n_by_param.keys(): # We might remove elements from 'functions' while iterating over # its keys. A generator will not allow this, so we need to # convert to a list. function_names = list(functions.keys()) for function_name in function_names: function_object = functions[function_name] - if is_numeric(param_key[1][param_index]) and not function_object.is_valid( - param_key[1][param_index] + if is_numeric(param_key[param_index]) and not function_object.is_valid( + param_key[param_index] ): functions.pop(function_name, None) @@ -261,12 +250,11 @@ def _try_fits( # for each parameter combination: for param_key in filter( - lambda x: x[0] == state_or_tran - and remove_index_from_tuple(x[1], param_index) + lambda x: remove_index_from_tuple(x, param_index) not in seen_parameter_combinations - and len(by_param[x]["param"]) - and match_parameter_values(by_param[x]["param"][0], param_filter), - by_param.keys(), + and len(n_by_param[x]) + and match_parameter_values(n_by_param[x][0], param_filter), + n_by_param.keys(), ): X = [] Y = [] @@ -275,24 +263,22 @@ def _try_fits( # Ensure that each parameter combination is only optimized once. Otherwise, with parameters (1, 2, 5), (1, 3, 5), (1, 4, 5) and param_index == 1, # the parameter combination (1, *, 5) would be optimized three times, both wasting time and biasing results towards more frequently occuring combinations of non-param_index parameters - seen_parameter_combinations.add( - remove_index_from_tuple(param_key[1], param_index) - ) + seen_parameter_combinations.add(remove_index_from_tuple(param_key, param_index)) # for each value of the parameter denoted by param_index (all other parameters remain the same): for k, v in filter( - lambda kv: param_slice_eq(kv[0], param_key, param_index), by_param.items() + lambda kv: param_slice_eq(kv[0], param_key, param_index), n_by_param.items() ): num_total += 1 - if is_numeric(k[1][param_index]): + if is_numeric(k[param_index]): num_valid += 1 - X.extend([float(k[1][param_index])] * len(v[model_attribute])) - Y.extend(v[model_attribute]) + X.extend([float(k[param_index])] * len(v)) + Y.extend(v) if num_valid > 2: X = np.array(X) Y = np.array(Y) - other_parameters = remove_index_from_tuple(k[1], param_index) + other_parameters = remove_index_from_tuple(k, param_index) raw_results_by_param[other_parameters] = dict() results_by_param[other_parameters] = dict() for function_name, param_function in functions.items(): @@ -318,7 +304,7 @@ def _try_fits( if not len(ref_results["mean"]): # Insufficient data for fitting - # print('[W] Insufficient data for fitting {}/{}/{}'.format(state_or_tran, model_attribute, param_index)) + # print('[W] Insufficient data for fitting {}'.format(param_index)) return {"best": None, "best_rmsd": np.inf, "results": results} for ( |