summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-02-07 14:20:43 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-02-07 14:20:43 +0100
commite7d6d9d98be7035862405c79ee0bd4cb6fda8943 (patch)
tree5a3a0217992df30cb139d0a83b6b8ad09cceba33 /lib
parent9b25065d34fb79ba08aed603447eafd9937b8f73 (diff)
SplitFunction: Use mean of children for predicting unknown values
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 399072b..b934e79 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -302,13 +302,19 @@ class SplitFunction(ModelFunction):
(e.g. None).
"""
param_value = param_list[self.param_index]
- if param_value not in self.child:
- return False
- return self.child[param_value].is_predictable(param_list)
+ if param_value in self.child:
+ return self.child[param_value].is_predictable(param_list)
+ return all(
+ map(lambda child: child.is_predictable(param_list), self.child.values())
+ )
def eval(self, param_list):
param_value = param_list[self.param_index]
- return self.child[param_value].eval(param_list)
+ if param_value in self.child:
+ return self.child[param_value].eval(param_list)
+ return np.mean(
+ list(map(lambda child: child.eval(param_list), self.child.values()))
+ )
def webconf_function_map(self):
ret = list()