diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-12-22 11:02:19 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-12-22 11:02:19 +0100 |
commit | 097d2fabe3cb4477e4515b4106f0c72001620de0 (patch) | |
tree | eeacfaec66498c13840ac082286dea3897e82f2e /lib | |
parent | 293f636491aae95e472bb99350791b8393131798 (diff) |
Make the threshold for the parameter relevance heuristic configurable via the DFATOOL_PARAM_RELEVANCE_TRESHOLD environment variable (default: 0.5)
Diffstat (limited to 'lib')
-rw-r--r-- | lib/parameters.py | 15 |
1 file changed, 12 insertions, 3 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index a3c5ab1..74b9091 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -42,7 +42,7 @@ def distinct_param_values(param_tuples): return distinct_values -def _depends_on_param(corr_param, std_param, std_lut): +def _depends_on_param(corr_param, std_param, std_lut, threshold=0.5): # if self.use_corrcoef: if False: return corr_param > 0.1 @@ -50,7 +50,7 @@ def _depends_on_param(corr_param, std_param, std_lut): # In general, std_param_lut < std_by_param. So, if std_by_param == 0, std_param_lut == 0 follows. # This means that the variation of param does not affect the model quality -> no influence return False - return std_lut / std_param < 0.5 + return std_lut / std_param < threshold def _mean_std_by_param(n_by_param, all_param_values, param_index): @@ -190,6 +190,8 @@ def _compute_param_statistics( np.seterr("raise") + relevance_threshold = float(os.getenv("DFATOOL_PARAM_RELEVANCE_TRESHOLD", 0.5)) + for param_idx, param in enumerate(param_names): if param_idx < len(codependent_params) and codependent_params[param_idx]: by_param = partition_by_param( @@ -209,6 +211,7 @@ def _compute_param_statistics( ret["corr_by_param"][param], ret["std_by_param"][param], ret["std_param_lut"], + relevance_threshold, ) if arg_count: @@ -1124,6 +1127,8 @@ class ModelAttribute: "build_dtree {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense." 
) + relevance_threshold = float(os.getenv("DFATOOL_PARAM_RELEVANCE_TRESHOLD", 0.5)) + self.model_function = self._build_dtree( parameters, data, @@ -1132,6 +1137,7 @@ class ModelAttribute: ignore_irrelevant_parameters=ignore_irrelevant_parameters, loss_ignore_scalar=loss_ignore_scalar, threshold=threshold, + relevance_threshold=relevance_threshold, ) def _build_dtree( @@ -1143,6 +1149,7 @@ class ModelAttribute: ignore_irrelevant_parameters=True, loss_ignore_scalar=False, threshold=100, + relevance_threshold=0.5, level=0, ): """ @@ -1206,7 +1213,9 @@ class ModelAttribute: std_by_param = _mean_std_by_param( by_param, distinct_values_by_param_index, param_index ) - if not _depends_on_param(None, std_by_param, std_lut): + if not _depends_on_param( + None, std_by_param, std_lut, relevance_threshold + ): loss.append(np.inf) continue |