summaryrefslogtreecommitdiff
path: root/lib/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/utils.py')
-rw-r--r--lib/utils.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/lib/utils.py b/lib/utils.py
index 2ed3d6e..c8f31c2 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -199,6 +199,33 @@ def by_name_to_by_param(by_name: dict):
return by_param
+def by_param_to_by_name(by_param: dict) -> dict:
+ """
+ Convert aggregation by name and parameter values to aggregation by name only.
+ """
+ by_name = dict()
+ for param_key in by_param.keys():
+ name, _ = param_key
+ if name not in by_name:
+ by_name[name] = dict()
+ for key in by_param[param_key].keys():
+ by_name[name][key] = list()
+ by_name[name]["attributes"] = by_param[param_key]["attributes"]
+ # special case for PTA models
+ if "isa" in by_param[param_key]:
+ by_name[name]["isa"] = by_param[param_key]["isa"]
+ for attribute in by_name[name]["attributes"]:
+ by_name[name][attribute].extend(by_param[param_key][attribute])
+ if "supports" in by_param[param_key]:
+ for support in by_param[param_key]["supports"]:
+ by_name[name][support].extend(by_param[param_key][support])
+ by_name[name]["param"].extend(by_param[param_key]["param"])
+ for name in by_name.keys():
+ for attribute in by_name[name]["attributes"]:
+ by_name[name][attribute] = np.array(by_name[name][attribute])
+ return by_name
+
+
def filter_aggregate_by_param(aggregate, parameters, parameter_filter):
"""
Remove entries which do not have certain parameter values from `aggregate`.