summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-03-22 09:23:02 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-03-22 09:23:02 +0100
commit81a1b42c554414bf9f0995741a866ebd1ece6c3c (patch)
tree686b4843ae1768143bdf1c397601ef8b51817041 /lib/functions.py
parent995753a7913b89ea5f58983cf2cfca1cb8a00b77 (diff)
RMT: add max depth hyper-parameter
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 25c0354..32fade0 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -444,6 +444,15 @@ class SplitFunction(ModelFunction):
child.to_dot(pydot, graph, feature_names, str(id(self)))
graph.add_edge(pydot.Edge(str(id(self)), str(id(child)), label=key))
+ def hyper_to_dref(self):
+ hyper = super().hyper_to_dref()
+ hyper.update(
+ {
+ "rmt/max depth": int(os.getenv("DFATOOL_RMT_MAX_DEPTH", "0"))
+ or "infty",
+ }
+ )
+
@classmethod
def from_json(cls, data):
assert data["type"] == "split"