diff options
Diffstat (limited to 'lib/utils.py')
-rw-r--r-- | lib/utils.py | 44 |
1 files changed, 41 insertions, 3 deletions
def observations_enum_to_bool(observations: list, kconfig=False):
    """
    Convert enum / categorical parameters in observations to boolean-only ones.

    A parameter is treated as an enum if it takes more than two distinct
    non-None values across all observations and at least one of those values
    is rejected by is_numeric(). Each such parameter is removed from every
    observation's "param" dict and replaced by one entry per distinct value,
    keyed by that value.

    'observations' is altered in-place; nothing is returned.

    :param observations: list of dicts, each providing a "param" dict which
        maps parameter names to parameter values.
    :param kconfig: if true, the generated entries are "y" (value selected) or
        "n" (value not selected); otherwise they are 1 or 0.
    """
    # Gather the distinct non-None values of each parameter. A parameter that
    # is only ever None still gets an (empty) entry.
    distinct_param_values = dict()
    for observation in observations:
        for name, value in observation["param"].items():
            values = distinct_param_values.setdefault(name, set())
            if value is not None:
                values.add(value)

    # Select the parameters to expand. The values are sorted once (by their
    # string form, so mixed types cannot break the comparison) so that the
    # generated keys are inserted in the same order on every run rather than
    # in set iteration order.
    replace_map = dict()
    for name, values in distinct_param_values.items():
        if len(values) > 2 and not all(is_numeric(value) for value in values):
            replace_map[name] = sorted(values, key=str)

    if kconfig:
        selected, unselected = "y", "n"
    else:
        selected, unselected = 1, 0

    for observation in observations:
        param = observation["param"]
        binary_keys = set()
        for name, values in replace_map.items():
            # Observations do not necessarily share the same parameter set
            # (the collection loop above copes with that), so a missing enum
            # parameter is handled like an unset (None) one: no value selected.
            enum_value = param.pop(name, None)
            for binary_key in values:
                param[binary_key] = (
                    selected if enum_value == binary_key else unselected
                )
                # Two enum parameters sharing a value name would overwrite
                # each other's generated entry.
                if binary_key in binary_keys:
                    print(f"Error: key '{binary_key}' is not unique")
                binary_keys.add(binary_key)
r2_score(actual, predicted), "count": len(actual), } |