summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/parameters.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 341b463..a3c5ab1 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -654,19 +654,17 @@ class ModelAttribute:
df.LMTFunction,
):
ret["decision tree/nodes"] = self.model_function.get_number_of_nodes()
+ ret["decision tree/leaves"] = self.model_function.get_number_of_leaves()
+ ret["decision tree/inner nodes"] = (
+ ret["decision tree/nodes"] - ret["decision tree/leaves"]
+ )
ret["decision tree/max depth"] = self.model_function.get_max_depth()
elif type(self.model_function) in (df.StaticFunction, df.AnalyticFunction):
ret["decision tree/nodes"] = 1
+ ret["decision tree/leaves"] = 1
+ ret["decision tree/inner nodes"] = 0
ret["decision tree/max depth"] = 0
- if type(self.model_function) in (
- df.SplitFunction,
- df.CARTFunction,
- df.LMTFunction,
- df.XGBoostFunction,
- ):
- ret["decision tree/leaves"] = self.model_function.get_number_of_leaves()
-
return ret
def to_dot(self):