diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-02-25 16:36:13 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-02-25 16:36:13 +0100 |
commit | 2c3e1ddceb88fd7ea6c3090fc48d27407ce751b1 (patch) | |
tree | 82f99c37a44cd24c87ecfcca62079825694cd876 /lib | |
parent | 30b7b17e5f49ad91e16104e9d1ab3f12ef72d4fe (diff) |
add --export-raw-predictions
Diffstat (limited to 'lib')
-rw-r--r-- | lib/cli.py | 6 | ||||
-rw-r--r-- | lib/model.py | 17 |
2 files changed, 19 insertions, 4 deletions
@@ -184,6 +184,12 @@ def add_standard_arguments(parser): help="Export raw cross-validation results to FILE for later analysis (e.g. to compare different modeling approaches by means of a t-test)", ) parser.add_argument( + "--export-raw-predictions", + metavar="FILE", + type=str, + help="Export raw model error data (i.e., ground truth vs. model output) to FILE for later analysis (e.g. to compare different modeling approaches by means of a t-test)", + ) + parser.add_argument( "--info", action="store_true", help="Show benchmark information (number of measurements, parameter values, ...)", diff --git a/lib/model.py b/lib/model.py index 227a323..558f049 100644 --- a/lib/model.py +++ b/lib/model.py @@ -399,7 +399,7 @@ class AnalyticModel: return model_getter, info_getter - def assess(self, model_function, ref=None): + def assess(self, model_function, ref=None, return_raw=False): """ Calculate MAE, SMAPE, etc. of model_function for each by_name entry. @@ -412,11 +412,13 @@ class AnalyticModel: exclusive (e.g. by performing cross validation). Otherwise, overfitting cannot be detected. """ - detailed_results = {} + detailed_results = dict() + raw_results = dict() if ref is None: ref = self.by_name for name, elem in sorted(ref.items()): - detailed_results[name] = {} + detailed_results[name] = dict() + raw_results[name] = dict() for attribute in elem["attributes"]: predicted_data = np.array( list( @@ -430,7 +432,14 @@ class AnalyticModel: ) measures = regression_measures(predicted_data, elem[attribute]) detailed_results[name][attribute] = measures - + if return_raw: + raw_results[name][attribute] = { + "ground_truth": list(elem[attribute]), + "model_output": list(predicted_data), + } + + if return_raw: + return detailed_results, raw_results return detailed_results def build_dtree( |