summaryrefslogtreecommitdiff
path: root/lib/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/utils.py')
-rw-r--r--lib/utils.py44
1 files changed, 41 insertions, 3 deletions
diff --git a/lib/utils.py b/lib/utils.py
index 4850a53..fb76367 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -48,6 +48,8 @@ def running_mean(x: np.ndarray, N: int) -> np.ndarray:
def human_readable(value, unit):
+ if value is None:
+ return value
for prefix, factor in (
("p", 1e-12),
("n", 1e-9),
@@ -55,6 +57,8 @@ def human_readable(value, unit):
("m", 1e-3),
("", 1),
("k", 1e3),
+ ("M", 1e6),
+ ("G", 1e9),
):
if value < 1e3 * factor:
return "{:.2f} {}{}".format(value * (1 / factor), prefix, unit)
@@ -150,7 +154,7 @@ def parse_conf_str(conf_str):
"""
conf_dict = dict()
for option in conf_str.split(","):
- key, value = option.split("=")
+ key, value = option.strip().split("=")
conf_dict[key] = soft_cast_float(value)
return conf_dict
@@ -205,6 +209,18 @@ def param_slice_eq(a, b, index):
return False
+def param_eq_or_none(a, b):
+ """
+ Check if by_param keys a and b are identical, allowing a None in a to match any key in b.
+ """
+ set_keys = tuple(filter(lambda i: a[i] is not None, range(len(a))))
+ a_not_none = tuple(map(lambda i: a[i], set_keys))
+ b_not_none = tuple(map(lambda i: b[i], set_keys))
+ if a_not_none == b_not_none:
+ return True
+ return False
+
+
def match_parameter_values(input_param: dict, match_param: dict):
"""
Check whether one of the paramaters in `input_param` has the same value in `match_param`.
@@ -302,6 +318,21 @@ def param_dict_to_list(param_dict, parameter_names, default=None):
return ret
+def param_dict_to_str(param_dict):
+ ret = list()
+ for parameter_name in sorted(param_dict.keys()):
+ ret.append(f"{parameter_name}={param_dict[parameter_name]}")
+ return " ".join(ret)
+
+
+def param_str_to_dict(param_str):
+ ret = dict()
+ for param_pair in param_str.split():
+ key, value = param_pair.split("=")
+ ret[key] = soft_cast_int_or_float(value)
+ return ret
+
+
def observations_enum_to_bool(observations: list, kconfig=False):
"""
Convert enum / categorical observations to boolean-only ones.
@@ -697,11 +728,18 @@ def regression_measures(predicted: np.ndarray, ground_truth: np.ndarray):
rsq -- R^2 measure, see sklearn.metrics.r2_score
count -- Number of values
"""
- if type(predicted) != np.ndarray:
+
+ if type(predicted) is list:
+ predicted = np.array(predicted)
+
+ if type(ground_truth) is list:
+ ground_truth = np.array(ground_truth)
+
+ if type(predicted) is not np.ndarray:
raise ValueError(
"first arg ('predicted') must be ndarray, is {}".format(type(predicted))
)
- if type(ground_truth) != np.ndarray:
+ if type(ground_truth) is not np.ndarray:
raise ValueError(
"second arg ('ground_truth') must be ndarray, is {}".format(
type(ground_truth)