diff options
-rw-r--r-- | lib/dfatool.py | 112 | ||||
-rwxr-xr-x[-rw-r--r--] | test/test_parameters.py | 2 |
2 files changed, 55 insertions, 59 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 1e38907..0596ad8 100644 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -1484,6 +1484,57 @@ class ParallelParamFit: with Pool() as pool: self.results = pool.map(_try_fits_parallel, self.fit_queue) + def get_result(self, name, attribute, param_filter: dict = None): + """ + Parse and sanitize fit results for state/transition/... 'name' and model attribute 'attribute'. + + Filters out results where the best function is worse (or not much better than) static mean/median estimates. + + :param name: state/transition/... name, e.g. 'TX' + :param attribute: model attribute, e.g. 'duration' + :param param_filter: + :returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} } + """ + fit_result = dict() + for result in self.results: + if ( + result["key"][0] == name + and result["key"][1] == attribute + and result["key"][3] == param_filter + and result["result"]["best"] is not None + ): # dürfte an ['best'] != None liegen-> Fit für gefilterten Kram schlägt fehl? + this_result = result["result"] + if this_result["best_rmsd"] >= min( + this_result["mean_rmsd"], this_result["median_rmsd"] + ): + logger.debug( + "Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format( + name, + attribute, + result["key"][2], + this_result["best_rmsd"], + this_result["mean_rmsd"], + this_result["median_rmsd"], + ) + ) + # See notes on depends_on_param + elif this_result["best_rmsd"] >= 0.8 * min( + this_result["mean_rmsd"], this_result["median_rmsd"] + ): + logger.debug( + "Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format( + name, + attribute, + result["key"][2], + this_result["best_rmsd"], + this_result["mean_rmsd"], + this_result["median_rmsd"], + ) + ) + else: + fit_result[result["key"][2]] = this_result + return fit_result + def _try_fits_parallel(arg): """ @@ -1668,59 +1719,6 @@ def _num_args_from_by_name(by_name): return num_args -def get_fit_result(results, name, attribute, param_filter: dict = None): - """ - Parse and sanitize fit results for state/transition/... 'name' and model attribute 'attribute'. - - Filters out results where the best function is worse (or not much better than) static mean/median estimates. - - :param results: fit results as returned by `paramfit.results` - :param name: state/transition/... name, e.g. 'TX' - :param attribute: model attribute, e.g. 'duration' - :param param_filter: - :returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} } - """ - fit_result = dict() - for result in results: - if ( - result["key"][0] == name - and result["key"][1] == attribute - and result["key"][3] == param_filter - and result["result"]["best"] is not None - ): # dürfte an ['best'] != None liegen-> Fit für gefilterten Kram schlägt fehl? - this_result = result["result"] - if this_result["best_rmsd"] >= min( - this_result["mean_rmsd"], this_result["median_rmsd"] - ): - logger.debug( - "Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format( - name, - attribute, - result["key"][2], - this_result["best_rmsd"], - this_result["mean_rmsd"], - this_result["median_rmsd"], - ) - ) - # See notes on depends_on_param - elif this_result["best_rmsd"] >= 0.8 * min( - this_result["mean_rmsd"], this_result["median_rmsd"] - ): - logger.debug( - "Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format( - name, - attribute, - result["key"][2], - this_result["best_rmsd"], - this_result["mean_rmsd"], - this_result["median_rmsd"], - ) - ) - else: - fit_result[result["key"][2]] = this_result - return fit_result - - class AnalyticModel: u""" Parameter-aware analytic energy/data size/... model. @@ -1929,7 +1927,7 @@ class AnalyticModel: if name in self._num_args: num_args = self._num_args[name] for attribute in self.by_name[name]["attributes"]: - fit_result = get_fit_result(paramfit.results, name, attribute) + fit_result = paramfit.get_result(name, attribute) if (name, attribute) in self.function_override: function_str = self.function_override[(name, attribute)] @@ -2355,9 +2353,7 @@ class PTAModel: ): num_args = self._num_args[state_or_tran] for model_attribute in self.by_name[state_or_tran]["attributes"]: - fit_results = get_fit_result( - paramfit.results, state_or_tran, model_attribute - ) + fit_results = paramfit.get_result(state_or_tran, model_attribute) for parameter_name in self._parameter_names: if self.depends_on_param( @@ -2369,7 +2365,7 @@ class PTAModel: state_or_tran, model_attribute, parameter_name ): pass - # FIXME get_fit_result hat ja gar keinen Parameter als Argument... + # FIXME paramfit.get_result hat ja gar keinen Parameter als Argument... if (state_or_tran, model_attribute) in self.function_override: function_str = self.function_override[ diff --git a/test/test_parameters.py b/test/test_parameters.py index a466d91..57ab166 100644..100755 --- a/test/test_parameters.py +++ b/test/test_parameters.py @@ -44,7 +44,7 @@ class TestModels(unittest.TestCase): paramfit.enqueue("TX", "power", 1, "p_linear") paramfit.fit() - fit_result = dt.get_fit_result(paramfit.results, "TX", "power") + fit_result = paramfit.get_result("TX", "power") self.assertEqual(fit_result["p_linear"]["best"], "linear") |