diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-17 14:38:03 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-17 14:38:03 +0100 |
commit | 0015665bda8511db30c254adb351af94494deb7f (patch) | |
tree | 19f998517bace3d542d05d82b9dd26c9b1d2058a /lib | |
parent | 1b4d6d743cf128e5c278559de7d75dde71b1f577 (diff) |
paramfit: use sum(ssr) rather than mean(rmsd) for best-fit selection
Diffstat (limited to 'lib')
-rw-r--r-- | lib/paramfit.py | 30 |
1 files changed, 21 insertions, 9 deletions
diff --git a/lib/paramfit.py b/lib/paramfit.py index eb0e141..e6539a4 100644 --- a/lib/paramfit.py +++ b/lib/paramfit.py @@ -16,7 +16,7 @@ from .utils import ( ) logger = logging.getLogger(__name__) -best_fit_metric = os.getenv("DFATOOL_ULS_ERROR_METRIC", "rmsd") +best_fit_metric = os.getenv("DFATOOL_ULS_ERROR_METRIC", "ssr") class ParamFit: @@ -255,16 +255,28 @@ def _try_fits( if len(result) > 0: results[function_name] = {} for measure in result.keys(): - results[function_name][measure] = np.mean(result[measure]) + if measure == "ssr": + results[function_name][measure] = np.sum(result[measure]) + else: + results[function_name][measure] = np.mean(result[measure]) err = results[function_name][best_fit_metric] if err < best_fit_val: best_fit_val = err best_fit_name = function_name - return { - "best": best_fit_name, - "best_err": best_fit_val, - "mean_err": np.mean(ref_results["mean"]), - "median_err": np.mean(ref_results["median"]), - "results": results, - } + if best_fit_metric == "ssr": + return { + "best": best_fit_name, + "best_err": best_fit_val, + "mean_err": np.sum(ref_results["mean"]), + "median_err": np.sum(ref_results["median"]), + "results": results, + } + else: + return { + "best": best_fit_name, + "best_err": best_fit_val, + "mean_err": np.mean(ref_results["mean"]), + "median_err": np.mean(ref_results["median"]), + "results": results, + } |