diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-21 17:12:30 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-21 17:12:30 +0100 |
commit | 943c8bee7a511e1aff5e1639a479ea868d7656a7 (patch) | |
tree | a5b5c05aeeac9506161fbf2501870ed9ccfedb52 /lib | |
parent | 7064c8d0b7b3e49fe9ac48902bf1f0a3d62ab4bb (diff) |
add depth / #nodes accessor to sklearn CART model
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 26 | ||||
-rw-r--r-- | lib/parameters.py | 4 |
2 files changed, 28 insertions, 2 deletions
diff --git a/lib/functions.py b/lib/functions.py index 9bbc0e7..14893ad 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -461,6 +461,32 @@ class SKLearnRegressionFunction(ModelFunction): return predictions +class CARTFunction(SKLearnRegressionFunction): + def get_number_of_nodes(self): + return self.regressor.tree_.node_count + + def get_max_depth(self): + return self.regressor.get_depth() + + +class XGBoostFunction(SKLearnRegressionFunction): + def get_number_of_nodes(self): + ret = 1 + for v in self.child.values(): + if type(v) is SplitFunction: + ret += v.get_number_of_nodes() + else: + ret += 1 + return ret + + def get_max_depth(self): + ret = [0] + for v in self.child.values(): + if type(v) is SplitFunction: + ret.append(v.get_max_depth()) + return 1 + max(ret) + + class AnalyticFunction(ModelFunction): """ A multi-dimensional model function, generated from a string, which can be optimized using regression. diff --git a/lib/parameters.py b/lib/parameters.py index 8ca8f39..9350a06 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -666,7 +666,7 @@ class ModelAttribute: def to_dref(self, unit=None): ret = {"mean": (self.mean, unit), "median": (self.median, unit)} - if type(self.model_function) is df.SplitFunction: + if type(self.model_function) in (df.SplitFunction, df.CARTFunction): ret["decision tree/nodes"] = self.model_function.get_number_of_nodes() ret["decision tree/max depth"] = self.model_function.get_max_depth() @@ -931,7 +931,7 @@ class ModelAttribute: self.model_function = df.StaticFunction(np.mean(data)) return cart.fit(fit_parameters, data) - self.model_function = df.SKLearnRegressionFunction( + self.model_function = df.CARTFunction( np.mean(data), cart, category_to_index, ignore_index ) return |