From a0933fef969c4555452fcbf70e6183eddf141956 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Thu, 24 Oct 2024 12:17:43 +0200 Subject: add --add-total-observation for behaviour model evaluation --- bin/analyze-log.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) (limited to 'bin') diff --git a/bin/analyze-log.py b/bin/analyze-log.py index 1978df0..fc7fc0d 100755 --- a/bin/analyze-log.py +++ b/bin/analyze-log.py @@ -214,14 +214,26 @@ def main(): ts = time.time() if xv_method == "montecarlo": static_quality, _ = xv.montecarlo( - lambda m: m.get_static(), xv_count, static=True + lambda m: m.get_static(), + xv_count, + static=True, + with_sum=args.add_total_observation, ) xv.export_filename = args.export_xv - analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + analytic_quality, _ = xv.montecarlo( + lambda m: m.get_fitted()[0], xv_count, with_sum=args.add_total_observation + ) elif xv_method == "kfold": - static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count, static=True) + static_quality, _ = xv.kfold( + lambda m: m.get_static(), + xv_count, + static=True, + with_sum=args.add_total_observation, + ) xv.export_filename = args.export_xv - analytic_quality, _ = xv.kfold(lambda m: m.get_fitted()[0], xv_count) + analytic_quality, _ = xv.kfold( + lambda m: m.get_fitted()[0], xv_count, with_sum=args.add_total_observation + ) else: static_quality = model.assess(static_model) if args.export_raw_predictions: -- cgit v1.2.3