summaryrefslogtreecommitdiff
path: root/lib/cli.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/cli.py')
-rw-r--r--lib/cli.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/lib/cli.py b/lib/cli.py
index a1e4a58..a176e88 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -115,3 +115,31 @@ def add_standard_arguments(parser):
action="store_true",
help="Perform parameter-aware cross-validation: ensure that parameter values (and not just observations) are mutually exclusive between training and validation sets.",
)
+ parser.add_argument(
+ "--param-shift",
+ metavar="<key>=<+|-|*|/><value>;...",
+ type=str,
+ help="Adjust parameter values before passing them to model generation",
+ )
+
+
+def parse_param_shift(raw_param_shift):
+ shift_list = list()
+ for shift_pair in raw_param_shift.split(";"):
+ param_name, param_shift = shift_pair.split("=")
+ if param_shift.startswith("+"):
+ param_shift_value = float(param_shift[1:])
+ param_shift_function = lambda p: p + param_shift_value
+ elif param_shift.startswith("-"):
+ param_shift_value = float(param_shift[1:])
+ param_shift_function = lambda p: p - param_shift_value
+ elif param_shift.startswith("*"):
+ param_shift_value = float(param_shift[1:])
+ param_shift_function = lambda p: p * param_shift_value
+ elif param_shift.startswith("/"):
+ param_shift_value = float(param_shift[1:])
+ param_shift_function = lambda p: p / param_shift_value
+ else:
+ raise ValueError(f"Unsupported shift operation {param_name}={param_shift}")
+ shift_list.append((param_name, param_shift_function))
+ return shift_list