diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-20 12:16:51 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-20 12:16:51 +0100 |
commit | c4c14295b205a8a780fc625b7c3bf6e2be96a0ee (patch) | |
tree | 2f052bc59b6715e7b8bf9f3bd223d350c0bb972f | |
parent | 66c2fe62fbde3344161e1de385760cf41284045a (diff) |
Do not run XV on LUT model; it's not helpful.
--show-quality=table now always compares LUT (training data) to model
(training or XV) and static (training or XV)
-rwxr-xr-x | bin/analyze-archive.py | 25 | ||||
-rwxr-xr-x | bin/analyze-kconfig.py | 22 | ||||
-rwxr-xr-x | bin/analyze-log.py | 8 | ||||
-rw-r--r-- | lib/cli.py | 67 |
4 files changed, 64 insertions, 58 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index d939c44..64792fd 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -626,15 +626,7 @@ if __name__ == "__main__": if len(show_models): print("--- LUT ---") lut_model = model.get_param_lut() - - if xv_method == "montecarlo": - lut_quality, _ = xv.montecarlo( - lambda m: m.get_param_lut(fallback=True), xv_count - ) - elif xv_method == "kfold": - lut_quality, _ = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count) - else: - lut_quality = model.assess(lut_model) + lut_quality = model.assess(lut_model) if len(show_models): print("--- param model ---") @@ -772,9 +764,11 @@ if __name__ == "__main__": if "table" in show_quality or "all" in show_quality: dfatool.cli.model_quality_table( - ["static", "parameterized", "LUT"], - [static_quality, analytic_quality, lut_quality], - [None, param_info, None], + lut=lut_quality, + model=analytic_quality, + static=static_quality, + model_info=param_info, + xv_method=xv_method, ) if args.with_substates: for submodel in model.submodel_by_name.values(): @@ -785,9 +779,10 @@ if __name__ == "__main__": sub_param_model, sub_param_info = submodel.get_fitted() sub_analytic_quality = submodel.assess(sub_param_model) dfatool.cli.model_quality_table( - ["static", "parameterized", "LUT"], - [sub_static_quality, sub_analytic_quality, sub_lut_quality], - [None, sub_param_info, None], + lut=sub_lut_quality, + model=sub_analytic_quality, + static=sub_static_quality, + model_info=sub_param_info, ) if "overall" in show_quality or "all" in show_quality: diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py index 0900255..7eb63c8 100755 --- a/bin/analyze-kconfig.py +++ b/bin/analyze-kconfig.py @@ -418,6 +418,7 @@ def main(): static_model = model.get_static() try: lut_model = model.get_param_lut() + lut_quality = model.assess(lut_model) except RuntimeError as e: if args.force_tree: # this is to be expected @@ -425,6 +426,7 @@ def main(): else: logging.warning(f"Skipping LUT model: {e}") lut_model = None + lut_quality = None if args.export_csv: for name in model.names: @@ -444,24 +446,12 @@ def main(): static_quality, _ = xv.montecarlo( lambda m: m.get_static(), xv_count, static=True ) - if lut_model: - lut_quality, _ = xv.montecarlo( - lambda m: m.get_param_lut(fallback=True), xv_count, static=True - ) - else: - lut_quality = None xv.export_filename = args.export_xv analytic_quality, xv_analytic_models = xv.montecarlo( lambda m: m.get_fitted()[0], xv_count ) elif xv_method == "kfold": static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count, static=True) - if lut_model: - lut_quality, _ = xv.kfold( - lambda m: m.get_param_lut(fallback=True), xv_count, static=True - ) - else: - lut_quality = None xv.export_filename = args.export_xv analytic_quality, xv_analytic_models = xv.kfold( lambda m: m.get_fitted()[0], xv_count @@ -521,9 +511,11 @@ def main(): else: print("Model error on training data:") dfatool.cli.model_quality_table( - ["static", "parameterized", "LUT"], - [static_quality, analytic_quality, lut_quality], - [None, param_info, None], + lut=lut_quality, + model=analytic_quality, + static=static_quality, + model_info=param_info, + xv_method=xv_method, ) if not args.show_quality: diff --git a/bin/analyze-log.py b/bin/analyze-log.py index 4a74116..4f5e420 100755 --- a/bin/analyze-log.py +++ b/bin/analyze-log.py @@ -284,9 +284,11 @@ def main(): else: print("Model error on training data:") dfatool.cli.model_quality_table( - ["static", "parameterized", "LUT"], - [static_quality, analytic_quality, lut_quality], - [None, param_info, None], + lut=lut_quality, + model=analytic_quality, + static=static_quality, + model_info=param_info, + xv_method=xv_method, ) if args.export_model: @@ -117,48 +117,65 @@ def print_model_size(model): ) -def format_quality_measures(result): - if "smape" in result: - return "{:6.2f}% / {:9.0f}".format(result["smape"], result["mae"]) +def format_quality_measures(result, measure="smape", col_len=8): + if measure in result and result[measure] is not np.nan: + if measure.endswith("pe"): + unit = "%" + else: + unit = " " + return f"{result[measure]:{col_len-1}.2f}{unit}" else: - return "{:6} {:9.0f}".format("", result["mae"]) - + return f"""{result["mae"]:{col_len-1}.0f} """ + + +def model_quality_table(lut, model, static, model_info, xv_method=None): + key_len = 0 + attr_len = 0 + for key in static.keys(): + if len(key) > key_len: + key_len = len(key) + for attr in static[key].keys(): + if len(attr) > attr_len: + attr_len = len(attr) + + if xv_method == "kfold": + xv_header = "kfold XV" + elif xv_method == "montecarlo": + xv_header = "MC XV" + elif xv_method: + xv_header = "XV" + else: + xv_header = "training" -def model_quality_table(header, result_lists, info_list): print( - "{:20s} {:15s} {:19s} {:19s} {:19s}".format( - "key", - "attribute", - header[0].center(19), - header[1].center(19), - header[2].center(19), - ) + f"""{"":>{key_len}s} {"":>{attr_len}s} {"training":>8s} {xv_header:>8s} {xv_header:>8s}""" + ) + print( + f"""{"Key":>{key_len}s} {"Attribute":>{attr_len}s} {"LUT":>8s} {"model":>8s} {"static":>8s}""" ) - for state_or_tran in sorted(result_lists[0].keys()): - for key in sorted(result_lists[0][state_or_tran].keys()): - buf = "{:20s} {:15s}".format(state_or_tran, key) - for i, results in enumerate(result_lists): - info = info_list[i] - buf += " ||| " + for key in sorted(static.keys()): + for attr in sorted(static[key].keys()): + buf = f"{key:>{key_len}s} {attr:>{attr_len}s}" + for results, info in ((lut, None), (model, model_info), (static, None)): + buf += " " if results is not None and ( info is None or ( key != "energy_Pt" - and type(info(state_or_tran, key)) is not StaticFunction + and type(info(key, attr)) is not StaticFunction ) or ( key == "energy_Pt" and ( - type(info(state_or_tran, "power")) is not StaticFunction - or type(info(state_or_tran, "duration")) - is not StaticFunction + type(info(key, "power")) is not StaticFunction + or type(info(key, "duration")) is not StaticFunction ) ) ): - result = results[state_or_tran][key] + result = results[key][attr] buf += format_quality_measures(result) else: - buf += "{:7}----{:8}".format("", "") + buf += f"""{"----":>7s} """ print(buf) |