summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/analyze-kconfig.py38
-rw-r--r--lib/model.py8
2 files changed, 38 insertions, 8 deletions
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index ec2f324..004e691 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -18,6 +18,7 @@ import numpy as np
import dfatool.utils
from dfatool.loader import KConfigAttributes
from dfatool.model import AnalyticModel
+from dfatool.validation import CrossValidator
def main():
@@ -66,6 +67,12 @@ def main():
help="Restrict model generation to N random samples",
metavar="N",
)
+ parser.add_argument(
+ "--cross-validate",
+ type=str,
+ help="Report modul accuracy via Cross-Validation",
+ metavar="METHOD:COUNT",
+ )
parser.add_argument("kconfig_path", type=str, help="Path to Kconfig file")
parser.add_argument(
"model",
@@ -113,18 +120,33 @@ def main():
observations = None
model = AnalyticModel(
- by_name, parameter_names, compute_stats=not args.force_tree
+ by_name,
+ parameter_names,
+ compute_stats=not args.force_tree,
+ force_tree=args.force_tree,
)
- if args.force_tree:
- for name in model.names:
- for attr in model.by_name[name]["attributes"]:
- # TODO specify correct threshold
- model.build_dtree(name, attr, 0)
- model.fit_done = True
+ if args.cross_validate:
+ xv_method, xv_count = args.cross_validate.split(":")
+ xv_count = int(xv_count)
+ xv = CrossValidator(
+ AnalyticModel,
+ by_name,
+ parameter_names,
+ compute_stats=not args.force_tree,
+ force_tree=args.force_tree,
+ )
+ else:
+ xv_method = None
param_model, param_info = model.get_fitted()
- analytic_quality = model.assess(param_model)
+
+ if xv_method == "montecarlo":
+ analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
+ elif xv_method == "kfold":
+ analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count)
+ else:
+ analytic_quality = model.assess(param_model)
print("Model Error on Training Data:")
for name in model.names:
diff --git a/lib/model.py b/lib/model.py
index 590c6f0..06ff25c 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -75,6 +75,7 @@ class AnalyticModel:
function_override=dict(),
use_corrcoef=False,
compute_stats=True,
+ force_tree=False,
):
"""
Create a new AnalyticModel and compute parameter statistics.
@@ -134,6 +135,13 @@ class AnalyticModel:
if compute_stats:
self._compute_stats(by_name)
+ if force_tree:
+ for name in self.names:
+ for attr in self.by_name[name]["attributes"]:
+ # TODO specify correct threshold
+ self.build_dtree(name, attr, 0)
+ self.fit_done = True
+
def __repr__(self):
names = ", ".join(self.by_name.keys())
return f"AnalyticModel<names=[{names}]>"