summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/functions.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 2e7735a..38b8b38 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -314,6 +314,7 @@ class SplitFunction(ModelFunction):
super().__init__(value, **kwargs)
self.param_index = param_index
self.child = child
+ self.use_weighted_avg = bool(int(os.getenv("DFATOOL_RMT_WEIGHTED_AVG", "0")))
def is_predictable(self, param_list):
"""
@@ -336,6 +337,11 @@ class SplitFunction(ModelFunction):
param_value = param_list[self.param_index]
if param_value in self.child:
return self.child[param_value].eval(param_list)
+ if self.use_weighted_avg:
+ return np.average(
+ list(map(lambda child: child.eval(param_list), self.child.values())),
+ weights=list(map(lambda child: child.n_samples, self.child.values())),
+ )
return np.mean(
list(map(lambda child: child.eval(param_list), self.child.values()))
)