diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index 9bbc0e7..14893ad 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -461,6 +461,32 @@ class SKLearnRegressionFunction(ModelFunction): return predictions +class CARTFunction(SKLearnRegressionFunction): + def get_number_of_nodes(self): + return self.regressor.tree_.node_count + + def get_max_depth(self): + return self.regressor.get_depth() + + +class XGBoostFunction(SKLearnRegressionFunction): + def get_number_of_nodes(self): + ret = 1 + for v in self.child.values(): + if type(v) is SplitFunction: + ret += v.get_number_of_nodes() + else: + ret += 1 + return ret + + def get_max_depth(self): + ret = [0] + for v in self.child.values(): + if type(v) is SplitFunction: + ret.append(v.get_max_depth()) + return 1 + max(ret) + + class AnalyticFunction(ModelFunction): """ A multi-dimensional model function, generated from a string, which can be optimized using regression. |