summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/analyze-archive.py25
-rwxr-xr-xbin/analyze-kconfig.py22
-rwxr-xr-xbin/analyze-log.py8
-rw-r--r--lib/cli.py67
4 files changed, 64 insertions, 58 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index d939c44..64792fd 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -626,15 +626,7 @@ if __name__ == "__main__":
if len(show_models):
print("--- LUT ---")
lut_model = model.get_param_lut()
-
- if xv_method == "montecarlo":
- lut_quality, _ = xv.montecarlo(
- lambda m: m.get_param_lut(fallback=True), xv_count
- )
- elif xv_method == "kfold":
- lut_quality, _ = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count)
- else:
- lut_quality = model.assess(lut_model)
+ lut_quality = model.assess(lut_model)
if len(show_models):
print("--- param model ---")
@@ -772,9 +764,11 @@ if __name__ == "__main__":
if "table" in show_quality or "all" in show_quality:
dfatool.cli.model_quality_table(
- ["static", "parameterized", "LUT"],
- [static_quality, analytic_quality, lut_quality],
- [None, param_info, None],
+ lut=lut_quality,
+ model=analytic_quality,
+ static=static_quality,
+ model_info=param_info,
+ xv_method=xv_method,
)
if args.with_substates:
for submodel in model.submodel_by_name.values():
@@ -785,9 +779,10 @@ if __name__ == "__main__":
sub_param_model, sub_param_info = submodel.get_fitted()
sub_analytic_quality = submodel.assess(sub_param_model)
dfatool.cli.model_quality_table(
- ["static", "parameterized", "LUT"],
- [sub_static_quality, sub_analytic_quality, sub_lut_quality],
- [None, sub_param_info, None],
+ lut=sub_lut_quality,
+ model=sub_analytic_quality,
+ static=sub_static_quality,
+ model_info=sub_param_info,
)
if "overall" in show_quality or "all" in show_quality:
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index 0900255..7eb63c8 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -418,6 +418,7 @@ def main():
static_model = model.get_static()
try:
lut_model = model.get_param_lut()
+ lut_quality = model.assess(lut_model)
except RuntimeError as e:
if args.force_tree:
# this is to be expected
@@ -425,6 +426,7 @@ def main():
else:
logging.warning(f"Skipping LUT model: {e}")
lut_model = None
+ lut_quality = None
if args.export_csv:
for name in model.names:
@@ -444,24 +446,12 @@ def main():
static_quality, _ = xv.montecarlo(
lambda m: m.get_static(), xv_count, static=True
)
- if lut_model:
- lut_quality, _ = xv.montecarlo(
- lambda m: m.get_param_lut(fallback=True), xv_count, static=True
- )
- else:
- lut_quality = None
xv.export_filename = args.export_xv
analytic_quality, xv_analytic_models = xv.montecarlo(
lambda m: m.get_fitted()[0], xv_count
)
elif xv_method == "kfold":
static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count, static=True)
- if lut_model:
- lut_quality, _ = xv.kfold(
- lambda m: m.get_param_lut(fallback=True), xv_count, static=True
- )
- else:
- lut_quality = None
xv.export_filename = args.export_xv
analytic_quality, xv_analytic_models = xv.kfold(
lambda m: m.get_fitted()[0], xv_count
@@ -521,9 +511,11 @@ def main():
else:
print("Model error on training data:")
dfatool.cli.model_quality_table(
- ["static", "parameterized", "LUT"],
- [static_quality, analytic_quality, lut_quality],
- [None, param_info, None],
+ lut=lut_quality,
+ model=analytic_quality,
+ static=static_quality,
+ model_info=param_info,
+ xv_method=xv_method,
)
if not args.show_quality:
diff --git a/bin/analyze-log.py b/bin/analyze-log.py
index 4a74116..4f5e420 100755
--- a/bin/analyze-log.py
+++ b/bin/analyze-log.py
@@ -284,9 +284,11 @@ def main():
else:
print("Model error on training data:")
dfatool.cli.model_quality_table(
- ["static", "parameterized", "LUT"],
- [static_quality, analytic_quality, lut_quality],
- [None, param_info, None],
+ lut=lut_quality,
+ model=analytic_quality,
+ static=static_quality,
+ model_info=param_info,
+ xv_method=xv_method,
)
if args.export_model:
diff --git a/lib/cli.py b/lib/cli.py
index 8447281..51ebb84 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -117,48 +117,65 @@ def print_model_size(model):
)
-def format_quality_measures(result):
- if "smape" in result:
- return "{:6.2f}% / {:9.0f}".format(result["smape"], result["mae"])
+def format_quality_measures(result, measure="smape", col_len=8):
+ if measure in result and result[measure] is not np.nan:
+ if measure.endswith("pe"):
+ unit = "%"
+ else:
+ unit = " "
+ return f"{result[measure]:{col_len-1}.2f}{unit}"
else:
- return "{:6} {:9.0f}".format("", result["mae"])
-
+ return f"""{result["mae"]:{col_len-1}.0f} """
+
+
+def model_quality_table(lut, model, static, model_info, xv_method=None):
+ key_len = 0
+ attr_len = 0
+ for key in static.keys():
+ if len(key) > key_len:
+ key_len = len(key)
+ for attr in static[key].keys():
+ if len(attr) > attr_len:
+ attr_len = len(attr)
+
+ if xv_method == "kfold":
+ xv_header = "kfold XV"
+ elif xv_method == "montecarlo":
+ xv_header = "MC XV"
+ elif xv_method:
+ xv_header = "XV"
+ else:
+ xv_header = "training"
-def model_quality_table(header, result_lists, info_list):
print(
- "{:20s} {:15s} {:19s} {:19s} {:19s}".format(
- "key",
- "attribute",
- header[0].center(19),
- header[1].center(19),
- header[2].center(19),
- )
+ f"""{"":>{key_len}s} {"":>{attr_len}s} {"training":>8s} {xv_header:>8s} {xv_header:>8s}"""
+ )
+ print(
+ f"""{"Key":>{key_len}s} {"Attribute":>{attr_len}s} {"LUT":>8s} {"model":>8s} {"static":>8s}"""
)
- for state_or_tran in sorted(result_lists[0].keys()):
- for key in sorted(result_lists[0][state_or_tran].keys()):
- buf = "{:20s} {:15s}".format(state_or_tran, key)
- for i, results in enumerate(result_lists):
- info = info_list[i]
- buf += " ||| "
+ for key in sorted(static.keys()):
+ for attr in sorted(static[key].keys()):
+ buf = f"{key:>{key_len}s} {attr:>{attr_len}s}"
+ for results, info in ((lut, None), (model, model_info), (static, None)):
+ buf += " "
if results is not None and (
info is None
or (
key != "energy_Pt"
- and type(info(state_or_tran, key)) is not StaticFunction
+ and type(info(key, attr)) is not StaticFunction
)
or (
key == "energy_Pt"
and (
- type(info(state_or_tran, "power")) is not StaticFunction
- or type(info(state_or_tran, "duration"))
- is not StaticFunction
+ type(info(key, "power")) is not StaticFunction
+ or type(info(key, "duration")) is not StaticFunction
)
)
):
- result = results[state_or_tran][key]
+ result = results[key][attr]
buf += format_quality_measures(result)
else:
- buf += "{:7}----{:8}".format("", "")
+ buf += f"""{"----":>7s} """
print(buf)