summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-07-19 14:47:12 +0200
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-07-19 14:47:12 +0200
commit359a262e784219b3e4e4b63ec861f91104c09feb (patch)
treeb6704d320bd4b0765b4092f55fb03e8caa24035e
parentdcf3ca444a3f820f5f3ae3a55f975e40f1c41526 (diff)
ULS: add support for roofline functions
they're not working as intended yet
-rw-r--r--lib/functions.py36
-rw-r--r--lib/paramfit.py6
2 files changed, 30 insertions, 12 deletions
diff --git a/lib/functions.py b/lib/functions.py
index d231c9c..6bbc268 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -90,7 +90,9 @@ class ParamFunction:
error measure.
"""
- def __init__(self, param_function, validation_function, num_vars, repr_str=None):
+ def __init__(
+ self, param_function, validation_function, num_vars, repr_str=None, ini=None
+ ):
"""
Create function object suitable for regression analysis.
@@ -113,6 +115,7 @@ class ParamFunction:
self._validation_function = validation_function
self._num_variables = num_vars
self.repr_str = repr_str
+ self.ini = ini
def __repr__(self) -> str:
if self.repr_str:
@@ -2025,6 +2028,7 @@ class analytic:
_safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0)
_safe_inv = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.0)
_safe_sqrt = np.vectorize(lambda x: np.sqrt(np.abs(x)))
+ _roofline = np.vectorize(lambda x, y: x if x < y else y)
_function_map = {
"linear": lambda x: x,
@@ -2040,6 +2044,7 @@ class analytic:
"safe_log": lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0,
"safe_inv": lambda x: 1 / x if np.abs(x) > 0.001 else 1.0,
"safe_sqrt": lambda x: np.sqrt(np.abs(x)),
+ "roofline": lambda x, y: x if x < y else y,
}
@staticmethod
@@ -2109,6 +2114,14 @@ class analytic:
2,
repr_str="β₀ + β₁ * np.sqrt(x)",
),
+ "roofline": ParamFunction(
+ lambda reg_param, model_param: reg_param[0]
+ + reg_param[1] * analytic._roofline(model_param, reg_param[2]),
+ lambda model_param: True,
+ 3,
+ repr_str="β₀ + β₁ * roofline(x, β₂)",
+ ini=[0, 1, 50],
+ ),
# "num0_8": ParamFunction(
# lambda reg_param, model_param: reg_param[0]
# + reg_param[1] * analytic._num0_8(model_param),
@@ -2164,7 +2177,7 @@ class analytic:
return functions
@staticmethod
- def _fmap(reference_type, reference_name, function_type):
+ def _fmap(reference_type, reference_name, function_type, arg_idx=None):
"""Map arg/parameter name and best-fit function name to function text suitable for AnalyticFunction."""
ref_str = "{}({})".format(reference_type, reference_name)
if function_type == "linear":
@@ -2183,6 +2196,8 @@ class analytic:
return "1/({})".format(ref_str)
if function_type == "sqrt":
return "np.sqrt({})".format(ref_str)
+ if function_type == "roofline":
+ return "analytic._roofline({}, regression_arg({}))".format(ref_str, arg_idx)
return "analytic._{}({})".format(function_type, ref_str)
@staticmethod
@@ -2212,17 +2227,16 @@ class analytic:
arg_idx += 1
for function_item in combination:
if is_numeric(function_item[0]):
- buf += " * {}".format(
- analytic._fmap(
- "function_arg", function_item[0], function_item[1]["best"]
- )
- )
+ mapkey = "function_arg"
else:
- buf += " * {}".format(
- analytic._fmap(
- "parameter", function_item[0], function_item[1]["best"]
- )
+ mapkey = "parameter"
+ buf += " * {}".format(
+ analytic._fmap(
+ mapkey, function_item[0], function_item[1]["best"], arg_idx
)
+ )
+ if function_item[1]["best"] == "roofline":
+ arg_idx += 1
return AnalyticFunction(
None, buf, parameter_names, num_args, fit_by_param=fit_results, **kwargs
)
diff --git a/lib/paramfit.py b/lib/paramfit.py
index e6539a4..779127e 100644
--- a/lib/paramfit.py
+++ b/lib/paramfit.py
@@ -205,9 +205,13 @@ def _try_fits(
if function_name not in raw_results:
raw_results[function_name] = dict()
error_function = param_function.error_function
+ if param_function.ini:
+ ini = param_function.ini
+ else:
+ ini = [0] + [1 for i in range(1, param_function._num_variables)]
try:
res = optimize.least_squares(
- error_function, [0, 1], args=(X, Y), xtol=2e-15
+ error_function, ini, args=(X, Y), xtol=2e-15
)
except FloatingPointError as e:
logger.warning(