summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md3
-rw-r--r--lib/parameters.py26
2 files changed, 28 insertions, 1 deletions
diff --git a/README.md b/README.md
index aafdc51..0fbbd16 100644
--- a/README.md
+++ b/README.md
@@ -145,7 +145,8 @@ The following variables may be set to alter the behaviour of dfatool components.
| `DFATOOL_KCONF_IGNORE_STRING` | 0, **1** | Ignore string configuration options. These often hold compiler paths and other not really helpful information. |
| `DFATOOL_REGRESSION_SAFE_FUNCTIONS` | **0**, 1 | Use safe functions only (e.g. 1/x returnning 1 for x==0) |
| `DFATOOL_RMT_NONBINARY_NODES` | 0, **1** | Enable non-binary nodes (i.e., nodes with more than two children corresponding to enum variables) in decision trees |
-| `DFATOOL_RMT_RELEVANCE_METHOD` | **none**, std\_by\_param | Ignore parameters deemed irrelevant by the specified heuristic during regression tree generation. Use with caution. |
+| `DFATOOL_RMT_RELEVANCE_METHOD` | **none**, mi, std\_by\_param | Ignore parameters deemed irrelevant by the specified heuristic during regression tree generation. mi := [Mutual Information Regression](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.mutual_info_regression.html). Use with caution. |
+| `DFATOOL_RMT_RELEVANCE_THRESHOLD` | .. **0.5** .. | Threshold for relevance checks. |
| `DFATOOL_PARAM_RELEVANCE_THRESHOLD` | 0 .. **0.5** .. 1 | Threshold for relevant parameter detection: parameter *i* is relevant if mean standard deviation (data partitioned by all parameters) / mean standard deviation (data partition by all parameters but *i*) is less than threshold |
| `DFATOOL_RMT_LOSS_IGNORE_SCALAR` | **0**, 1 | Ignore scalar parameters when computing the loss for split node candidates. Instead of computing the loss of a single partition for each `x_i == j`, compute the loss of partitions for `x_i == j` in which non-scalar parameters vary and scalar parameters are constant. This way, scalar parameters do not affect the decision about which non-scalar parameter to use for splitting. |
| `DFATOOL_PARAM_CATEGORICAL_TO_SCALAR` | **0**, 1 | Some models (e.g. FOL, sklearn CART, XGBoost) do not support categorical parameters. Ignore them (0) or convert them to scalar indexes (1). Conversion uses lexical order. |
diff --git a/lib/parameters.py b/lib/parameters.py
index a154918..ac69075 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -18,6 +18,12 @@ logger = logging.getLogger(__name__)
dfatool_fol_relevance_method = os.getenv("DFATOOL_FOL_RELEVANCE_METHOD", None)
dfatool_symreg_relevance_method = os.getenv("DFATOOL_SYMREG_RELEVANCE_METHOD", None)
dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None)
+dfatool_rmt_relevance_threshold = float(
+ os.getenv("DFATOOL_RMT_RELEVANCE_THRESHOLD", "0.5")
+)
+
+if dfatool_rmt_relevance_method == "mi":
+ import sklearn.feature_selection
def distinct_param_values(param_tuples):
@@ -1124,6 +1130,19 @@ class ModelAttribute:
by_param = partition_by_param(data, parameters)
distinct_values_by_param_index = distinct_param_values(parameters)
std_lut = np.mean([np.std(v) for v in by_param.values()])
+ elif dfatool_rmt_relevance_method == "mi":
+ fit_parameters, _, ignore_index = param_to_ndarray(
+ parameters, with_nan=False, categorical_to_scalar=True
+ )
+ param_to_fit_param = dict()
+ j = 0
+ for i in range(param_count):
+ if not ignore_index[i]:
+ param_to_fit_param[i] = j
+ j += 1
+ mutual_information = sklearn.feature_selection.mutual_info_regression(
+ fit_parameters, data
+ )
if loss_ignore_scalar:
ffs_eligible_params = list()
@@ -1188,6 +1207,13 @@ class ModelAttribute:
irrelevant_params.append(param_index)
loss.append(np.inf)
continue
+ elif dfatool_rmt_relevance_method == "mi":
+ if (
+ mutual_information[param_to_fit_param[param_index]]
+ < dfatool_rmt_relevance_threshold
+ ):
+ loss.append(np.inf)
+ continue
child_indexes = list()
for value in unique_values: