diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-17 17:56:01 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-03-17 17:56:01 +0100 |
commit | 53546c9c52b4b45525726babbeb01f7532922783 (patch) | |
tree | 7b71e101619134f19a1da57c7dbda99940649294 | |
parent | 162a0c287f5dab664e9168a60d76c9f8da07e46a (diff) |
always handle co-dependent parameters
-rwxr-xr-x | bin/analyze-timing.py | 3 | ||||
-rw-r--r-- | lib/model.py | 4 | ||||
-rw-r--r-- | lib/parameters.py | 191 | ||||
-rw-r--r-- | lib/paramfit.py | 8 | ||||
-rw-r--r-- | lib/utils.py | 6 | ||||
-rwxr-xr-x | test/test_ptamodel.py | 212 | ||||
-rwxr-xr-x | test/test_timingharness.py | 40 |
7 files changed, 272 insertions, 192 deletions
diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py index 129b0ff..d67c553 100755 --- a/bin/analyze-timing.py +++ b/bin/analyze-timing.py @@ -84,7 +84,6 @@ from dfatool.functions import gplearn_to_function, StaticFunction, AnalyticFunct from dfatool.model import AnalyticModel from dfatool.validation import CrossValidator from dfatool.utils import filter_aggregate_by_param, NpEncoder -from dfatool.parameters import prune_dependent_parameters opt = dict() @@ -254,8 +253,6 @@ if __name__ == "__main__": preprocessed_data, ignored_trace_indexes ) - prune_dependent_parameters(by_name, parameters) - filter_aggregate_by_param(by_name, parameters, opt["filter-param"]) model = AnalyticModel( diff --git a/lib/model.py b/lib/model.py index 75f7195..5000caa 100644 --- a/lib/model.py +++ b/lib/model.py @@ -235,12 +235,12 @@ class AnalyticModel: for name in self.names: for attr in self.attr_by_name[name].keys(): - for key, param, args in self.attr_by_name[name][ + for key, param, args, kwargs in self.attr_by_name[name][ attr ].get_data_for_paramfit( safe_functions_enabled=safe_functions_enabled ): - paramfit.enqueue(key, param, args) + paramfit.enqueue(key, param, args, kwargs) paramfit.fit() diff --git a/lib/parameters.py b/lib/parameters.py index ea14ad2..99510b2 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -3,7 +3,6 @@ import itertools import logging import numpy as np import os -import warnings from collections import OrderedDict from copy import deepcopy from multiprocessing import Pool @@ -109,12 +108,10 @@ def _std_by_param(n_by_param, all_param_values, param_index): # vprint(verbose, '[W] parameter value partition for {} is empty'.format(param_value)) if np.all(np.isnan(stddev_matrix)): - warnings.warn( - "parameter #{} has no data partitions. stddev_matrix = {}".format( - param_index, stddev_matrix - ) + logger.warning( + f"parameter #{param_index} has no data partitions. All stddev_matrix entries are NaN." ) - return stddev_matrix, 0.0 + return stddev_matrix, 0.0, lut_matrix return ( stddev_matrix, @@ -159,7 +156,12 @@ def _corr_by_param(attribute_data, param_values, param_index): def _compute_param_statistics( - data, param_names, param_tuples, arg_count=None, use_corrcoef=False + data, + param_names, + param_tuples, + arg_count=None, + use_corrcoef=False, + codependent_params=list(), ): """ Compute standard deviation and correlation coefficient on parameterized data partitions. @@ -230,8 +232,18 @@ def _compute_param_statistics( np.seterr("raise") for param_idx, param in enumerate(param_names): + if param_idx < len(codependent_params) and codependent_params[param_idx]: + by_param = partition_by_param( + data, param_tuples, ignore_parameters=codependent_params[param_idx] + ) + distinct_values = ret["distinct_values_by_param_index"].copy() + for codependent_param_index in codependent_params[param_idx]: + distinct_values[codependent_param_index] = [None] + else: + by_param = ret["by_param"] + distinct_values = ret["distinct_values_by_param_index"] std_matrix, mean_std, lut_matrix = _std_by_param( - by_param, ret["distinct_values_by_param_index"], param_idx + by_param, distinct_values, param_idx ) ret["std_by_param"][param] = mean_std ret["std_by_param_values"][param] = std_matrix @@ -246,17 +258,26 @@ def _compute_param_statistics( if arg_count: for arg_index in range(arg_count): + param_idx = len(param_names) + arg_index + if param_idx < len(codependent_params) and codependent_params[param_idx]: + by_param = partition_by_param( + data, param_tuples, ignore_parameters=codependent_params[param_idx] + ) + distinct_values = ret["distinct_values_by_param_index"].copy() + for codependent_param_index in codependent_params[param_idx]: + distinct_values[codependent_param_index] = [None] + else: + by_param = ret["by_param"] + distinct_values = ret["distinct_values_by_param_index"] std_matrix, mean_std, lut_matrix = _std_by_param( by_param, - ret["distinct_values_by_param_index"], - len(param_names) + arg_index, + distinct_values, + param_idx, ) ret["std_by_arg"].append(mean_std) ret["std_by_arg_values"].append(std_matrix) ret["lut_by_arg_values"].append(lut_matrix) - ret["corr_by_arg"].append( - _corr_by_param(data, param_tuples, len(param_names) + arg_index) - ) + ret["corr_by_arg"].append(_corr_by_param(data, param_tuples, param_idx)) if False: ret["_depends_on_arg"].append(ret["corr_by_arg"][arg_index] > 0.1) @@ -326,91 +347,6 @@ def _all_params_are_numeric(data, param_idx): return False -def prune_dependent_parameters(by_name, parameter_names, correlation_threshold=0.5): - """ - Remove dependent parameters from aggregate. - - :param by_name: measurements partitioned by state/transition/... name and attribute, edited in-place. - by_name[name][attribute] must be a list or 1-D numpy array. - by_name[stanamete_or_trans]['param'] must be a list of parameter values. - Other dict members are left as-is - :param parameter_names: List of parameter names in the order they are used in by_name[name]['param'], edited in-place. - :param correlation_threshold: Remove parameter if absolute correlation exceeds this threshold (default: 0.5) - - Model generation (and its components, such as relevant parameter detection and least squares optimization) only works if input variables (i.e., parameters) - are independent of each other. This function computes the correlation coefficient for each pair of parameters and removes those which depend on each other. - For each pair of dependent parameters, the lexically greater one is removed (e.g. "a" and "b" -> "b" is removed). - """ - - parameter_indices_to_remove = list() - for parameter_combination in itertools.product( - range(len(parameter_names)), range(len(parameter_names)) - ): - index_1, index_2 = parameter_combination - if index_1 >= index_2: - continue - parameter_values = [list(), list()] # both parameters have a value - parameter_values_1 = list() # parameter 1 has a value - parameter_values_2 = list() # parameter 2 has a value - for name in by_name: - for measurement in by_name[name]["param"]: - value_1 = measurement[index_1] - value_2 = measurement[index_2] - if is_numeric(value_1): - parameter_values_1.append(value_1) - if is_numeric(value_2): - parameter_values_2.append(value_2) - if is_numeric(value_1) and is_numeric(value_2): - parameter_values[0].append(value_1) - parameter_values[1].append(value_2) - if len(parameter_values[0]): - # Calculating the correlation coefficient only makes sense when neither value is constant - if np.std(parameter_values_1) != 0 and np.std(parameter_values_2) != 0: - correlation = np.corrcoef(parameter_values)[0][1] - if ( - correlation != np.nan - and np.abs(correlation) > correlation_threshold - ): - logger.debug( - "Parameters {} <-> {} are correlated with coefficcient {}".format( - parameter_names[index_1], - parameter_names[index_2], - correlation, - ) - ) - if len(parameter_values_1) < len(parameter_values_2): - index_to_remove = index_1 - else: - index_to_remove = index_2 - logger.debug( - " Removing parameter {}".format( - parameter_names[index_to_remove] - ) - ) - parameter_indices_to_remove.append(index_to_remove) - remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove) - - -def remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove): - """ - Remove parameters listed in `parameter_indices` from aggregate `by_name` and `parameter_names`. - - :param by_name: measurements partitioned by state/transition/... name and attribute, edited in-place. - by_name[name][attribute] must be a list or 1-D numpy array. - by_name[stanamete_or_trans]['param'] must be a list of parameter values. - Other dict members are left as-is - :param parameter_names: List of parameter names in the order they are used in by_name[name]['param'], edited in-place. - :param parameter_indices_to_remove: List of parameter indices to be removed - """ - - # Start removal from the end of the list to avoid renumbering of list elemenets - for parameter_index in sorted(parameter_indices_to_remove, reverse=True): - for name in by_name: - for measurement in by_name[name]["param"]: - measurement.pop(parameter_index) - parameter_names.pop(parameter_index) - - class ParallelParamStats: def __init__(self): self.queue = list() @@ -425,6 +361,8 @@ class ParallelParamStats: attr.param_names, attr.param_values, attr.arg_count, + False, + attr.codependent_params, ], } ) @@ -458,6 +396,7 @@ class ParamStats: attr.param_values, arg_count=attr.arg_count, use_corrcoef=use_corrcoef, + codependent_params=attr.codependent_params, ) attr.by_param = res.pop("by_param") attr.stats = cls(res) @@ -609,9 +548,10 @@ class ModelAttribute: map(lambda i: f"arg{i}", range(arg_count)) ) - # Co-dependent parameters. If (paam1_index, param2_index) in codependent_param, they are codependent. + # Co-dependent parameters. If (param1_index, param2_index) in codependent_param, they are codependent. # In this case, only one of them must be used for parameter-dependent model attribute detection and modeling - self.codependent_param = codependent_param + self.codependent_param_pair = codependent_param + self.codependent_params = [list() for x in self.log_param_names] self.ignore_param = dict() # Static model used as lower bound of model accuracy @@ -663,7 +603,7 @@ class ModelAttribute: for ( param1_index, param2_index, - ), is_codependent in self.codependent_param.items(): + ), is_codependent in self.codependent_param_pair.items(): if not is_codependent: continue param1_values = map(lambda pv: pv[param1_index], self.param_values) @@ -672,12 +612,14 @@ class ModelAttribute: param2_numeric_count = sum(map(is_numeric, param2_values)) if param1_numeric_count >= param2_numeric_count: self.ignore_param[param2_index] = True - logger.warning( + self.codependent_params[param1_index].append(param2_index) + logger.info( f"{self.name} {self.attr}: parameters ({self.log_param_names[param1_index]}, {self.log_param_names[param2_index]}) are codependent. Ignoring {self.log_param_names[param2_index]}" ) else: self.ignore_param[param1_index] = True - logger.warning( + self.codependent_params[param2_index].append(param1_index) + logger.info( f"{self.name} {self.attr}: parameters ({self.log_param_names[param1_index]}, {self.log_param_names[param2_index]}) are codependent. Ignoring {self.log_param_names[param1_index]}" ) @@ -719,10 +661,22 @@ class ModelAttribute: # ) child1 = ModelAttribute( - self.name, self.attr, self.data[tt1], pv1, self.param_names, self.arg_count + self.name, + self.attr, + self.data[tt1], + pv1, + self.param_names, + self.arg_count, + codependent_param_dict(pv1), ) child2 = ModelAttribute( - self.name, self.attr, self.data[tt2], pv2, self.param_names, self.arg_count + self.name, + self.attr, + self.data[tt2], + pv2, + self.param_names, + self.arg_count, + codependent_param_dict(pv2), ) ParamStats.compute_for_attr(child1) @@ -814,10 +768,20 @@ class ModelAttribute: child_ret = child.get_data_for_paramfit( safe_functions_enabled=safe_functions_enabled ) - for key, param, val in child_ret: - ret.append((key[:2] + (param_value,) + key[2:], param, val)) + for key, param, args, kwargs in child_ret: + ret.append((key[:2] + (param_value,) + key[2:], param, args, kwargs)) return ret + def _by_param_for_index(self, param_index): + if not self.codependent_params[param_index]: + return self.by_param + new_param_values = list() + for param_tuple in self.param_values: + for i in self.codependent_params[param_index]: + param_tuple[i] = None + new_param_values.append(param_tuple) + return partition_by_param(self.data, new_param_values) + def get_data_for_paramfit_this(self, safe_functions_enabled=False): ret = list() for param_index, param_name in enumerate(self.param_names): @@ -825,28 +789,33 @@ class ModelAttribute: self.stats.depends_on_param(param_name) and not param_index in self.ignore_param ): + by_param = self._by_param_for_index(param_index) ret.append( ( (self.name, self.attr), param_name, - (self.by_param, param_index, safe_functions_enabled), + (by_param, param_index, safe_functions_enabled), + dict(), ) ) if self.arg_count: for arg_index in range(self.arg_count): + param_index = len(self.param_names) + arg_index if ( self.stats.depends_on_arg(arg_index) - and not arg_index + len(self.param_names) in self.ignore_param + and not param_index in self.ignore_param ): + by_param = self._by_param_for_index(param_index) ret.append( ( (self.name, self.attr), arg_index, ( - self.by_param, - len(self.param_names) + arg_index, + by_param, + param_index, safe_functions_enabled, ), + dict(), ) ) diff --git a/lib/paramfit.py b/lib/paramfit.py index eed8eed..8bc7505 100644 --- a/lib/paramfit.py +++ b/lib/paramfit.py @@ -29,7 +29,7 @@ class ParallelParamFit: """Create a new ParallelParamFit object.""" self.fit_queue = list() - def enqueue(self, key, param, args): + def enqueue(self, key, param, args, kwargs=dict()): """ Add state_or_tran/attribute/param_name to fit queue. @@ -41,7 +41,7 @@ class ParallelParamFit: :param args: [by_param, param_index, safe_functions_enabled, param_filter] by_param[(param 1, param2, ...)] holds measurements. """ - self.fit_queue.append({"key": (key, param), "args": args}) + self.fit_queue.append({"key": (key, param), "args": args, "kwargs": kwargs}) def fit(self): """ @@ -102,11 +102,11 @@ class ParallelParamFit: def _try_fits_parallel(arg): """ - Call _try_fits(*arg['args']) and return arg['key'] and the _try_fits result. + Call _try_fits(*arg['args'], **arg["kwargs"]) and return arg['key'] and the _try_fits result. Must be a global function as it is called from a multiprocessing Pool. """ - return {"key": arg["key"], "result": _try_fits(*arg["args"])} + return {"key": arg["key"], "result": _try_fits(*arg["args"], **arg["kwargs"])} def _try_fits( diff --git a/lib/utils.py b/lib/utils.py index e39b329..17cdb51 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -175,9 +175,13 @@ def match_parameter_values(input_param: dict, match_param: dict): return True -def partition_by_param(data, param_values): +def partition_by_param(data, param_values, ignore_parameters=list()): ret = dict() for i, parameters in enumerate(param_values): + # ensure that parameters[param_index] = None does not affect the "param_values" entries passed to this function + parameters = list(parameters) + for param_index in ignore_parameters: + parameters[param_index] = None param_key = tuple(parameters) if param_key not in ret: ret[param_key] = list() diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py index 9d1b39b..ece83c5 100755 --- a/test/test_ptamodel.py +++ b/test/test_ptamodel.py @@ -694,74 +694,162 @@ class TestFromFile(unittest.TestCase): self.assertAlmostEqual( param_model( - "write", "duration", param=[0, 76, 1000, 0, 10, None, None, 1500, 0] + "write", + "duration", + param=[0, 76, 1000, 0, 10, None, None, 1500, 0, None, 9, None, None], ), - 1090, + 1133, places=0, ) - # only bitrate is relevant + # only bitrate and packet length are relevant self.assertAlmostEqual( param_model( "write", "duration", - param=[0, None, 1000, None, None, None, None, None, None], + param=[ + 0, + None, + 1000, + None, + None, + None, + None, + None, + None, + None, + 9, + None, + None, + ], ), - 1090, + 1133, places=0, ) self.assertAlmostEqual( param_model( "write", "duration", - param=[0, None, 250, None, None, None, None, None, None], + param=[ + 0, + None, + 250, + None, + None, + None, + None, + None, + None, + None, + 9, + None, + None, + ], ), - 2057, + 2100, places=0, ) self.assertAlmostEqual( param_model( "write", "duration", - param=[0, None, 2000, None, None, None, None, None, None], + param=[ + 0, + None, + 2000, + None, + None, + None, + None, + None, + None, + None, + 9, + None, + None, + ], ), - 929, + 972, places=0, ) - # auto_ack == 1 has a different write duration, still only bitrate is relevant + # auto_ack == 1 has a different write duration, still only bitrate and packet length are relevant self.assertAlmostEqual( param_model( - "write", "duration", param=[1, 76, 1000, 0, 10, None, None, 1500, 0] + "write", + "duration", + param=[1, 76, 1000, 0, 10, None, None, 1500, 0, None, 9, None, None], ), - 22284, + 22327, places=0, ) self.assertAlmostEqual( param_model( "write", "duration", - param=[1, None, 1000, None, None, None, None, None, None], + param=[ + 1, + None, + 1000, + None, + None, + None, + None, + None, + None, + None, + 9, + None, + None, + ], ), - 22284, + 22327, places=0, ) self.assertAlmostEqual( param_model( "write", "duration", - param=[1, None, 250, None, None, None, None, None, None], + param=[ + 1, + None, + 250, + None, + None, + None, + None, + None, + None, + None, + 9, + None, + None, + ], ), - 33229, + 33273, places=0, ) self.assertAlmostEqual( param_model( "write", "duration", - param=[1, None, 2000, None, None, None, None, None, None], + param=[ + 1, + None, + 2000, + None, + None, + None, + None, + None, + None, + None, + 9, + None, + None, + ], ), - 20459, + 20503, places=0, ) @@ -864,12 +952,12 @@ class TestFromFile(unittest.TestCase): 250, 0, None, - None, + 9, None, None, ], ), - 1150, + 1148, places=0, ) self.assertAlmostEqual( @@ -887,12 +975,12 @@ class TestFromFile(unittest.TestCase): 250, 0, None, - None, + 9, None, None, ], ), - 1150, + 1148, places=0, ) self.assertAlmostEqual( @@ -910,12 +998,12 @@ class TestFromFile(unittest.TestCase): 2750, 0, None, - None, + 9, None, None, ], ), - 1150, + 1148, places=0, ) self.assertAlmostEqual( @@ -933,12 +1021,12 @@ class TestFromFile(unittest.TestCase): 2750, 15, None, - None, + 9, None, None, ], ), - 1150, + 1148, places=0, ) @@ -958,12 +1046,12 @@ class TestFromFile(unittest.TestCase): 250, 0, None, - None, + 9, None, None, ], ), - 1425, + 1473, places=0, ) self.assertAlmostEqual( @@ -981,12 +1069,12 @@ class TestFromFile(unittest.TestCase): 250, 12, None, - None, + 9, None, None, ], ), - 1425, + 1473, places=0, ) self.assertAlmostEqual( @@ -1004,12 +1092,12 @@ class TestFromFile(unittest.TestCase): 250, 0, None, - None, + 9, None, None, ], ), - 4982, + 5030, places=0, ) self.assertAlmostEqual( @@ -1027,12 +1115,12 @@ class TestFromFile(unittest.TestCase): 250, 16, None, - None, + 9, None, None, ], ), - 4982, + 5030, places=0, ) self.assertAlmostEqual( @@ -1050,12 +1138,12 @@ class TestFromFile(unittest.TestCase): 2750, 0, None, - None, + 9, None, None, ], ), - 19982, + 20029, places=0, ) self.assertAlmostEqual( @@ -1073,12 +1161,12 @@ class TestFromFile(unittest.TestCase): 2750, 15, None, - None, + 9, None, None, ], ), - 19982, + 20029, places=0, ) @@ -1099,12 +1187,12 @@ class TestFromFile(unittest.TestCase): 250, 0, None, - None, + 9, None, None, ], ), - 12989, + 12420, places=0, ) self.assertAlmostEqual( @@ -1122,12 +1210,12 @@ class TestFromFile(unittest.TestCase): 250, 12, None, - None, + 9, None, None, ], ), - 20055, + 19172, places=0, ) self.assertAlmostEqual( @@ -1145,12 +1233,12 @@ class TestFromFile(unittest.TestCase): 250, 0, None, - None, + 9, None, None, ], ), - 12989, + 12420, places=0, ) self.assertAlmostEqual( @@ -1168,12 +1256,12 @@ class TestFromFile(unittest.TestCase): 250, 12, None, - None, + 9, None, None, ], ), - 20055, + 19172, places=0, ) self.assertAlmostEqual( @@ -1191,12 +1279,12 @@ class TestFromFile(unittest.TestCase): 2750, 0, None, - None, + 9, None, None, ], ), - 12989, + 12420, places=0, ) self.assertAlmostEqual( @@ -1214,12 +1302,12 @@ class TestFromFile(unittest.TestCase): 2750, 12, None, - None, + 9, None, None, ], ), - 20055, + 19172, places=0, ) @@ -1239,12 +1327,12 @@ class TestFromFile(unittest.TestCase): 250, 0, None, - None, + 9, None, None, ], ), - 17255, + 16692, places=0, ) self.assertAlmostEqual( @@ -1262,12 +1350,12 @@ class TestFromFile(unittest.TestCase): 250, 12, None, - None, + 9, None, None, ], ), - 23074, + 22317, places=0, ) self.assertAlmostEqual( @@ -1285,12 +1373,12 @@ class TestFromFile(unittest.TestCase): 250, 0, None, - None, + 9, None, None, ], ), - 26633, + 26361, places=0, ) self.assertAlmostEqual( @@ -1308,12 +1396,12 @@ class TestFromFile(unittest.TestCase): 250, 12, None, - None, + 9, None, None, ], ), - 35656, + 35292, places=0, ) self.assertAlmostEqual( @@ -1331,12 +1419,12 @@ class TestFromFile(unittest.TestCase): 2750, 0, None, - None, + 9, None, None, ], ), - 7964, + 7931, places=0, ) self.assertAlmostEqual( @@ -1354,12 +1442,12 @@ class TestFromFile(unittest.TestCase): 2750, 12, None, - None, + 9, None, None, ], ), - 10400, + 10356, places=0, ) diff --git a/test/test_timingharness.py b/test/test_timingharness.py index 06edc16..0741c7a 100755 --- a/test/test_timingharness.py +++ b/test/test_timingharness.py @@ -3,7 +3,6 @@ from dfatool.functions import StaticFunction from dfatool.loader import TimingData, pta_trace_to_aggregate from dfatool.model import AnalyticModel -from dfatool.parameters import prune_dependent_parameters import os import unittest @@ -36,11 +35,11 @@ class TestModels(unittest.TestCase): self.assertIsInstance(param_info("setup", "duration"), StaticFunction) self.assertEqual( param_info("write", "duration").model_function, - "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)", + "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * function_arg(1) + regression_arg(4) * parameter(max_retry_count) * parameter(retry_delay) + regression_arg(5) * parameter(max_retry_count) * function_arg(1) + regression_arg(6) * parameter(retry_delay) * function_arg(1) + regression_arg(7) * parameter(max_retry_count) * parameter(retry_delay) * function_arg(1)", ) self.assertAlmostEqual( - param_info("write", "duration").model_args[0], 1163, places=0 + param_info("write", "duration").model_args[0], 1281, places=0 ) self.assertAlmostEqual( param_info("write", "duration").model_args[1], 464, places=0 @@ -49,14 +48,25 @@ class TestModels(unittest.TestCase): param_info("write", "duration").model_args[2], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").model_args[3], 1, places=0 + param_info("write", "duration").model_args[3], -9, places=0 + ) + self.assertAlmostEqual( + param_info("write", "duration").model_args[4], 1, places=0 + ) + self.assertAlmostEqual( + param_info("write", "duration").model_args[5], 0, places=0 + ) + self.assertAlmostEqual( + param_info("write", "duration").model_args[6], 0, places=0 + ) + self.assertAlmostEqual( + param_info("write", "duration").model_args[7], 0, places=0 ) def test_dependent_parameter_pruning(self): raw_data = TimingData(["test-data/20190815_103347_nRF24_no-rx.json"]) preprocessed_data = raw_data.get_preprocessed_data() by_name, parameters, arg_count = pta_trace_to_aggregate(preprocessed_data) - prune_dependent_parameters(by_name, parameters) model = AnalyticModel(by_name, parameters, arg_count) self.assertEqual( model.names, "getObserveTx setPALevel setRetries setup write".split(" ") @@ -84,20 +94,32 @@ class TestModels(unittest.TestCase): self.assertIsInstance(param_info("setup", "duration"), StaticFunction) self.assertEqual( param_info("write", "duration").model_function, - "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)", + "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * function_arg(1) + regression_arg(4) * parameter(max_retry_count) * parameter(retry_delay) + regression_arg(5) * parameter(max_retry_count) * function_arg(1) + regression_arg(6) * parameter(retry_delay) * function_arg(1) + regression_arg(7) * parameter(max_retry_count) * parameter(retry_delay) * function_arg(1)", ) self.assertAlmostEqual( - param_info("write", "duration").model_args[0], 1163, places=0 + param_info("write", "duration").model_args[0], 1282, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").model_args[1], 464, places=0 + param_info("write", "duration").model_args[1], 463, places=0 ) self.assertAlmostEqual( param_info("write", "duration").model_args[2], 1, places=0 ) self.assertAlmostEqual( - param_info("write", "duration").model_args[3], 1, places=0 + param_info("write", "duration").model_args[3], -9, places=0 + ) + self.assertAlmostEqual( + param_info("write", "duration").model_args[4], 1, places=0 + ) + self.assertAlmostEqual( + param_info("write", "duration").model_args[5], 0, places=0 + ) + self.assertAlmostEqual( + param_info("write", "duration").model_args[6], 0, places=0 + ) + self.assertAlmostEqual( + param_info("write", "duration").model_args[7], 0, places=0 ) def test_function_override(self): |