summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py57
1 files changed, 20 insertions, 37 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 9d799c7..0bdea45 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -222,21 +222,18 @@ class AnalyticFunction:
else:
self.model_args = []
- def get_fit_data(self, by_param, state_or_tran, model_attribute):
+ def get_fit_data(self, by_param):
"""
Return training data suitable for scipy.optimize.least_squares.
- :param by_param: measurement data, partitioned by state/transition name and parameter/arg values.
- This function only uses by_param[(state_or_tran, *)][model_attribute],
- which must be a list or 1-D NumPy array containing the ground truth.
- The parameter values in (state_or_tran, *) must be numeric for
+ :param by_param: measurement data, partitioned by parameter/arg values.
+ by_param[*] must be a list or 1-D NumPy array containing the ground truth.
+ The parameter values (dict keys) must be numeric for
all parameters this function depends on -- otherwise, the
corresponding data will be left out. Parameter values must be
ordered according to the order of parameter names used in
the ParamFunction constructor. Argument values (if any) always come after
parameters, in the order of their index in the function signature.
- :param state_or_tran: state or transition name, e.g. "TX" or "send"
- :param model_attribute: model attribute name, e.g. "power" or "duration"
:return: (X, Y, num_valid, num_total):
X -- 2-D NumPy array of parameter combinations (model input).
@@ -255,48 +252,44 @@ class AnalyticFunction:
num_total = 0
for key, val in by_param.items():
- if key[0] == state_or_tran and len(key[1]) == dimension:
+ if len(key) == dimension:
valid = True
num_total += 1
for i in range(dimension):
- if self._dependson[i] and not is_numeric(key[1][i]):
+ if self._dependson[i] and not is_numeric(key[i]):
valid = False
if valid:
num_valid += 1
- Y.extend(val[model_attribute])
+ Y.extend(val)
for i in range(dimension):
if self._dependson[i]:
- X[i].extend([float(key[1][i])] * len(val[model_attribute]))
+ X[i].extend([float(key[i])] * len(val))
else:
- X[i].extend([np.nan] * len(val[model_attribute]))
- elif key[0] == state_or_tran and len(key[1]) != dimension:
+ X[i].extend([np.nan] * len(val))
+ else:
logger.warning(
- "Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.".format(
- state_or_tran, model_attribute, len(key[1]), dimension
- ),
+ "Invalid parameter key length while gathering fit data. is {}, want {}.".format(
+ len(key), dimension
+ )
)
X = np.array(X)
Y = np.array(Y)
return X, Y, num_valid, num_total
- def fit(self, by_param, state_or_tran, model_attribute):
+ def fit(self, by_param):
"""
Fit the function on measurements via least squares regression.
- :param by_param: measurement data, partitioned by state/transition name and parameter/arg values
- :param state_or_tran: state or transition name, e.g. "TX" or "send"
- :param model_attribute: model attribute name, e.g. "power" or "duration"
+ :param by_param: measurement data, partitioned by parameter/arg values
- The ground truth is read from by_param[(state_or_tran, *)][model_attribute],
+ The ground truth is read from by_param[*],
which must be a list or 1-D NumPy array. Parameter values must be
ordered according to the parameter names in the constructor. If
argument values are present, they must come after parameter values
in the order of their appearance in the function signature.
"""
- X, Y, num_valid, num_total = self.get_fit_data(
- by_param, state_or_tran, model_attribute
- )
+ X, Y, num_valid, num_total = self.get_fit_data(by_param)
if num_valid > 2:
error_function = lambda P, X, y: self._function(P, X) - y
try:
@@ -304,27 +297,17 @@ class AnalyticFunction:
error_function, self.model_args, args=(X, Y), xtol=2e-15
)
except ValueError as err:
- logger.warning(
- "Fit failed for {}/{}: {} (function: {})".format(
- state_or_tran, model_attribute, err, self.model_function
- ),
- )
+ logger.warning(f"Fit failed: {err} (function: {self.model_function})")
return
if res.status > 0:
self.model_args = res.x
self.fit_success = True
else:
logger.warning(
- "Fit failed for {}/{}: {} (function: {})".format(
- state_or_tran, model_attribute, res.message, self.model_function
- ),
+ f"Fit failed: {res.message} (function: {self.model_function})"
)
else:
- logger.warning(
- "Insufficient amount of valid parameter keys, cannot fit {}/{}".format(
- state_or_tran, model_attribute
- ),
- )
+ logger.warning("Insufficient amount of valid parameter keys, cannot fit")
def is_predictable(self, param_list):
"""