author    Daniel Friesel <daniel.friesel@uos.de>    2021-03-17 17:56:01 +0100
committer Daniel Friesel <daniel.friesel@uos.de>    2021-03-17 17:56:01 +0100
commit    53546c9c52b4b45525726babbeb01f7532922783 (patch)
tree      7b71e101619134f19a1da57c7dbda99940649294
parent    162a0c287f5dab664e9168a60d76c9f8da07e46a (diff)
always handle co-dependent parameters
-rwxr-xr-x  bin/analyze-timing.py       |   3
-rw-r--r--  lib/model.py                |   4
-rw-r--r--  lib/parameters.py           | 191
-rw-r--r--  lib/paramfit.py             |   8
-rw-r--r--  lib/utils.py                |   6
-rwxr-xr-x  test/test_ptamodel.py       | 212
-rwxr-xr-x  test/test_timingharness.py  |  40
7 files changed, 272 insertions, 192 deletions
diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py
index 129b0ff..d67c553 100755
--- a/bin/analyze-timing.py
+++ b/bin/analyze-timing.py
@@ -84,7 +84,6 @@ from dfatool.functions import gplearn_to_function, StaticFunction, AnalyticFunct
from dfatool.model import AnalyticModel
from dfatool.validation import CrossValidator
from dfatool.utils import filter_aggregate_by_param, NpEncoder
-from dfatool.parameters import prune_dependent_parameters
opt = dict()
@@ -254,8 +253,6 @@ if __name__ == "__main__":
preprocessed_data, ignored_trace_indexes
)
- prune_dependent_parameters(by_name, parameters)
-
filter_aggregate_by_param(by_name, parameters, opt["filter-param"])
model = AnalyticModel(
diff --git a/lib/model.py b/lib/model.py
index 75f7195..5000caa 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -235,12 +235,12 @@ class AnalyticModel:
for name in self.names:
for attr in self.attr_by_name[name].keys():
- for key, param, args in self.attr_by_name[name][
+ for key, param, args, kwargs in self.attr_by_name[name][
attr
].get_data_for_paramfit(
safe_functions_enabled=safe_functions_enabled
):
- paramfit.enqueue(key, param, args)
+ paramfit.enqueue(key, param, args, kwargs)
paramfit.fit()
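
The lib/model.py hunk above changes the contract between ModelAttribute.get_data_for_paramfit and ParallelParamFit.enqueue from three-tuples to four-tuples: each queue entry now carries a kwargs dict that is forwarded to the fitting worker. A minimal sketch of the new contract, with hypothetical names and measurements (the real entries are produced by get_data_for_paramfit):

    fit_queue = list()

    def enqueue(key, param, args, kwargs=dict()):
        # Mirrors ParallelParamFit.enqueue after this commit: kwargs is stored
        # next to args and later expanded into _try_fits(*args, **kwargs).
        fit_queue.append({"key": (key, param), "args": args, "kwargs": kwargs})

    # by_param maps parameter tuples to duration measurements (hypothetical data)
    by_param = {
        (0, 250): [2100.1, 2099.7],
        (0, 1000): [1133.2, 1132.9],
        (0, 2000): [972.4, 971.8],
    }

    key = ("write", "duration")   # (state/transition name, attribute)
    param = "bitrate"             # parameter the fit is performed against
    args = (by_param, 1, False)   # (by_param, param_index, safe_functions_enabled)
    kwargs = dict()               # empty for now; reserved for fit options

    enqueue(key, param, args, kwargs)
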
diff --git a/lib/parameters.py b/lib/parameters.py
index ea14ad2..99510b2 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -3,7 +3,6 @@ import itertools
import logging
import numpy as np
import os
-import warnings
from collections import OrderedDict
from copy import deepcopy
from multiprocessing import Pool
@@ -109,12 +108,10 @@ def _std_by_param(n_by_param, all_param_values, param_index):
# vprint(verbose, '[W] parameter value partition for {} is empty'.format(param_value))
if np.all(np.isnan(stddev_matrix)):
- warnings.warn(
- "parameter #{} has no data partitions. stddev_matrix = {}".format(
- param_index, stddev_matrix
- )
+ logger.warning(
+ f"parameter #{param_index} has no data partitions. All stddev_matrix entries are NaN."
)
- return stddev_matrix, 0.0
+ return stddev_matrix, 0.0, lut_matrix
return (
stddev_matrix,
@@ -159,7 +156,12 @@ def _corr_by_param(attribute_data, param_values, param_index):
def _compute_param_statistics(
- data, param_names, param_tuples, arg_count=None, use_corrcoef=False
+ data,
+ param_names,
+ param_tuples,
+ arg_count=None,
+ use_corrcoef=False,
+ codependent_params=list(),
):
"""
Compute standard deviation and correlation coefficient on parameterized data partitions.
@@ -230,8 +232,18 @@ def _compute_param_statistics(
np.seterr("raise")
for param_idx, param in enumerate(param_names):
+ if param_idx < len(codependent_params) and codependent_params[param_idx]:
+ by_param = partition_by_param(
+ data, param_tuples, ignore_parameters=codependent_params[param_idx]
+ )
+ distinct_values = ret["distinct_values_by_param_index"].copy()
+ for codependent_param_index in codependent_params[param_idx]:
+ distinct_values[codependent_param_index] = [None]
+ else:
+ by_param = ret["by_param"]
+ distinct_values = ret["distinct_values_by_param_index"]
std_matrix, mean_std, lut_matrix = _std_by_param(
- by_param, ret["distinct_values_by_param_index"], param_idx
+ by_param, distinct_values, param_idx
)
ret["std_by_param"][param] = mean_std
ret["std_by_param_values"][param] = std_matrix
@@ -246,17 +258,26 @@ def _compute_param_statistics(
if arg_count:
for arg_index in range(arg_count):
+ param_idx = len(param_names) + arg_index
+ if param_idx < len(codependent_params) and codependent_params[param_idx]:
+ by_param = partition_by_param(
+ data, param_tuples, ignore_parameters=codependent_params[param_idx]
+ )
+ distinct_values = ret["distinct_values_by_param_index"].copy()
+ for codependent_param_index in codependent_params[param_idx]:
+ distinct_values[codependent_param_index] = [None]
+ else:
+ by_param = ret["by_param"]
+ distinct_values = ret["distinct_values_by_param_index"]
std_matrix, mean_std, lut_matrix = _std_by_param(
by_param,
- ret["distinct_values_by_param_index"],
- len(param_names) + arg_index,
+ distinct_values,
+ param_idx,
)
ret["std_by_arg"].append(mean_std)
ret["std_by_arg_values"].append(std_matrix)
ret["lut_by_arg_values"].append(lut_matrix)
- ret["corr_by_arg"].append(
- _corr_by_param(data, param_tuples, len(param_names) + arg_index)
- )
+ ret["corr_by_arg"].append(_corr_by_param(data, param_tuples, param_idx))
if False:
ret["_depends_on_arg"].append(ret["corr_by_arg"][arg_index] > 0.1)
@@ -326,91 +347,6 @@ def _all_params_are_numeric(data, param_idx):
return False
-def prune_dependent_parameters(by_name, parameter_names, correlation_threshold=0.5):
- """
- Remove dependent parameters from aggregate.
-
- :param by_name: measurements partitioned by state/transition/... name and attribute, edited in-place.
- by_name[name][attribute] must be a list or 1-D numpy array.
- by_name[stanamete_or_trans]['param'] must be a list of parameter values.
- Other dict members are left as-is
- :param parameter_names: List of parameter names in the order they are used in by_name[name]['param'], edited in-place.
- :param correlation_threshold: Remove parameter if absolute correlation exceeds this threshold (default: 0.5)
-
- Model generation (and its components, such as relevant parameter detection and least squares optimization) only works if input variables (i.e., parameters)
- are independent of each other. This function computes the correlation coefficient for each pair of parameters and removes those which depend on each other.
- For each pair of dependent parameters, the lexically greater one is removed (e.g. "a" and "b" -> "b" is removed).
- """
-
- parameter_indices_to_remove = list()
- for parameter_combination in itertools.product(
- range(len(parameter_names)), range(len(parameter_names))
- ):
- index_1, index_2 = parameter_combination
- if index_1 >= index_2:
- continue
- parameter_values = [list(), list()] # both parameters have a value
- parameter_values_1 = list() # parameter 1 has a value
- parameter_values_2 = list() # parameter 2 has a value
- for name in by_name:
- for measurement in by_name[name]["param"]:
- value_1 = measurement[index_1]
- value_2 = measurement[index_2]
- if is_numeric(value_1):
- parameter_values_1.append(value_1)
- if is_numeric(value_2):
- parameter_values_2.append(value_2)
- if is_numeric(value_1) and is_numeric(value_2):
- parameter_values[0].append(value_1)
- parameter_values[1].append(value_2)
- if len(parameter_values[0]):
- # Calculating the correlation coefficient only makes sense when neither value is constant
- if np.std(parameter_values_1) != 0 and np.std(parameter_values_2) != 0:
- correlation = np.corrcoef(parameter_values)[0][1]
- if (
- correlation != np.nan
- and np.abs(correlation) > correlation_threshold
- ):
- logger.debug(
- "Parameters {} <-> {} are correlated with coefficcient {}".format(
- parameter_names[index_1],
- parameter_names[index_2],
- correlation,
- )
- )
- if len(parameter_values_1) < len(parameter_values_2):
- index_to_remove = index_1
- else:
- index_to_remove = index_2
- logger.debug(
- " Removing parameter {}".format(
- parameter_names[index_to_remove]
- )
- )
- parameter_indices_to_remove.append(index_to_remove)
- remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove)
-
-
-def remove_parameters_by_indices(by_name, parameter_names, parameter_indices_to_remove):
- """
- Remove parameters listed in `parameter_indices` from aggregate `by_name` and `parameter_names`.
-
- :param by_name: measurements partitioned by state/transition/... name and attribute, edited in-place.
- by_name[name][attribute] must be a list or 1-D numpy array.
- by_name[stanamete_or_trans]['param'] must be a list of parameter values.
- Other dict members are left as-is
- :param parameter_names: List of parameter names in the order they are used in by_name[name]['param'], edited in-place.
- :param parameter_indices_to_remove: List of parameter indices to be removed
- """
-
- # Start removal from the end of the list to avoid renumbering of list elemenets
- for parameter_index in sorted(parameter_indices_to_remove, reverse=True):
- for name in by_name:
- for measurement in by_name[name]["param"]:
- measurement.pop(parameter_index)
- parameter_names.pop(parameter_index)
-
-
class ParallelParamStats:
def __init__(self):
self.queue = list()
@@ -425,6 +361,8 @@ class ParallelParamStats:
attr.param_names,
attr.param_values,
attr.arg_count,
+ False,
+ attr.codependent_params,
],
}
)
@@ -458,6 +396,7 @@ class ParamStats:
attr.param_values,
arg_count=attr.arg_count,
use_corrcoef=use_corrcoef,
+ codependent_params=attr.codependent_params,
)
attr.by_param = res.pop("by_param")
attr.stats = cls(res)
@@ -609,9 +548,10 @@ class ModelAttribute:
map(lambda i: f"arg{i}", range(arg_count))
)
- # Co-dependent parameters. If (paam1_index, param2_index) in codependent_param, they are codependent.
+ # Co-dependent parameters. If (param1_index, param2_index) in codependent_param, they are codependent.
# In this case, only one of them must be used for parameter-dependent model attribute detection and modeling
- self.codependent_param = codependent_param
+ self.codependent_param_pair = codependent_param
+ self.codependent_params = [list() for x in self.log_param_names]
self.ignore_param = dict()
# Static model used as lower bound of model accuracy
@@ -663,7 +603,7 @@ class ModelAttribute:
for (
param1_index,
param2_index,
- ), is_codependent in self.codependent_param.items():
+ ), is_codependent in self.codependent_param_pair.items():
if not is_codependent:
continue
param1_values = map(lambda pv: pv[param1_index], self.param_values)
@@ -672,12 +612,14 @@ class ModelAttribute:
param2_numeric_count = sum(map(is_numeric, param2_values))
if param1_numeric_count >= param2_numeric_count:
self.ignore_param[param2_index] = True
- logger.warning(
+ self.codependent_params[param1_index].append(param2_index)
+ logger.info(
f"{self.name} {self.attr}: parameters ({self.log_param_names[param1_index]}, {self.log_param_names[param2_index]}) are codependent. Ignoring {self.log_param_names[param2_index]}"
)
else:
self.ignore_param[param1_index] = True
- logger.warning(
+ self.codependent_params[param2_index].append(param1_index)
+ logger.info(
f"{self.name} {self.attr}: parameters ({self.log_param_names[param1_index]}, {self.log_param_names[param2_index]}) are codependent. Ignoring {self.log_param_names[param1_index]}"
)
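
With this hunk, codependent parameter pairs are no longer only logged and ignored: for the parameter that is kept, the indices of its codependent partners are recorded in self.codependent_params so that later analysis steps can mask them. A small illustration of the resulting bookkeeping, with hypothetical parameter names:

    log_param_names = ["auto_ack", "tx_bytes", "payload_len"]

    # Assume tx_bytes and payload_len were detected as codependent and
    # tx_bytes has at least as many numeric values:
    ignore_param = {2: True}   # payload_len is excluded from dependency detection
    codependent_params = [
        [],    # auto_ack: no codependent partners
        [2],   # tx_bytes: mask payload_len when analyzing tx_bytes
        [],    # payload_len itself (ignored)
    ]
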
@@ -719,10 +661,22 @@ class ModelAttribute:
# )
child1 = ModelAttribute(
- self.name, self.attr, self.data[tt1], pv1, self.param_names, self.arg_count
+ self.name,
+ self.attr,
+ self.data[tt1],
+ pv1,
+ self.param_names,
+ self.arg_count,
+ codependent_param_dict(pv1),
)
child2 = ModelAttribute(
- self.name, self.attr, self.data[tt2], pv2, self.param_names, self.arg_count
+ self.name,
+ self.attr,
+ self.data[tt2],
+ pv2,
+ self.param_names,
+ self.arg_count,
+ codependent_param_dict(pv2),
)
ParamStats.compute_for_attr(child1)
@@ -814,10 +768,20 @@ class ModelAttribute:
child_ret = child.get_data_for_paramfit(
safe_functions_enabled=safe_functions_enabled
)
- for key, param, val in child_ret:
- ret.append((key[:2] + (param_value,) + key[2:], param, val))
+ for key, param, args, kwargs in child_ret:
+ ret.append((key[:2] + (param_value,) + key[2:], param, args, kwargs))
return ret
+ def _by_param_for_index(self, param_index):
+ if not self.codependent_params[param_index]:
+ return self.by_param
+ new_param_values = list()
+ for param_tuple in self.param_values:
+ for i in self.codependent_params[param_index]:
+ param_tuple[i] = None
+ new_param_values.append(param_tuple)
+ return partition_by_param(self.data, new_param_values)
+
def get_data_for_paramfit_this(self, safe_functions_enabled=False):
ret = list()
for param_index, param_name in enumerate(self.param_names):
@@ -825,28 +789,33 @@ class ModelAttribute:
self.stats.depends_on_param(param_name)
and not param_index in self.ignore_param
):
+ by_param = self._by_param_for_index(param_index)
ret.append(
(
(self.name, self.attr),
param_name,
- (self.by_param, param_index, safe_functions_enabled),
+ (by_param, param_index, safe_functions_enabled),
+ dict(),
)
)
if self.arg_count:
for arg_index in range(self.arg_count):
+ param_index = len(self.param_names) + arg_index
if (
self.stats.depends_on_arg(arg_index)
- and not arg_index + len(self.param_names) in self.ignore_param
+ and not param_index in self.ignore_param
):
+ by_param = self._by_param_for_index(param_index)
ret.append(
(
(self.name, self.attr),
arg_index,
(
- self.by_param,
- len(self.param_names) + arg_index,
+ by_param,
+ param_index,
safe_functions_enabled,
),
+ dict(),
)
)
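
The new _by_param_for_index helper uses that bookkeeping during fitting: when the parameter under analysis has codependent partners, the partner columns are blanked out and the data is re-partitioned, so the parameter itself is the only remaining source of variation. A hedged sketch of the masking step with made-up parameter tuples (columns: auto_ack, tx_bytes, payload_len; codependent_params[1] == [2]):

    param_values = [
        [0, 8, 8],
        [0, 16, 16],
        [0, 32, 32],
    ]
    codependent_partners = [2]   # partners of tx_bytes, the parameter being fitted

    masked = list()
    for param_tuple in param_values:
        t = list(param_tuple)            # work on a copy in this sketch
        for i in codependent_partners:
            t[i] = None
        masked.append(t)

    # masked == [[0, 8, None], [0, 16, None], [0, 32, None]]
    # partition_by_param(self.data, masked) then groups measurements by
    # (auto_ack, tx_bytes) only.
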
diff --git a/lib/paramfit.py b/lib/paramfit.py
index eed8eed..8bc7505 100644
--- a/lib/paramfit.py
+++ b/lib/paramfit.py
@@ -29,7 +29,7 @@ class ParallelParamFit:
"""Create a new ParallelParamFit object."""
self.fit_queue = list()
- def enqueue(self, key, param, args):
+ def enqueue(self, key, param, args, kwargs=dict()):
"""
Add state_or_tran/attribute/param_name to fit queue.
@@ -41,7 +41,7 @@ class ParallelParamFit:
:param args: [by_param, param_index, safe_functions_enabled, param_filter]
by_param[(param 1, param2, ...)] holds measurements.
"""
- self.fit_queue.append({"key": (key, param), "args": args})
+ self.fit_queue.append({"key": (key, param), "args": args, "kwargs": kwargs})
def fit(self):
"""
@@ -102,11 +102,11 @@ class ParallelParamFit:
def _try_fits_parallel(arg):
"""
- Call _try_fits(*arg['args']) and return arg['key'] and the _try_fits result.
+ Call _try_fits(*arg['args'], **arg["kwargs"]) and return arg['key'] and the _try_fits result.
Must be a global function as it is called from a multiprocessing Pool.
"""
- return {"key": arg["key"], "result": _try_fits(*arg["args"])}
+ return {"key": arg["key"], "result": _try_fits(*arg["args"], **arg["kwargs"])}
def _try_fits(
diff --git a/lib/utils.py b/lib/utils.py
index e39b329..17cdb51 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -175,9 +175,13 @@ def match_parameter_values(input_param: dict, match_param: dict):
return True
-def partition_by_param(data, param_values):
+def partition_by_param(data, param_values, ignore_parameters=list()):
ret = dict()
for i, parameters in enumerate(param_values):
+ # ensure that parameters[param_index] = None does not affect the "param_values" entries passed to this function
+ parameters = list(parameters)
+ for param_index in ignore_parameters:
+ parameters[param_index] = None
param_key = tuple(parameters)
if param_key not in ret:
ret[param_key] = list()
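
partition_by_param now accepts an ignore_parameters list and copies each parameter tuple before overwriting the ignored positions with None, so the caller's param_values stay untouched. A short worked example with hypothetical data:

    from dfatool.utils import partition_by_param

    data = [10.0, 11.0, 30.0, 31.0]
    param_values = [
        (250, 8), (250, 16),     # bitrate 250, payload length differs
        (1000, 8), (1000, 16),   # bitrate 1000, payload length differs
    ]

    # Ignoring the codependent payload-length column collapses four partitions into two:
    print(partition_by_param(data, param_values, ignore_parameters=[1]))
    # {(250, None): [10.0, 11.0], (1000, None): [30.0, 31.0]}
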
diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py
index 9d1b39b..ece83c5 100755
--- a/test/test_ptamodel.py
+++ b/test/test_ptamodel.py
@@ -694,74 +694,162 @@ class TestFromFile(unittest.TestCase):
self.assertAlmostEqual(
param_model(
- "write", "duration", param=[0, 76, 1000, 0, 10, None, None, 1500, 0]
+ "write",
+ "duration",
+ param=[0, 76, 1000, 0, 10, None, None, 1500, 0, None, 9, None, None],
),
- 1090,
+ 1133,
places=0,
)
- # only bitrate is relevant
+ # only bitrate and packet length are relevant
self.assertAlmostEqual(
param_model(
"write",
"duration",
- param=[0, None, 1000, None, None, None, None, None, None],
+ param=[
+ 0,
+ None,
+ 1000,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ 9,
+ None,
+ None,
+ ],
),
- 1090,
+ 1133,
places=0,
)
self.assertAlmostEqual(
param_model(
"write",
"duration",
- param=[0, None, 250, None, None, None, None, None, None],
+ param=[
+ 0,
+ None,
+ 250,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ 9,
+ None,
+ None,
+ ],
),
- 2057,
+ 2100,
places=0,
)
self.assertAlmostEqual(
param_model(
"write",
"duration",
- param=[0, None, 2000, None, None, None, None, None, None],
+ param=[
+ 0,
+ None,
+ 2000,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ 9,
+ None,
+ None,
+ ],
),
- 929,
+ 972,
places=0,
)
- # auto_ack == 1 has a different write duration, still only bitrate is relevant
+ # auto_ack == 1 has a different write duration, still only bitrate and packet length are relevant
self.assertAlmostEqual(
param_model(
- "write", "duration", param=[1, 76, 1000, 0, 10, None, None, 1500, 0]
+ "write",
+ "duration",
+ param=[1, 76, 1000, 0, 10, None, None, 1500, 0, None, 9, None, None],
),
- 22284,
+ 22327,
places=0,
)
self.assertAlmostEqual(
param_model(
"write",
"duration",
- param=[1, None, 1000, None, None, None, None, None, None],
+ param=[
+ 1,
+ None,
+ 1000,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ 9,
+ None,
+ None,
+ ],
),
- 22284,
+ 22327,
places=0,
)
self.assertAlmostEqual(
param_model(
"write",
"duration",
- param=[1, None, 250, None, None, None, None, None, None],
+ param=[
+ 1,
+ None,
+ 250,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ 9,
+ None,
+ None,
+ ],
),
- 33229,
+ 33273,
places=0,
)
self.assertAlmostEqual(
param_model(
"write",
"duration",
- param=[1, None, 2000, None, None, None, None, None, None],
+ param=[
+ 1,
+ None,
+ 2000,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ 9,
+ None,
+ None,
+ ],
),
- 20459,
+ 20503,
places=0,
)
@@ -864,12 +952,12 @@ class TestFromFile(unittest.TestCase):
250,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 1150,
+ 1148,
places=0,
)
self.assertAlmostEqual(
@@ -887,12 +975,12 @@ class TestFromFile(unittest.TestCase):
250,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 1150,
+ 1148,
places=0,
)
self.assertAlmostEqual(
@@ -910,12 +998,12 @@ class TestFromFile(unittest.TestCase):
2750,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 1150,
+ 1148,
places=0,
)
self.assertAlmostEqual(
@@ -933,12 +1021,12 @@ class TestFromFile(unittest.TestCase):
2750,
15,
None,
- None,
+ 9,
None,
None,
],
),
- 1150,
+ 1148,
places=0,
)
@@ -958,12 +1046,12 @@ class TestFromFile(unittest.TestCase):
250,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 1425,
+ 1473,
places=0,
)
self.assertAlmostEqual(
@@ -981,12 +1069,12 @@ class TestFromFile(unittest.TestCase):
250,
12,
None,
- None,
+ 9,
None,
None,
],
),
- 1425,
+ 1473,
places=0,
)
self.assertAlmostEqual(
@@ -1004,12 +1092,12 @@ class TestFromFile(unittest.TestCase):
250,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 4982,
+ 5030,
places=0,
)
self.assertAlmostEqual(
@@ -1027,12 +1115,12 @@ class TestFromFile(unittest.TestCase):
250,
16,
None,
- None,
+ 9,
None,
None,
],
),
- 4982,
+ 5030,
places=0,
)
self.assertAlmostEqual(
@@ -1050,12 +1138,12 @@ class TestFromFile(unittest.TestCase):
2750,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 19982,
+ 20029,
places=0,
)
self.assertAlmostEqual(
@@ -1073,12 +1161,12 @@ class TestFromFile(unittest.TestCase):
2750,
15,
None,
- None,
+ 9,
None,
None,
],
),
- 19982,
+ 20029,
places=0,
)
@@ -1099,12 +1187,12 @@ class TestFromFile(unittest.TestCase):
250,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 12989,
+ 12420,
places=0,
)
self.assertAlmostEqual(
@@ -1122,12 +1210,12 @@ class TestFromFile(unittest.TestCase):
250,
12,
None,
- None,
+ 9,
None,
None,
],
),
- 20055,
+ 19172,
places=0,
)
self.assertAlmostEqual(
@@ -1145,12 +1233,12 @@ class TestFromFile(unittest.TestCase):
250,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 12989,
+ 12420,
places=0,
)
self.assertAlmostEqual(
@@ -1168,12 +1256,12 @@ class TestFromFile(unittest.TestCase):
250,
12,
None,
- None,
+ 9,
None,
None,
],
),
- 20055,
+ 19172,
places=0,
)
self.assertAlmostEqual(
@@ -1191,12 +1279,12 @@ class TestFromFile(unittest.TestCase):
2750,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 12989,
+ 12420,
places=0,
)
self.assertAlmostEqual(
@@ -1214,12 +1302,12 @@ class TestFromFile(unittest.TestCase):
2750,
12,
None,
- None,
+ 9,
None,
None,
],
),
- 20055,
+ 19172,
places=0,
)
@@ -1239,12 +1327,12 @@ class TestFromFile(unittest.TestCase):
250,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 17255,
+ 16692,
places=0,
)
self.assertAlmostEqual(
@@ -1262,12 +1350,12 @@ class TestFromFile(unittest.TestCase):
250,
12,
None,
- None,
+ 9,
None,
None,
],
),
- 23074,
+ 22317,
places=0,
)
self.assertAlmostEqual(
@@ -1285,12 +1373,12 @@ class TestFromFile(unittest.TestCase):
250,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 26633,
+ 26361,
places=0,
)
self.assertAlmostEqual(
@@ -1308,12 +1396,12 @@ class TestFromFile(unittest.TestCase):
250,
12,
None,
- None,
+ 9,
None,
None,
],
),
- 35656,
+ 35292,
places=0,
)
self.assertAlmostEqual(
@@ -1331,12 +1419,12 @@ class TestFromFile(unittest.TestCase):
2750,
0,
None,
- None,
+ 9,
None,
None,
],
),
- 7964,
+ 7931,
places=0,
)
self.assertAlmostEqual(
@@ -1354,12 +1442,12 @@ class TestFromFile(unittest.TestCase):
2750,
12,
None,
- None,
+ 9,
None,
None,
],
),
- 10400,
+ 10356,
places=0,
)
diff --git a/test/test_timingharness.py b/test/test_timingharness.py
index 06edc16..0741c7a 100755
--- a/test/test_timingharness.py
+++ b/test/test_timingharness.py
@@ -3,7 +3,6 @@
from dfatool.functions import StaticFunction
from dfatool.loader import TimingData, pta_trace_to_aggregate
from dfatool.model import AnalyticModel
-from dfatool.parameters import prune_dependent_parameters
import os
import unittest
@@ -36,11 +35,11 @@ class TestModels(unittest.TestCase):
self.assertIsInstance(param_info("setup", "duration"), StaticFunction)
self.assertEqual(
param_info("write", "duration").model_function,
- "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)",
+ "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * function_arg(1) + regression_arg(4) * parameter(max_retry_count) * parameter(retry_delay) + regression_arg(5) * parameter(max_retry_count) * function_arg(1) + regression_arg(6) * parameter(retry_delay) * function_arg(1) + regression_arg(7) * parameter(max_retry_count) * parameter(retry_delay) * function_arg(1)",
)
self.assertAlmostEqual(
- param_info("write", "duration").model_args[0], 1163, places=0
+ param_info("write", "duration").model_args[0], 1281, places=0
)
self.assertAlmostEqual(
param_info("write", "duration").model_args[1], 464, places=0
@@ -49,14 +48,25 @@ class TestModels(unittest.TestCase):
param_info("write", "duration").model_args[2], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").model_args[3], 1, places=0
+ param_info("write", "duration").model_args[3], -9, places=0
+ )
+ self.assertAlmostEqual(
+ param_info("write", "duration").model_args[4], 1, places=0
+ )
+ self.assertAlmostEqual(
+ param_info("write", "duration").model_args[5], 0, places=0
+ )
+ self.assertAlmostEqual(
+ param_info("write", "duration").model_args[6], 0, places=0
+ )
+ self.assertAlmostEqual(
+ param_info("write", "duration").model_args[7], 0, places=0
)
def test_dependent_parameter_pruning(self):
raw_data = TimingData(["test-data/20190815_103347_nRF24_no-rx.json"])
preprocessed_data = raw_data.get_preprocessed_data()
by_name, parameters, arg_count = pta_trace_to_aggregate(preprocessed_data)
- prune_dependent_parameters(by_name, parameters)
model = AnalyticModel(by_name, parameters, arg_count)
self.assertEqual(
model.names, "getObserveTx setPALevel setRetries setup write".split(" ")
@@ -84,20 +94,32 @@ class TestModels(unittest.TestCase):
self.assertIsInstance(param_info("setup", "duration"), StaticFunction)
self.assertEqual(
param_info("write", "duration").model_function,
- "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * parameter(max_retry_count) * parameter(retry_delay)",
+ "0 + regression_arg(0) + regression_arg(1) * parameter(max_retry_count) + regression_arg(2) * parameter(retry_delay) + regression_arg(3) * function_arg(1) + regression_arg(4) * parameter(max_retry_count) * parameter(retry_delay) + regression_arg(5) * parameter(max_retry_count) * function_arg(1) + regression_arg(6) * parameter(retry_delay) * function_arg(1) + regression_arg(7) * parameter(max_retry_count) * parameter(retry_delay) * function_arg(1)",
)
self.assertAlmostEqual(
- param_info("write", "duration").model_args[0], 1163, places=0
+ param_info("write", "duration").model_args[0], 1282, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").model_args[1], 464, places=0
+ param_info("write", "duration").model_args[1], 463, places=0
)
self.assertAlmostEqual(
param_info("write", "duration").model_args[2], 1, places=0
)
self.assertAlmostEqual(
- param_info("write", "duration").model_args[3], 1, places=0
+ param_info("write", "duration").model_args[3], -9, places=0
+ )
+ self.assertAlmostEqual(
+ param_info("write", "duration").model_args[4], 1, places=0
+ )
+ self.assertAlmostEqual(
+ param_info("write", "duration").model_args[5], 0, places=0
+ )
+ self.assertAlmostEqual(
+ param_info("write", "duration").model_args[6], 0, places=0
+ )
+ self.assertAlmostEqual(
+ param_info("write", "duration").model_args[7], 0, places=0
)
def test_function_override(self):
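
For reference, the write-duration model function asserted in these tests reads as a plain polynomial in max_retry_count, retry_delay and function_arg(1) (presumably the write length argument). A hedged sketch using the approximate coefficients from the first set of assertions above; the actual fitted values differ slightly:

    def write_duration(max_retry_count, retry_delay, arg1,
                       coef=(1281, 464, 1, -9, 1, 0, 0, 0)):
        c = coef
        return (
            c[0]
            + c[1] * max_retry_count
            + c[2] * retry_delay
            + c[3] * arg1
            + c[4] * max_retry_count * retry_delay
            + c[5] * max_retry_count * arg1
            + c[6] * retry_delay * arg1
            + c[7] * max_retry_count * retry_delay * arg1
        )

    # e.g. no retries, retry_delay 250, 8-byte write:
    print(write_duration(0, 250, 8))   # ≈ 1459, illustrative only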