summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-07-19 15:59:13 +0200
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-07-19 15:59:13 +0200
commit26d1aca811534e5c1673e01cc5eee47696bf7426 (patch)
tree5a1da5bcb3162d841b6926c53f19b46716e058e0 /lib
parent359a262e784219b3e4e4b63ec861f91104c09feb (diff)
specify appropriate bounds for roofline functions.
Now they are working as intended. Hurray!
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py40
-rw-r--r--lib/paramfit.py11
2 files changed, 46 insertions, 5 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 6bbc268..0a00f83 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -91,7 +91,13 @@ class ParamFunction:
"""
def __init__(
- self, param_function, validation_function, num_vars, repr_str=None, ini=None
+ self,
+ param_function,
+ validation_function,
+ num_vars,
+ repr_str=None,
+ ini=None,
+ bounds=((-np.inf, -np.inf), (np.inf, np.inf)),
):
"""
Create function object suitable for regression analysis.
@@ -116,6 +122,7 @@ class ParamFunction:
self._num_variables = num_vars
self.repr_str = repr_str
self.ini = ini
+ self.bounds = bounds
def __repr__(self) -> str:
if self.repr_str:
@@ -1796,6 +1803,7 @@ class AnalyticFunction(ModelFunction):
both for function usage and least squares optimization.
If unset, defaults to [1, 1, 1, ...]
"""
+ bounds = kwargs.pop("bounds", dict())
super().__init__(value, **kwargs)
self._parameter_names = parameters
self._num_args = num_args
@@ -1804,6 +1812,7 @@ class AnalyticFunction(ModelFunction):
self._dependson = [False] * (len(parameters) + num_args)
self.fit_success = False
self.fit_by_param = fit_by_param
+ self.bounds = bounds
if type(function_str) == str:
num_vars_re = re.compile(r"regression_arg\(([0-9]+)\)")
@@ -1909,10 +1918,25 @@ class AnalyticFunction(ModelFunction):
"""
X, Y, num_valid, num_total = self.get_fit_data(by_param)
if num_valid > 2:
+ lower_bounds = list()
+ upper_bounds = list()
+ for i in range(len(self.model_args)):
+ if i in self.bounds and self.bounds[i][0] == "range":
+ param_index = self._parameter_names.index(self.bounds[i][1])
+ lower_bounds.append(np.min(X[param_index]))
+ upper_bounds.append(np.max(X[param_index]))
+ self.model_args[i] = np.mean(X[param_index])
+ else:
+ lower_bounds.append(-np.inf)
+ upper_bounds.append(np.inf)
error_function = lambda P, X, y: self._function(P, X) - y
try:
res = optimize.least_squares(
- error_function, self.model_args, args=(X, Y), xtol=2e-15
+ error_function,
+ self.model_args,
+ args=(X, Y),
+ xtol=2e-15,
+ bounds=(lower_bounds, upper_bounds),
)
except ValueError as err:
logger.warning(f"Fit failed: {err} (function: {self.model_function})")
@@ -2120,7 +2144,7 @@ class analytic:
lambda model_param: True,
3,
repr_str="β₀ + β₁ * roofline(x, β₂)",
- ini=[0, 1, 50],
+ bounds=((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)),
),
# "num0_8": ParamFunction(
# lambda reg_param, model_param: reg_param[0]
@@ -2222,6 +2246,7 @@ class analytic:
"""
buf = "0"
arg_idx = 0
+ bounds = dict()
for combination in powerset(fit_results.items()):
buf += " + regression_arg({:d})".format(arg_idx)
arg_idx += 1
@@ -2236,7 +2261,14 @@ class analytic:
)
)
if function_item[1]["best"] == "roofline":
+ bounds[arg_idx] = ("range", function_item[0])
arg_idx += 1
return AnalyticFunction(
- None, buf, parameter_names, num_args, fit_by_param=fit_results, **kwargs
+ None,
+ buf,
+ parameter_names,
+ num_args,
+ fit_by_param=fit_results,
+ bounds=bounds,
+ **kwargs,
)
diff --git a/lib/paramfit.py b/lib/paramfit.py
index 779127e..44ff0a6 100644
--- a/lib/paramfit.py
+++ b/lib/paramfit.py
@@ -209,9 +209,18 @@ def _try_fits(
ini = param_function.ini
else:
ini = [0] + [1 for i in range(1, param_function._num_variables)]
+ if function_name == "roofline":
+ param_function.bounds = (
+ (-np.inf, -np.inf, np.min(X)),
+ (np.inf, np.inf, np.max(X)),
+ )
try:
res = optimize.least_squares(
- error_function, ini, args=(X, Y), xtol=2e-15
+ error_function,
+ ini,
+ args=(X, Y),
+ xtol=2e-15,
+ bounds=param_function.bounds,
)
except FloatingPointError as e:
logger.warning(