diff options
-rwxr-xr-x | bin/analyze-log.py | 5 | ||||
-rw-r--r-- | lib/cli.py | 51 | ||||
-rw-r--r-- | lib/utils.py | 9 |
3 files changed, 48 insertions, 17 deletions
diff --git a/bin/analyze-log.py b/bin/analyze-log.py index 81298d0..667921d 100755 --- a/bin/analyze-log.py +++ b/bin/analyze-log.py @@ -125,9 +125,12 @@ def main(): if args.param_shift: param_shift = dfatool.cli.parse_param_shift(args.param_shift) - print(param_shift) dfatool.utils.shift_param_in_aggregate(by_name, parameter_names, param_shift) + if args.normalize_nfp: + norm = dfatool.cli.parse_nfp_normalization(args.normalize_nfp) + dfatool.utils.normalize_nfp_in_aggregate(by_name, norm) + model = AnalyticModel( by_name, parameter_names, @@ -272,6 +272,12 @@ def add_standard_arguments(parser): help="Adjust parameter values before passing them to model generation", ) parser.add_argument( + "--normalize-nfp", + metavar="<newkey>=<oldkey>=<+|-|*|/><value>|none-to-0;...", + type=str, + help="Normalize observation values before passing them to model generation", + ) + parser.add_argument( "--filter-param", metavar="<parameter name>=<parameter value>[,<parameter name>=<parameter value>...]", type=str, @@ -287,25 +293,38 @@ def add_standard_arguments(parser): ) +def parse_shift_function(param_name, param_shift): + if param_shift.startswith("+"): + param_shift_value = float(param_shift[1:]) + return lambda p: p + param_shift_value + elif param_shift.startswith("-"): + param_shift_value = float(param_shift[1:]) + return lambda p: p - param_shift_value + elif param_shift.startswith("*"): + param_shift_value = float(param_shift[1:]) + return lambda p: p * param_shift_value + elif param_shift.startswith("/"): + param_shift_value = float(param_shift[1:]) + return lambda p: p / param_shift_value + elif param_shift == "none-to-0": + return lambda p: p or 0 + else: + raise ValueError(f"Unsupported shift operation {param_name}={param_shift}") + + +def parse_nfp_normalization(raw_normalization): + norm_list = list() + for norm_pair in raw_normalization.split(";"): + new_name, old_name, norm_val = norm_pair.split("=") + norm_function = parse_shift_function(new_name, norm_val) + norm_list.append((new_name, old_name, norm_function)) + return norm_list + + def parse_param_shift(raw_param_shift): shift_list = list() for shift_pair in raw_param_shift.split(";"): param_name, param_shift = shift_pair.split("=") - if param_shift.startswith("+"): - param_shift_value = float(param_shift[1:]) - param_shift_function = lambda p: p + param_shift_value - elif param_shift.startswith("-"): - param_shift_value = float(param_shift[1:]) - param_shift_function = lambda p: p - param_shift_value - elif param_shift.startswith("*"): - param_shift_value = float(param_shift[1:]) - param_shift_function = lambda p: p * param_shift_value - elif param_shift.startswith("/"): - param_shift_value = float(param_shift[1:]) - param_shift_function = lambda p: p / param_shift_value - elif param_shift == "none-to-0": - param_shift_function = lambda p: p or 0 - else: - raise ValueError(f"Unsupported shift operation {param_name}={param_shift}") + param_shift_function = parse_shift_function(param_name, param_shift) shift_list.append((param_name, param_shift_function)) return shift_list diff --git a/lib/utils.py b/lib/utils.py index 5b40f51..ed5aa14 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -447,6 +447,15 @@ def by_param_to_by_name(by_param: dict) -> dict: return by_name +def normalize_nfp_in_aggregate(aggregate, nfp_norm): + for name in aggregate.keys(): + for new_name, old_name, norm_function in nfp_norm: + if old_name in aggregate[name]["attributes"]: + aggregate[name][new_name] = norm_function(aggregate[name].pop(old_name)) + aggregate[name]["attributes"].remove(old_name) + aggregate[name]["attributes"].append(new_name) + + def shift_param_in_aggregate(aggregate, parameters, parameter_shift): """ Remove entries which do not have certain parameter values from `aggregate`. |