From 2c3e1ddceb88fd7ea6c3090fc48d27407ce751b1 Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Fri, 25 Feb 2022 16:36:13 +0100 Subject: add --export-raw-predictions --- lib/cli.py | 6 ++++++ lib/model.py | 17 +++++++++++++---- 2 files changed, 19 insertions(+), 4 deletions(-) (limited to 'lib') diff --git a/lib/cli.py b/lib/cli.py index 3af8cc1..f988421 100644 --- a/lib/cli.py +++ b/lib/cli.py @@ -183,6 +183,12 @@ def add_standard_arguments(parser): type=str, help="Export raw cross-validation results to FILE for later analysis (e.g. to compare different modeling approaches by means of a t-test)", ) + parser.add_argument( + "--export-raw-predictions", + metavar="FILE", + type=str, + help="Export raw model error data (i.e., ground truth vs. model output) to FILE for later analysis (e.g. to compare different modeling approaches by means of a t-test)", + ) parser.add_argument( "--info", action="store_true", diff --git a/lib/model.py b/lib/model.py index 227a323..558f049 100644 --- a/lib/model.py +++ b/lib/model.py @@ -399,7 +399,7 @@ class AnalyticModel: return model_getter, info_getter - def assess(self, model_function, ref=None): + def assess(self, model_function, ref=None, return_raw=False): """ Calculate MAE, SMAPE, etc. of model_function for each by_name entry. @@ -412,11 +412,13 @@ class AnalyticModel: exclusive (e.g. by performing cross validation). Otherwise, overfitting cannot be detected. """ - detailed_results = {} + detailed_results = dict() + raw_results = dict() if ref is None: ref = self.by_name for name, elem in sorted(ref.items()): - detailed_results[name] = {} + detailed_results[name] = dict() + raw_results[name] = dict() for attribute in elem["attributes"]: predicted_data = np.array( list( @@ -430,7 +432,14 @@ class AnalyticModel: ) measures = regression_measures(predicted_data, elem[attribute]) detailed_results[name][attribute] = measures - + if return_raw: + raw_results[name][attribute] = { + "ground_truth": list(elem[attribute]), + "model_output": list(predicted_data), + } + + if return_raw: + return detailed_results, raw_results return detailed_results def build_dtree( -- cgit v1.2.3