diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-07-19 14:47:12 +0200 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-07-19 14:47:12 +0200 |
commit | 359a262e784219b3e4e4b63ec861f91104c09feb (patch) | |
tree | b6704d320bd4b0765b4092f55fb03e8caa24035e | |
parent | dcf3ca444a3f820f5f3ae3a55f975e40f1c41526 (diff) |
ULS: add support for roofline functions
they're not working as intended yet
-rw-r--r-- | lib/functions.py | 36 | ||||
-rw-r--r-- | lib/paramfit.py | 6 |
2 files changed, 30 insertions, 12 deletions
diff --git a/lib/functions.py b/lib/functions.py index d231c9c..6bbc268 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -90,7 +90,9 @@ class ParamFunction: error measure. """ - def __init__(self, param_function, validation_function, num_vars, repr_str=None): + def __init__( + self, param_function, validation_function, num_vars, repr_str=None, ini=None + ): """ Create function object suitable for regression analysis. @@ -113,6 +115,7 @@ class ParamFunction: self._validation_function = validation_function self._num_variables = num_vars self.repr_str = repr_str + self.ini = ini def __repr__(self) -> str: if self.repr_str: @@ -2025,6 +2028,7 @@ class analytic: _safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0) _safe_inv = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.0) _safe_sqrt = np.vectorize(lambda x: np.sqrt(np.abs(x))) + _roofline = np.vectorize(lambda x, y: x if x < y else y) _function_map = { "linear": lambda x: x, @@ -2040,6 +2044,7 @@ class analytic: "safe_log": lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0, "safe_inv": lambda x: 1 / x if np.abs(x) > 0.001 else 1.0, "safe_sqrt": lambda x: np.sqrt(np.abs(x)), + "roofline": lambda x, y: x if x < y else y, } @staticmethod @@ -2109,6 +2114,14 @@ class analytic: 2, repr_str="β₀ + β₁ * np.sqrt(x)", ), + "roofline": ParamFunction( + lambda reg_param, model_param: reg_param[0] + + reg_param[1] * analytic._roofline(model_param, reg_param[2]), + lambda model_param: True, + 3, + repr_str="β₀ + β₁ * roofline(x, β₂)", + ini=[0, 1, 50], + ), # "num0_8": ParamFunction( # lambda reg_param, model_param: reg_param[0] # + reg_param[1] * analytic._num0_8(model_param), @@ -2164,7 +2177,7 @@ class analytic: return functions @staticmethod - def _fmap(reference_type, reference_name, function_type): + def _fmap(reference_type, reference_name, function_type, arg_idx=None): """Map arg/parameter name and best-fit function name to function text suitable for AnalyticFunction.""" ref_str = "{}({})".format(reference_type, reference_name) if function_type == "linear": @@ -2183,6 +2196,8 @@ class analytic: return "1/({})".format(ref_str) if function_type == "sqrt": return "np.sqrt({})".format(ref_str) + if function_type == "roofline": + return "analytic._roofline({}, regression_arg({}))".format(ref_str, arg_idx) return "analytic._{}({})".format(function_type, ref_str) @staticmethod @@ -2212,17 +2227,16 @@ class analytic: arg_idx += 1 for function_item in combination: if is_numeric(function_item[0]): - buf += " * {}".format( - analytic._fmap( - "function_arg", function_item[0], function_item[1]["best"] - ) - ) + mapkey = "function_arg" else: - buf += " * {}".format( - analytic._fmap( - "parameter", function_item[0], function_item[1]["best"] - ) + mapkey = "parameter" + buf += " * {}".format( + analytic._fmap( + mapkey, function_item[0], function_item[1]["best"], arg_idx ) + ) + if function_item[1]["best"] == "roofline": + arg_idx += 1 return AnalyticFunction( None, buf, parameter_names, num_args, fit_by_param=fit_results, **kwargs ) diff --git a/lib/paramfit.py b/lib/paramfit.py index e6539a4..779127e 100644 --- a/lib/paramfit.py +++ b/lib/paramfit.py @@ -205,9 +205,13 @@ def _try_fits( if function_name not in raw_results: raw_results[function_name] = dict() error_function = param_function.error_function + if param_function.ini: + ini = param_function.ini + else: + ini = [0] + [1 for i in range(1, param_function._num_variables)] try: res = optimize.least_squares( - error_function, [0, 1], args=(X, Y), xtol=2e-15 + error_function, ini, args=(X, Y), xtol=2e-15 ) except FloatingPointError as e: logger.warning( |