summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-01-31 13:21:13 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-01-31 13:21:13 +0100
commit2d1247d25953e5d1479a6ddc6d7a7145f20a2cc5 (patch)
treee599fd9f0ce41c80309bad325516807728113154
parent1f391139530f8051a1ece4fcd9de46afe6afde06 (diff)
allow parameter values to be adjusted before modeling
(e.g. to ensure they're positive)
-rwxr-xr-xbin/analyze-archive.py6
-rw-r--r--lib/cli.py28
-rw-r--r--lib/utils.py21
3 files changed, 55 insertions, 0 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index 54ba1ef..5cc01f6 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -57,6 +57,7 @@ from dfatool.model import PTAModel
from dfatool.validation import CrossValidator
from dfatool.utils import (
filter_aggregate_by_param,
+ shift_param_in_aggregate,
detect_outliers_in_aggregate,
NpEncoder,
is_numeric,
@@ -630,6 +631,11 @@ if __name__ == "__main__":
)
filter_aggregate_by_param(by_name, parameters, args.filter_param)
+
+ if args.param_shift:
+ param_shift = dfatool.cli.parse_param_shift(args.param_shift)
+ shift_param_in_aggregate(by_name, parameters, param_shift)
+
detect_outliers_in_aggregate(
by_name, z_limit=args.z_score, remove_outliers=args.remove_outliers
)
diff --git a/lib/cli.py b/lib/cli.py
index a1e4a58..a176e88 100644
--- a/lib/cli.py
+++ b/lib/cli.py
@@ -115,3 +115,31 @@ def add_standard_arguments(parser):
action="store_true",
help="Perform parameter-aware cross-validation: ensure that parameter values (and not just observations) are mutually exclusive between training and validation sets.",
)
+ parser.add_argument(
+ "--param-shift",
+ metavar="<key>=<+|-|*|/><value>;...",
+ type=str,
+ help="Adjust parameter values before passing them to model generation",
+ )
+
+
+def parse_param_shift(raw_param_shift):
+ shift_list = list()
+ for shift_pair in raw_param_shift.split(";"):
+ param_name, param_shift = shift_pair.split("=")
+ if param_shift.startswith("+"):
+ param_shift_value = float(param_shift[1:])
+ param_shift_function = lambda p: p + param_shift_value
+ elif param_shift.startswith("-"):
+ param_shift_value = float(param_shift[1:])
+ param_shift_function = lambda p: p - param_shift_value
+ elif param_shift.startswith("*"):
+ param_shift_value = float(param_shift[1:])
+ param_shift_function = lambda p: p * param_shift_value
+ elif param_shift.startswith("/"):
+ param_shift_value = float(param_shift[1:])
+ param_shift_function = lambda p: p / param_shift_value
+ else:
+ raise ValueError(f"Unsupported shift operation {param_name}={param_shift}")
+ shift_list.append((param_name, param_shift_function))
+ return shift_list
diff --git a/lib/utils.py b/lib/utils.py
index 10f0172..7372995 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -314,6 +314,27 @@ def by_param_to_by_name(by_param: dict) -> dict:
return by_name
+def shift_param_in_aggregate(aggregate, parameters, parameter_shift):
+ """
+ Remove entries which do not have certain parameter values from `aggregate`.
+
+ :param aggregate: aggregated measurement data, must be a dict conforming to
+ aggregate[state or transition name]['param'] = (first parameter value, second parameter value, ...)
+ and
+ aggregate[state or transition name]['attributes'] = [list of keys with measurement data, e.g. 'power' or 'duration']
+ :param parameters: list of parameters, used to map parameter index to parameter name. parameters=['foo', ...] means 'foo' is the first parameter
+ :param parameter_filter: [[name, value], [name, value], ...] list of parameter values to keep, all others are removed. Values refer to normalizad parameter data.
+ """
+ for param_name, param_shift_function in parameter_shift:
+ param_index = parameters.index(param_name)
+ for name in aggregate.keys():
+ for param_list in aggregate[name]["param"]:
+ if param_list[param_index] is not None:
+ param_list[param_index] = param_shift_function(
+ param_list[param_index]
+ )
+
+
def filter_aggregate_by_param(aggregate, parameters, parameter_filter):
"""
Remove entries which do not have certain parameter values from `aggregate`.