diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/validation.py | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/lib/validation.py b/lib/validation.py index cfd4deb..e10ba6c 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -53,7 +53,7 @@ class CrossValidator: Reports the mean model error over all cross-validation runs. """ - def __init__(self, model_class, by_name, parameters, arg_count): + def __init__(self, model_class, by_name, parameters, *args, **kwargs): """ Create a new CrossValidator object. @@ -62,7 +62,7 @@ class CrossValidator: arguments: model_class -- model class/type used for model synthesis, e.g. PTAModel or AnalyticModel. model_class must have a - constructor accepting (by_name, parameters, arg_count) + constructor accepting (by_name, parameters, *args, **kwargs) and provide an `assess` method. by_name -- measurements aggregated by state/transition/function/... name. Layout: by_name[name][attribute] = list of data. Additionally, @@ -73,7 +73,8 @@ class CrossValidator: self.by_name = by_name self.names = sorted(by_name.keys()) self.parameters = sorted(parameters) - self.arg_count = arg_count + self.args = args + self.kwargs = kwargs def kfold(self, model_getter, k=10): """ @@ -231,8 +232,12 @@ class CrossValidator: for idx in validation_subset: validation[name]["param"].append(self.by_name[name]["param"][idx]) - training_data = self.model_class(training, self.parameters, self.arg_count) + training_data = self.model_class( + training, self.parameters, *self.args, **self.kwargs + ) training_model = model_getter(training_data) - validation_data = self.model_class(validation, self.parameters, self.arg_count) + validation_data = self.model_class( + validation, self.parameters, *self.args, **self.kwargs + ) return validation_data.assess(training_model) |