summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-02-26 08:55:25 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-02-26 08:55:25 +0100
commit2be86e1d7cccb6f2b0daaa5c69c95796537b47c1 (patch)
tree5631ff4dbe6e97ef2914b2a48414515ef6a24e94
parentc2865937f841863b3d21499e18d72b144ee9dac0 (diff)
add simple decisiontree test
-rw-r--r--lib/model.py2
-rwxr-xr-xtest/test_ptamodel.py157
2 files changed, 156 insertions, 3 deletions
diff --git a/lib/model.py b/lib/model.py
index 451a39a..83c31b1 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -422,7 +422,7 @@ class ModelAttribute:
# )
# None -> kein split notwendig
- # andernfalls: Parameter, anhand dessen eine Decision Tree-Ebene aufgespannt wird
+ # andernfalls: Parameter-Index, anhand dessen eine Decision Tree-Ebene aufgespannt wird
# (Kinder sind wiederum ModelAttributes, in denen dieser Parameter konstant ist)
def get_split_param_index(self):
if not self.param_names:
diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py
index bcbb19a..dad4328 100755
--- a/test/test_ptamodel.py
+++ b/test/test_ptamodel.py
@@ -338,8 +338,6 @@ class TestSynthetic(unittest.TestCase):
static_quality = validator.kfold(lambda m: m.get_static(), 10)
param_quality = validator.kfold(lambda m: m.get_fitted()[0], 10)
- print(static_quality)
-
# static quality reflects normal distribution scale for non-parameterized data
# the Root Mean Square Deviation must not be greater the scale (i.e., standard deviation) of the normal distribution
@@ -669,6 +667,161 @@ class TestFromFile(unittest.TestCase):
param_model("RX", "power", param=[1, None, None]), 48647, places=-1
)
+ def test_decisiontrees_rf24(self):
+ raw_data = RawData(["test-data/20191024-152648-nrf24l01.tar"])
+ preprocessed_data = raw_data.get_preprocessed_data()
+ by_name, parameters, arg_count = pta_trace_to_aggregate(preprocessed_data)
+ model = PTAModel(by_name, parameters, arg_count)
+ self.assertEqual(model.states(), "RX STANDBY1".split(" "))
+ self.assertEqual(
+ model.transitions(),
+ "setAutoAck setDataRate setPALevel setup startListening stopListening write".split(
+ " "
+ ),
+ )
+ static_model = model.get_static()
+ self.assertAlmostEqual(static_model("RX", "power"), 47964, places=0)
+ self.assertAlmostEqual(static_model("STANDBY1", "power"), 128, places=0)
+ self.assertAlmostEqual(static_model("setAutoAck", "power"), 151, places=0)
+ self.assertAlmostEqual(static_model("setDataRate", "power"), 146, places=0)
+ self.assertAlmostEqual(static_model("setPALevel", "power"), 147, places=0)
+ self.assertAlmostEqual(static_model("setup", "power"), 153, places=0)
+ self.assertAlmostEqual(static_model("startListening", "power"), 18954, places=0)
+ self.assertAlmostEqual(static_model("stopListening", "power"), 2426, places=0)
+ self.assertAlmostEqual(static_model("write", "power"), 17629, places=0)
+
+ self.assertAlmostEqual(static_model("setAutoAck", "duration"), 90, places=0)
+ self.assertAlmostEqual(static_model("setDataRate", "duration"), 240, places=0)
+ self.assertAlmostEqual(static_model("setPALevel", "duration"), 160, places=0)
+ self.assertAlmostEqual(static_model("setup", "duration"), 6550, places=0)
+ self.assertAlmostEqual(
+ static_model("startListening", "duration"), 470, places=0
+ )
+ self.assertAlmostEqual(static_model("stopListening", "duration"), 510, places=0)
+ self.assertAlmostEqual(static_model("write", "duration"), 11230, places=0)
+
+ self.assertAlmostEqual(
+ model.attr_by_name["write"]["duration"].stats.param_dependence_ratio(
+ "auto_ack!"
+ ),
+ 1,
+ places=2,
+ )
+ self.assertAlmostEqual(
+ model.attr_by_name["write"]["power"].stats.param_dependence_ratio(
+ "auto_ack!"
+ ),
+ 0.99,
+ places=2,
+ )
+
+ param_model, param_info = model.get_fitted()
+
+ self.assertAlmostEqual(
+ param_model(
+ "write", "duration", param=[0, 76, 1000, 0, 10, None, None, 1500, 0]
+ ),
+ 1090,
+ places=0,
+ )
+
+ # only bitrate is relevant
+ self.assertAlmostEqual(
+ param_model(
+ "write",
+ "duration",
+ param=[0, None, 1000, None, None, None, None, None, None],
+ ),
+ 1090,
+ places=0,
+ )
+ self.assertAlmostEqual(
+ param_model(
+ "write",
+ "duration",
+ param=[0, None, 250, None, None, None, None, None, None],
+ ),
+ 2057,
+ places=0,
+ )
+ self.assertAlmostEqual(
+ param_model(
+ "write",
+ "duration",
+ param=[0, None, 2000, None, None, None, None, None, None],
+ ),
+ 929,
+ places=0,
+ )
+
+ # auto_ack == 1 has a different write duration, still only bitrate is relevant
+ self.assertAlmostEqual(
+ param_model(
+ "write", "duration", param=[1, 76, 1000, 0, 10, None, None, 1500, 0]
+ ),
+ 22284,
+ places=0,
+ )
+ self.assertAlmostEqual(
+ param_model(
+ "write",
+ "duration",
+ param=[1, None, 1000, None, None, None, None, None, None],
+ ),
+ 22284,
+ places=0,
+ )
+ self.assertAlmostEqual(
+ param_model(
+ "write",
+ "duration",
+ param=[1, None, 250, None, None, None, None, None, None],
+ ),
+ 33229,
+ places=0,
+ )
+ self.assertAlmostEqual(
+ param_model(
+ "write",
+ "duration",
+ param=[1, None, 2000, None, None, None, None, None, None],
+ ),
+ 20459,
+ places=0,
+ )
+
+ """
+ param_model, param_info = model.get_fitted()
+ self.assertEqual(param_info("POWERDOWN", "power"), None)
+ self.assertEqual(
+ param_info("RX", "power")["function"].model_function,
+ "0 + regression_arg(0) + regression_arg(1) * np.sqrt(parameter(datarate))",
+ )
+ self.assertAlmostEqual(
+ param_info("RX", "power")["function"].model_args[0], 48530.7, places=0
+ )
+ self.assertAlmostEqual(
+ param_info("RX", "power")["function"].model_args[1], 117, places=0
+ )
+ self.assertEqual(param_info("STANDBY1", "power"), None)
+ self.assertEqual(
+ param_info("TX", "power")["function"].model_function,
+ "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * parameter(txpower) + regression_arg(3) * 1/(parameter(datarate)) * parameter(txpower)",
+ )
+ self.assertEqual(
+ param_info("epilogue", "timeout")["function"].model_function,
+ "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))",
+ )
+ self.assertEqual(
+ param_info("stopListening", "duration")["function"].model_function,
+ "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))",
+ )
+
+ self.assertAlmostEqual(
+ param_model("RX", "power", param=[1, None, None]), 48647, places=-1
+ )
+ """
+
def test_singlefile_mmparam(self):
raw_data = RawData(["test-data/20161221_123347_mmparam.tar"])
preprocessed_data = raw_data.get_preprocessed_data()