summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-11-22 15:31:40 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-11-22 15:31:40 +0100
commit5b5380430eb3b701c1c43e18524a6b4759f46e27 (patch)
treea75cd1304c374354522a547e8fc667d3e094fc1a
parent5f79f5d018cad72b1c821dd962894d3483e29553 (diff)
handle fit errors in cross validation
-rw-r--r--lib/functions.py18
-rw-r--r--lib/paramfit.py12
2 files changed, 26 insertions, 4 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 698d68c..b1477da 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -75,7 +75,7 @@ class ParamFunction:
error measure.
"""
- def __init__(self, param_function, validation_function, num_vars):
+ def __init__(self, param_function, validation_function, num_vars, repr_str=None):
"""
Create function object suitable for regression analysis.
@@ -97,6 +97,12 @@ class ParamFunction:
self._param_function = param_function
self._validation_function = validation_function
self._num_variables = num_vars
+ self.repr_str = repr_str
+
+ def __repr__(self) -> str:
+ if self.repr_str:
+ return f"ParamFunction<{self.repr_str}>"
+ return f"ParamFunction<{self._param_function}, {self.validation_function}, {self._num_variables}>"
def is_valid(self, arg: float) -> bool:
"""
@@ -691,24 +697,28 @@ class analytic:
+ reg_param[1] * model_param,
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * x",
),
"logarithmic": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * np.log(model_param),
lambda model_param: model_param > 0,
2,
+ repr_str="β₀ + β₁ * np.log(x)",
),
"logarithmic1": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * np.log(model_param + 1),
lambda model_param: model_param > -1,
2,
+ repr_str="β₀ + β₁ * np.log(x+1)",
),
"exponential": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * np.exp(model_param),
lambda model_param: model_param <= 64,
2,
+ repr_str="β₀ + β₁ * np.exp(x)",
),
#'polynomial' : lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param + reg_param[2] * model_param ** 2,
"square": ParamFunction(
@@ -716,18 +726,21 @@ class analytic:
+ reg_param[1] * model_param ** 2,
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * x²",
),
"inverse": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] / model_param,
lambda model_param: model_param != 0,
2,
+ repr_str="β₀ + β₁ * 1/x",
),
"sqrt": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * np.sqrt(model_param),
lambda model_param: model_param >= 0,
2,
+ repr_str="β₀ + β₁ * np.sqrt(x)",
),
# "num0_8": ParamFunction(
# lambda reg_param, model_param: reg_param[0]
@@ -759,6 +772,7 @@ class analytic:
+ reg_param[1] * analytic._safe_log(model_param),
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * safe_log(x)",
)
functions.pop("inverse")
functions["safe_inv"] = ParamFunction(
@@ -766,6 +780,7 @@ class analytic:
+ reg_param[1] * analytic._safe_inv(model_param),
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * safe(1/x)",
)
functions.pop("sqrt")
functions["safe_sqrt"] = ParamFunction(
@@ -773,6 +788,7 @@ class analytic:
+ reg_param[1] * analytic._safe_sqrt(model_param),
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * safe_sqrt(x)",
)
if bool(int(os.getenv("DFATOOL_FIT_LINEAR_ONLY", "0"))):
diff --git a/lib/paramfit.py b/lib/paramfit.py
index 8cfd486..9923e1a 100644
--- a/lib/paramfit.py
+++ b/lib/paramfit.py
@@ -202,9 +202,15 @@ def _try_fits(
if function_name not in raw_results:
raw_results[function_name] = dict()
error_function = param_function.error_function
- res = optimize.least_squares(
- error_function, [0, 1], args=(X, Y), xtol=2e-15
- )
+ try:
+ res = optimize.least_squares(
+ error_function, [0, 1], args=(X, Y), xtol=2e-15
+ )
+ except FloatingPointError as e:
+ logger.warning(
+ f"optimize.least_squares threw '{e}' when fitting {param_function} on {X}, {Y}"
+ )
+ continue
measures = regression_measures(param_function.eval(res.x, X), Y)
raw_results_by_param[other_parameters][function_name] = measures
for measure, error_rate in measures.items():