summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 698d68c..b1477da 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -75,7 +75,7 @@ class ParamFunction:
error measure.
"""
- def __init__(self, param_function, validation_function, num_vars):
+ def __init__(self, param_function, validation_function, num_vars, repr_str=None):
"""
Create function object suitable for regression analysis.
@@ -97,6 +97,12 @@ class ParamFunction:
self._param_function = param_function
self._validation_function = validation_function
self._num_variables = num_vars
+ self.repr_str = repr_str
+
+ def __repr__(self) -> str:
+ if self.repr_str:
+ return f"ParamFunction<{self.repr_str}>"
+ return f"ParamFunction<{self._param_function}, {self.validation_function}, {self._num_variables}>"
def is_valid(self, arg: float) -> bool:
"""
@@ -691,24 +697,28 @@ class analytic:
+ reg_param[1] * model_param,
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * x",
),
"logarithmic": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * np.log(model_param),
lambda model_param: model_param > 0,
2,
+ repr_str="β₀ + β₁ * np.log(x)",
),
"logarithmic1": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * np.log(model_param + 1),
lambda model_param: model_param > -1,
2,
+ repr_str="β₀ + β₁ * np.log(x+1)",
),
"exponential": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * np.exp(model_param),
lambda model_param: model_param <= 64,
2,
+ repr_str="β₀ + β₁ * np.exp(x)",
),
#'polynomial' : lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param + reg_param[2] * model_param ** 2,
"square": ParamFunction(
@@ -716,18 +726,21 @@ class analytic:
+ reg_param[1] * model_param ** 2,
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * x²",
),
"inverse": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] / model_param,
lambda model_param: model_param != 0,
2,
+ repr_str="β₀ + β₁ * 1/x",
),
"sqrt": ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * np.sqrt(model_param),
lambda model_param: model_param >= 0,
2,
+ repr_str="β₀ + β₁ * np.sqrt(x)",
),
# "num0_8": ParamFunction(
# lambda reg_param, model_param: reg_param[0]
@@ -759,6 +772,7 @@ class analytic:
+ reg_param[1] * analytic._safe_log(model_param),
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * safe_log(x)",
)
functions.pop("inverse")
functions["safe_inv"] = ParamFunction(
@@ -766,6 +780,7 @@ class analytic:
+ reg_param[1] * analytic._safe_inv(model_param),
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * safe(1/x)",
)
functions.pop("sqrt")
functions["safe_sqrt"] = ParamFunction(
@@ -773,6 +788,7 @@ class analytic:
+ reg_param[1] * analytic._safe_sqrt(model_param),
lambda model_param: True,
2,
+ repr_str="β₀ + β₁ * safe_sqrt(x)",
)
if bool(int(os.getenv("DFATOOL_FIT_LINEAR_ONLY", "0"))):