diff options
-rwxr-xr-x | bin/analyze-kconfig.py | 38 | ||||
-rw-r--r-- | lib/model.py | 8 |
2 files changed, 38 insertions, 8 deletions
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py index ec2f324..004e691 100755 --- a/bin/analyze-kconfig.py +++ b/bin/analyze-kconfig.py @@ -18,6 +18,7 @@ import numpy as np import dfatool.utils from dfatool.loader import KConfigAttributes from dfatool.model import AnalyticModel +from dfatool.validation import CrossValidator def main(): @@ -66,6 +67,12 @@ def main(): help="Restrict model generation to N random samples", metavar="N", ) + parser.add_argument( + "--cross-validate", + type=str, + help="Report modul accuracy via Cross-Validation", + metavar="METHOD:COUNT", + ) parser.add_argument("kconfig_path", type=str, help="Path to Kconfig file") parser.add_argument( "model", @@ -113,18 +120,33 @@ def main(): observations = None model = AnalyticModel( - by_name, parameter_names, compute_stats=not args.force_tree + by_name, + parameter_names, + compute_stats=not args.force_tree, + force_tree=args.force_tree, ) - if args.force_tree: - for name in model.names: - for attr in model.by_name[name]["attributes"]: - # TODO specify correct threshold - model.build_dtree(name, attr, 0) - model.fit_done = True + if args.cross_validate: + xv_method, xv_count = args.cross_validate.split(":") + xv_count = int(xv_count) + xv = CrossValidator( + AnalyticModel, + by_name, + parameter_names, + compute_stats=not args.force_tree, + force_tree=args.force_tree, + ) + else: + xv_method = None param_model, param_info = model.get_fitted() - analytic_quality = model.assess(param_model) + + if xv_method == "montecarlo": + analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + elif xv_method == "kfold": + analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count) + else: + analytic_quality = model.assess(param_model) print("Model Error on Training Data:") for name in model.names: diff --git a/lib/model.py b/lib/model.py index 590c6f0..06ff25c 100644 --- a/lib/model.py +++ b/lib/model.py @@ -75,6 +75,7 @@ class AnalyticModel: function_override=dict(), use_corrcoef=False, compute_stats=True, + force_tree=False, ): """ Create a new AnalyticModel and compute parameter statistics. @@ -134,6 +135,13 @@ class AnalyticModel: if compute_stats: self._compute_stats(by_name) + if force_tree: + for name in self.names: + for attr in self.by_name[name]["attributes"]: + # TODO specify correct threshold + self.build_dtree(name, attr, 0) + self.fit_done = True + def __repr__(self): names = ", ".join(self.by_name.keys()) return f"AnalyticModel<names=[{names}]>" |