diff options
-rw-r--r-- | lib/paramfit.py | 30 |
1 files changed, 21 insertions, 9 deletions
diff --git a/lib/paramfit.py b/lib/paramfit.py index eb0e141..e6539a4 100644 --- a/lib/paramfit.py +++ b/lib/paramfit.py @@ -16,7 +16,7 @@ from .utils import ( ) logger = logging.getLogger(__name__) -best_fit_metric = os.getenv("DFATOOL_ULS_ERROR_METRIC", "rmsd") +best_fit_metric = os.getenv("DFATOOL_ULS_ERROR_METRIC", "ssr") class ParamFit: @@ -255,16 +255,28 @@ def _try_fits( if len(result) > 0: results[function_name] = {} for measure in result.keys(): - results[function_name][measure] = np.mean(result[measure]) + if measure == "ssr": + results[function_name][measure] = np.sum(result[measure]) + else: + results[function_name][measure] = np.mean(result[measure]) err = results[function_name][best_fit_metric] if err < best_fit_val: best_fit_val = err best_fit_name = function_name - return { - "best": best_fit_name, - "best_err": best_fit_val, - "mean_err": np.mean(ref_results["mean"]), - "median_err": np.mean(ref_results["median"]), - "results": results, - } + if best_fit_metric == "ssr": + return { + "best": best_fit_name, + "best_err": best_fit_val, + "mean_err": np.sum(ref_results["mean"]), + "median_err": np.sum(ref_results["median"]), + "results": results, + } + else: + return { + "best": best_fit_name, + "best_err": best_fit_val, + "mean_err": np.mean(ref_results["mean"]), + "median_err": np.mean(ref_results["median"]), + "results": results, + } |