summaryrefslogtreecommitdiff
path: root/bin
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-10-13 16:06:30 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2021-10-13 16:06:30 +0200
commit0f59ffb5f2ecb3dc23764cd566d962e483bf31e2 (patch)
tree01acb783f3bfe9593d7d1353c0677f2702bd98b4 /bin
parentb4f7b9e9407dbdc3be957fdfc6da0d7755b4b64d (diff)
analyze-kconfig: add cross-validation support
Diffstat (limited to 'bin')
-rwxr-xr-xbin/analyze-kconfig.py38
1 file changed, 30 insertions, 8 deletions
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index ec2f324..004e691 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -18,6 +18,7 @@ import numpy as np
import dfatool.utils
from dfatool.loader import KConfigAttributes
from dfatool.model import AnalyticModel
+from dfatool.validation import CrossValidator
def main():
@@ -66,6 +67,12 @@ def main():
help="Restrict model generation to N random samples",
metavar="N",
)
+ parser.add_argument(
+ "--cross-validate",
+ type=str,
+        help="Report model accuracy via cross-validation",
+ metavar="METHOD:COUNT",
+ )
parser.add_argument("kconfig_path", type=str, help="Path to Kconfig file")
parser.add_argument(
"model",
@@ -113,18 +120,33 @@ def main():
observations = None
model = AnalyticModel(
- by_name, parameter_names, compute_stats=not args.force_tree
+ by_name,
+ parameter_names,
+ compute_stats=not args.force_tree,
+ force_tree=args.force_tree,
)
- if args.force_tree:
- for name in model.names:
- for attr in model.by_name[name]["attributes"]:
- # TODO specify correct threshold
- model.build_dtree(name, attr, 0)
- model.fit_done = True
+ if args.cross_validate:
+ xv_method, xv_count = args.cross_validate.split(":")
+ xv_count = int(xv_count)
+ xv = CrossValidator(
+ AnalyticModel,
+ by_name,
+ parameter_names,
+ compute_stats=not args.force_tree,
+ force_tree=args.force_tree,
+ )
+ else:
+ xv_method = None
param_model, param_info = model.get_fitted()
- analytic_quality = model.assess(param_model)
+
+ if xv_method == "montecarlo":
+ analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
+ elif xv_method == "kfold":
+ analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count)
+ else:
+ analytic_quality = model.assess(param_model)
print("Model Error on Training Data:")
for name in model.names: