summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-03-07 10:52:34 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-03-07 10:52:34 +0100
commit92b4dd6e05df3b2805570fa1f86c35c33f147bec (patch)
tree96bddb10dc889888269e5dc808319720c74d30a3 /lib
parent9754b3a46dad43211539a3dbfbc7c5095bdf30f5 (diff)
DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS=1 → DFATOOL_RMT_RELEVANCE_METHOD=std_by_param
Diffstat (limited to 'lib')
-rw-r--r--lib/parameters.py33
1 files changed, 12 insertions, 21 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 8c7c9cb..a154918 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -15,6 +15,10 @@ from .utils import soft_cast_int, soft_cast_float
logger = logging.getLogger(__name__)
+dfatool_fol_relevance_method = os.getenv("DFATOOL_FOL_RELEVANCE_METHOD", None)
+dfatool_symreg_relevance_method = os.getenv("DFATOOL_SYMREG_RELEVANCE_METHOD", None)
+dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None)
+
def distinct_param_values(param_tuples):
"""
@@ -912,11 +916,8 @@ class ModelAttribute:
return False
def build_fol(self):
- ignore_irrelevant = bool(
- int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0"))
- )
ignore_param_indexes = list()
- if ignore_irrelevant:
+ if dfatool_fol_relevance_method == "std_by_param":
for param_index, param in enumerate(self.param_names):
if not self.stats.depends_on_param(param):
ignore_param_indexes.append(param_index)
@@ -964,11 +965,8 @@ class ModelAttribute:
return False
def build_symreg(self):
- ignore_irrelevant = bool(
- int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0"))
- )
ignore_param_indexes = list()
- if ignore_irrelevant:
+ if dfatool_symreg_relevance_method == "std_by_param":
for param_index, param in enumerate(self.param_names):
if not self.stats.depends_on_param(param):
ignore_param_indexes.append(param_index)
@@ -1031,7 +1029,6 @@ class ModelAttribute:
with_function_leaves=None,
with_nonbinary_nodes=None,
with_gplearn_symreg=None,
- ignore_irrelevant_parameters=None,
loss_ignore_scalar=None,
threshold=100,
):
@@ -1059,10 +1056,6 @@ class ModelAttribute:
)
if with_gplearn_symreg is None:
with_gplearn_symreg = bool(int(os.getenv("DFATOOL_USE_SYMREG", "0")))
- if ignore_irrelevant_parameters is None:
- ignore_irrelevant_parameters = bool(
- int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0"))
- )
if loss_ignore_scalar is None:
loss_ignore_scalar = bool(
int(os.getenv("DFATOOL_RMT_LOSS_IGNORE_SCALAR", "0"))
@@ -1084,7 +1077,6 @@ class ModelAttribute:
self.data,
with_function_leaves=with_function_leaves,
with_nonbinary_nodes=with_nonbinary_nodes,
- ignore_irrelevant_parameters=ignore_irrelevant_parameters,
loss_ignore_scalar=loss_ignore_scalar,
submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"),
threshold=threshold,
@@ -1097,7 +1089,6 @@ class ModelAttribute:
data,
with_function_leaves=False,
with_nonbinary_nodes=True,
- ignore_irrelevant_parameters=True,
loss_ignore_scalar=False,
submodel="uls",
threshold=100,
@@ -1127,11 +1118,12 @@ class ModelAttribute:
loss = list()
ffs_feasible = False
- if ignore_irrelevant_parameters:
- by_param = partition_by_param(data, parameters)
- distinct_values_by_param_index = distinct_param_values(parameters)
- std_lut = np.mean([np.std(v) for v in by_param.values()])
+ if dfatool_rmt_relevance_method:
irrelevant_params = list()
+ if dfatool_rmt_relevance_method == "std_by_param":
+ by_param = partition_by_param(data, parameters)
+ distinct_values_by_param_index = distinct_param_values(parameters)
+ std_lut = np.mean([np.std(v) for v in by_param.values()])
if loss_ignore_scalar:
ffs_eligible_params = list()
@@ -1182,7 +1174,7 @@ class ModelAttribute:
loss.append(np.inf)
continue
- if ignore_irrelevant_parameters:
+ if dfatool_rmt_relevance_method == "std_by_param":
std_by_param = _mean_std_by_params(
by_param,
distinct_values_by_param_index,
@@ -1335,7 +1327,6 @@ class ModelAttribute:
child_data,
with_function_leaves=with_function_leaves,
with_nonbinary_nodes=with_nonbinary_nodes,
- ignore_irrelevant_parameters=ignore_irrelevant_parameters,
loss_ignore_scalar=loss_ignore_scalar,
submodel=submodel,
threshold=threshold,