summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/cli.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/lib/cli.py b/lib/cli.py
index 6b5b796..e0cbc7f 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -255,7 +255,7 @@ def add_standard_arguments(parser):
)
parser.add_argument(
"--param-shift",
- metavar="<key>=<+|-|*|/><value>;...",
+ metavar="<key>=<+|-|*|/><value>|none-to-0;...",
type=str,
help="Adjust parameter values before passing them to model generation",
)
@@ -285,6 +285,8 @@ def parse_param_shift(raw_param_shift):
elif param_shift.startswith("/"):
param_shift_value = float(param_shift[1:])
param_shift_function = lambda p: p / param_shift_value
+ elif param_shift == "none-to-0":
+ param_shift_function = lambda p: p or 0
else:
raise ValueError(f"Unsupported shift operation {param_name}={param_shift}")
shift_list.append((param_name, param_shift_function))