From 31ed212674589f5806880281eb2ef9a006dfcfa9 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Fri, 6 Jun 2025 12:36:10 +0200 Subject: regression_measures: support list arguments --- lib/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'lib') diff --git a/lib/utils.py b/lib/utils.py index 48a29d8..fb76367 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -728,11 +728,18 @@ def regression_measures(predicted: np.ndarray, ground_truth: np.ndarray): rsq -- R^2 measure, see sklearn.metrics.r2_score count -- Number of values """ - if type(predicted) != np.ndarray: + + if type(predicted) is list: + predicted = np.array(predicted) + + if type(ground_truth) is list: + ground_truth = np.array(ground_truth) + + if type(predicted) is not np.ndarray: raise ValueError( "first arg ('predicted') must be ndarray, is {}".format(type(predicted)) ) - if type(ground_truth) != np.ndarray: + if type(ground_truth) is not np.ndarray: raise ValueError( "second arg ('ground_truth') must be ndarray, is {}".format( type(ground_truth) -- cgit v1.2.3