From 0f59ffb5f2ecb3dc23764cd566d962e483bf31e2 Mon Sep 17 00:00:00 2001
From: Daniel Friesel
Date: Wed, 13 Oct 2021 16:06:30 +0200
Subject: analyze-kconfig: add cross-validation support

---
 bin/analyze-kconfig.py | 38 ++++++++++++++++++++++++++++++--------
 1 file changed, 30 insertions(+), 8 deletions(-)

(limited to 'bin')

diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index ec2f324..004e691 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -18,6 +18,7 @@ import numpy as np
 import dfatool.utils
 from dfatool.loader import KConfigAttributes
 from dfatool.model import AnalyticModel
+from dfatool.validation import CrossValidator
 
 
 def main():
@@ -66,6 +67,12 @@ def main():
         help="Restrict model generation to N random samples",
         metavar="N",
     )
+    parser.add_argument(
+        "--cross-validate",
+        type=str,
+        help="Report model accuracy via Cross-Validation",
+        metavar="METHOD:COUNT",
+    )
     parser.add_argument("kconfig_path", type=str, help="Path to Kconfig file")
     parser.add_argument(
         "model",
@@ -113,18 +120,33 @@ def main():
         observations = None
 
     model = AnalyticModel(
-        by_name, parameter_names, compute_stats=not args.force_tree
+        by_name,
+        parameter_names,
+        compute_stats=not args.force_tree,
+        force_tree=args.force_tree,
     )
-    if args.force_tree:
-        for name in model.names:
-            for attr in model.by_name[name]["attributes"]:
-                # TODO specify correct threshold
-                model.build_dtree(name, attr, 0)
-        model.fit_done = True
+
+    if args.cross_validate:
+        xv_method, xv_count = args.cross_validate.split(":")
+        xv_count = int(xv_count)
+        xv = CrossValidator(
+            AnalyticModel,
+            by_name,
+            parameter_names,
+            compute_stats=not args.force_tree,
+            force_tree=args.force_tree,
+        )
+    else:
+        xv_method = None
 
     param_model, param_info = model.get_fitted()
-    analytic_quality = model.assess(param_model)
+
+    if xv_method == "montecarlo":
+        analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
+    elif xv_method == "kfold":
+        analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count)
+    else:
+        analytic_quality = model.assess(param_model)
 
     print("Model Error on Training Data:")
     for name in model.names:
--
cgit v1.2.3
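
For readers who want to try the new flag in isolation, the following is a minimal standalone sketch of the METHOD:COUNT plumbing this commit introduces. It mirrors only the argparse option and the method dispatch from analyze-kconfig.py; assess_with_cross_validation is a hypothetical stub standing in for the CrossValidator and model.assess calls, so the sketch runs without dfatool or any measurement data.

#!/usr/bin/env python3
# Standalone sketch (not part of dfatool) of the --cross-validate plumbing
# added by this commit. Only the argparse option and the METHOD:COUNT
# dispatch are reproduced; the actual cross-validation is stubbed out.

import argparse


def assess_with_cross_validation(method, count):
    # Hypothetical stand-in for xv.montecarlo(...), xv.kfold(...) and
    # model.assess(...) in analyze-kconfig.py.
    if method == "montecarlo":
        return f"monte-carlo cross-validation, {count} iterations"
    if method == "kfold":
        return f"{count}-fold cross-validation"
    return "assessment on training data only"


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cross-validate",
        type=str,
        help="Report model accuracy via Cross-Validation",
        metavar="METHOD:COUNT",
    )
    args = parser.parse_args()

    if args.cross_validate:
        # Same METHOD:COUNT split as in the patch, e.g. "kfold:10".
        xv_method, xv_count = args.cross_validate.split(":")
        xv_count = int(xv_count)
    else:
        xv_method, xv_count = None, None

    print(assess_with_cross_validation(xv_method, xv_count))


if __name__ == "__main__":
    main()

Run as, for example, python3 cross-validate-sketch.py --cross-validate kfold:10 (the file name is illustrative); without the flag it falls back to the training-data path, matching the final else branch in the patch.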