diff options
Diffstat (limited to 'test')
-rwxr-xr-x | test/test_parameters.py | 127 |
1 files changed, 127 insertions, 0 deletions
diff --git a/test/test_parameters.py b/test/test_parameters.py index baf1c99..6cd9c71 100755 --- a/test/test_parameters.py +++ b/test/test_parameters.py @@ -89,6 +89,133 @@ class TestModels(unittest.TestCase): for i in range(100): self.assertAlmostEqual(combined_fit.eval([None, i]), i, places=0) + def test_parameter_detection_multi_dimensional(self): + rng = np.random.default_rng(seed=1312) + # vary each parameter from 1 to 10 + Xi = (np.arange(50) % 10) + 1 + # Three parameters -> Build input array [[1, 1, 1], [1, 1, 2], ..., [10, 10, 10]] + X = np.array(np.meshgrid(Xi, Xi, Xi)).T.reshape(-1, 3) + + f_lls = np.vectorize( + lambda x: 42 + 7 * x[0] + 10 * np.log(x[1]) - 0.5 * x[2] * x[2], + signature="(n)->()", + ) + f_ll = np.vectorize(lambda x: 23 + 5 * x[0] - 3 * x[1], signature="(n)->()") + + Y_lls = f_lls(X) + rng.normal(size=X.shape[0]) + Y_ll = f_ll(X) + rng.normal(size=X.shape[0]) + + parameter_names = ["lin_lin", "log_lin", "square_none"] + + by_name = { + "someKey": { + "param": X, + "lls": Y_lls, + "ll": Y_ll, + "attributes": ["lls", "ll"], + } + } + by_param = by_name_to_by_param(by_name) + stats = parameters.ParamStats(by_name, by_param, parameter_names, dict()) + + self.assertEqual(stats.depends_on_param("someKey", "lls", "lin_lin"), True) + self.assertEqual(stats.depends_on_param("someKey", "lls", "log_lin"), True) + self.assertEqual(stats.depends_on_param("someKey", "lls", "square_none"), True) + + self.assertEqual(stats.depends_on_param("someKey", "ll", "lin_lin"), True) + self.assertEqual(stats.depends_on_param("someKey", "ll", "log_lin"), True) + self.assertEqual(stats.depends_on_param("someKey", "ll", "square_none"), False) + + paramfit = dt.ParallelParamFit(by_param) + paramfit.enqueue("someKey", "lls", 0, "lin_lin") + paramfit.enqueue("someKey", "lls", 1, "log_lin") + paramfit.enqueue("someKey", "lls", 2, "square_none") + paramfit.enqueue("someKey", "ll", 0, "lin_lin") + paramfit.enqueue("someKey", "ll", 1, "log_lin") + paramfit.fit() + + fit_lls = paramfit.get_result("someKey", "lls") + self.assertEqual(fit_lls["lin_lin"]["best"], "linear") + self.assertEqual(fit_lls["log_lin"]["best"], "logarithmic") + self.assertEqual(fit_lls["square_none"]["best"], "square") + + combined_fit_lls = analytic.function_powerset(fit_lls, parameter_names, 0) + + self.assertEqual( + combined_fit_lls.model_function, + "0 + regression_arg(0) + regression_arg(1) * parameter(lin_lin)" + " + regression_arg(2) * np.log(parameter(log_lin))" + " + regression_arg(3) * (parameter(square_none))**2" + " + regression_arg(4) * parameter(lin_lin) * np.log(parameter(log_lin))" + " + regression_arg(5) * parameter(lin_lin) * (parameter(square_none))**2" + " + regression_arg(6) * np.log(parameter(log_lin)) * (parameter(square_none))**2" + " + regression_arg(7) * parameter(lin_lin) * np.log(parameter(log_lin)) * (parameter(square_none))**2", + ) + + combined_fit_lls.fit(by_param, "someKey", "lls") + + self.assertEqual(combined_fit_lls.fit_success, True) + + # Verify that f_lls parameters have been found + self.assertAlmostEqual(combined_fit_lls.model_args[0], 42, places=0) + self.assertAlmostEqual(combined_fit_lls.model_args[1], 7, places=0) + self.assertAlmostEqual(combined_fit_lls.model_args[2], 10, places=0) + self.assertAlmostEqual(combined_fit_lls.model_args[3], -0.5, places=1) + self.assertAlmostEqual(combined_fit_lls.model_args[4], 0, places=2) + self.assertAlmostEqual(combined_fit_lls.model_args[5], 0, places=2) + self.assertAlmostEqual(combined_fit_lls.model_args[6], 0, places=2) + self.assertAlmostEqual(combined_fit_lls.model_args[7], 0, places=2) + + self.assertEqual(combined_fit_lls.is_predictable([None, None, None]), False) + self.assertEqual(combined_fit_lls.is_predictable([None, None, 11]), False) + self.assertEqual(combined_fit_lls.is_predictable([None, 11, None]), False) + self.assertEqual(combined_fit_lls.is_predictable([None, 11, 11]), False) + self.assertEqual(combined_fit_lls.is_predictable([11, None, None]), False) + self.assertEqual(combined_fit_lls.is_predictable([11, None, 11]), False) + self.assertEqual(combined_fit_lls.is_predictable([11, 11, None]), False) + self.assertEqual(combined_fit_lls.is_predictable([11, 11, 11]), True) + + # Verify that fitted function behaves like input function + for i, x in enumerate(X): + self.assertAlmostEqual(combined_fit_lls.eval(x), f_lls(x), places=0) + + fit_ll = paramfit.get_result("someKey", "ll") + self.assertEqual(fit_ll["lin_lin"]["best"], "linear") + self.assertEqual(fit_ll["log_lin"]["best"], "linear") + self.assertEqual("quare_none" not in fit_ll, True) + + combined_fit_ll = analytic.function_powerset(fit_ll, parameter_names, 0) + + self.assertEqual( + combined_fit_ll.model_function, + "0 + regression_arg(0) + regression_arg(1) * parameter(lin_lin)" + " + regression_arg(2) * parameter(log_lin)" + " + regression_arg(3) * parameter(lin_lin) * parameter(log_lin)", + ) + + combined_fit_ll.fit(by_param, "someKey", "ll") + + self.assertEqual(combined_fit_ll.fit_success, True) + + # Verify that f_ll parameters have been found + self.assertAlmostEqual(combined_fit_ll.model_args[0], 23, places=0) + self.assertAlmostEqual(combined_fit_ll.model_args[1], 5, places=0) + self.assertAlmostEqual(combined_fit_ll.model_args[2], -3, places=0) + self.assertAlmostEqual(combined_fit_ll.model_args[3], 0, places=2) + + self.assertEqual(combined_fit_ll.is_predictable([None, None, None]), False) + self.assertEqual(combined_fit_ll.is_predictable([None, None, 11]), False) + self.assertEqual(combined_fit_ll.is_predictable([None, 11, None]), False) + self.assertEqual(combined_fit_ll.is_predictable([None, 11, 11]), False) + self.assertEqual(combined_fit_ll.is_predictable([11, None, None]), False) + self.assertEqual(combined_fit_ll.is_predictable([11, None, 11]), False) + self.assertEqual(combined_fit_ll.is_predictable([11, 11, None]), True) + self.assertEqual(combined_fit_ll.is_predictable([11, 11, 11]), True) + + # Verify that fitted function behaves like input function + for i, x in enumerate(X): + self.assertAlmostEqual(combined_fit_ll.eval(x), f_ll(x), places=0) + if __name__ == "__main__": unittest.main() |