diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2025-06-06 12:36:10 +0200 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2025-06-06 12:36:10 +0200 |
commit | 31ed212674589f5806880281eb2ef9a006dfcfa9 (patch) | |
tree | 8578ec5b56f252bbd0e1ef433279d1e529174b2a /lib | |
parent | 40644c636b5084304d18fa7012fd91bc273e3973 (diff) |
regression_measures: support list arguments
Diffstat (limited to 'lib')
-rw-r--r-- | lib/utils.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/lib/utils.py b/lib/utils.py index 48a29d8..fb76367 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -728,11 +728,18 @@ def regression_measures(predicted: np.ndarray, ground_truth: np.ndarray): rsq -- R^2 measure, see sklearn.metrics.r2_score count -- Number of values """ - if type(predicted) != np.ndarray: + + if type(predicted) is list: + predicted = np.array(predicted) + + if type(ground_truth) is list: + ground_truth = np.array(ground_truth) + + if type(predicted) is not np.ndarray: raise ValueError( "first arg ('predicted') must be ndarray, is {}".format(type(predicted)) ) - if type(ground_truth) != np.ndarray: + if type(ground_truth) is not np.ndarray: raise ValueError( "second arg ('ground_truth') must be ndarray, is {}".format( type(ground_truth) |