summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py44
-rw-r--r--lib/parameters.py8
2 files changed, 38 insertions, 14 deletions
diff --git a/lib/functions.py b/lib/functions.py
index de5c722..978a993 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -355,6 +355,15 @@ class SplitFunction(ModelFunction):
ret.append(v.get_max_depth())
return 1 + max(ret)
+ def get_number_of_leaves(self):
+ ret = 0
+ for v in self.child.values():
+ if type(v) is SplitFunction:
+ ret += v.get_number_of_leaves()
+ else:
+ ret += 1
+ return ret
+
def to_dot(self, pydot, graph, feature_names, parent=None):
try:
label = feature_names[self.param_index]
@@ -484,6 +493,9 @@ class CARTFunction(SKLearnRegressionFunction):
def get_number_of_nodes(self):
return self.regressor.tree_.node_count
+ def get_number_of_leaves(self):
+ return self.regressor.tree_.n_leaves
+
def get_max_depth(self):
return self.regressor.get_depth()
@@ -543,16 +555,21 @@ class LMTFunction(SKLearnRegressionFunction):
class XGBoostFunction(SKLearnRegressionFunction):
- def get_number_of_nodes(self):
+ def to_json(self):
import json
+ tempfile = f"/tmp/xgb{os.getpid()}.json"
+
self.regressor.get_booster().dump_model(
- "/tmp/xgb.json", dump_format="json", with_stats=True
+ tempfile, dump_format="json", with_stats=True
)
- with open("/tmp/xgb.json", "r") as f:
+ with open(tempfile, "r") as f:
data = json.load(f)
+ os.remove(tempfile)
+ return data
- return sum(map(self._get_number_of_nodes, data))
+ def get_number_of_nodes(self):
+ return sum(map(self._get_number_of_nodes, self.to_json()))
def _get_number_of_nodes(self, data):
ret = 1
@@ -560,16 +577,19 @@ class XGBoostFunction(SKLearnRegressionFunction):
ret += self._get_number_of_nodes(child)
return ret
- def get_max_depth(self):
- import json
+ def get_number_of_leaves(self):
+ return sum(map(self._get_number_of_leaves, self.to_json()))
- self.regressor.get_booster().dump_model(
- "/tmp/xgb.json", dump_format="json", with_stats=True
- )
- with open("/tmp/xgb.json", "r") as f:
- data = json.load(f)
+ def _get_number_of_leaves(self, data):
+ if "leaf" in data:
+ return 1
+ ret = 0
+ for child in data.get("children", list()):
+ ret += self._get_number_of_leaves(child)
+ return ret
- return max(map(self._get_max_depth, data))
+ def get_max_depth(self):
+ return max(map(self._get_max_depth, self.to_json()))
def _get_max_depth(self, data):
ret = [0]
diff --git a/lib/parameters.py b/lib/parameters.py
index dae6e2a..cb4b76f 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -693,7 +693,12 @@ class ModelAttribute:
ret["decision tree/nodes"] = 1
ret["decision tree/max depth"] = 1
- if type(self.model_function) in (df.LMTFunction,):
+ if type(self.model_function) in (
+ df.SplitFunction,
+ df.CARTFunction,
+ df.LMTFunction,
+ df.XGBoostFunction,
+ ):
ret["decision tree/leaves"] = self.model_function.get_number_of_leaves()
return ret
@@ -1257,7 +1262,6 @@ class ModelAttribute:
if np.all(np.isinf(loss)):
# all children have the same configuration. We shouldn't get here due to the threshold check above...
if ffs_feasible:
- logger.debug("ffs feasible, attempting to fit a leaf")
# try generating a function. if it fails, model_function is a StaticFunction.
ma = ModelAttribute(
self.name + "_",