diff options
-rw-r--r-- | lib/functions.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index 2e7735a..38b8b38 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -314,6 +314,7 @@ class SplitFunction(ModelFunction): super().__init__(value, **kwargs) self.param_index = param_index self.child = child + self.use_weighted_avg = bool(int(os.getenv("DFATOOL_RMT_WEIGHTED_AVG", "0"))) def is_predictable(self, param_list): """ @@ -336,6 +337,11 @@ class SplitFunction(ModelFunction): param_value = param_list[self.param_index] if param_value in self.child: return self.child[param_value].eval(param_list) + if self.use_weighted_avg: + return np.average( + list(map(lambda child: child.eval(param_list), self.child.values())), + weights=list(map(lambda child: child.n_samples, self.child.values())), + ) return np.mean( list(map(lambda child: child.eval(param_list), self.child.values())) ) |