From 2be86e1d7cccb6f2b0daaa5c69c95796537b47c1 Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Fri, 26 Feb 2021 08:55:25 +0100 Subject: add simple decisiontree test --- lib/model.py | 2 +- test/test_ptamodel.py | 157 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 156 insertions(+), 3 deletions(-) diff --git a/lib/model.py b/lib/model.py index 451a39a..83c31b1 100644 --- a/lib/model.py +++ b/lib/model.py @@ -422,7 +422,7 @@ class ModelAttribute: # ) # None -> kein split notwendig - # andernfalls: Parameter, anhand dessen eine Decision Tree-Ebene aufgespannt wird + # andernfalls: Parameter-Index, anhand dessen eine Decision Tree-Ebene aufgespannt wird # (Kinder sind wiederum ModelAttributes, in denen dieser Parameter konstant ist) def get_split_param_index(self): if not self.param_names: diff --git a/test/test_ptamodel.py b/test/test_ptamodel.py index bcbb19a..dad4328 100755 --- a/test/test_ptamodel.py +++ b/test/test_ptamodel.py @@ -338,8 +338,6 @@ class TestSynthetic(unittest.TestCase): static_quality = validator.kfold(lambda m: m.get_static(), 10) param_quality = validator.kfold(lambda m: m.get_fitted()[0], 10) - print(static_quality) - # static quality reflects normal distribution scale for non-parameterized data # the Root Mean Square Deviation must not be greater the scale (i.e., standard deviation) of the normal distribution @@ -669,6 +667,161 @@ class TestFromFile(unittest.TestCase): param_model("RX", "power", param=[1, None, None]), 48647, places=-1 ) + def test_decisiontrees_rf24(self): + raw_data = RawData(["test-data/20191024-152648-nrf24l01.tar"]) + preprocessed_data = raw_data.get_preprocessed_data() + by_name, parameters, arg_count = pta_trace_to_aggregate(preprocessed_data) + model = PTAModel(by_name, parameters, arg_count) + self.assertEqual(model.states(), "RX STANDBY1".split(" ")) + self.assertEqual( + model.transitions(), + "setAutoAck setDataRate setPALevel setup startListening stopListening write".split( + " " + ), + ) + static_model = model.get_static() + self.assertAlmostEqual(static_model("RX", "power"), 47964, places=0) + self.assertAlmostEqual(static_model("STANDBY1", "power"), 128, places=0) + self.assertAlmostEqual(static_model("setAutoAck", "power"), 151, places=0) + self.assertAlmostEqual(static_model("setDataRate", "power"), 146, places=0) + self.assertAlmostEqual(static_model("setPALevel", "power"), 147, places=0) + self.assertAlmostEqual(static_model("setup", "power"), 153, places=0) + self.assertAlmostEqual(static_model("startListening", "power"), 18954, places=0) + self.assertAlmostEqual(static_model("stopListening", "power"), 2426, places=0) + self.assertAlmostEqual(static_model("write", "power"), 17629, places=0) + + self.assertAlmostEqual(static_model("setAutoAck", "duration"), 90, places=0) + self.assertAlmostEqual(static_model("setDataRate", "duration"), 240, places=0) + self.assertAlmostEqual(static_model("setPALevel", "duration"), 160, places=0) + self.assertAlmostEqual(static_model("setup", "duration"), 6550, places=0) + self.assertAlmostEqual( + static_model("startListening", "duration"), 470, places=0 + ) + self.assertAlmostEqual(static_model("stopListening", "duration"), 510, places=0) + self.assertAlmostEqual(static_model("write", "duration"), 11230, places=0) + + self.assertAlmostEqual( + model.attr_by_name["write"]["duration"].stats.param_dependence_ratio( + "auto_ack!" + ), + 1, + places=2, + ) + self.assertAlmostEqual( + model.attr_by_name["write"]["power"].stats.param_dependence_ratio( + "auto_ack!" + ), + 0.99, + places=2, + ) + + param_model, param_info = model.get_fitted() + + self.assertAlmostEqual( + param_model( + "write", "duration", param=[0, 76, 1000, 0, 10, None, None, 1500, 0] + ), + 1090, + places=0, + ) + + # only bitrate is relevant + self.assertAlmostEqual( + param_model( + "write", + "duration", + param=[0, None, 1000, None, None, None, None, None, None], + ), + 1090, + places=0, + ) + self.assertAlmostEqual( + param_model( + "write", + "duration", + param=[0, None, 250, None, None, None, None, None, None], + ), + 2057, + places=0, + ) + self.assertAlmostEqual( + param_model( + "write", + "duration", + param=[0, None, 2000, None, None, None, None, None, None], + ), + 929, + places=0, + ) + + # auto_ack == 1 has a different write duration, still only bitrate is relevant + self.assertAlmostEqual( + param_model( + "write", "duration", param=[1, 76, 1000, 0, 10, None, None, 1500, 0] + ), + 22284, + places=0, + ) + self.assertAlmostEqual( + param_model( + "write", + "duration", + param=[1, None, 1000, None, None, None, None, None, None], + ), + 22284, + places=0, + ) + self.assertAlmostEqual( + param_model( + "write", + "duration", + param=[1, None, 250, None, None, None, None, None, None], + ), + 33229, + places=0, + ) + self.assertAlmostEqual( + param_model( + "write", + "duration", + param=[1, None, 2000, None, None, None, None, None, None], + ), + 20459, + places=0, + ) + + """ + param_model, param_info = model.get_fitted() + self.assertEqual(param_info("POWERDOWN", "power"), None) + self.assertEqual( + param_info("RX", "power")["function"].model_function, + "0 + regression_arg(0) + regression_arg(1) * np.sqrt(parameter(datarate))", + ) + self.assertAlmostEqual( + param_info("RX", "power")["function"].model_args[0], 48530.7, places=0 + ) + self.assertAlmostEqual( + param_info("RX", "power")["function"].model_args[1], 117, places=0 + ) + self.assertEqual(param_info("STANDBY1", "power"), None) + self.assertEqual( + param_info("TX", "power")["function"].model_function, + "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate)) + regression_arg(2) * parameter(txpower) + regression_arg(3) * 1/(parameter(datarate)) * parameter(txpower)", + ) + self.assertEqual( + param_info("epilogue", "timeout")["function"].model_function, + "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))", + ) + self.assertEqual( + param_info("stopListening", "duration")["function"].model_function, + "0 + regression_arg(0) + regression_arg(1) * 1/(parameter(datarate))", + ) + + self.assertAlmostEqual( + param_model("RX", "power", param=[1, None, None]), 48647, places=-1 + ) + """ + def test_singlefile_mmparam(self): raw_data = RawData(["test-data/20161221_123347_mmparam.tar"]) preprocessed_data = raw_data.get_preprocessed_data() -- cgit v1.2.3