From 943c8bee7a511e1aff5e1639a479ea868d7656a7 Mon Sep 17 00:00:00 2001
From: Daniel Friesel <daniel.friesel@uos.de>
Date: Fri, 21 Jan 2022 17:12:30 +0100
Subject: add depth / #nodes accessor to sklearn CART model

---
 lib/functions.py  | 26 ++++++++++++++++++++++++++
 lib/parameters.py |  4 ++--
 2 files changed, 28 insertions(+), 2 deletions(-)

(limited to 'lib')

diff --git a/lib/functions.py b/lib/functions.py
index 9bbc0e7..14893ad 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -461,6 +461,32 @@ class SKLearnRegressionFunction(ModelFunction):
         return predictions
 
 
+class CARTFunction(SKLearnRegressionFunction):  # adds tree-size accessors that read self.regressor (presumably the fitted sklearn tree -- confirm)
+    def get_number_of_nodes(self):  # returns the node_count attribute of the regressor's tree_
+        return self.regressor.tree_.node_count
+
+    def get_max_depth(self):  # returns whatever the regressor's get_depth() reports
+        return self.regressor.get_depth()
+
+
+class XGBoostFunction(SKLearnRegressionFunction):  # NOTE(review): methods read self.child / SplitFunction, not self.regressor -- looks carried over from a split-tree class; confirm self.child exists here
+    def get_number_of_nodes(self):  # counts this object plus everything reachable through self.child
+        ret = 1  # this node itself
+        for v in self.child.values():
+            if type(v) is SplitFunction:  # exact type match: subclasses of SplitFunction take the else branch
+                ret += v.get_number_of_nodes()  # nested split: add its whole subtree count
+            else:
+                ret += 1  # any other child counts as a single node
+        return ret
+
+    def get_max_depth(self):  # 1 + the largest depth among SplitFunction children; 1 when there are none
+        ret = [0]  # seeded with 0 so max() has a value when no child is a SplitFunction
+        for v in self.child.values():
+            if type(v) is SplitFunction:  # only exact SplitFunction children contribute depth
+                ret.append(v.get_max_depth())
+        return 1 + max(ret)
+
+
 class AnalyticFunction(ModelFunction):
     """
     A multi-dimensional model function, generated from a string, which can be optimized using regression.
diff --git a/lib/parameters.py b/lib/parameters.py
index 8ca8f39..9350a06 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -666,7 +666,7 @@ class ModelAttribute:
     def to_dref(self, unit=None):
         ret = {"mean": (self.mean, unit), "median": (self.median, unit)}
 
-        if type(self.model_function) is df.SplitFunction:
+        if type(self.model_function) in (df.SplitFunction, df.CARTFunction):
             ret["decision tree/nodes"] = self.model_function.get_number_of_nodes()
             ret["decision tree/max depth"] = self.model_function.get_max_depth()
 
@@ -931,7 +931,7 @@ class ModelAttribute:
                 self.model_function = df.StaticFunction(np.mean(data))
                 return
             cart.fit(fit_parameters, data)
-            self.model_function = df.SKLearnRegressionFunction(
+            self.model_function = df.CARTFunction(
                 np.mean(data), cart, category_to_index, ignore_index
             )
             return
-- 
cgit v1.2.3