diff options
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/lib/model.py b/lib/model.py index 227a323..558f049 100644 --- a/lib/model.py +++ b/lib/model.py @@ -399,7 +399,7 @@ class AnalyticModel: return model_getter, info_getter - def assess(self, model_function, ref=None): + def assess(self, model_function, ref=None, return_raw=False): """ Calculate MAE, SMAPE, etc. of model_function for each by_name entry. @@ -412,11 +412,13 @@ class AnalyticModel: exclusive (e.g. by performing cross validation). Otherwise, overfitting cannot be detected. """ - detailed_results = {} + detailed_results = dict() + raw_results = dict() if ref is None: ref = self.by_name for name, elem in sorted(ref.items()): - detailed_results[name] = {} + detailed_results[name] = dict() + raw_results[name] = dict() for attribute in elem["attributes"]: predicted_data = np.array( list( @@ -430,7 +432,14 @@ class AnalyticModel: ) measures = regression_measures(predicted_data, elem[attribute]) detailed_results[name][attribute] = measures - + if return_raw: + raw_results[name][attribute] = { + "ground_truth": list(elem[attribute]), + "model_output": list(predicted_data), + } + + if return_raw: + return detailed_results, raw_results return detailed_results def build_dtree( |