summaryrefslogtreecommitdiff
path: root/lib/validation.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/validation.py')
-rw-r--r--lib/validation.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/lib/validation.py b/lib/validation.py
index 68f4ddb..6203003 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -105,6 +105,7 @@ class CrossValidator:
self.parameters = sorted(parameters)
self.parameter_aware = False
self.export_filename = None
+ self.show_progress = kwargs.pop("show_progress", False)
self.args = args
self.kwargs = kwargs
@@ -217,6 +218,15 @@ class CrossValidator:
ret = dict()
models = list()
+ if self.show_progress:
+ from progress.bar import Bar
+
+ if static:
+ title = "Static XV"
+ else:
+ title = "Model XV"
+ bar = Bar(title, max=len(training_and_validation_sets))
+
for name in self.names:
ret[name] = dict()
for attribute in self.by_name[name]["attributes"]:
@@ -226,6 +236,8 @@ class CrossValidator:
}
for training_and_validation_by_name in training_and_validation_sets:
+ if self.show_progress:
+ bar.next()
model, (res, raw) = self._single_xv(
model_getter, training_and_validation_by_name, static=static
)
@@ -238,6 +250,8 @@ class CrossValidator:
ret[name][attribute]["modelOutput"].extend(
raw[name]["attribute"][attribute]["modelOutput"]
)
+ if self.show_progress:
+ bar.finish()
if self.export_filename:
import json