Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 250 |
1 file changed, 155 insertions, 95 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 2451ef6..6d8daa4 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -12,6 +12,7 @@ from .utils import is_numeric, vprint
 
 arg_support_enabled = True
 
+
 def powerset(iterable):
     """
     Return powerset of `iterable` elements.
@@ -19,7 +20,8 @@ def powerset(iterable):
     Example: `powerset([1, 2])` -> `[(), (1), (2), (1, 2)]`
     """
     s = list(iterable)
-    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
+    return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
+
 
 class ParamFunction:
     """
@@ -82,6 +84,7 @@ class ParamFunction:
         """
         return self._param_function(P, X) - y
 
+
 class NormalizationFunction:
     """
     Wrapper for parameter normalization functions used in YAML PTA/DFA models.
@@ -95,7 +98,7 @@ class NormalizationFunction:
         `param` and return a float.
         """
         self._function_str = function_str
-        self._function = eval('lambda param: ' + function_str)
+        self._function = eval("lambda param: " + function_str)
 
     def eval(self, param_value: float) -> float:
         """
@@ -105,6 +108,7 @@ class NormalizationFunction:
         """
         return self._function(param_value)
 
+
 class AnalyticFunction:
     """
     A multi-dimensional model function, generated from a string, which can be optimized using regression.
@@ -114,7 +118,9 @@ class AnalyticFunction:
     packet length.
     """
 
-    def __init__(self, function_str, parameters, num_args, verbose = True, regression_args = None):
+    def __init__(
+        self, function_str, parameters, num_args, verbose=True, regression_args=None
+    ):
         """
         Create a new AnalyticFunction object from a function string.
 
@@ -143,22 +149,30 @@ class AnalyticFunction:
         self.verbose = verbose
 
         if type(function_str) == str:
-            num_vars_re = re.compile(r'regression_arg\(([0-9]+)\)')
+            num_vars_re = re.compile(r"regression_arg\(([0-9]+)\)")
             num_vars = max(map(int, num_vars_re.findall(function_str))) + 1
             for i in range(len(parameters)):
-                if rawfunction.find('parameter({})'.format(parameters[i])) >= 0:
+                if rawfunction.find("parameter({})".format(parameters[i])) >= 0:
                     self._dependson[i] = True
-                    rawfunction = rawfunction.replace('parameter({})'.format(parameters[i]), 'model_param[{:d}]'.format(i))
+                    rawfunction = rawfunction.replace(
+                        "parameter({})".format(parameters[i]),
+                        "model_param[{:d}]".format(i),
+                    )
             for i in range(0, num_args):
-                if rawfunction.find('function_arg({:d})'.format(i)) >= 0:
+                if rawfunction.find("function_arg({:d})".format(i)) >= 0:
                     self._dependson[len(parameters) + i] = True
-                    rawfunction = rawfunction.replace('function_arg({:d})'.format(i), 'model_param[{:d}]'.format(len(parameters) + i))
+                    rawfunction = rawfunction.replace(
+                        "function_arg({:d})".format(i),
+                        "model_param[{:d}]".format(len(parameters) + i),
+                    )
             for i in range(num_vars):
-                rawfunction = rawfunction.replace('regression_arg({:d})'.format(i), 'reg_param[{:d}]'.format(i))
+                rawfunction = rawfunction.replace(
+                    "regression_arg({:d})".format(i), "reg_param[{:d}]".format(i)
+                )
             self._function_str = rawfunction
-            self._function = eval('lambda reg_param, model_param: ' + rawfunction)
+            self._function = eval("lambda reg_param, model_param: " + rawfunction)
         else:
-            self._function_str = 'raise ValueError'
+            self._function_str = "raise ValueError"
             self._function = function_str
 
         if regression_args:
@@ -217,7 +231,12 @@ class AnalyticFunction:
                         else:
                             X[i].extend([np.nan] * len(val[model_attribute]))
             elif key[0] == state_or_tran and len(key[1]) != dimension:
-                vprint(self.verbose, '[W] Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.'.format(state_or_tran, model_attribute, len(key[1]), dimension))
+                vprint(
+                    self.verbose,
+                    "[W] Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.".format(
+                        state_or_tran, model_attribute, len(key[1]), dimension
+                    ),
+                )
         X = np.array(X)
         Y = np.array(Y)
 
@@ -237,21 +256,40 @@ class AnalyticFunction:
         argument values are present, they must come after parameter values
         in the order of their appearance in the function signature.
         """
-        X, Y, num_valid, num_total = self.get_fit_data(by_param, state_or_tran, model_attribute)
+        X, Y, num_valid, num_total = self.get_fit_data(
+            by_param, state_or_tran, model_attribute
+        )
         if num_valid > 2:
             error_function = lambda P, X, y: self._function(P, X) - y
             try:
-                res = optimize.least_squares(error_function, self._regression_args, args=(X, Y), xtol=2e-15)
+                res = optimize.least_squares(
+                    error_function, self._regression_args, args=(X, Y), xtol=2e-15
+                )
             except ValueError as err:
-                vprint(self.verbose, '[W] Fit failed for {}/{}: {} (function: {})'.format(state_or_tran, model_attribute, err, self._model_str))
+                vprint(
+                    self.verbose,
+                    "[W] Fit failed for {}/{}: {} (function: {})".format(
+                        state_or_tran, model_attribute, err, self._model_str
+                    ),
+                )
                 return
             if res.status > 0:
                 self._regression_args = res.x
                 self.fit_success = True
             else:
-                vprint(self.verbose, '[W] Fit failed for {}/{}: {} (function: {})'.format(state_or_tran, model_attribute, res.message, self._model_str))
+                vprint(
+                    self.verbose,
+                    "[W] Fit failed for {}/{}: {} (function: {})".format(
+                        state_or_tran, model_attribute, res.message, self._model_str
+                    ),
+                )
         else:
-            vprint(self.verbose, '[W] Insufficient amount of valid parameter keys, cannot fit {}/{}'.format(state_or_tran, model_attribute))
+            vprint(
+                self.verbose,
+                "[W] Insufficient amount of valid parameter keys, cannot fit {}/{}".format(
+                    state_or_tran, model_attribute
+                ),
+            )
 
     def is_predictable(self, param_list):
         """
@@ -268,7 +306,7 @@ class AnalyticFunction:
                 return False
         return True
 
-    def eval(self, param_list, arg_list = []):
+    def eval(self, param_list, arg_list=[]):
         """
         Evaluate model function with specified param/arg values.
 
@@ -280,6 +318,7 @@ class AnalyticFunction:
             return self._function(param_list, arg_list)
         return self._function(self._regression_args, param_list)
 
+
 class analytic:
     """
     Utilities for analytic description of parameter-dependent model attributes and regression analysis.
@@ -292,28 +331,28 @@ class analytic:
     _num0_8 = np.vectorize(lambda x: 8 - bin(int(x)).count("1"))
     _num0_16 = np.vectorize(lambda x: 16 - bin(int(x)).count("1"))
     _num1 = np.vectorize(lambda x: bin(int(x)).count("1"))
-    _safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.)
-    _safe_inv = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.)
+    _safe_log = np.vectorize(lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0)
+    _safe_inv = np.vectorize(lambda x: 1 / x if np.abs(x) > 0.001 else 1.0)
     _safe_sqrt = np.vectorize(lambda x: np.sqrt(np.abs(x)))
 
     _function_map = {
-        'linear' : lambda x: x,
-        'logarithmic' : np.log,
-        'logarithmic1' : lambda x: np.log(x + 1),
-        'exponential' : np.exp,
-        'square' : lambda x : x ** 2,
-        'inverse' : lambda x : 1 / x,
-        'sqrt' : lambda x: np.sqrt(np.abs(x)),
-        'num0_8' : _num0_8,
-        'num0_16' : _num0_16,
-        'num1' : _num1,
-        'safe_log' : lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.,
-        'safe_inv' : lambda x: 1 / x if np.abs(x) > 0.001 else 1.,
-        'safe_sqrt': lambda x: np.sqrt(np.abs(x)),
+        "linear": lambda x: x,
+        "logarithmic": np.log,
+        "logarithmic1": lambda x: np.log(x + 1),
+        "exponential": np.exp,
+        "square": lambda x: x ** 2,
+        "inverse": lambda x: 1 / x,
+        "sqrt": lambda x: np.sqrt(np.abs(x)),
+        "num0_8": _num0_8,
+        "num0_16": _num0_16,
+        "num1": _num1,
+        "safe_log": lambda x: np.log(np.abs(x)) if np.abs(x) > 0.001 else 1.0,
+        "safe_inv": lambda x: 1 / x if np.abs(x) > 0.001 else 1.0,
+        "safe_sqrt": lambda x: np.sqrt(np.abs(x)),
     }
 
     @staticmethod
-    def functions(safe_functions_enabled = False):
+    def functions(safe_functions_enabled=False):
         """
         Retrieve pre-defined set of regression function candidates.
 
@@ -329,74 +368,87 @@ class analytic:
         variables are expected.
         """
         functions = {
-            'linear' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param,
+            "linear": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * model_param,
                 lambda model_param: True,
-                2
+                2,
             ),
-            'logarithmic' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.log(model_param),
+            "logarithmic": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * np.log(model_param),
                 lambda model_param: model_param > 0,
-                2
+                2,
             ),
-            'logarithmic1' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.log(model_param + 1),
+            "logarithmic1": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * np.log(model_param + 1),
                 lambda model_param: model_param > -1,
-                2
+                2,
            ),
-            'exponential' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.exp(model_param),
+            "exponential": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * np.exp(model_param),
                 lambda model_param: model_param <= 64,
-                2
+                2,
             ),
             #'polynomial' : lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param + reg_param[2] * model_param ** 2,
-            'square' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * model_param ** 2,
+            "square": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * model_param ** 2,
                 lambda model_param: True,
-                2
+                2,
            ),
-            'inverse' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] / model_param,
+            "inverse": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] / model_param,
                 lambda model_param: model_param != 0,
-                2
+                2,
            ),
-            'sqrt' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * np.sqrt(model_param),
+            "sqrt": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * np.sqrt(model_param),
                 lambda model_param: model_param >= 0,
-                2
+                2,
            ),
-            'num0_8' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._num0_8(model_param),
+            "num0_8": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._num0_8(model_param),
                 lambda model_param: True,
-                2
+                2,
            ),
-            'num0_16' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._num0_16(model_param),
+            "num0_16": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._num0_16(model_param),
                 lambda model_param: True,
-                2
+                2,
            ),
-            'num1' : ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._num1(model_param),
+            "num1": ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._num1(model_param),
                 lambda model_param: True,
-                2
+                2,
            ),
         }
 
         if safe_functions_enabled:
-            functions['safe_log'] = ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_log(model_param),
+            functions["safe_log"] = ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._safe_log(model_param),
                 lambda model_param: True,
-                2
+                2,
            )
-            functions['safe_inv'] = ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_inv(model_param),
+            functions["safe_inv"] = ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._safe_inv(model_param),
                 lambda model_param: True,
-                2
+                2,
            )
-            functions['safe_sqrt'] = ParamFunction(
-                lambda reg_param, model_param: reg_param[0] + reg_param[1] * analytic._safe_sqrt(model_param),
+            functions["safe_sqrt"] = ParamFunction(
+                lambda reg_param, model_param: reg_param[0]
+                + reg_param[1] * analytic._safe_sqrt(model_param),
                 lambda model_param: True,
-                2
+                2,
            )
 
         return functions
@@ -404,27 +456,27 @@ class analytic:
     @staticmethod
     def _fmap(reference_type, reference_name, function_type):
         """Map arg/parameter name and best-fit function name to function text suitable for AnalyticFunction."""
-        ref_str = '{}({})'.format(reference_type,reference_name)
-        if function_type == 'linear':
+        ref_str = "{}({})".format(reference_type, reference_name)
+        if function_type == "linear":
             return ref_str
-        if function_type == 'logarithmic':
-            return 'np.log({})'.format(ref_str)
-        if function_type == 'logarithmic1':
-            return 'np.log({} + 1)'.format(ref_str)
-        if function_type == 'exponential':
-            return 'np.exp({})'.format(ref_str)
-        if function_type == 'exponential':
-            return 'np.exp({})'.format(ref_str)
-        if function_type == 'square':
-            return '({})**2'.format(ref_str)
-        if function_type == 'inverse':
-            return '1/({})'.format(ref_str)
-        if function_type == 'sqrt':
-            return 'np.sqrt({})'.format(ref_str)
-        return 'analytic._{}({})'.format(function_type, ref_str)
+        if function_type == "logarithmic":
+            return "np.log({})".format(ref_str)
+        if function_type == "logarithmic1":
+            return "np.log({} + 1)".format(ref_str)
+        if function_type == "exponential":
+            return "np.exp({})".format(ref_str)
+        if function_type == "exponential":
+            return "np.exp({})".format(ref_str)
+        if function_type == "square":
+            return "({})**2".format(ref_str)
+        if function_type == "inverse":
+            return "1/({})".format(ref_str)
+        if function_type == "sqrt":
+            return "np.sqrt({})".format(ref_str)
+        return "analytic._{}({})".format(function_type, ref_str)
 
     @staticmethod
-    def function_powerset(fit_results, parameter_names, num_args = 0):
+    def function_powerset(fit_results, parameter_names, num_args=0):
         """
         Combine per-parameter regression results into a single multi-dimensional function.
 
@@ -443,14 +495,22 @@ class analytic:
 
         Returns an AnalyticFunction instantce corresponding to the combined function.
         """
-        buf = '0'
+        buf = "0"
         arg_idx = 0
         for combination in powerset(fit_results.items()):
-            buf += ' + regression_arg({:d})'.format(arg_idx)
+            buf += " + regression_arg({:d})".format(arg_idx)
             arg_idx += 1
             for function_item in combination:
                 if arg_support_enabled and is_numeric(function_item[0]):
-                    buf += ' * {}'.format(analytic._fmap('function_arg', function_item[0], function_item[1]['best']))
+                    buf += " * {}".format(
+                        analytic._fmap(
+                            "function_arg", function_item[0], function_item[1]["best"]
+                        )
+                    )
                 else:
-                    buf += ' * {}'.format(analytic._fmap('parameter', function_item[0], function_item[1]['best']))
+                    buf += " * {}".format(
+                        analytic._fmap(
+                            "parameter", function_item[0], function_item[1]["best"]
+                        )
+                    )
         return AnalyticFunction(buf, parameter_names, num_args)