diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-07 11:17:49 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-07 11:17:49 +0100 |
commit | d26f7a51e17911cc5a11749df27d69cb095ced4c (patch) | |
tree | 9b5fa74f43cbc59fa3f7e8752178cebf22bc244d /lib | |
parent | 92b4dd6e05df3b2805570fa1f86c35c33f147bec (diff) |
RMT: support mutual information-based parameter relevance
Diffstat (limited to 'lib')
-rw-r--r-- | lib/parameters.py | 26 |
1 file changed, 26 insertions, 0 deletions
```diff
diff --git a/lib/parameters.py b/lib/parameters.py
index a154918..ac69075 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -18,6 +18,12 @@ logger = logging.getLogger(__name__)
 dfatool_fol_relevance_method = os.getenv("DFATOOL_FOL_RELEVANCE_METHOD", None)
 dfatool_symreg_relevance_method = os.getenv("DFATOOL_SYMREG_RELEVANCE_METHOD", None)
 dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None)
+dfatool_rmt_relevance_threshold = float(
+    os.getenv("DFATOOL_RMT_RELEVANCE_THRESHOLD", "0.5")
+)
+
+if dfatool_rmt_relevance_method == "mi":
+    import sklearn.feature_selection
 
 
 def distinct_param_values(param_tuples):
@@ -1124,6 +1130,19 @@ class ModelAttribute:
             by_param = partition_by_param(data, parameters)
             distinct_values_by_param_index = distinct_param_values(parameters)
             std_lut = np.mean([np.std(v) for v in by_param.values()])
+        elif dfatool_rmt_relevance_method == "mi":
+            fit_parameters, _, ignore_index = param_to_ndarray(
+                parameters, with_nan=False, categorical_to_scalar=True
+            )
+            param_to_fit_param = dict()
+            j = 0
+            for i in range(param_count):
+                if not ignore_index[i]:
+                    param_to_fit_param[i] = j
+                    j += 1
+            mutual_information = sklearn.feature_selection.mutual_info_regression(
+                fit_parameters, data
+            )
 
         if loss_ignore_scalar:
             ffs_eligible_params = list()
@@ -1188,6 +1207,13 @@ class ModelAttribute:
                     irrelevant_params.append(param_index)
                     loss.append(np.inf)
                     continue
+            elif dfatool_rmt_relevance_method == "mi":
+                if (
+                    mutual_information[param_to_fit_param[param_index]]
+                    < dfatool_rmt_relevance_threshold
+                ):
+                    loss.append(np.inf)
+                    continue
 
             child_indexes = list()
             for value in unique_values:
```