summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/validation.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/lib/validation.py b/lib/validation.py
index cfd4deb..e10ba6c 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -53,7 +53,7 @@ class CrossValidator:
Reports the mean model error over all cross-validation runs.
"""
- def __init__(self, model_class, by_name, parameters, arg_count):
+ def __init__(self, model_class, by_name, parameters, *args, **kwargs):
"""
Create a new CrossValidator object.
@@ -62,7 +62,7 @@ class CrossValidator:
arguments:
model_class -- model class/type used for model synthesis,
e.g. PTAModel or AnalyticModel. model_class must have a
- constructor accepting (by_name, parameters, arg_count)
+ constructor accepting (by_name, parameters, *args, **kwargs)
and provide an `assess` method.
by_name -- measurements aggregated by state/transition/function/... name.
Layout: by_name[name][attribute] = list of data. Additionally,
@@ -73,7 +73,8 @@ class CrossValidator:
self.by_name = by_name
self.names = sorted(by_name.keys())
self.parameters = sorted(parameters)
- self.arg_count = arg_count
+ self.args = args
+ self.kwargs = kwargs
def kfold(self, model_getter, k=10):
"""
@@ -231,8 +232,12 @@ class CrossValidator:
for idx in validation_subset:
validation[name]["param"].append(self.by_name[name]["param"][idx])
- training_data = self.model_class(training, self.parameters, self.arg_count)
+ training_data = self.model_class(
+ training, self.parameters, *self.args, **self.kwargs
+ )
training_model = model_getter(training_data)
- validation_data = self.model_class(validation, self.parameters, self.arg_count)
+ validation_data = self.model_class(
+ validation, self.parameters, *self.args, **self.kwargs
+ )
return validation_data.assess(training_model)