diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-07-19 15:59:13 +0200 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-07-19 15:59:13 +0200 |
commit | 26d1aca811534e5c1673e01cc5eee47696bf7426 (patch) | |
tree | 5a1da5bcb3162d841b6926c53f19b46716e058e0 /lib | |
parent | 359a262e784219b3e4e4b63ec861f91104c09feb (diff) |
specify appropriate bounds for roofline functions.
Now they are working as intended. Hurray!
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 40 | ||||
-rw-r--r-- | lib/paramfit.py | 11 |
2 files changed, 46 insertions, 5 deletions
diff --git a/lib/functions.py b/lib/functions.py index 6bbc268..0a00f83 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -91,7 +91,13 @@ class ParamFunction: """ def __init__( - self, param_function, validation_function, num_vars, repr_str=None, ini=None + self, + param_function, + validation_function, + num_vars, + repr_str=None, + ini=None, + bounds=((-np.inf, -np.inf), (np.inf, np.inf)), ): """ Create function object suitable for regression analysis. @@ -116,6 +122,7 @@ class ParamFunction: self._num_variables = num_vars self.repr_str = repr_str self.ini = ini + self.bounds = bounds def __repr__(self) -> str: if self.repr_str: @@ -1796,6 +1803,7 @@ class AnalyticFunction(ModelFunction): both for function usage and least squares optimization. If unset, defaults to [1, 1, 1, ...] """ + bounds = kwargs.pop("bounds", dict()) super().__init__(value, **kwargs) self._parameter_names = parameters self._num_args = num_args @@ -1804,6 +1812,7 @@ class AnalyticFunction(ModelFunction): self._dependson = [False] * (len(parameters) + num_args) self.fit_success = False self.fit_by_param = fit_by_param + self.bounds = bounds if type(function_str) == str: num_vars_re = re.compile(r"regression_arg\(([0-9]+)\)") @@ -1909,10 +1918,25 @@ class AnalyticFunction(ModelFunction): """ X, Y, num_valid, num_total = self.get_fit_data(by_param) if num_valid > 2: + lower_bounds = list() + upper_bounds = list() + for i in range(len(self.model_args)): + if i in self.bounds and self.bounds[i][0] == "range": + param_index = self._parameter_names.index(self.bounds[i][1]) + lower_bounds.append(np.min(X[param_index])) + upper_bounds.append(np.max(X[param_index])) + self.model_args[i] = np.mean(X[param_index]) + else: + lower_bounds.append(-np.inf) + upper_bounds.append(np.inf) error_function = lambda P, X, y: self._function(P, X) - y try: res = optimize.least_squares( - error_function, self.model_args, args=(X, Y), xtol=2e-15 + error_function, + self.model_args, + args=(X, Y), + xtol=2e-15, + bounds=(lower_bounds, upper_bounds), ) except ValueError as err: logger.warning(f"Fit failed: {err} (function: {self.model_function})") @@ -2120,7 +2144,7 @@ class analytic: lambda model_param: True, 3, repr_str="β₀ + β₁ * roofline(x, β₂)", - ini=[0, 1, 50], + bounds=((-np.inf, -np.inf, -np.inf), (np.inf, np.inf, np.inf)), ), # "num0_8": ParamFunction( # lambda reg_param, model_param: reg_param[0] @@ -2222,6 +2246,7 @@ class analytic: """ buf = "0" arg_idx = 0 + bounds = dict() for combination in powerset(fit_results.items()): buf += " + regression_arg({:d})".format(arg_idx) arg_idx += 1 @@ -2236,7 +2261,14 @@ class analytic: ) ) if function_item[1]["best"] == "roofline": + bounds[arg_idx] = ("range", function_item[0]) arg_idx += 1 return AnalyticFunction( - None, buf, parameter_names, num_args, fit_by_param=fit_results, **kwargs + None, + buf, + parameter_names, + num_args, + fit_by_param=fit_results, + bounds=bounds, + **kwargs, ) diff --git a/lib/paramfit.py b/lib/paramfit.py index 779127e..44ff0a6 100644 --- a/lib/paramfit.py +++ b/lib/paramfit.py @@ -209,9 +209,18 @@ def _try_fits( ini = param_function.ini else: ini = [0] + [1 for i in range(1, param_function._num_variables)] + if function_name == "roofline": + param_function.bounds = ( + (-np.inf, -np.inf, np.min(X)), + (np.inf, np.inf, np.max(X)), + ) try: res = optimize.least_squares( - error_function, ini, args=(X, Y), xtol=2e-15 + error_function, + ini, + args=(X, Y), + xtol=2e-15, + bounds=param_function.bounds, ) except FloatingPointError as e: logger.warning( |