From 6e2fa5f36deb9882e5bddb212695434c219c3d94 Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Wed, 5 Jan 2022 15:56:24 +0100 Subject: store decision tree attributes of xv models in dataref export --- lib/functions.py | 7 +++++++ lib/model.py | 18 +++++++++++++++++- lib/parameters.py | 3 ++- lib/validation.py | 2 +- 4 files changed, 27 insertions(+), 3 deletions(-) (limited to 'lib') diff --git a/lib/functions.py b/lib/functions.py index 320e8ed..5358d8e 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -339,6 +339,13 @@ class SplitFunction(ModelFunction): ret += 1 return ret + def get_max_depth(self): + ret = [0] + for v in self.child.values(): + if type(v) is SplitFunction: + ret.append(v.get_max_depth()) + return 1 + max(ret) + @classmethod def from_json(cls, data): assert data["type"] == "split" diff --git a/lib/model.py b/lib/model.py index 4f5f60f..5a44c7b 100644 --- a/lib/model.py +++ b/lib/model.py @@ -449,7 +449,9 @@ class AnalyticModel: threshold=threshold, ) - def to_dref(self, static_quality, lut_quality, model_quality) -> dict: + def to_dref( + self, static_quality, lut_quality, model_quality, xv_models=None + ) -> dict: ret = dict() for name in self.names: param_data = { @@ -546,6 +548,20 @@ class AnalyticModel: ) except KeyError: logger.warning(f"{name} {attr_name} param model has no MAPE") + + if xv_models is not None: + keys = ("decision tree/nodes", "decision tree/max depth") + entry = dict() + for k in keys: + entry[k] = list() + for xv_model in xv_models: + dref = xv_model.attr_by_name[name][attr_name].to_dref() + for k in keys: + if k in dref: + entry[k].append(dref[k]) + for k in keys: + if len(entry[k]): + ret[k] = np.mean(entry[k]) return ret def to_json(self, **kwargs) -> dict: diff --git a/lib/parameters.py b/lib/parameters.py index 9cfe145..38e36b2 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -652,7 +652,8 @@ class ModelAttribute: ret = {"mean": (self.mean, unit), "median": (self.median, unit)} if type(self.model_function) is df.SplitFunction: - ret["decision tree nodes"] = self.model_function.get_number_of_nodes() + ret["decision tree/nodes"] = self.model_function.get_number_of_nodes() + ret["decision tree/max depth"] = self.model_function.get_max_depth() return ret diff --git a/lib/validation.py b/lib/validation.py index 5c65fe3..89bc67c 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -280,4 +280,4 @@ class CrossValidator: validation, self.parameters, *self.args, **self.kwargs ) - return training_model, validation_data.assess(training_model) + return training_data, validation_data.assess(training_model) -- cgit v1.2.3