summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2023-12-19 18:04:59 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2023-12-19 18:04:59 +0100
commit66c2fe62fbde3344161e1de385760cf41284045a (patch)
tree15bdf0ef75af34d73d1b8812421051269cf02e18
parenta5b4fe753556d9e78028a9759f8cbe7cc0126d3c (diff)
analyze-log: actually do cross validation when requested. derp...
-rwxr-xr-xbin/analyze-log.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/bin/analyze-log.py b/bin/analyze-log.py
index 4e22c87..4a74116 100755
--- a/bin/analyze-log.py
+++ b/bin/analyze-log.py
@@ -230,6 +230,7 @@ def main():
static_model = model.get_static()
try:
lut_model = model.get_param_lut()
+ lut_quality = model.assess(lut_model)
except RuntimeError as e:
if args.force_tree:
# this is to be expected
@@ -237,14 +238,21 @@ def main():
else:
logging.warning(f"Skipping LUT model: {e}")
lut_model = None
+ lut_quality = None
param_model, param_info = model.get_fitted()
- static_quality = model.assess(static_model)
- analytic_quality = model.assess(param_model)
- if lut_model:
- lut_quality = model.assess(lut_model)
+
+ if xv_method == "montecarlo":
+ static_quality, _ = xv.montecarlo(
+ lambda m: m.get_static(), xv_count, static=True
+ )
+ analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
+ elif xv_method == "kfold":
+ static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count, static=True)
+ analytic_quality, _ = xv.kfold(lambda m: m.get_fitted()[0], xv_count)
else:
- lut_quality = None
+ static_quality = model.assess(static_model)
+ analytic_quality = model.assess(param_model)
if "static" in args.show_model or "all" in args.show_model:
print("--- static model ---")