diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-10-18 12:32:45 +0200 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-10-18 12:32:45 +0200 |
commit | 0fa46d5aa180031b195c167f2bfa4720e4fc9e28 (patch) | |
tree | af1857b0089f559b24c67130de77c25e764c7370 /lib | |
parent | 00bf254e111d1642de68259f68a73bbb0907566b (diff) |
--export-dref: calculate model complexity
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 29 | ||||
-rw-r--r-- | lib/parameters.py | 3 |
2 files changed, 32 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py index 9e5639a..c229594 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -197,6 +197,9 @@ class ModelFunction: def eval_arr(self, params): raise NotImplementedError + def get_complexity_score(self): + raise NotImplementedError + def eval_mae(self, param_list): """Return model Mean Absolute Error (MAE) for `param_list`.""" if self.is_predictable(param_list): @@ -282,6 +285,9 @@ class StaticFunction(ModelFunction): def eval_arr(self, params): return [self.value for p in params] + def get_complexity_score(self): + return 1 + def to_json(self, **kwargs): ret = super().to_json(**kwargs) ret.update({"type": "static", "value": self.value}) @@ -377,6 +383,14 @@ class SplitFunction(ModelFunction): ret += 1 return ret + def get_complexity_score(self): + if not self.child: + return 1 + ret = 0 + for v in self.child.values(): + ret += v.get_complexity_score() + return ret + def to_dot(self, pydot, graph, feature_names, parent=None): try: label = feature_names[self.param_index] @@ -537,6 +551,9 @@ class CARTFunction(SKLearnRegressionFunction): def get_max_depth(self): return self.regressor.get_depth() + def get_complexity_score(self): + return self.get_number_of_leaves() + def to_json(self, feature_names=None, **kwargs): import sklearn.tree @@ -588,6 +605,9 @@ class LMTFunction(SKLearnRegressionFunction): def get_number_of_leaves(self): return len(self.regressor._leaves.keys()) + # def get_complexity_score(self): + # FIXME + def get_max_depth(self): return max(map(len, self.regressor._leaves.keys())) + 1 @@ -635,6 +655,9 @@ class XGBoostFunction(SKLearnRegressionFunction): ret.append(self._get_max_depth(child)) return 1 + max(ret) + def get_complexity_score(self): + return self.get_number_of_leaves() + # first-order linear function (no feature interaction) class FOLFunction(ModelFunction): @@ -766,6 +789,9 @@ class FOLFunction(ModelFunction): ) raise + def get_complexity_score(self): + return len(self.model_args) + def to_json(self, **kwargs): ret = super().to_json(**kwargs) ret.update( @@ -978,6 +1004,9 @@ class AnalyticFunction(ModelFunction): ) return self.value + def get_complexity_score(self): + return len(self.model_args) + def webconf_function_map(self): js_buf = self.model_function for i in range(len(self.model_args)): diff --git a/lib/parameters.py b/lib/parameters.py index 51c47b8..00260fc 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -642,6 +642,9 @@ class ModelAttribute: def to_dref(self, unit=None): ret = {"mean": (self.mean, unit), "median": (self.median, unit)} + if issubclass(type(self.model_function), df.ModelFunction): + ret["complexity"] = self.model_function.get_complexity_score() + if type(self.model_function) in ( df.SplitFunction, df.CARTFunction, |