summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-25 10:18:55 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-25 10:18:55 +0100
commitc2e1e6f4034e7800f8b151fa2e971478d4376347 (patch)
tree773c9c86cf7a85298c90ccd614aa60b7536f9ca3 /lib
parenteb34056dbf9e10be7bed3835600f4edd7f7a1ef3 (diff)
add an optional XV progress bar
Diffstat (limited to 'lib')
-rw-r--r--lib/cli.py5
-rw-r--r--lib/validation.py14
2 files changed, 19 insertions, 0 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 51f77d3..abeb3d3 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -621,6 +621,11 @@ def add_standard_arguments(parser):
action="store_true",
help="Build regression tree without checking whether static/analytic functions are sufficient.",
)
+ parser.add_argument(
+ "--progress",
+ action="store_true",
+ help="Show progress bars while executing compute-intensive tasks such as cross-validation.",
+ )
def parse_shift_function(param_name, param_shift):
diff --git a/lib/validation.py b/lib/validation.py
index 68f4ddb..6203003 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -105,6 +105,7 @@ class CrossValidator:
self.parameters = sorted(parameters)
self.parameter_aware = False
self.export_filename = None
+ self.show_progress = kwargs.pop("show_progress", False)
self.args = args
self.kwargs = kwargs
@@ -217,6 +218,15 @@ class CrossValidator:
ret = dict()
models = list()
+ if self.show_progress:
+ from progress.bar import Bar
+
+ if static:
+ title = "Static XV"
+ else:
+ title = "Model XV"
+ bar = Bar(title, max=len(training_and_validation_sets))
+
for name in self.names:
ret[name] = dict()
for attribute in self.by_name[name]["attributes"]:
@@ -226,6 +236,8 @@ class CrossValidator:
}
for training_and_validation_by_name in training_and_validation_sets:
+ if self.show_progress:
+ bar.next()
model, (res, raw) = self._single_xv(
model_getter, training_and_validation_by_name, static=static
)
@@ -238,6 +250,8 @@ class CrossValidator:
ret[name][attribute]["modelOutput"].extend(
raw[name]["attribute"][attribute]["modelOutput"]
)
+ if self.show_progress:
+ bar.finish()
if self.export_filename:
import json