diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-20 16:32:47 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2023-12-20 16:32:47 +0100 |
commit | 62fcf3c40a121e30807e1936ff98b7542716e5f3 (patch) | |
tree | b4472e996097da2849b6f8cc6749b38fa1ffa238 /lib | |
parent | e70a20ed920ef968b70380537ce925dc91902edd (diff) |
--show-quality=table: customizable --error-metric; default smape, fallback mae
Diffstat (limited to 'lib')
-rw-r--r-- | lib/cli.py | 44 |
1 files changed, 38 insertions, 6 deletions
@@ -117,18 +117,20 @@ def print_model_size(model): ) -def format_quality_measures(result, measure="smape", col_len=8): - if measure in result and result[measure] is not np.nan: - if measure.endswith("pe"): +def format_quality_measures(result, error_metric="smape", col_len=8): + if error_metric in result and result[error_metric] is not np.nan: + if error_metric.endswith("pe"): unit = "%" else: unit = " " - return f"{result[measure]:{col_len-1}.2f}{unit}" + return f"{result[error_metric]:{col_len-1}.2f}{unit}" else: return f"""{result["mae"]:{col_len-1}.0f} """ -def model_quality_table(lut, model, static, model_info, xv_method=None): +def model_quality_table( + lut, model, static, model_info, xv_method=None, error_metric="smape" +): key_len = len("Key") attr_len = len("Attribute") for key in static.keys(): @@ -173,7 +175,7 @@ def model_quality_table(lut, model, static, model_info, xv_method=None): ) ): result = results[key][attr] - buf += format_quality_measures(result) + buf += format_quality_measures(result, error_metric=error_metric) else: buf += f"""{"----":>7s} """ if type(model_info(key, attr)) is not StaticFunction: @@ -375,6 +377,36 @@ def add_standard_arguments(parser): "A function specified this way bypasses parameter detection: " "It is always assigned, even if the model seems to be independent of the parameters it references.", ) + parser.add_argument( + "--error-metric", + metavar="METRIC", + choices=[ + "mae", + "mape", + "smape", + "p50", + "p90", + "p95", + "p99", + "msd", + "rmsd", + "ssr", + "rsq", + ], + default="smape", + help="Error metric to use in --show-quality reports. In case a metric is undefined for a particular set of ground truth and prediction entries, dfatool falls back to mae.\n" + "MAE : Mean Absolute Error\n" + "MAPE : Mean Absolute Percentage Error\n" + "SMAPE : Symmetric Mean Absolute Percentage Error\n" + "p50 : Median (50th Percentile) Absolute Error\n" + "p90 : 90th Percentile Absolute Error\n" + "p95 : 95th Percentile Absolute Error\n" + "p99 : 99th Percentile Absolute Error\n" + "msd : Mean Square Deviation\n" + "rmsd : Root Mean Square Deviation\n" + "ssr : Sum of Squared Residuals\n" + "rsq : R² Score", + ) def parse_shift_function(param_name, param_shift): |