summaryrefslogtreecommitdiff
path: root/lib/functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/functions.py')
-rw-r--r--lib/functions.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 9e5639a..c229594 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -197,6 +197,9 @@ class ModelFunction:
def eval_arr(self, params):
raise NotImplementedError
+ def get_complexity_score(self):
+ raise NotImplementedError
+
def eval_mae(self, param_list):
"""Return model Mean Absolute Error (MAE) for `param_list`."""
if self.is_predictable(param_list):
@@ -282,6 +285,9 @@ class StaticFunction(ModelFunction):
def eval_arr(self, params):
return [self.value for p in params]
+ def get_complexity_score(self):
+ return 1
+
def to_json(self, **kwargs):
ret = super().to_json(**kwargs)
ret.update({"type": "static", "value": self.value})
@@ -377,6 +383,14 @@ class SplitFunction(ModelFunction):
ret += 1
return ret
+ def get_complexity_score(self):
+ if not self.child:
+ return 1
+ ret = 0
+ for v in self.child.values():
+ ret += v.get_complexity_score()
+ return ret
+
def to_dot(self, pydot, graph, feature_names, parent=None):
try:
label = feature_names[self.param_index]
@@ -537,6 +551,9 @@ class CARTFunction(SKLearnRegressionFunction):
def get_max_depth(self):
return self.regressor.get_depth()
+ def get_complexity_score(self):
+ return self.get_number_of_leaves()
+
def to_json(self, feature_names=None, **kwargs):
import sklearn.tree
@@ -588,6 +605,9 @@ class LMTFunction(SKLearnRegressionFunction):
def get_number_of_leaves(self):
return len(self.regressor._leaves.keys())
+ # def get_complexity_score(self):
+ # FIXME
+
def get_max_depth(self):
return max(map(len, self.regressor._leaves.keys())) + 1
@@ -635,6 +655,9 @@ class XGBoostFunction(SKLearnRegressionFunction):
ret.append(self._get_max_depth(child))
return 1 + max(ret)
+ def get_complexity_score(self):
+ return self.get_number_of_leaves()
+
# first-order linear function (no feature interaction)
class FOLFunction(ModelFunction):
@@ -766,6 +789,9 @@ class FOLFunction(ModelFunction):
)
raise
+ def get_complexity_score(self):
+ return len(self.model_args)
+
def to_json(self, **kwargs):
ret = super().to_json(**kwargs)
ret.update(
@@ -978,6 +1004,9 @@ class AnalyticFunction(ModelFunction):
)
return self.value
+ def get_complexity_score(self):
+ return len(self.model_args)
+
def webconf_function_map(self):
js_buf = self.model_function
for i in range(len(self.model_args)):