summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py26
-rw-r--r--lib/parameters.py4
2 files changed, 28 insertions, 2 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 9bbc0e7..14893ad 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -461,6 +461,32 @@ class SKLearnRegressionFunction(ModelFunction):
return predictions
+class CARTFunction(SKLearnRegressionFunction):
+ def get_number_of_nodes(self):
+ return self.regressor.tree_.node_count
+
+ def get_max_depth(self):
+ return self.regressor.get_depth()
+
+
+class XGBoostFunction(SKLearnRegressionFunction):
+ def get_number_of_nodes(self):
+ ret = 1
+ for v in self.child.values():
+ if type(v) is SplitFunction:
+ ret += v.get_number_of_nodes()
+ else:
+ ret += 1
+ return ret
+
+ def get_max_depth(self):
+ ret = [0]
+ for v in self.child.values():
+ if type(v) is SplitFunction:
+ ret.append(v.get_max_depth())
+ return 1 + max(ret)
+
+
class AnalyticFunction(ModelFunction):
"""
A multi-dimensional model function, generated from a string, which can be optimized using regression.
diff --git a/lib/parameters.py b/lib/parameters.py
index 8ca8f39..9350a06 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -666,7 +666,7 @@ class ModelAttribute:
def to_dref(self, unit=None):
ret = {"mean": (self.mean, unit), "median": (self.median, unit)}
- if type(self.model_function) is df.SplitFunction:
+ if type(self.model_function) in (df.SplitFunction, df.CARTFunction):
ret["decision tree/nodes"] = self.model_function.get_number_of_nodes()
ret["decision tree/max depth"] = self.model_function.get_max_depth()
@@ -931,7 +931,7 @@ class ModelAttribute:
self.model_function = df.StaticFunction(np.mean(data))
return
cart.fit(fit_parameters, data)
- self.model_function = df.SKLearnRegressionFunction(
+ self.model_function = df.CARTFunction(
np.mean(data), cart, category_to_index, ignore_index
)
return