diff options
Diffstat (limited to 'test/test_parameters.py')
-rwxr-xr-x | test/test_parameters.py | 67 |
1 files changed, 41 insertions, 26 deletions
diff --git a/test/test_parameters.py b/test/test_parameters.py index e36b1a1..a40ebfe 100755 --- a/test/test_parameters.py +++ b/test/test_parameters.py @@ -20,7 +20,7 @@ class TestModels(unittest.TestCase): } } self.assertEqual( - parameters.distinct_param_values(by_name, "TX"), + parameters.distinct_param_values(by_name["TX"]["param"]), [list(range(5)), list(range(7))], ) @@ -44,16 +44,20 @@ class TestModels(unittest.TestCase): "attributes": ["power"], } } - by_param = by_name_to_by_param(by_name) - stats = parameters.ParamStats(by_name, by_param, parameter_names, dict()) - self.assertEqual(stats.depends_on_param("TX", "power", "p_mod5"), False) - self.assertEqual(stats.depends_on_param("TX", "power", "p_linear"), True) + stats = parameters.ParamStats( + parameters._compute_param_statistics( + by_name["TX"]["power"], parameter_names, by_name["TX"]["param"] + ) + ) + + self.assertEqual(stats.depends_on_param("p_mod5"), False) + self.assertEqual(stats.depends_on_param("p_linear"), True) # Fit individual functions for each parameter (only "p_linear" in this case) - paramfit = ParallelParamFit(by_param) - paramfit.enqueue("TX", "power", 1, "p_linear") + paramfit = ParallelParamFit() + paramfit.enqueue(("TX", "power", "p_linear"), (stats.by_param, 1, False)) paramfit.fit() fit_result = paramfit.get_result("TX", "power") @@ -73,7 +77,7 @@ class TestModels(unittest.TestCase): "0 + reg_param[0] + reg_param[1] * model_param[1]", ) - combined_fit.fit(by_param, "TX", "power") + combined_fit.fit(stats.by_param) self.assertEqual(combined_fit.fit_success, True) @@ -123,22 +127,33 @@ class TestModels(unittest.TestCase): } } by_param = by_name_to_by_param(by_name) - stats = parameters.ParamStats(by_name, by_param, parameter_names, dict()) - - self.assertEqual(stats.depends_on_param("someKey", "lls", "lin_lin"), True) - self.assertEqual(stats.depends_on_param("someKey", "lls", "log_inv"), True) - self.assertEqual(stats.depends_on_param("someKey", "lls", "square_none"), True) - - self.assertEqual(stats.depends_on_param("someKey", "ll", "lin_lin"), True) - self.assertEqual(stats.depends_on_param("someKey", "ll", "log_inv"), True) - self.assertEqual(stats.depends_on_param("someKey", "ll", "square_none"), False) - - paramfit = ParallelParamFit(by_param) - paramfit.enqueue("someKey", "lls", 0, "lin_lin") - paramfit.enqueue("someKey", "lls", 1, "log_inv") - paramfit.enqueue("someKey", "lls", 2, "square_none") - paramfit.enqueue("someKey", "ll", 0, "lin_lin") - paramfit.enqueue("someKey", "ll", 1, "log_inv") + lls_stats = parameters.ParamStats( + parameters._compute_param_statistics( + by_name["someKey"]["lls"], parameter_names, by_name["someKey"]["param"] + ) + ) + ll_stats = parameters.ParamStats( + parameters._compute_param_statistics( + by_name["someKey"]["ll"], parameter_names, by_name["someKey"]["param"] + ) + ) + + self.assertEqual(lls_stats.depends_on_param("lin_lin"), True) + self.assertEqual(lls_stats.depends_on_param("log_inv"), True) + self.assertEqual(lls_stats.depends_on_param("square_none"), True) + + self.assertEqual(ll_stats.depends_on_param("lin_lin"), True) + self.assertEqual(ll_stats.depends_on_param("log_inv"), True) + self.assertEqual(ll_stats.depends_on_param("square_none"), False) + + paramfit = ParallelParamFit() + paramfit.enqueue(("someKey", "lls", "lin_lin"), (lls_stats.by_param, 0, False)) + paramfit.enqueue(("someKey", "lls", "log_inv"), (lls_stats.by_param, 1, False)) + paramfit.enqueue( + ("someKey", "lls", "square_none"), (lls_stats.by_param, 2, False) + ) + paramfit.enqueue(("someKey", "ll", "lin_lin"), (ll_stats.by_param, 0, False)) + paramfit.enqueue(("someKey", "ll", "log_inv"), (ll_stats.by_param, 1, False)) paramfit.fit() fit_lls = paramfit.get_result("someKey", "lls") @@ -159,7 +174,7 @@ class TestModels(unittest.TestCase): " + regression_arg(7) * parameter(lin_lin) * np.log(parameter(log_inv)) * (parameter(square_none))**2", ) - combined_fit_lls.fit(by_param, "someKey", "lls") + combined_fit_lls.fit(lls_stats.by_param) self.assertEqual(combined_fit_lls.fit_success, True) @@ -200,7 +215,7 @@ class TestModels(unittest.TestCase): " + regression_arg(3) * parameter(lin_lin) * 1/(parameter(log_inv))", ) - combined_fit_ll.fit(by_param, "someKey", "ll") + combined_fit_ll.fit(ll_stats.by_param) self.assertEqual(combined_fit_ll.fit_success, True) |