summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-11-17 14:35:10 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-11-17 14:35:10 +0100
commit91563e84617d096058f92a4c074bad0e3a1e0b12 (patch)
tree06a69556e14d8eb91a47c0f632bf5a66397322cd /lib
parent3090f274f698bd2e9c2fed2af2f730d9bf14fc07 (diff)
export number of dtree nodes to dataref
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py9
-rw-r--r--lib/parameters.py4
2 files changed, 13 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 0a488dc..698d68c 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -324,6 +324,15 @@ class SplitFunction(ModelFunction):
ret.update(update)
return ret
+ def get_number_of_nodes(self):
+ ret = 1
+ for v in self.child.values():
+ if type(v) is SplitFunction:
+ ret += v.get_number_of_nodes()
+ else:
+ ret += 1
+ return ret
+
@classmethod
def from_json(cls, data):
assert data["type"] == "split"
diff --git a/lib/parameters.py b/lib/parameters.py
index 5ebf25c..40d0ba5 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -572,6 +572,10 @@ class ModelAttribute:
def to_dref(self, unit=None):
ret = {"mean": (self.mean, unit), "median": (self.median, unit)}
+
+ if type(self.model_function) is df.SplitFunction:
+ ret["decision tree nodes"] = self.model_function.get_number_of_nodes()
+
return ret
def webconf_function_map(self):