diff options
Diffstat (limited to 'lib/functions.py')
-rw-r--r-- | lib/functions.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index 187e6ff..35b04ef 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -466,6 +466,23 @@ class SplitFunction(ModelFunction): ) return hyper + # SplitFunction only + def flatten(self): + paths = list() + for param_value, subtree in self.child.items(): + if type(subtree) is SplitFunction: + for path, value in subtree.flatten(): + path = [(self.param_name, param_value)] + path + paths.append((path, value)) + elif type(subtree) is StaticFunction: + path = [(self.param_name, param_value)] + paths.append((path, subtree.value)) + else: + raise RuntimeError( + "flatten is only implemented for RMTs with constant leaves" + ) + return paths + @classmethod def from_json(cls, data): assert data["type"] == "split" |