diff options
Diffstat (limited to 'lib/utils.py')
-rw-r--r-- | lib/utils.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/lib/utils.py b/lib/utils.py index 2ed3d6e..c8f31c2 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -199,6 +199,33 @@ def by_name_to_by_param(by_name: dict): return by_param +def by_param_to_by_name(by_param: dict) -> dict: + """ + Convert aggregation by name and parameter values to aggregation by name only. + """ + by_name = dict() + for param_key in by_param.keys(): + name, _ = param_key + if name not in by_name: + by_name[name] = dict() + for key in by_param[param_key].keys(): + by_name[name][key] = list() + by_name[name]["attributes"] = by_param[param_key]["attributes"] + # special case for PTA models + if "isa" in by_param[param_key]: + by_name[name]["isa"] = by_param[param_key]["isa"] + for attribute in by_name[name]["attributes"]: + by_name[name][attribute].extend(by_param[param_key][attribute]) + if "supports" in by_param[param_key]: + for support in by_param[param_key]["supports"]: + by_name[name][support].extend(by_param[param_key][support]) + by_name[name]["param"].extend(by_param[param_key]["param"]) + for name in by_name.keys(): + for attribute in by_name[name]["attributes"]: + by_name[name][attribute] = np.array(by_name[name][attribute]) + return by_name + + def filter_aggregate_by_param(aggregate, parameters, parameter_filter): """ Remove entries which do not have certain parameter values from `aggregate`. |