From b4f7b9e9407dbdc3be957fdfc6da0d7755b4b64d Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Wed, 13 Oct 2021 15:38:27 +0200 Subject: Make CrossValidator independent of PTAModel signature --- lib/validation.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'lib/validation.py') diff --git a/lib/validation.py b/lib/validation.py index cfd4deb..e10ba6c 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -53,7 +53,7 @@ class CrossValidator: Reports the mean model error over all cross-validation runs. """ - def __init__(self, model_class, by_name, parameters, arg_count): + def __init__(self, model_class, by_name, parameters, *args, **kwargs): """ Create a new CrossValidator object. @@ -62,7 +62,7 @@ class CrossValidator: arguments: model_class -- model class/type used for model synthesis, e.g. PTAModel or AnalyticModel. model_class must have a - constructor accepting (by_name, parameters, arg_count) + constructor accepting (by_name, parameters, *args, **kwargs) and provide an `assess` method. by_name -- measurements aggregated by state/transition/function/... name. Layout: by_name[name][attribute] = list of data. Additionally, @@ -73,7 +73,8 @@ class CrossValidator: self.by_name = by_name self.names = sorted(by_name.keys()) self.parameters = sorted(parameters) - self.arg_count = arg_count + self.args = args + self.kwargs = kwargs def kfold(self, model_getter, k=10): """ @@ -231,8 +232,12 @@ class CrossValidator: for idx in validation_subset: validation[name]["param"].append(self.by_name[name]["param"][idx]) - training_data = self.model_class(training, self.parameters, self.arg_count) + training_data = self.model_class( + training, self.parameters, *self.args, **self.kwargs + ) training_model = model_getter(training_data) - validation_data = self.model_class(validation, self.parameters, self.arg_count) + validation_data = self.model_class( + validation, self.parameters, *self.args, **self.kwargs + ) return validation_data.assess(training_model) -- cgit v1.2.3