summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/cli.py67
1 files changed, 42 insertions, 25 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 8447281..51ebb84 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -117,48 +117,65 @@ def print_model_size(model):
)
-def format_quality_measures(result):
- if "smape" in result:
- return "{:6.2f}% / {:9.0f}".format(result["smape"], result["mae"])
+def format_quality_measures(result, measure="smape", col_len=8):
+ if measure in result and result[measure] is not np.nan:
+ if measure.endswith("pe"):
+ unit = "%"
+ else:
+ unit = " "
+ return f"{result[measure]:{col_len-1}.2f}{unit}"
else:
- return "{:6} {:9.0f}".format("", result["mae"])
-
+ return f"""{result["mae"]:{col_len-1}.0f} """
+
+
+def model_quality_table(lut, model, static, model_info, xv_method=None):
+ key_len = 0
+ attr_len = 0
+ for key in static.keys():
+ if len(key) > key_len:
+ key_len = len(key)
+ for attr in static[key].keys():
+ if len(attr) > attr_len:
+ attr_len = len(attr)
+
+ if xv_method == "kfold":
+ xv_header = "kfold XV"
+ elif xv_method == "montecarlo":
+ xv_header = "MC XV"
+ elif xv_method:
+ xv_header = "XV"
+ else:
+ xv_header = "training"
-def model_quality_table(header, result_lists, info_list):
print(
- "{:20s} {:15s} {:19s} {:19s} {:19s}".format(
- "key",
- "attribute",
- header[0].center(19),
- header[1].center(19),
- header[2].center(19),
- )
+ f"""{"":>{key_len}s} {"":>{attr_len}s} {"training":>8s} {xv_header:>8s} {xv_header:>8s}"""
+ )
+ print(
+ f"""{"Key":>{key_len}s} {"Attribute":>{attr_len}s} {"LUT":>8s} {"model":>8s} {"static":>8s}"""
)
- for state_or_tran in sorted(result_lists[0].keys()):
- for key in sorted(result_lists[0][state_or_tran].keys()):
- buf = "{:20s} {:15s}".format(state_or_tran, key)
- for i, results in enumerate(result_lists):
- info = info_list[i]
- buf += " ||| "
+ for key in sorted(static.keys()):
+ for attr in sorted(static[key].keys()):
+ buf = f"{key:>{key_len}s} {attr:>{attr_len}s}"
+ for results, info in ((lut, None), (model, model_info), (static, None)):
+ buf += " "
if results is not None and (
info is None
or (
key != "energy_Pt"
- and type(info(state_or_tran, key)) is not StaticFunction
+ and type(info(key, attr)) is not StaticFunction
)
or (
key == "energy_Pt"
and (
- type(info(state_or_tran, "power")) is not StaticFunction
- or type(info(state_or_tran, "duration"))
- is not StaticFunction
+ type(info(key, "power")) is not StaticFunction
+ or type(info(key, "duration")) is not StaticFunction
)
)
):
- result = results[state_or_tran][key]
+ result = results[key][attr]
buf += format_quality_measures(result)
else:
- buf += "{:7}----{:8}".format("", "")
+ buf += f"""{"----":>7s} """
print(buf)