diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-19 18:04:59 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-19 18:04:59 +0100 |
commit | 66c2fe62fbde3344161e1de385760cf41284045a (patch) | |
tree | 15bdf0ef75af34d73d1b8812421051269cf02e18 | |
parent | a5b4fe753556d9e78028a9759f8cbe7cc0126d3c (diff) |
analyze-log: actually do cross validation when requested. derp...
-rwxr-xr-x | bin/analyze-log.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/bin/analyze-log.py b/bin/analyze-log.py index 4e22c87..4a74116 100755 --- a/bin/analyze-log.py +++ b/bin/analyze-log.py @@ -230,6 +230,7 @@ def main(): static_model = model.get_static() try: lut_model = model.get_param_lut() + lut_quality = model.assess(lut_model) except RuntimeError as e: if args.force_tree: # this is to be expected @@ -237,14 +238,21 @@ def main(): else: logging.warning(f"Skipping LUT model: {e}") lut_model = None + lut_quality = None param_model, param_info = model.get_fitted() - static_quality = model.assess(static_model) - analytic_quality = model.assess(param_model) - if lut_model: - lut_quality = model.assess(lut_model) + + if xv_method == "montecarlo": + static_quality, _ = xv.montecarlo( + lambda m: m.get_static(), xv_count, static=True + ) + analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + elif xv_method == "kfold": + static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count, static=True) + analytic_quality, _ = xv.kfold(lambda m: m.get_fitted()[0], xv_count) else: - lut_quality = None + static_quality = model.assess(static_model) + analytic_quality = model.assess(param_model) if "static" in args.show_model or "all" in args.show_model: print("--- static model ---") |