diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-11-22 15:31:40 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-11-22 15:31:40 +0100 |
commit | 5b5380430eb3b701c1c43e18524a6b4759f46e27 (patch) | |
tree | a75cd1304c374354522a547e8fc667d3e094fc1a | |
parent | 5f79f5d018cad72b1c821dd962894d3483e29553 (diff) |
handle fit errors in cross validation
-rw-r--r-- | lib/functions.py | 18 | ||||
-rw-r--r-- | lib/paramfit.py | 12 |
2 files changed, 26 insertions, 4 deletions
diff --git a/lib/functions.py b/lib/functions.py index 698d68c..b1477da 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -75,7 +75,7 @@ class ParamFunction: error measure. """ - def __init__(self, param_function, validation_function, num_vars): + def __init__(self, param_function, validation_function, num_vars, repr_str=None): """ Create function object suitable for regression analysis. @@ -97,6 +97,12 @@ class ParamFunction: self._param_function = param_function self._validation_function = validation_function self._num_variables = num_vars + self.repr_str = repr_str + + def __repr__(self) -> str: + if self.repr_str: + return f"ParamFunction<{self.repr_str}>" + return f"ParamFunction<{self._param_function}, {self.validation_function}, {self._num_variables}>" def is_valid(self, arg: float) -> bool: """ @@ -691,24 +697,28 @@ class analytic: + reg_param[1] * model_param, lambda model_param: True, 2, + repr_str="β₀ + β₁ * x", ), "logarithmic": ParamFunction( lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.log(model_param), lambda model_param: model_param > 0, 2, + repr_str="β₀ + β₁ * np.log(x)", ), "logarithmic1": ParamFunction( lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.log(model_param + 1), lambda model_param: model_param > -1, 2, + repr_str="β₀ + β₁ * np.log(x+1)", ), "exponential": ParamFunction( lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.exp(model_param), lambda model_param: model_param <= 64, 2, + repr_str="β₀ + β₁ * np.exp(x)", ), #'polynomial' : lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param + reg_param[2] * model_param ** 2, "square": ParamFunction( @@ -716,18 +726,21 @@ class analytic: + reg_param[1] * model_param ** 2, lambda model_param: True, 2, + repr_str="β₀ + β₁ * x²", ), "inverse": ParamFunction( lambda reg_param, model_param: reg_param[0] + reg_param[1] / model_param, lambda model_param: model_param != 0, 2, + repr_str="β₀ + β₁ * 1/x", ), "sqrt": ParamFunction( lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.sqrt(model_param), lambda model_param: model_param >= 0, 2, + repr_str="β₀ + β₁ * np.sqrt(x)", ), # "num0_8": ParamFunction( # lambda reg_param, model_param: reg_param[0] @@ -759,6 +772,7 @@ class analytic: + reg_param[1] * analytic._safe_log(model_param), lambda model_param: True, 2, + repr_str="β₀ + β₁ * safe_log(x)", ) functions.pop("inverse") functions["safe_inv"] = ParamFunction( @@ -766,6 +780,7 @@ class analytic: + reg_param[1] * analytic._safe_inv(model_param), lambda model_param: True, 2, + repr_str="β₀ + β₁ * safe(1/x)", ) functions.pop("sqrt") functions["safe_sqrt"] = ParamFunction( @@ -773,6 +788,7 @@ class analytic: + reg_param[1] * analytic._safe_sqrt(model_param), lambda model_param: True, 2, + repr_str="β₀ + β₁ * safe_sqrt(x)", ) if bool(int(os.getenv("DFATOOL_FIT_LINEAR_ONLY", "0"))): diff --git a/lib/paramfit.py b/lib/paramfit.py index 8cfd486..9923e1a 100644 --- a/lib/paramfit.py +++ b/lib/paramfit.py @@ -202,9 +202,15 @@ def _try_fits( if function_name not in raw_results: raw_results[function_name] = dict() error_function = param_function.error_function - res = optimize.least_squares( - error_function, [0, 1], args=(X, Y), xtol=2e-15 - ) + try: + res = optimize.least_squares( + error_function, [0, 1], args=(X, Y), xtol=2e-15 + ) + except FloatingPointError as e: + logger.warning( + f"optimize.least_squares threw '{e}' when fitting {param_function} on {X}, {Y}" + ) + continue measures = regression_measures(param_function.eval(res.x, X), Y) raw_results_by_param[other_parameters][function_name] = measures for measure, error_rate in measures.items(): |