author    Daniel Friesel <daniel.friesel@uos.de>  2020-07-03 11:05:38 +0200
committer Daniel Friesel <daniel.friesel@uos.de>  2020-07-03 11:05:38 +0200
commit    b911860adb05e9712d16d335c9d1d9785733eea0 (patch)
tree      a5a53c6160b58a5444d7e4d38f27a27d65066325
parent    230d52e784392fd7053099e345943f58e3a5e32e (diff)
move get_fit_result to ParallelParamFit class (as get_result method)
 lib/dfatool.py                              | 112
 test/test_parameters.py (100644 -> 100755)  |   2
 2 files changed, 55 insertions(+), 59 deletions(-)
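In short: the module-level get_fit_result() helper becomes a get_result() method on ParallelParamFit, so callers no longer pass paramfit.results around explicitly. A minimal before/after sketch of the call sites (mirroring the test change below):

    # before: free function, the caller hands over the raw results list
    fit_result = get_fit_result(paramfit.results, "TX", "power")

    # after: method on ParallelParamFit, which reads self.results itself
    fit_result = paramfit.get_result("TX", "power")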
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 1e38907..0596ad8 100644
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -1484,6 +1484,57 @@ class ParallelParamFit:
with Pool() as pool:
self.results = pool.map(_try_fits_parallel, self.fit_queue)
+ def get_result(self, name, attribute, param_filter: dict = None):
+ """
+ Parse and sanitize fit results for state/transition/... 'name' and model attribute 'attribute'.
+
+ Filters out results where the best function is worse than (or not much better than) the static mean/median estimates.
+
+ :param name: state/transition/... name, e.g. 'TX'
+ :param attribute: model attribute, e.g. 'duration'
+ :param param_filter: only consider fit results whose key was created with this exact parameter filter (None for unfiltered fits)
+ :returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} }
+ """
+ fit_result = dict()
+ for result in self.results:
+ if (
+ result["key"][0] == name
+ and result["key"][1] == attribute
+ and result["key"][3] == param_filter
+ and result["result"]["best"] is not None
+ ): # probably due to ['best'] != None -> does the fit for filtered data fail?
+ this_result = result["result"]
+ if this_result["best_rmsd"] >= min(
+ this_result["mean_rmsd"], this_result["median_rmsd"]
+ ):
+ logger.debug(
+ "Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
+ name,
+ attribute,
+ result["key"][2],
+ this_result["best_rmsd"],
+ this_result["mean_rmsd"],
+ this_result["median_rmsd"],
+ )
+ )
+ # See notes on depends_on_param
+ elif this_result["best_rmsd"] >= 0.8 * min(
+ this_result["mean_rmsd"], this_result["median_rmsd"]
+ ):
+ logger.debug(
+ "Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
+ name,
+ attribute,
+ result["key"][2],
+ this_result["best_rmsd"],
+ this_result["mean_rmsd"],
+ this_result["median_rmsd"],
+ )
+ )
+ else:
+ fit_result[result["key"][2]] = this_result
+ return fit_result
+
def _try_fits_parallel(arg):
"""
@@ -1668,59 +1719,6 @@ def _num_args_from_by_name(by_name):
return num_args
-def get_fit_result(results, name, attribute, param_filter: dict = None):
- """
- Parse and sanitize fit results for state/transition/... 'name' and model attribute 'attribute'.
-
- Filters out results where the best function is worse (or not much better than) static mean/median estimates.
-
- :param results: fit results as returned by `paramfit.results`
- :param name: state/transition/... name, e.g. 'TX'
- :param attribute: model attribute, e.g. 'duration'
- :param param_filter:
- :returns: dict with fit result (see `_try_fits`) for each successfully fitted parameter. E.g. {'param 1': {'best' : 'function name', ...} }
- """
- fit_result = dict()
- for result in results:
- if (
- result["key"][0] == name
- and result["key"][1] == attribute
- and result["key"][3] == param_filter
- and result["result"]["best"] is not None
- ): # probably due to ['best'] != None -> does the fit for filtered data fail?
- this_result = result["result"]
- if this_result["best_rmsd"] >= min(
- this_result["mean_rmsd"], this_result["median_rmsd"]
- ):
- logger.debug(
- "Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})".format(
- name,
- attribute,
- result["key"][2],
- this_result["best_rmsd"],
- this_result["mean_rmsd"],
- this_result["median_rmsd"],
- )
- )
- # See notes on depends_on_param
- elif this_result["best_rmsd"] >= 0.8 * min(
- this_result["mean_rmsd"], this_result["median_rmsd"]
- ):
- logger.debug(
- "Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ref ({:.0f}, {:.0f})".format(
- name,
- attribute,
- result["key"][2],
- this_result["best_rmsd"],
- this_result["mean_rmsd"],
- this_result["median_rmsd"],
- )
- )
- else:
- fit_result[result["key"][2]] = this_result
- return fit_result
-
-
class AnalyticModel:
u"""
Parameter-aware analytic energy/data size/... model.
@@ -1929,7 +1927,7 @@ class AnalyticModel:
if name in self._num_args:
num_args = self._num_args[name]
for attribute in self.by_name[name]["attributes"]:
- fit_result = get_fit_result(paramfit.results, name, attribute)
+ fit_result = paramfit.get_result(name, attribute)
if (name, attribute) in self.function_override:
function_str = self.function_override[(name, attribute)]
@@ -2355,9 +2353,7 @@ class PTAModel:
):
num_args = self._num_args[state_or_tran]
for model_attribute in self.by_name[state_or_tran]["attributes"]:
- fit_results = get_fit_result(
- paramfit.results, state_or_tran, model_attribute
- )
+ fit_results = paramfit.get_result(state_or_tran, model_attribute)
for parameter_name in self._parameter_names:
if self.depends_on_param(
@@ -2369,7 +2365,7 @@ class PTAModel:
state_or_tran, model_attribute, parameter_name
):
pass
- # FIXME get_fit_result does not take a parameter as an argument at all...
+ # FIXME paramfit.get_result does not take a parameter as an argument at all...
if (state_or_tran, model_attribute) in self.function_override:
function_str = self.function_override[
diff --git a/test/test_parameters.py b/test/test_parameters.py
index a466d91..57ab166 100644..100755
--- a/test/test_parameters.py
+++ b/test/test_parameters.py
@@ -44,7 +44,7 @@ class TestModels(unittest.TestCase):
paramfit.enqueue("TX", "power", 1, "p_linear")
paramfit.fit()
- fit_result = dt.get_fit_result(paramfit.results, "TX", "power")
+ fit_result = paramfit.get_result("TX", "power")
self.assertEqual(fit_result["p_linear"]["best"], "linear")
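For reference, the filter that get_result() applies can be stated compactly: a parameter-dependent fit is kept only if its RMSD is below 80 % of the better (smaller) of the two static reference errors. A standalone sketch of just that acceptance rule, assuming the same dict keys as in the diff above (the helper name is made up for illustration):

    def fit_is_accepted(fit_result: dict) -> bool:
        # reference error: the better (smaller) of the static mean/median RMSD
        ref = min(fit_result["mean_rmsd"], fit_result["median_rmsd"])
        # keep the fitted function only if it clearly beats the static reference,
        # i.e. its RMSD is below 80 % of the reference error
        return fit_result["best"] is not None and fit_result["best_rmsd"] < 0.8 * ref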