summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-12-22 11:02:19 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-12-22 11:02:19 +0100
commit097d2fabe3cb4477e4515b4106f0c72001620de0 (patch)
treeeeacfaec66498c13840ac082286dea3897e82f2e /lib
parent293f636491aae95e472bb99350791b8393131798 (diff)
make threshold for parameter relevance heuristic configurable
Diffstat (limited to 'lib')
-rw-r--r--lib/parameters.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index a3c5ab1..74b9091 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -42,7 +42,7 @@ def distinct_param_values(param_tuples):
return distinct_values
-def _depends_on_param(corr_param, std_param, std_lut):
+def _depends_on_param(corr_param, std_param, std_lut, threshold=0.5):
# if self.use_corrcoef:
if False:
return corr_param > 0.1
@@ -50,7 +50,7 @@ def _depends_on_param(corr_param, std_param, std_lut):
# In general, std_param_lut < std_by_param. So, if std_by_param == 0, std_param_lut == 0 follows.
# This means that the variation of param does not affect the model quality -> no influence
return False
- return std_lut / std_param < 0.5
+ return std_lut / std_param < threshold
def _mean_std_by_param(n_by_param, all_param_values, param_index):
@@ -190,6 +190,8 @@ def _compute_param_statistics(
np.seterr("raise")
+ relevance_threshold = float(os.getenv("DFATOOL_PARAM_RELEVANCE_TRESHOLD", 0.5))
+
for param_idx, param in enumerate(param_names):
if param_idx < len(codependent_params) and codependent_params[param_idx]:
by_param = partition_by_param(
@@ -209,6 +211,7 @@ def _compute_param_statistics(
ret["corr_by_param"][param],
ret["std_by_param"][param],
ret["std_param_lut"],
+ relevance_threshold,
)
if arg_count:
@@ -1124,6 +1127,8 @@ class ModelAttribute:
"build_dtree {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense."
)
+ relevance_threshold = float(os.getenv("DFATOOL_PARAM_RELEVANCE_TRESHOLD", 0.5))
+
self.model_function = self._build_dtree(
parameters,
data,
@@ -1132,6 +1137,7 @@ class ModelAttribute:
ignore_irrelevant_parameters=ignore_irrelevant_parameters,
loss_ignore_scalar=loss_ignore_scalar,
threshold=threshold,
+ relevance_threshold=relevance_threshold,
)
def _build_dtree(
@@ -1143,6 +1149,7 @@ class ModelAttribute:
ignore_irrelevant_parameters=True,
loss_ignore_scalar=False,
threshold=100,
+ relevance_threshold=0.5,
level=0,
):
"""
@@ -1206,7 +1213,9 @@ class ModelAttribute:
std_by_param = _mean_std_by_param(
by_param, distinct_values_by_param_index, param_index
)
- if not _depends_on_param(None, std_by_param, std_lut):
+ if not _depends_on_param(
+ None, std_by_param, std_lut, relevance_threshold
+ ):
loss.append(np.inf)
continue