diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-07 10:52:34 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-03-07 10:52:34 +0100 |
commit | 92b4dd6e05df3b2805570fa1f86c35c33f147bec (patch) | |
tree | 96bddb10dc889888269e5dc808319720c74d30a3 /lib | |
parent | 9754b3a46dad43211539a3dbfbc7c5095bdf30f5 (diff) |
DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS=1 → DFATOOL_RMT_RELEVANCE_METHOD=std_by_param
Diffstat (limited to 'lib')
-rw-r--r-- | lib/parameters.py | 33 |
1 files changed, 12 insertions, 21 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 8c7c9cb..a154918 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -15,6 +15,10 @@ from .utils import soft_cast_int, soft_cast_float logger = logging.getLogger(__name__) +dfatool_fol_relevance_method = os.getenv("DFATOOL_FOL_RELEVANCE_METHOD", None) +dfatool_symreg_relevance_method = os.getenv("DFATOOL_SYMREG_RELEVANCE_METHOD", None) +dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None) + def distinct_param_values(param_tuples): """ @@ -912,11 +916,8 @@ class ModelAttribute: return False def build_fol(self): - ignore_irrelevant = bool( - int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0")) - ) ignore_param_indexes = list() - if ignore_irrelevant: + if dfatool_fol_relevance_method == "std_by_param": for param_index, param in enumerate(self.param_names): if not self.stats.depends_on_param(param): ignore_param_indexes.append(param_index) @@ -964,11 +965,8 @@ class ModelAttribute: return False def build_symreg(self): - ignore_irrelevant = bool( - int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0")) - ) ignore_param_indexes = list() - if ignore_irrelevant: + if dfatool_symreg_relevance_method == "std_by_param": for param_index, param in enumerate(self.param_names): if not self.stats.depends_on_param(param): ignore_param_indexes.append(param_index) @@ -1031,7 +1029,6 @@ class ModelAttribute: with_function_leaves=None, with_nonbinary_nodes=None, with_gplearn_symreg=None, - ignore_irrelevant_parameters=None, loss_ignore_scalar=None, threshold=100, ): @@ -1059,10 +1056,6 @@ class ModelAttribute: ) if with_gplearn_symreg is None: with_gplearn_symreg = bool(int(os.getenv("DFATOOL_USE_SYMREG", "0"))) - if ignore_irrelevant_parameters is None: - ignore_irrelevant_parameters = bool( - int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0")) - ) if loss_ignore_scalar is None: loss_ignore_scalar = bool( int(os.getenv("DFATOOL_RMT_LOSS_IGNORE_SCALAR", "0")) @@ -1084,7 +1077,6 @@ class ModelAttribute: self.data, with_function_leaves=with_function_leaves, with_nonbinary_nodes=with_nonbinary_nodes, - ignore_irrelevant_parameters=ignore_irrelevant_parameters, loss_ignore_scalar=loss_ignore_scalar, submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"), threshold=threshold, @@ -1097,7 +1089,6 @@ class ModelAttribute: data, with_function_leaves=False, with_nonbinary_nodes=True, - ignore_irrelevant_parameters=True, loss_ignore_scalar=False, submodel="uls", threshold=100, @@ -1127,11 +1118,12 @@ class ModelAttribute: loss = list() ffs_feasible = False - if ignore_irrelevant_parameters: - by_param = partition_by_param(data, parameters) - distinct_values_by_param_index = distinct_param_values(parameters) - std_lut = np.mean([np.std(v) for v in by_param.values()]) + if dfatool_rmt_relevance_method: irrelevant_params = list() + if dfatool_rmt_relevance_method == "std_by_param": + by_param = partition_by_param(data, parameters) + distinct_values_by_param_index = distinct_param_values(parameters) + std_lut = np.mean([np.std(v) for v in by_param.values()]) if loss_ignore_scalar: ffs_eligible_params = list() @@ -1182,7 +1174,7 @@ class ModelAttribute: loss.append(np.inf) continue - if ignore_irrelevant_parameters: + if dfatool_rmt_relevance_method == "std_by_param": std_by_param = _mean_std_by_params( by_param, distinct_values_by_param_index, @@ -1335,7 +1327,6 @@ class ModelAttribute: child_data, with_function_leaves=with_function_leaves, with_nonbinary_nodes=with_nonbinary_nodes, - ignore_irrelevant_parameters=ignore_irrelevant_parameters, loss_ignore_scalar=loss_ignore_scalar, submodel=submodel, threshold=threshold, |