diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/parameters.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index a154918..ac69075 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -18,6 +18,12 @@ logger = logging.getLogger(__name__) dfatool_fol_relevance_method = os.getenv("DFATOOL_FOL_RELEVANCE_METHOD", None) dfatool_symreg_relevance_method = os.getenv("DFATOOL_SYMREG_RELEVANCE_METHOD", None) dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None) +dfatool_rmt_relevance_threshold = float( + os.getenv("DFATOOL_RMT_RELEVANCE_THRESHOLD", "0.5") +) + +if dfatool_rmt_relevance_method == "mi": + import sklearn.feature_selection def distinct_param_values(param_tuples): @@ -1124,6 +1130,19 @@ class ModelAttribute: by_param = partition_by_param(data, parameters) distinct_values_by_param_index = distinct_param_values(parameters) std_lut = np.mean([np.std(v) for v in by_param.values()]) + elif dfatool_rmt_relevance_method == "mi": + fit_parameters, _, ignore_index = param_to_ndarray( + parameters, with_nan=False, categorical_to_scalar=True + ) + param_to_fit_param = dict() + j = 0 + for i in range(param_count): + if not ignore_index[i]: + param_to_fit_param[i] = j + j += 1 + mutual_information = sklearn.feature_selection.mutual_info_regression( + fit_parameters, data + ) if loss_ignore_scalar: ffs_eligible_params = list() @@ -1188,6 +1207,13 @@ class ModelAttribute: irrelevant_params.append(param_index) loss.append(np.inf) continue + elif dfatool_rmt_relevance_method == "mi": + if ( + mutual_information[param_to_fit_param[param_index]] + < dfatool_rmt_relevance_threshold + ): + loss.append(np.inf) + continue child_indexes = list() for value in unique_values: |