summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-02-25 16:36:13 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-02-25 16:36:13 +0100
commit2c3e1ddceb88fd7ea6c3090fc48d27407ce751b1 (patch)
tree82f99c37a44cd24c87ecfcca62079825694cd876 /lib
parent30b7b17e5f49ad91e16104e9d1ab3f12ef72d4fe (diff)
add --export-raw-predictions
Diffstat (limited to 'lib')
-rw-r--r--lib/cli.py6
-rw-r--r--lib/model.py17
2 files changed, 19 insertions, 4 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 3af8cc1..f988421 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -184,6 +184,12 @@ def add_standard_arguments(parser):
help="Export raw cross-validation results to FILE for later analysis (e.g. to compare different modeling approaches by means of a t-test)",
)
parser.add_argument(
+ "--export-raw-predictions",
+ metavar="FILE",
+ type=str,
+ help="Export raw model error data (i.e., ground truth vs. model output) to FILE for later analysis (e.g. to compare different modeling approaches by means of a t-test)",
+ )
+ parser.add_argument(
"--info",
action="store_true",
help="Show benchmark information (number of measurements, parameter values, ...)",
diff --git a/lib/model.py b/lib/model.py
index 227a323..558f049 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -399,7 +399,7 @@ class AnalyticModel:
return model_getter, info_getter
- def assess(self, model_function, ref=None):
+ def assess(self, model_function, ref=None, return_raw=False):
"""
Calculate MAE, SMAPE, etc. of model_function for each by_name entry.
@@ -412,11 +412,13 @@ class AnalyticModel:
exclusive (e.g. by performing cross validation). Otherwise,
overfitting cannot be detected.
"""
- detailed_results = {}
+ detailed_results = dict()
+ raw_results = dict()
if ref is None:
ref = self.by_name
for name, elem in sorted(ref.items()):
- detailed_results[name] = {}
+ detailed_results[name] = dict()
+ raw_results[name] = dict()
for attribute in elem["attributes"]:
predicted_data = np.array(
list(
@@ -430,7 +432,14 @@ class AnalyticModel:
)
measures = regression_measures(predicted_data, elem[attribute])
detailed_results[name][attribute] = measures
-
+ if return_raw:
+ raw_results[name][attribute] = {
+ "ground_truth": list(elem[attribute]),
+ "model_output": list(predicted_data),
+ }
+
+ if return_raw:
+ return detailed_results, raw_results
return detailed_results
def build_dtree(