summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py17
1 files changed, 13 insertions, 4 deletions
diff --git a/lib/model.py b/lib/model.py
index 227a323..558f049 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -399,7 +399,7 @@ class AnalyticModel:
return model_getter, info_getter
- def assess(self, model_function, ref=None):
+ def assess(self, model_function, ref=None, return_raw=False):
"""
Calculate MAE, SMAPE, etc. of model_function for each by_name entry.
@@ -412,11 +412,13 @@ class AnalyticModel:
exclusive (e.g. by performing cross validation). Otherwise,
overfitting cannot be detected.
"""
- detailed_results = {}
+ detailed_results = dict()
+ raw_results = dict()
if ref is None:
ref = self.by_name
for name, elem in sorted(ref.items()):
- detailed_results[name] = {}
+ detailed_results[name] = dict()
+ raw_results[name] = dict()
for attribute in elem["attributes"]:
predicted_data = np.array(
list(
@@ -430,7 +432,14 @@ class AnalyticModel:
)
measures = regression_measures(predicted_data, elem[attribute])
detailed_results[name][attribute] = measures
-
+ if return_raw:
+ raw_results[name][attribute] = {
+ "ground_truth": list(elem[attribute]),
+ "model_output": list(predicted_data),
+ }
+
+ if return_raw:
+ return detailed_results, raw_results
return detailed_results
def build_dtree(