diff options
Diffstat (limited to 'bin')
-rwxr-xr-x | bin/analyze-kconfig.py | 38 |
1 files changed, 30 insertions, 8 deletions
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py index ec2f324..004e691 100755 --- a/bin/analyze-kconfig.py +++ b/bin/analyze-kconfig.py @@ -18,6 +18,7 @@ import numpy as np import dfatool.utils from dfatool.loader import KConfigAttributes from dfatool.model import AnalyticModel +from dfatool.validation import CrossValidator def main(): @@ -66,6 +67,12 @@ def main(): help="Restrict model generation to N random samples", metavar="N", ) + parser.add_argument( + "--cross-validate", + type=str, + help="Report modul accuracy via Cross-Validation", + metavar="METHOD:COUNT", + ) parser.add_argument("kconfig_path", type=str, help="Path to Kconfig file") parser.add_argument( "model", @@ -113,18 +120,33 @@ def main(): observations = None model = AnalyticModel( - by_name, parameter_names, compute_stats=not args.force_tree + by_name, + parameter_names, + compute_stats=not args.force_tree, + force_tree=args.force_tree, ) - if args.force_tree: - for name in model.names: - for attr in model.by_name[name]["attributes"]: - # TODO specify correct threshold - model.build_dtree(name, attr, 0) - model.fit_done = True + if args.cross_validate: + xv_method, xv_count = args.cross_validate.split(":") + xv_count = int(xv_count) + xv = CrossValidator( + AnalyticModel, + by_name, + parameter_names, + compute_stats=not args.force_tree, + force_tree=args.force_tree, + ) + else: + xv_method = None param_model, param_info = model.get_fitted() - analytic_quality = model.assess(param_model) + + if xv_method == "montecarlo": + analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + elif xv_method == "kfold": + analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count) + else: + analytic_quality = model.assess(param_model) print("Model Error on Training Data:") for name in model.names: |