From 2c5bcd77f2c952cc5269ca3e4b6e0a7323ebd085 Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Wed, 5 Jan 2022 15:23:51 +0100 Subject: cross validation: return intermediate models used for XV These are interesting for statistics, e.g. to determine the average dtree size --- bin/analyze-kconfig.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) (limited to 'bin/analyze-kconfig.py') diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py index bd9cccb..048c8c9 100755 --- a/bin/analyze-kconfig.py +++ b/bin/analyze-kconfig.py @@ -237,19 +237,21 @@ def main(): fit_duration = time.time() - fit_start_time if xv_method == "montecarlo": - static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) - analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count) + analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) if lut_model: - lut_quality = xv.montecarlo( + lut_quality, _ = xv.montecarlo( lambda m: m.get_param_lut(fallback=True), xv_count ) else: lut_quality = None elif xv_method == "kfold": - static_quality = xv.kfold(lambda m: m.get_static(), xv_count) - analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count) + static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count) + analytic_quality, _ = xv.kfold(lambda m: m.get_fitted()[0], xv_count) if lut_model: - lut_quality = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count) + lut_quality, _ = xv.kfold( + lambda m: m.get_param_lut(fallback=True), xv_count + ) else: lut_quality = None else: @@ -315,9 +317,9 @@ def main(): json.dump(json_model, f, sort_keys=True, cls=dfatool.utils.NpEncoder) if xv_method == "montecarlo": - static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) + static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count) elif xv_method == "kfold": - static_quality = xv.kfold(lambda m: m.get_static(), xv_count) + static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count) else: static_quality = model.assess(static_model) -- cgit v1.2.3