diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-19 13:37:17 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-19 13:37:17 +0100 |
commit | 2a8aed73ba35107e35b2343670c01c5f760282b0 (patch) | |
tree | 4ae15a2d4a07d42a0687b6d33df3665d417ac8dd | |
parent | 1e54270e724f719e6b60ae02662cf7d8a175bf4c (diff) |
RMT: Set DFATOOL_RMT_WEIGHTED_AVG=1 for weighted average in queries
-rw-r--r-- | lib/functions.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index 2e7735a..38b8b38 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -314,6 +314,7 @@ class SplitFunction(ModelFunction): super().__init__(value, **kwargs) self.param_index = param_index self.child = child + self.use_weighted_avg = bool(int(os.getenv("DFATOOL_RMT_WEIGHTED_AVG", "0"))) def is_predictable(self, param_list): """ @@ -336,6 +337,11 @@ class SplitFunction(ModelFunction): param_value = param_list[self.param_index] if param_value in self.child: return self.child[param_value].eval(param_list) + if self.use_weighted_avg: + return np.average( + list(map(lambda child: child.eval(param_list), self.child.values())), + weights=list(map(lambda child: child.n_samples, self.child.values())), + ) return np.mean( list(map(lambda child: child.eval(param_list), self.child.values())) ) |