From 2a8aed73ba35107e35b2343670c01c5f760282b0 Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Fri, 19 Jan 2024 13:37:17 +0100 Subject: RMT: Set DFATOOL_RMT_WEIGHTED_AVG=1 for weighted average in queries --- lib/functions.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'lib') diff --git a/lib/functions.py b/lib/functions.py index 2e7735a..38b8b38 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -314,6 +314,7 @@ class SplitFunction(ModelFunction): super().__init__(value, **kwargs) self.param_index = param_index self.child = child + self.use_weighted_avg = bool(int(os.getenv("DFATOOL_RMT_WEIGHTED_AVG", "0"))) def is_predictable(self, param_list): """ @@ -336,6 +337,11 @@ class SplitFunction(ModelFunction): param_value = param_list[self.param_index] if param_value in self.child: return self.child[param_value].eval(param_list) + if self.use_weighted_avg: + return np.average( + list(map(lambda child: child.eval(param_list), self.child.values())), + weights=list(map(lambda child: child.n_samples, self.child.values())), + ) return np.mean( list(map(lambda child: child.eval(param_list), self.child.values())) ) -- cgit v1.2.3