summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-11-29 13:16:45 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-11-29 13:16:45 +0100
commit48241ecff7fb76ff2139d9c261e146d2501b7610 (patch)
treea0ff10b3d10931d06240f18208f472cd71693657 /lib
parent6ec55b014b779550f86daa70c17fb24bf2c4d816 (diff)
dataref export: add number of inner decision tree nodes
Diffstat (limited to 'lib')
-rw-r--r--lib/parameters.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 341b463..a3c5ab1 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -654,19 +654,17 @@ class ModelAttribute:
df.LMTFunction,
):
ret["decision tree/nodes"] = self.model_function.get_number_of_nodes()
+ ret["decision tree/leaves"] = self.model_function.get_number_of_leaves()
+ ret["decision tree/inner nodes"] = (
+ ret["decision tree/nodes"] - ret["decision tree/leaves"]
+ )
ret["decision tree/max depth"] = self.model_function.get_max_depth()
elif type(self.model_function) in (df.StaticFunction, df.AnalyticFunction):
ret["decision tree/nodes"] = 1
+ ret["decision tree/leaves"] = 1
+ ret["decision tree/inner nodes"] = 0
ret["decision tree/max depth"] = 0
- if type(self.model_function) in (
- df.SplitFunction,
- df.CARTFunction,
- df.LMTFunction,
- df.XGBoostFunction,
- ):
- ret["decision tree/leaves"] = self.model_function.get_number_of_leaves()
-
return ret
def to_dot(self):