summaryrefslogtreecommitdiff
path: root/test/test_parameters.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_parameters.py')
-rwxr-xr-xtest/test_parameters.py67
1 files changed, 41 insertions, 26 deletions
diff --git a/test/test_parameters.py b/test/test_parameters.py
index e36b1a1..a40ebfe 100755
--- a/test/test_parameters.py
+++ b/test/test_parameters.py
@@ -20,7 +20,7 @@ class TestModels(unittest.TestCase):
}
}
self.assertEqual(
- parameters.distinct_param_values(by_name, "TX"),
+ parameters.distinct_param_values(by_name["TX"]["param"]),
[list(range(5)), list(range(7))],
)
@@ -44,16 +44,20 @@ class TestModels(unittest.TestCase):
"attributes": ["power"],
}
}
- by_param = by_name_to_by_param(by_name)
- stats = parameters.ParamStats(by_name, by_param, parameter_names, dict())
- self.assertEqual(stats.depends_on_param("TX", "power", "p_mod5"), False)
- self.assertEqual(stats.depends_on_param("TX", "power", "p_linear"), True)
+ stats = parameters.ParamStats(
+ parameters._compute_param_statistics(
+ by_name["TX"]["power"], parameter_names, by_name["TX"]["param"]
+ )
+ )
+
+ self.assertEqual(stats.depends_on_param("p_mod5"), False)
+ self.assertEqual(stats.depends_on_param("p_linear"), True)
# Fit individual functions for each parameter (only "p_linear" in this case)
- paramfit = ParallelParamFit(by_param)
- paramfit.enqueue("TX", "power", 1, "p_linear")
+ paramfit = ParallelParamFit()
+ paramfit.enqueue(("TX", "power", "p_linear"), (stats.by_param, 1, False))
paramfit.fit()
fit_result = paramfit.get_result("TX", "power")
@@ -73,7 +77,7 @@ class TestModels(unittest.TestCase):
"0 + reg_param[0] + reg_param[1] * model_param[1]",
)
- combined_fit.fit(by_param, "TX", "power")
+ combined_fit.fit(stats.by_param)
self.assertEqual(combined_fit.fit_success, True)
@@ -123,22 +127,33 @@ class TestModels(unittest.TestCase):
}
}
by_param = by_name_to_by_param(by_name)
- stats = parameters.ParamStats(by_name, by_param, parameter_names, dict())
-
- self.assertEqual(stats.depends_on_param("someKey", "lls", "lin_lin"), True)
- self.assertEqual(stats.depends_on_param("someKey", "lls", "log_inv"), True)
- self.assertEqual(stats.depends_on_param("someKey", "lls", "square_none"), True)
-
- self.assertEqual(stats.depends_on_param("someKey", "ll", "lin_lin"), True)
- self.assertEqual(stats.depends_on_param("someKey", "ll", "log_inv"), True)
- self.assertEqual(stats.depends_on_param("someKey", "ll", "square_none"), False)
-
- paramfit = ParallelParamFit(by_param)
- paramfit.enqueue("someKey", "lls", 0, "lin_lin")
- paramfit.enqueue("someKey", "lls", 1, "log_inv")
- paramfit.enqueue("someKey", "lls", 2, "square_none")
- paramfit.enqueue("someKey", "ll", 0, "lin_lin")
- paramfit.enqueue("someKey", "ll", 1, "log_inv")
+ lls_stats = parameters.ParamStats(
+ parameters._compute_param_statistics(
+ by_name["someKey"]["lls"], parameter_names, by_name["someKey"]["param"]
+ )
+ )
+ ll_stats = parameters.ParamStats(
+ parameters._compute_param_statistics(
+ by_name["someKey"]["ll"], parameter_names, by_name["someKey"]["param"]
+ )
+ )
+
+ self.assertEqual(lls_stats.depends_on_param("lin_lin"), True)
+ self.assertEqual(lls_stats.depends_on_param("log_inv"), True)
+ self.assertEqual(lls_stats.depends_on_param("square_none"), True)
+
+ self.assertEqual(ll_stats.depends_on_param("lin_lin"), True)
+ self.assertEqual(ll_stats.depends_on_param("log_inv"), True)
+ self.assertEqual(ll_stats.depends_on_param("square_none"), False)
+
+ paramfit = ParallelParamFit()
+ paramfit.enqueue(("someKey", "lls", "lin_lin"), (lls_stats.by_param, 0, False))
+ paramfit.enqueue(("someKey", "lls", "log_inv"), (lls_stats.by_param, 1, False))
+ paramfit.enqueue(
+ ("someKey", "lls", "square_none"), (lls_stats.by_param, 2, False)
+ )
+ paramfit.enqueue(("someKey", "ll", "lin_lin"), (ll_stats.by_param, 0, False))
+ paramfit.enqueue(("someKey", "ll", "log_inv"), (ll_stats.by_param, 1, False))
paramfit.fit()
fit_lls = paramfit.get_result("someKey", "lls")
@@ -159,7 +174,7 @@ class TestModels(unittest.TestCase):
" + regression_arg(7) * parameter(lin_lin) * np.log(parameter(log_inv)) * (parameter(square_none))**2",
)
- combined_fit_lls.fit(by_param, "someKey", "lls")
+ combined_fit_lls.fit(lls_stats.by_param)
self.assertEqual(combined_fit_lls.fit_success, True)
@@ -200,7 +215,7 @@ class TestModels(unittest.TestCase):
" + regression_arg(3) * parameter(lin_lin) * 1/(parameter(log_inv))",
)
- combined_fit_ll.fit(by_param, "someKey", "ll")
+ combined_fit_ll.fit(ll_stats.by_param)
self.assertEqual(combined_fit_ll.fit_success, True)