diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-19 15:32:08 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-19 15:32:08 +0100 |
commit | 805d4f982b104ce372e61a413935767c27943b0c (patch) | |
tree | d5e877880a39b326679077d1522f4fd4bbdb7784 /lib | |
parent | 972bcef0202e97d39e3bda4297961d5616e7cc3c (diff) |
dataref export: add preprocessing and RMT method/threshold
Diffstat (limited to 'lib')
-rw-r--r-- | lib/functions.py | 35 |
1 files changed, 29 insertions, 6 deletions
diff --git a/lib/functions.py b/lib/functions.py index 0f19668..25c0354 100644 --- a/lib/functions.py +++ b/lib/functions.py @@ -22,6 +22,11 @@ dfatool_preproc_relevance_threshold = float( os.getenv("DFATOOL_PREPROCESSING_RELEVANCE_THRESHOLD", "0.1") ) +dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None) +dfatool_rmt_relevance_threshold = float( + os.getenv("DFATOOL_RMT_RELEVANCE_THRESHOLD", "0.5") +) + if dfatool_preproc_relevance_method == "mi": import sklearn.feature_selection @@ -233,7 +238,22 @@ class ModelFunction: return ret def hyper_to_dref(self): - return dict() + hyper = dict() + if dfatool_preproc_relevance_method: + hyper.update( + { + "preprocessing/relevance/method": dfatool_preproc_relevance_method, + "preprocessing/relevance/threshold": dfatool_preproc_relevance_threshold, + } + ) + if dfatool_rmt_relevance_method: + hyper.update( + { + "rmt/relevance/method": dfatool_rmt_relevance_method, + "rmt/relevance/threshold": dfatool_rmt_relevance_threshold, + } + ) + return hyper @classmethod def from_json(cls, data): @@ -660,12 +680,15 @@ class SKLearnRegressionFunction(ModelFunction): return np.asarray(ret) def hyper_to_dref(self): - ret = { - "paramcount/ndarray": self.paramcount_ndarray, - } + hyper = super().hyper_to_dref() + hyper.update( + { + "paramcount/ndarray": self.paramcount_ndarray, + } + ) if self.paramcount_preprocessed is not None: - ret["paramcount/preprocessed"] = self.paramcount_preprocessed - return ret + hyper["paramcount/preprocessed"] = self.paramcount_preprocessed + return hyper def _build_feature_names(self): # SKLearnRegressionFunction descendants use self.param_names \ self.ignore_index as features. |