diff options
-rw-r--r-- | lib/parameters.py | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 341b463..a3c5ab1 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -654,19 +654,17 @@ class ModelAttribute: df.LMTFunction, ): ret["decision tree/nodes"] = self.model_function.get_number_of_nodes() + ret["decision tree/leaves"] = self.model_function.get_number_of_leaves() + ret["decision tree/inner nodes"] = ( + ret["decision tree/nodes"] - ret["decision tree/leaves"] + ) ret["decision tree/max depth"] = self.model_function.get_max_depth() elif type(self.model_function) in (df.StaticFunction, df.AnalyticFunction): ret["decision tree/nodes"] = 1 + ret["decision tree/leaves"] = 1 + ret["decision tree/inner nodes"] = 0 ret["decision tree/max depth"] = 0 - if type(self.model_function) in ( - df.SplitFunction, - df.CARTFunction, - df.LMTFunction, - df.XGBoostFunction, - ): - ret["decision tree/leaves"] = self.model_function.get_number_of_leaves() - return ret def to_dot(self): |