summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-17 14:38:03 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-17 14:38:03 +0100
commit0015665bda8511db30c254adb351af94494deb7f (patch)
tree19f998517bace3d542d05d82b9dd26c9b1d2058a /lib
parent1b4d6d743cf128e5c278559de7d75dde71b1f577 (diff)
paramfit: use sum(ssr) rather than mean(rmsd) for best-fit selection
Diffstat (limited to 'lib')
-rw-r--r--lib/paramfit.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/lib/paramfit.py b/lib/paramfit.py
index eb0e141..e6539a4 100644
--- a/lib/paramfit.py
+++ b/lib/paramfit.py
@@ -16,7 +16,7 @@ from .utils import (
)
logger = logging.getLogger(__name__)
-best_fit_metric = os.getenv("DFATOOL_ULS_ERROR_METRIC", "rmsd")
+best_fit_metric = os.getenv("DFATOOL_ULS_ERROR_METRIC", "ssr")
class ParamFit:
@@ -255,16 +255,28 @@ def _try_fits(
if len(result) > 0:
results[function_name] = {}
for measure in result.keys():
- results[function_name][measure] = np.mean(result[measure])
+ if measure == "ssr":
+ results[function_name][measure] = np.sum(result[measure])
+ else:
+ results[function_name][measure] = np.mean(result[measure])
err = results[function_name][best_fit_metric]
if err < best_fit_val:
best_fit_val = err
best_fit_name = function_name
- return {
- "best": best_fit_name,
- "best_err": best_fit_val,
- "mean_err": np.mean(ref_results["mean"]),
- "median_err": np.mean(ref_results["median"]),
- "results": results,
- }
+ if best_fit_metric == "ssr":
+ return {
+ "best": best_fit_name,
+ "best_err": best_fit_val,
+ "mean_err": np.sum(ref_results["mean"]),
+ "median_err": np.sum(ref_results["median"]),
+ "results": results,
+ }
+ else:
+ return {
+ "best": best_fit_name,
+ "best_err": best_fit_val,
+ "mean_err": np.mean(ref_results["mean"]),
+ "median_err": np.mean(ref_results["median"]),
+ "results": results,
+ }