diff options
Diffstat (limited to 'lib/dfatool.py')
-rw-r--r-- | lib/dfatool.py | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 478f800..a3d5c0f 100644 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -18,7 +18,7 @@ from functions import analytic from functions import AnalyticFunction from parameters import ParamStats from utils import vprint, is_numeric, soft_cast_int, param_slice_eq, remove_index_from_tuple -from utils import by_name_to_by_param +from utils import by_name_to_by_param, match_parameter_values arg_support_enabled = True @@ -505,12 +505,10 @@ class RawData: self.cache_file = '{}/{}.json'.format(self.cache_dir, cache_key) def load_cache(self): - print('checking {}...'.format(self.cache_file)) if os.path.exists(self.cache_file): with open(self.cache_file, 'r') as f: self.traces = json.load(f) self.preprocessed = True - print('loaded cache') def save_cache(self): try: @@ -902,15 +900,15 @@ class ParallelParamFit: self.fit_queue = [] self.by_param = by_param - def enqueue(self, state_or_tran, attribute, param_index, param_name, safe_functions_enabled = False): + def enqueue(self, state_or_tran, attribute, param_index, param_name, safe_functions_enabled = False, param_filter = None): """ Add state_or_tran/attribute/param_name to fit queue. This causes fit() to compute the best-fitting function for this model part. """ self.fit_queue.append({ - 'key' : [state_or_tran, attribute, param_name], - 'args' : [self.by_param, state_or_tran, attribute, param_index, safe_functions_enabled] + 'key' : [state_or_tran, attribute, param_name, param_filter], + 'args' : [self.by_param, state_or_tran, attribute, param_index, safe_functions_enabled, param_filter] }) def fit(self): @@ -935,16 +933,17 @@ def _try_fits_parallel(arg): 'result' : _try_fits(*arg['args']) } - -def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled = False): +def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functions_enabled = False, param_filter: dict = None): """ Determine goodness-of-fit for prediction of `by_param[(state_or_tran, *)][model_attribute]` dependence on `param_index` using various functions. This is done by varying `param_index` while keeping all other parameters constant and doing one least squares optimization for each function and for each combination of the remaining parameters. The value of the parameter corresponding to `param_index` (e.g. txpower or packet length) is the sole input to the model function. + Only numeric parameter values (as determined by `utils.is_numeric`) are used for fitting, non-numeric values such as None or enum strings are ignored. + Fitting is only performed if at least three distinct parameter values exist in `by_param[(state_or_tran, *)]`. - :return: a dictionary with the following elements: - best -- name of the best-fitting function (see `analytic.functions`) + :returns: a dictionary with the following elements: + best -- name of the best-fitting function (see `analytic.functions`). `None` in case of insufficient data. best_rmsd -- mean Root Mean Square Deviation of best-fitting function over all combinations of the remaining parameters mean_rmsd -- mean Root Mean Square Deviation of a reference model using the mean of its respective input data as model value median_rmsd -- mean Root Mean Square Deviation of a reference model using the median of its respective input data as model value @@ -961,6 +960,7 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi :param param_index: index of the parameter used as model input :param safe_functions_enabled: Include "safe" variants of functions with limited argument range. + :param param_filter: Only use measurements whose parameters match param_filter for fitting. """ functions = analytic.functions(safe_functions_enabled = safe_functions_enabled) @@ -987,7 +987,7 @@ def _try_fits(by_param, state_or_tran, model_attribute, param_index, safe_functi seen_parameter_combinations = set() # for each parameter combination: - for param_key in filter(lambda x: x[0] == state_or_tran and remove_index_from_tuple(x[1], param_index) not in seen_parameter_combinations, by_param.keys()): + for param_key in filter(lambda x: x[0] == state_or_tran and remove_index_from_tuple(x[1], param_index) not in seen_parameter_combinations and len(by_param[x]['param']) and match_parameter_values(by_param[x]['param'][0], param_filter), by_param.keys()): X = [] Y = [] num_valid = 0 @@ -1087,7 +1087,7 @@ def _num_args_from_by_name(by_name): num_args[key] = len(value['args'][0]) return num_args -def get_fit_result(results, name, attribute, verbose = False): +def get_fit_result(results, name, attribute, verbose = False, param_filter: dict = None): """ Parse and sanitize fit results for state/transition/... 'name' and model attribute 'attribute'. @@ -1097,10 +1097,12 @@ def get_fit_result(results, name, attribute, verbose = False): :param name: state/transition/... name, e.g. 'TX' :param attribute: model attribute, e.g. 'duration' :param verbose: print debug message to stdout when deliberately not using a determined fit function + :param param_filter: + :returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} } """ fit_result = dict() for result in results: - if result['key'][0] == name and result['key'][1] == attribute and result['result']['best'] != None: + if result['key'][0] == name and result['key'][1] == attribute and result['key'][3] == param_filter and result['result']['best'] != None: # dürfte an ['best'] != None liegen-> Fit für gefilterten Kram schlägt fehl? this_result = result['result'] if this_result['best_rmsd'] >= min(this_result['mean_rmsd'], this_result['median_rmsd']): vprint(verbose, '[I] Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})'.format( @@ -1583,7 +1585,7 @@ class PTAModel: """ Get static model function: name, attribute -> model value. - Uses the median of by_name for modeling. + Uses the median of by_name for modeling, unless `use_mean` is set. """ getter_function = np.median @@ -1633,7 +1635,7 @@ class PTAModel: def get_fitted(self, safe_functions_enabled = False): """ - Get paramete-aware model function and model information function. + Get parameter-aware model function and model information function. Returns two functions: model_function(name, attribute, param=parameter values) -> model value. @@ -1651,6 +1653,8 @@ class PTAModel: for parameter_index, parameter_name in enumerate(self._parameter_names): if self.depends_on_param(state_or_tran, model_attribute, parameter_name): paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled) + for codependent_param_dict in self.stats.codependent_parameter_value_dicts(state_or_tran, model_attribute, parameter_name): + paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled, codependent_param_dict) if arg_support_enabled and self.by_name[state_or_tran]['isa'] == 'transition': for arg_index in range(self._num_args[state_or_tran]): if self.depends_on_arg(state_or_tran, model_attribute, arg_index): @@ -1664,6 +1668,12 @@ class PTAModel: for model_attribute in self.by_name[state_or_tran]['attributes']: fit_results = get_fit_result(paramfit.results, state_or_tran, model_attribute, self.verbose) + for parameter_name in self._parameter_names: + if self.depends_on_param(state_or_tran, model_attribute, parameter_name): + for codependent_param_dict in self.stats.codependent_parameter_value_dicts(state_or_tran, model_attribute, parameter_name): + pass + # FIXME get_fit_result hat ja gar keinen Parameter als Argument... + if (state_or_tran, model_attribute) in self.function_override: function_str = self.function_override[(state_or_tran, model_attribute)] x = AnalyticFunction(function_str, self._parameter_names, num_args) |