diff options
| -rwxr-xr-x | bin/analyze-log.py | 5 | ||||
| -rw-r--r-- | lib/cli.py | 51 | ||||
| -rw-r--r-- | lib/utils.py | 9 | 
3 files changed, 48 insertions, 17 deletions
diff --git a/bin/analyze-log.py b/bin/analyze-log.py index 81298d0..667921d 100755 --- a/bin/analyze-log.py +++ b/bin/analyze-log.py @@ -125,9 +125,12 @@ def main():      if args.param_shift:          param_shift = dfatool.cli.parse_param_shift(args.param_shift) -        print(param_shift)          dfatool.utils.shift_param_in_aggregate(by_name, parameter_names, param_shift) +    if args.normalize_nfp: +        norm = dfatool.cli.parse_nfp_normalization(args.normalize_nfp) +        dfatool.utils.normalize_nfp_in_aggregate(by_name, norm) +      model = AnalyticModel(          by_name,          parameter_names, @@ -272,6 +272,12 @@ def add_standard_arguments(parser):          help="Adjust parameter values before passing them to model generation",      )      parser.add_argument( +        "--normalize-nfp", +        metavar="<newkey>=<oldkey>=<+|-|*|/><value>|none-to-0;...", +        type=str, +        help="Normalize observation values before passing them to model generation", +    ) +    parser.add_argument(          "--filter-param",          metavar="<parameter name>=<parameter value>[,<parameter name>=<parameter value>...]",          type=str, @@ -287,25 +293,38 @@ def add_standard_arguments(parser):      ) +def parse_shift_function(param_name, param_shift): +    if param_shift.startswith("+"): +        param_shift_value = float(param_shift[1:]) +        return lambda p: p + param_shift_value +    elif param_shift.startswith("-"): +        param_shift_value = float(param_shift[1:]) +        return lambda p: p - param_shift_value +    elif param_shift.startswith("*"): +        param_shift_value = float(param_shift[1:]) +        return lambda p: p * param_shift_value +    elif param_shift.startswith("/"): +        param_shift_value = float(param_shift[1:]) +        return lambda p: p / param_shift_value +    elif param_shift == "none-to-0": +        return lambda p: p or 0 +    else: +        raise ValueError(f"Unsupported shift operation {param_name}={param_shift}") + + +def parse_nfp_normalization(raw_normalization): +    norm_list = list() +    for norm_pair in raw_normalization.split(";"): +        new_name, old_name, norm_val = norm_pair.split("=") +        norm_function = parse_shift_function(new_name, norm_val) +        norm_list.append((new_name, old_name, norm_function)) +    return norm_list + +  def parse_param_shift(raw_param_shift):      shift_list = list()      for shift_pair in raw_param_shift.split(";"):          param_name, param_shift = shift_pair.split("=") -        if param_shift.startswith("+"): -            param_shift_value = float(param_shift[1:]) -            param_shift_function = lambda p: p + param_shift_value -        elif param_shift.startswith("-"): -            param_shift_value = float(param_shift[1:]) -            param_shift_function = lambda p: p - param_shift_value -        elif param_shift.startswith("*"): -            param_shift_value = float(param_shift[1:]) -            param_shift_function = lambda p: p * param_shift_value -        elif param_shift.startswith("/"): -            param_shift_value = float(param_shift[1:]) -            param_shift_function = lambda p: p / param_shift_value -        elif param_shift == "none-to-0": -            param_shift_function = lambda p: p or 0 -        else: -            raise ValueError(f"Unsupported shift operation {param_name}={param_shift}") +        param_shift_function = parse_shift_function(param_name, param_shift)          shift_list.append((param_name, param_shift_function))      return shift_list diff --git a/lib/utils.py b/lib/utils.py index 5b40f51..ed5aa14 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -447,6 +447,15 @@ def by_param_to_by_name(by_param: dict) -> dict:      return by_name +def normalize_nfp_in_aggregate(aggregate, nfp_norm): +    for name in aggregate.keys(): +        for new_name, old_name, norm_function in nfp_norm: +            if old_name in aggregate[name]["attributes"]: +                aggregate[name][new_name] = norm_function(aggregate[name].pop(old_name)) +                aggregate[name]["attributes"].remove(old_name) +                aggregate[name]["attributes"].append(new_name) + +  def shift_param_in_aggregate(aggregate, parameters, parameter_shift):      """      Remove entries which do not have certain parameter values from `aggregate`.  | 
