diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-31 13:21:13 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-31 13:21:13 +0100 |
commit | 2d1247d25953e5d1479a6ddc6d7a7145f20a2cc5 (patch) | |
tree | e599fd9f0ce41c80309bad325516807728113154 | |
parent | 1f391139530f8051a1ece4fcd9de46afe6afde06 (diff) |
allow parameter values to be adjusted before modeling
(e.g. to ensure they're positive)
-rwxr-xr-x | bin/analyze-archive.py | 6 | ||||
-rw-r--r-- | lib/cli.py | 28 | ||||
-rw-r--r-- | lib/utils.py | 21 |
3 files changed, 55 insertions, 0 deletions
# Reconstructed from the collapsed patch text: the two functions added by this
# commit, with the call-site and CLI hunks kept as comments for reference.
#
# bin/analyze-archive.py imports shift_param_in_aggregate from dfatool.utils
# and calls it right after filter_aggregate_by_param():
#
#     if args.param_shift:
#         param_shift = dfatool.cli.parse_param_shift(args.param_shift)
#         shift_param_in_aggregate(by_name, parameters, param_shift)
#
# lib/cli.py, add_standard_arguments() registers the option:
#
#     parser.add_argument(
#         "--param-shift",
#         metavar="<key>=<+|-|*|/><value>;...",
#         type=str,
#         help="Adjust parameter values before passing them to model generation",
#     )


# --- lib/cli.py ---


def _make_param_shift_function(operation, value):
    """
    Return a one-argument function which applies `operation` with `value`.

    Building the lambda inside its own function call gives each returned
    function a private `value`. Lambdas created directly inside a loop share
    the loop's variable and would all end up using the value of the last
    iteration.

    :param operation: one of "+", "-", "*", "/"
    :param value: right-hand operand (float)
    :returns: function p -> p <operation> value
    :raises ValueError: if `operation` is not one of the four supported operators
    """
    if operation == "+":
        return lambda p: p + value
    if operation == "-":
        return lambda p: p - value
    if operation == "*":
        return lambda p: p * value
    if operation == "/":
        return lambda p: p / value
    raise ValueError(f"Unsupported shift operation {operation}")


def parse_param_shift(raw_param_shift):
    """
    Parse a `--param-shift` argument into a list of (name, function) pairs.

    :param raw_param_shift: string of the form "<key>=<+|-|*|/><value>;...",
        e.g. "txpower=+10;datarate=/1000"
    :returns: list of (parameter name, shift function) tuples in input order.
        Each shift function takes the old parameter value and returns the new one.
    :raises ValueError: if an entry has no "=", uses an unsupported operator,
        or its value cannot be converted to float
    """
    shift_list = []
    for shift_pair in raw_param_shift.split(";"):
        # partition never raises, so a missing "=" can be reported explicitly.
        param_name, separator, param_shift = shift_pair.partition("=")
        if not separator:
            raise ValueError(
                f"Invalid parameter shift '{shift_pair}': "
                "expected <key>=<+|-|*|/><value>"
            )
        operation = param_shift[:1]
        if operation not in ("+", "-", "*", "/"):
            raise ValueError(f"Unsupported shift operation {param_name}={param_shift}")
        param_shift_value = float(param_shift[1:])
        shift_list.append(
            (param_name, _make_param_shift_function(operation, param_shift_value))
        )
    return shift_list


# --- lib/utils.py ---


def shift_param_in_aggregate(aggregate, parameters, parameter_shift):
    """
    Adjust parameter values of the entries in `aggregate` in place.

    :param aggregate: aggregated measurement data, must be a dict conforming to
        aggregate[state or transition name]['param'] = [[first parameter value, second parameter value, ...], ...]
        The inner parameter lists are modified in place and must therefore be mutable.
    :param parameters: list of parameters, used to map parameter index to parameter name. parameters=['foo', ...] means 'foo' is the first parameter
    :param parameter_shift: [(name, function), (name, function), ...] list of
        parameters to adjust, e.g. as returned by parse_param_shift. Each function
        is called with the old value and returns the new one. Values which are
        None are left untouched.
    :raises ValueError: if a name in `parameter_shift` is not in `parameters`
    """
    for param_name, param_shift_function in parameter_shift:
        try:
            param_index = parameters.index(param_name)
        except ValueError:
            raise ValueError(
                f"Cannot shift unknown parameter '{param_name}'"
            ) from None
        for entry in aggregate.values():
            for param_list in entry["param"]:
                if param_list[param_index] is not None:
                    param_list[param_index] = param_shift_function(
                        param_list[param_index]
                    )