diff options
-rw-r--r-- | lib/model.py | 30 |
1 files changed, 9 insertions, 21 deletions
diff --git a/lib/model.py b/lib/model.py index 4b0f46d..003ca16 100644 --- a/lib/model.py +++ b/lib/model.py @@ -277,7 +277,7 @@ class AnalyticModel: return model_getter, info_getter - def assess(self, model_function): + def assess(self, model_function, ref=None): """ Calculate MAE, SMAPE, etc. of model_function for each by_name entry. @@ -291,22 +291,22 @@ class AnalyticModel: overfitting cannot be detected. """ detailed_results = {} - for name in self.names: + if ref is None: + ref = self.by_name + for name, elem in sorted(ref.items()): detailed_results[name] = {} - for attribute in self.attr_by_name[name].keys(): - data = self.attr_by_name[name][attribute].data - param_values = self.attr_by_name[name][attribute].param_values + for attribute in elem["attributes"]: predicted_data = np.array( list( map( lambda i: model_function( - name, attribute, param=param_values[i] + name, attribute, param=elem["param"][i] ), - range(len(data)), + range(len(elem[attribute])), ) ) ) - measures = regression_measures(predicted_data, data) + measures = regression_measures(predicted_data, elem[attribute]) detailed_results[name][attribute] = measures return {"by_name": detailed_results} @@ -808,22 +808,10 @@ class PTAModel(AnalyticModel): exclusive (e.g. by performing cross validation). Otherwise, overfitting cannot be detected. """ - detailed_results = {} if ref is None: ref = self.by_name + detailed_results = super().assess(model_function, ref=ref)["by_name"] for name, elem in sorted(ref.items()): - detailed_results[name] = {} - for key in elem["attributes"]: - predicted_data = np.array( - list( - map( - lambda i: model_function(name, key, param=elem["param"][i]), - range(len(elem[key])), - ) - ) - ) - measures = regression_measures(predicted_data, elem[key]) - detailed_results[name][key] = measures if elem["isa"] == "transition": predicted_data = np.array( list( |