diff options
Diffstat (limited to 'lib/cli.py')
-rw-r--r-- | lib/cli.py | 28 |
1 files changed, 28 insertions, 0 deletions
@@ -115,3 +115,31 @@ def add_standard_arguments(parser): action="store_true", help="Perform parameter-aware cross-validation: ensure that parameter values (and not just observations) are mutually exclusive between training and validation sets.", ) + parser.add_argument( + "--param-shift", + metavar="<key>=<+|-|*|/><value>;...", + type=str, + help="Adjust parameter values before passing them to model generation", + ) + + +def parse_param_shift(raw_param_shift): + shift_list = list() + for shift_pair in raw_param_shift.split(";"): + param_name, param_shift = shift_pair.split("=") + if param_shift.startswith("+"): + param_shift_value = float(param_shift[1:]) + param_shift_function = lambda p: p + param_shift_value + elif param_shift.startswith("-"): + param_shift_value = float(param_shift[1:]) + param_shift_function = lambda p: p - param_shift_value + elif param_shift.startswith("*"): + param_shift_value = float(param_shift[1:]) + param_shift_function = lambda p: p * param_shift_value + elif param_shift.startswith("/"): + param_shift_value = float(param_shift[1:]) + param_shift_function = lambda p: p / param_shift_value + else: + raise ValueError(f"Unsupported shift operation {param_name}={param_shift}") + shift_list.append((param_name, param_shift_function)) + return shift_list |