diff options
Diffstat (limited to 'lib/utils.py')
-rw-r--r-- | lib/utils.py | 44 |
1 files changed, 41 insertions, 3 deletions
diff --git a/lib/utils.py b/lib/utils.py index 4850a53..fb76367 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -48,6 +48,8 @@ def running_mean(x: np.ndarray, N: int) -> np.ndarray: def human_readable(value, unit): + if value is None: + return value for prefix, factor in ( ("p", 1e-12), ("n", 1e-9), @@ -55,6 +57,8 @@ def human_readable(value, unit): ("m", 1e-3), ("", 1), ("k", 1e3), + ("M", 1e6), + ("G", 1e9), ): if value < 1e3 * factor: return "{:.2f} {}{}".format(value * (1 / factor), prefix, unit) @@ -150,7 +154,7 @@ def parse_conf_str(conf_str): """ conf_dict = dict() for option in conf_str.split(","): - key, value = option.split("=") + key, value = option.strip().split("=") conf_dict[key] = soft_cast_float(value) return conf_dict @@ -205,6 +209,18 @@ def param_slice_eq(a, b, index): return False +def param_eq_or_none(a, b): + """ + Check if by_param keys a and b are identical, allowing a None in a to match any key in b. + """ + set_keys = tuple(filter(lambda i: a[i] is not None, range(len(a)))) + a_not_none = tuple(map(lambda i: a[i], set_keys)) + b_not_none = tuple(map(lambda i: b[i], set_keys)) + if a_not_none == b_not_none: + return True + return False + + def match_parameter_values(input_param: dict, match_param: dict): """ Check whether one of the paramaters in `input_param` has the same value in `match_param`. @@ -302,6 +318,21 @@ def param_dict_to_list(param_dict, parameter_names, default=None): return ret +def param_dict_to_str(param_dict): + ret = list() + for parameter_name in sorted(param_dict.keys()): + ret.append(f"{parameter_name}={param_dict[parameter_name]}") + return " ".join(ret) + + +def param_str_to_dict(param_str): + ret = dict() + for param_pair in param_str.split(): + key, value = param_pair.split("=") + ret[key] = soft_cast_int_or_float(value) + return ret + + def observations_enum_to_bool(observations: list, kconfig=False): """ Convert enum / categorical observations to boolean-only ones. @@ -697,11 +728,18 @@ def regression_measures(predicted: np.ndarray, ground_truth: np.ndarray): rsq -- R^2 measure, see sklearn.metrics.r2_score count -- Number of values """ - if type(predicted) != np.ndarray: + + if type(predicted) is list: + predicted = np.array(predicted) + + if type(ground_truth) is list: + ground_truth = np.array(ground_truth) + + if type(predicted) is not np.ndarray: raise ValueError( "first arg ('predicted') must be ndarray, is {}".format(type(predicted)) ) - if type(ground_truth) != np.ndarray: + if type(ground_truth) is not np.ndarray: raise ValueError( "second arg ('ground_truth') must be ndarray, is {}".format( type(ground_truth) |