summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-10-18 12:32:45 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2022-10-18 12:32:45 +0200
commit0fa46d5aa180031b195c167f2bfa4720e4fc9e28 (patch)
treeaf1857b0089f559b24c67130de77c25e764c7370
parent00bf254e111d1642de68259f68a73bbb0907566b (diff)
--export-dref: calculate model complexity
-rw-r--r--lib/functions.py29
-rw-r--r--lib/parameters.py3
2 files changed, 32 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 9e5639a..c229594 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -197,6 +197,9 @@ class ModelFunction:
def eval_arr(self, params):
raise NotImplementedError
+ def get_complexity_score(self):
+ raise NotImplementedError
+
def eval_mae(self, param_list):
"""Return model Mean Absolute Error (MAE) for `param_list`."""
if self.is_predictable(param_list):
@@ -282,6 +285,9 @@ class StaticFunction(ModelFunction):
def eval_arr(self, params):
return [self.value for p in params]
+ def get_complexity_score(self):
+ return 1
+
def to_json(self, **kwargs):
ret = super().to_json(**kwargs)
ret.update({"type": "static", "value": self.value})
@@ -377,6 +383,14 @@ class SplitFunction(ModelFunction):
ret += 1
return ret
+ def get_complexity_score(self):
+ if not self.child:
+ return 1
+ ret = 0
+ for v in self.child.values():
+ ret += v.get_complexity_score()
+ return ret
+
def to_dot(self, pydot, graph, feature_names, parent=None):
try:
label = feature_names[self.param_index]
@@ -537,6 +551,9 @@ class CARTFunction(SKLearnRegressionFunction):
def get_max_depth(self):
return self.regressor.get_depth()
+ def get_complexity_score(self):
+ return self.get_number_of_leaves()
+
def to_json(self, feature_names=None, **kwargs):
import sklearn.tree
@@ -588,6 +605,9 @@ class LMTFunction(SKLearnRegressionFunction):
def get_number_of_leaves(self):
return len(self.regressor._leaves.keys())
+ # def get_complexity_score(self):
+ # FIXME
+
def get_max_depth(self):
return max(map(len, self.regressor._leaves.keys())) + 1
@@ -635,6 +655,9 @@ class XGBoostFunction(SKLearnRegressionFunction):
ret.append(self._get_max_depth(child))
return 1 + max(ret)
+ def get_complexity_score(self):
+ return self.get_number_of_leaves()
+
# first-order linear function (no feature interaction)
class FOLFunction(ModelFunction):
@@ -766,6 +789,9 @@ class FOLFunction(ModelFunction):
)
raise
+ def get_complexity_score(self):
+ return len(self.model_args)
+
def to_json(self, **kwargs):
ret = super().to_json(**kwargs)
ret.update(
@@ -978,6 +1004,9 @@ class AnalyticFunction(ModelFunction):
)
return self.value
+ def get_complexity_score(self):
+ return len(self.model_args)
+
def webconf_function_map(self):
js_buf = self.model_function
for i in range(len(self.model_args)):
diff --git a/lib/parameters.py b/lib/parameters.py
index 51c47b8..00260fc 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -642,6 +642,9 @@ class ModelAttribute:
def to_dref(self, unit=None):
ret = {"mean": (self.mean, unit), "median": (self.median, unit)}
+ if issubclass(type(self.model_function), df.ModelFunction):
+ ret["complexity"] = self.model_function.get_complexity_score()
+
if type(self.model_function) in (
df.SplitFunction,
df.CARTFunction,