From c68c4a2bc617dd1356d5d0d2c3ee0ff9754261ab Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Thu, 18 Feb 2021 12:07:52 +0100 Subject: refactor model generation from Analytic/PTAModel into ModelAttribute class Iteration over states/transitions and model attributes is no longer hardcoded into most model generation code. This should make support for decision trees and sub-states much easier. --- lib/functions.py | 57 ++++++++++++++++++++------------------------------------ 1 file changed, 20 insertions(+), 37 deletions(-) (limited to 'lib/functions.py') diff --git a/lib/functions.py b/lib/functions.py index 9d799c7..0bdea45 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -222,21 +222,18 @@ class AnalyticFunction: else: self.model_args = [] - def get_fit_data(self, by_param, state_or_tran, model_attribute): + def get_fit_data(self, by_param): """ Return training data suitable for scipy.optimize.least_squares. - :param by_param: measurement data, partitioned by state/transition name and parameter/arg values. - This function only uses by_param[(state_or_tran, *)][model_attribute], - which must be a list or 1-D NumPy array containing the ground truth. - The parameter values in (state_or_tran, *) must be numeric for + :param by_param: measurement data, partitioned by parameter/arg values. + by_param[*] must be a list or 1-D NumPy array containing the ground truth. + The parameter values (dict keys) must be numeric for all parameters this function depends on -- otherwise, the corresponding data will be left out. Parameter values must be ordered according to the order of parameter names used in the ParamFunction constructor. Argument values (if any) always come after parameters, in the order of their index in the function signature. - :param state_or_tran: state or transition name, e.g. "TX" or "send" - :param model_attribute: model attribute name, e.g. "power" or "duration" :return: (X, Y, num_valid, num_total): X -- 2-D NumPy array of parameter combinations (model input). @@ -255,48 +252,44 @@ class AnalyticFunction: num_total = 0 for key, val in by_param.items(): - if key[0] == state_or_tran and len(key[1]) == dimension: + if len(key) == dimension: valid = True num_total += 1 for i in range(dimension): - if self._dependson[i] and not is_numeric(key[1][i]): + if self._dependson[i] and not is_numeric(key[i]): valid = False if valid: num_valid += 1 - Y.extend(val[model_attribute]) + Y.extend(val) for i in range(dimension): if self._dependson[i]: - X[i].extend([float(key[1][i])] * len(val[model_attribute])) + X[i].extend([float(key[i])] * len(val)) else: - X[i].extend([np.nan] * len(val[model_attribute])) - elif key[0] == state_or_tran and len(key[1]) != dimension: + X[i].extend([np.nan] * len(val)) + else: logger.warning( - "Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.".format( - state_or_tran, model_attribute, len(key[1]), dimension - ), + "Invalid parameter key length while gathering fit data. is {}, want {}.".format( + len(key), dimension + ) ) X = np.array(X) Y = np.array(Y) return X, Y, num_valid, num_total - def fit(self, by_param, state_or_tran, model_attribute): + def fit(self, by_param): """ Fit the function on measurements via least squares regression. - :param by_param: measurement data, partitioned by state/transition name and parameter/arg values - :param state_or_tran: state or transition name, e.g. "TX" or "send" - :param model_attribute: model attribute name, e.g. "power" or "duration" + :param by_param: measurement data, partitioned by parameter/arg values - The ground truth is read from by_param[(state_or_tran, *)][model_attribute], + The ground truth is read from by_param[*], which must be a list or 1-D NumPy array. Parameter values must be ordered according to the parameter names in the constructor. If argument values are present, they must come after parameter values in the order of their appearance in the function signature. """ - X, Y, num_valid, num_total = self.get_fit_data( - by_param, state_or_tran, model_attribute - ) + X, Y, num_valid, num_total = self.get_fit_data(by_param) if num_valid > 2: error_function = lambda P, X, y: self._function(P, X) - y try: @@ -304,27 +297,17 @@ class AnalyticFunction: error_function, self.model_args, args=(X, Y), xtol=2e-15 ) except ValueError as err: - logger.warning( - "Fit failed for {}/{}: {} (function: {})".format( - state_or_tran, model_attribute, err, self.model_function - ), - ) + logger.warning(f"Fit failed: {err} (function: {self.model_function})") return if res.status > 0: self.model_args = res.x self.fit_success = True else: logger.warning( - "Fit failed for {}/{}: {} (function: {})".format( - state_or_tran, model_attribute, res.message, self.model_function - ), + f"Fit failed: {res.message} (function: {self.model_function})" ) else: - logger.warning( - "Insufficient amount of valid parameter keys, cannot fit {}/{}".format( - state_or_tran, model_attribute - ), - ) + logger.warning("Insufficient amount of valid parameter keys, cannot fit") def is_predictable(self, param_list): """ -- cgit v1.2.3