summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-19 13:37:17 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-19 13:37:17 +0100
commit2a8aed73ba35107e35b2343670c01c5f760282b0 (patch)
tree4ae15a2d4a07d42a0687b6d33df3665d417ac8dd
parent1e54270e724f719e6b60ae02662cf7d8a175bf4c (diff)
RMT: Set DFATOOL_RMT_WEIGHTED_AVG=1 for weighted average in queries
-rw-r--r--lib/functions.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 2e7735a..38b8b38 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -314,6 +314,7 @@ class SplitFunction(ModelFunction):
super().__init__(value, **kwargs)
self.param_index = param_index
self.child = child
+ self.use_weighted_avg = bool(int(os.getenv("DFATOOL_RMT_WEIGHTED_AVG", "0")))
def is_predictable(self, param_list):
"""
@@ -336,6 +337,11 @@ class SplitFunction(ModelFunction):
param_value = param_list[self.param_index]
if param_value in self.child:
return self.child[param_value].eval(param_list)
+ if self.use_weighted_avg:
+ return np.average(
+ list(map(lambda child: child.eval(param_list), self.child.values())),
+ weights=list(map(lambda child: child.n_samples, self.child.values())),
+ )
return np.mean(
list(map(lambda child: child.eval(param_list), self.child.values()))
)