summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py35
1 files changed, 29 insertions, 6 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 0f19668..25c0354 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -22,6 +22,11 @@ dfatool_preproc_relevance_threshold = float(
os.getenv("DFATOOL_PREPROCESSING_RELEVANCE_THRESHOLD", "0.1")
)
+dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None)
+dfatool_rmt_relevance_threshold = float(
+ os.getenv("DFATOOL_RMT_RELEVANCE_THRESHOLD", "0.5")
+)
+
if dfatool_preproc_relevance_method == "mi":
import sklearn.feature_selection
@@ -233,7 +238,22 @@ class ModelFunction:
return ret
def hyper_to_dref(self):
- return dict()
+ hyper = dict()
+ if dfatool_preproc_relevance_method:
+ hyper.update(
+ {
+ "preprocessing/relevance/method": dfatool_preproc_relevance_method,
+ "preprocessing/relevance/threshold": dfatool_preproc_relevance_threshold,
+ }
+ )
+ if dfatool_rmt_relevance_method:
+ hyper.update(
+ {
+ "rmt/relevance/method": dfatool_rmt_relevance_method,
+ "rmt/relevance/threshold": dfatool_rmt_relevance_threshold,
+ }
+ )
+ return hyper
@classmethod
def from_json(cls, data):
@@ -660,12 +680,15 @@ class SKLearnRegressionFunction(ModelFunction):
return np.asarray(ret)
def hyper_to_dref(self):
- ret = {
- "paramcount/ndarray": self.paramcount_ndarray,
- }
+ hyper = super().hyper_to_dref()
+ hyper.update(
+ {
+ "paramcount/ndarray": self.paramcount_ndarray,
+ }
+ )
if self.paramcount_preprocessed is not None:
- ret["paramcount/preprocessed"] = self.paramcount_preprocessed
- return ret
+ hyper["paramcount/preprocessed"] = self.paramcount_preprocessed
+ return hyper
def _build_feature_names(self):
# SKLearnRegressionFunction descendants use self.param_names \ self.ignore_index as features.