From d26f7a51e17911cc5a11749df27d69cb095ced4c Mon Sep 17 00:00:00 2001
From: Birte Kristina Friesel
Date: Thu, 7 Mar 2024 11:17:49 +0100
Subject: RMT: support mutual information-based parameter relevance

---
Review notes (commentary area; not part of the commit message):

* Hunk 1 adds the module-level setting dfatool_rmt_relevance_threshold. It is
  read from the environment variable DFATOOL_RMT_RELEVANCE_THRESHOLD, falls
  back to the string "0.5", and is converted with float() when the module is
  imported. sklearn.feature_selection is imported only when
  DFATOOL_RMT_RELEVANCE_METHOD equals "mi".
* Hunk 2 adds an "mi" branch next to the existing relevance-method branch. It
  calls param_to_ndarray(parameters, with_nan=False,
  categorical_to_scalar=True), builds param_to_fit_param (original parameter
  index -> running index over the parameters whose ignore_index entry is
  false), and calls sklearn.feature_selection.mutual_info_regression(
  fit_parameters, data) once.
* Hunk 3 skips a parameter (loss np.inf, continue) when its mutual information
  value is below dfatool_rmt_relevance_threshold.
* NOTE(review): unlike the branch directly above it, the new "mi" branch does
  not append param_index to irrelevant_params -- confirm this is intentional.
* NOTE(review): param_to_fit_param has no entry for indices flagged in
  ignore_index; presumably those indices never reach this lookup -- verify
  against the enclosing loop, which is not visible in this patch.
* NOTE(review): the leading indentation of the hunk lines below was
  reconstructed from the Python block structure; verify it against
  lib/parameters.py before applying.

 lib/parameters.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

(limited to 'lib')

diff --git a/lib/parameters.py b/lib/parameters.py
index a154918..ac69075 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -18,6 +18,12 @@ logger = logging.getLogger(__name__)
 dfatool_fol_relevance_method = os.getenv("DFATOOL_FOL_RELEVANCE_METHOD", None)
 dfatool_symreg_relevance_method = os.getenv("DFATOOL_SYMREG_RELEVANCE_METHOD", None)
 dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None)
+dfatool_rmt_relevance_threshold = float(
+    os.getenv("DFATOOL_RMT_RELEVANCE_THRESHOLD", "0.5")
+)
+
+if dfatool_rmt_relevance_method == "mi":
+    import sklearn.feature_selection
 
 
 def distinct_param_values(param_tuples):
@@ -1124,6 +1130,19 @@ class ModelAttribute:
             by_param = partition_by_param(data, parameters)
             distinct_values_by_param_index = distinct_param_values(parameters)
             std_lut = np.mean([np.std(v) for v in by_param.values()])
+        elif dfatool_rmt_relevance_method == "mi":
+            fit_parameters, _, ignore_index = param_to_ndarray(
+                parameters, with_nan=False, categorical_to_scalar=True
+            )
+            param_to_fit_param = dict()
+            j = 0
+            for i in range(param_count):
+                if not ignore_index[i]:
+                    param_to_fit_param[i] = j
+                    j += 1
+            mutual_information = sklearn.feature_selection.mutual_info_regression(
+                fit_parameters, data
+            )
 
         if loss_ignore_scalar:
             ffs_eligible_params = list()
@@ -1188,6 +1207,13 @@ class ModelAttribute:
                     irrelevant_params.append(param_index)
                     loss.append(np.inf)
                     continue
+            elif dfatool_rmt_relevance_method == "mi":
+                if (
+                    mutual_information[param_to_fit_param[param_index]]
+                    < dfatool_rmt_relevance_threshold
+                ):
+                    loss.append(np.inf)
+                    continue
 
             child_indexes = list()
             for value in unique_values:
-- 
cgit v1.2.3