diff options
-rwxr-xr-x | bin/analyze-archive.py | 2 | ||||
-rwxr-xr-x | bin/analyze-kconfig.py | 7 | ||||
-rwxr-xr-x | bin/analyze-log.py | 7 | ||||
-rw-r--r-- | lib/cli.py | 44 |
4 files changed, 50 insertions, 10 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index 535979c..ebd16ce 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -768,6 +768,7 @@ if __name__ == "__main__": static=static_quality, model_info=param_info, xv_method=xv_method, + error_metric=args.error_metric, ) if args.with_substates: for submodel in model.submodel_by_name.values(): @@ -782,6 +783,7 @@ if __name__ == "__main__": model=sub_analytic_quality, static=sub_static_quality, model_info=sub_param_info, + error_metric=args.error_metric, ) if "overall" in show_quality or "all" in show_quality: diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py index 8fcb623..f2c6f8f 100755 --- a/bin/analyze-kconfig.py +++ b/bin/analyze-kconfig.py @@ -504,15 +504,18 @@ def main(): if "table" in args.show_quality or "all" in args.show_quality: if xv_method is not None: - print(f"Model error after cross validation ({xv_method}, {xv_count}):") + print( + f"Model error ({args.error_metric}) after cross validation ({xv_method}, {xv_count}):" + ) else: - print("Model error on training data:") + print(f"Model error ({args.error_metric}) on training data:") dfatool.cli.model_quality_table( lut=lut_quality, model=analytic_quality, static=static_quality, model_info=param_info, xv_method=xv_method, + error_metric=args.error_metric, ) if not args.show_quality: diff --git a/bin/analyze-log.py b/bin/analyze-log.py index 7d6f5bc..e3cd7aa 100755 --- a/bin/analyze-log.py +++ b/bin/analyze-log.py @@ -278,15 +278,18 @@ def main(): if "table" in args.show_quality or "all" in args.show_quality: if xv_method is not None: - print(f"Model error after cross validation ({xv_method}, {xv_count}):") + print( + f"Model error ({args.error_metric}) after cross validation ({xv_method}, {xv_count}):" + ) else: - print("Model error on training data:") + print(f"Model error ({args.error_metric}) on training data:") dfatool.cli.model_quality_table( lut=lut_quality, model=analytic_quality, static=static_quality, model_info=param_info, xv_method=xv_method, + error_metric=args.error_metric, ) if args.export_model: @@ -117,18 +117,20 @@ def print_model_size(model): ) -def format_quality_measures(result, measure="smape", col_len=8): - if measure in result and result[measure] is not np.nan: - if measure.endswith("pe"): +def format_quality_measures(result, error_metric="smape", col_len=8): + if error_metric in result and result[error_metric] is not np.nan: + if error_metric.endswith("pe"): unit = "%" else: unit = " " - return f"{result[measure]:{col_len-1}.2f}{unit}" + return f"{result[error_metric]:{col_len-1}.2f}{unit}" else: return f"""{result["mae"]:{col_len-1}.0f} """ -def model_quality_table(lut, model, static, model_info, xv_method=None): +def model_quality_table( + lut, model, static, model_info, xv_method=None, error_metric="smape" +): key_len = len("Key") attr_len = len("Attribute") for key in static.keys(): @@ -173,7 +175,7 @@ def model_quality_table(lut, model, static, model_info, xv_method=None): ) ): result = results[key][attr] - buf += format_quality_measures(result) + buf += format_quality_measures(result, error_metric=error_metric) else: buf += f"""{"----":>7s} """ if type(model_info(key, attr)) is not StaticFunction: @@ -375,6 +377,36 @@ def add_standard_arguments(parser): "A function specified this way bypasses parameter detection: " "It is always assigned, even if the model seems to be independent of the parameters it references.", ) + parser.add_argument( + "--error-metric", + metavar="METRIC", + choices=[ + "mae", + "mape", + "smape", + "p50", + "p90", + "p95", + "p99", + "msd", + "rmsd", + "ssr", + "rsq", + ], + default="smape", + help="Error metric to use in --show-quality reports. In case a metric is undefined for a particular set of ground truth and prediction entries, dfatool falls back to mae.\n" + "MAE : Mean Absolute Error\n" + "MAPE : Mean Absolute Percentage Error\n" + "SMAPE : Symmetric Mean Absolute Percentage Error\n" + "p50 : Median (50th Percentile) Absolute Error\n" + "p90 : 90th Percentile Absolute Error\n" + "p95 : 95th Percentile Absolute Error\n" + "p99 : 99th Percentile Absolute Error\n" + "msd : Mean Square Deviation\n" + "rmsd : Root Mean Square Deviation\n" + "ssr : Sum of Squared Residuals\n" + "rsq : R² Score", + ) def parse_shift_function(param_name, param_shift): |