From 805d4f982b104ce372e61a413935767c27943b0c Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Tue, 19 Mar 2024 15:32:08 +0100 Subject: dataref export: add preprocessing and RMT method/threshold --- lib/functions.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/lib/functions.py b/lib/functions.py index 0f19668..25c0354 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -22,6 +22,11 @@ dfatool_preproc_relevance_threshold = float( os.getenv("DFATOOL_PREPROCESSING_RELEVANCE_THRESHOLD", "0.1") ) +dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None) +dfatool_rmt_relevance_threshold = float( + os.getenv("DFATOOL_RMT_RELEVANCE_THRESHOLD", "0.5") +) + if dfatool_preproc_relevance_method == "mi": import sklearn.feature_selection @@ -233,7 +238,22 @@ class ModelFunction: return ret def hyper_to_dref(self): - return dict() + hyper = dict() + if dfatool_preproc_relevance_method: + hyper.update( + { + "preprocessing/relevance/method": dfatool_preproc_relevance_method, + "preprocessing/relevance/threshold": dfatool_preproc_relevance_threshold, + } + ) + if dfatool_rmt_relevance_method: + hyper.update( + { + "rmt/relevance/method": dfatool_rmt_relevance_method, + "rmt/relevance/threshold": dfatool_rmt_relevance_threshold, + } + ) + return hyper @classmethod def from_json(cls, data): @@ -660,12 +680,15 @@ class SKLearnRegressionFunction(ModelFunction): return np.asarray(ret) def hyper_to_dref(self): - ret = { - "paramcount/ndarray": self.paramcount_ndarray, - } + hyper = super().hyper_to_dref() + hyper.update( + { + "paramcount/ndarray": self.paramcount_ndarray, + } + ) if self.paramcount_preprocessed is not None: - ret["paramcount/preprocessed"] = self.paramcount_preprocessed - return ret + hyper["paramcount/preprocessed"] = self.paramcount_preprocessed + return hyper def _build_feature_names(self): # SKLearnRegressionFunction descendants use self.param_names \ self.ignore_index as features. -- cgit v1.2.3