summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-03-07 11:17:49 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-03-07 11:17:49 +0100
commitd26f7a51e17911cc5a11749df27d69cb095ced4c (patch)
tree9b5fa74f43cbc59fa3f7e8752178cebf22bc244d /lib
parent92b4dd6e05df3b2805570fa1f86c35c33f147bec (diff)
RMT: support mutual information-based parameter relevance
Diffstat (limited to 'lib')
-rw-r--r--lib/parameters.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index a154918..ac69075 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -18,6 +18,12 @@ logger = logging.getLogger(__name__)
dfatool_fol_relevance_method = os.getenv("DFATOOL_FOL_RELEVANCE_METHOD", None)
dfatool_symreg_relevance_method = os.getenv("DFATOOL_SYMREG_RELEVANCE_METHOD", None)
dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None)
+dfatool_rmt_relevance_threshold = float(
+ os.getenv("DFATOOL_RMT_RELEVANCE_THRESHOLD", "0.5")
+)
+
+if dfatool_rmt_relevance_method == "mi":
+ import sklearn.feature_selection
def distinct_param_values(param_tuples):
@@ -1124,6 +1130,19 @@ class ModelAttribute:
by_param = partition_by_param(data, parameters)
distinct_values_by_param_index = distinct_param_values(parameters)
std_lut = np.mean([np.std(v) for v in by_param.values()])
+ elif dfatool_rmt_relevance_method == "mi":
+ fit_parameters, _, ignore_index = param_to_ndarray(
+ parameters, with_nan=False, categorical_to_scalar=True
+ )
+ param_to_fit_param = dict()
+ j = 0
+ for i in range(param_count):
+ if not ignore_index[i]:
+ param_to_fit_param[i] = j
+ j += 1
+ mutual_information = sklearn.feature_selection.mutual_info_regression(
+ fit_parameters, data
+ )
if loss_ignore_scalar:
ffs_eligible_params = list()
@@ -1188,6 +1207,13 @@ class ModelAttribute:
irrelevant_params.append(param_index)
loss.append(np.inf)
continue
+ elif dfatool_rmt_relevance_method == "mi":
+ if (
+ mutual_information[param_to_fit_param[param_index]]
+ < dfatool_rmt_relevance_threshold
+ ):
+ loss.append(np.inf)
+ continue
child_indexes = list()
for value in unique_values: