diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-20 12:16:51 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-20 12:16:51 +0100 |
commit | c4c14295b205a8a780fc625b7c3bf6e2be96a0ee (patch) | |
tree | 2f052bc59b6715e7b8bf9f3bd223d350c0bb972f /lib | |
parent | 66c2fe62fbde3344161e1de385760cf41284045a (diff) |
Do not run XV on LUT model; it's not helpful.
--show-quality=table now always compares LUT (training data) to model
(training or XV) and static (training or XV)
Diffstat (limited to 'lib')
-rw-r--r-- | lib/cli.py | 67 |
1 files changed, 42 insertions, 25 deletions
@@ -117,48 +117,65 @@ def print_model_size(model): ) -def format_quality_measures(result): - if "smape" in result: - return "{:6.2f}% / {:9.0f}".format(result["smape"], result["mae"]) +def format_quality_measures(result, measure="smape", col_len=8): + if measure in result and result[measure] is not np.nan: + if measure.endswith("pe"): + unit = "%" + else: + unit = " " + return f"{result[measure]:{col_len-1}.2f}{unit}" else: - return "{:6} {:9.0f}".format("", result["mae"]) - + return f"""{result["mae"]:{col_len-1}.0f} """ + + +def model_quality_table(lut, model, static, model_info, xv_method=None): + key_len = 0 + attr_len = 0 + for key in static.keys(): + if len(key) > key_len: + key_len = len(key) + for attr in static[key].keys(): + if len(attr) > attr_len: + attr_len = len(attr) + + if xv_method == "kfold": + xv_header = "kfold XV" + elif xv_method == "montecarlo": + xv_header = "MC XV" + elif xv_method: + xv_header = "XV" + else: + xv_header = "training" -def model_quality_table(header, result_lists, info_list): print( - "{:20s} {:15s} {:19s} {:19s} {:19s}".format( - "key", - "attribute", - header[0].center(19), - header[1].center(19), - header[2].center(19), - ) + f"""{"":>{key_len}s} {"":>{attr_len}s} {"training":>8s} {xv_header:>8s} {xv_header:>8s}""" + ) + print( + f"""{"Key":>{key_len}s} {"Attribute":>{attr_len}s} {"LUT":>8s} {"model":>8s} {"static":>8s}""" ) - for state_or_tran in sorted(result_lists[0].keys()): - for key in sorted(result_lists[0][state_or_tran].keys()): - buf = "{:20s} {:15s}".format(state_or_tran, key) - for i, results in enumerate(result_lists): - info = info_list[i] - buf += " ||| " + for key in sorted(static.keys()): + for attr in sorted(static[key].keys()): + buf = f"{key:>{key_len}s} {attr:>{attr_len}s}" + for results, info in ((lut, None), (model, model_info), (static, None)): + buf += " " if results is not None and ( info is None or ( key != "energy_Pt" - and type(info(state_or_tran, key)) is not StaticFunction + and type(info(key, attr)) is not StaticFunction ) or ( key == "energy_Pt" and ( - type(info(state_or_tran, "power")) is not StaticFunction - or type(info(state_or_tran, "duration")) - is not StaticFunction + type(info(key, "power")) is not StaticFunction + or type(info(key, "duration")) is not StaticFunction ) ) ): - result = results[state_or_tran][key] + result = results[key][attr] buf += format_quality_measures(result) else: - buf += "{:7}----{:8}".format("", "") + buf += f"""{"----":>7s} """ print(buf) |