From 92b4dd6e05df3b2805570fa1f86c35c33f147bec Mon Sep 17 00:00:00 2001 From: Birte Kristina Friesel Date: Thu, 7 Mar 2024 10:52:34 +0100 Subject: DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS=1 → DFATOOL_RMT_RELEVANCE_METHOD=std_by_param MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lib/parameters.py | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) (limited to 'lib') diff --git a/lib/parameters.py b/lib/parameters.py index 8c7c9cb..a154918 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -15,6 +15,10 @@ from .utils import soft_cast_int, soft_cast_float logger = logging.getLogger(__name__) +dfatool_fol_relevance_method = os.getenv("DFATOOL_FOL_RELEVANCE_METHOD", None) +dfatool_symreg_relevance_method = os.getenv("DFATOOL_SYMREG_RELEVANCE_METHOD", None) +dfatool_rmt_relevance_method = os.getenv("DFATOOL_RMT_RELEVANCE_METHOD", None) + def distinct_param_values(param_tuples): """ @@ -912,11 +916,8 @@ class ModelAttribute: return False def build_fol(self): - ignore_irrelevant = bool( - int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0")) - ) ignore_param_indexes = list() - if ignore_irrelevant: + if dfatool_fol_relevance_method == "std_by_param": for param_index, param in enumerate(self.param_names): if not self.stats.depends_on_param(param): ignore_param_indexes.append(param_index) @@ -964,11 +965,8 @@ class ModelAttribute: return False def build_symreg(self): - ignore_irrelevant = bool( - int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0")) - ) ignore_param_indexes = list() - if ignore_irrelevant: + if dfatool_symreg_relevance_method == "std_by_param": for param_index, param in enumerate(self.param_names): if not self.stats.depends_on_param(param): ignore_param_indexes.append(param_index) @@ -1031,7 +1029,6 @@ class ModelAttribute: with_function_leaves=None, with_nonbinary_nodes=None, with_gplearn_symreg=None, - ignore_irrelevant_parameters=None, loss_ignore_scalar=None, threshold=100, ): @@ -1059,10 +1056,6 @@ class ModelAttribute: ) if with_gplearn_symreg is None: with_gplearn_symreg = bool(int(os.getenv("DFATOOL_USE_SYMREG", "0"))) - if ignore_irrelevant_parameters is None: - ignore_irrelevant_parameters = bool( - int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0")) - ) if loss_ignore_scalar is None: loss_ignore_scalar = bool( int(os.getenv("DFATOOL_RMT_LOSS_IGNORE_SCALAR", "0")) @@ -1084,7 +1077,6 @@ class ModelAttribute: self.data, with_function_leaves=with_function_leaves, with_nonbinary_nodes=with_nonbinary_nodes, - ignore_irrelevant_parameters=ignore_irrelevant_parameters, loss_ignore_scalar=loss_ignore_scalar, submodel=os.getenv("DFATOOL_RMT_SUBMODEL", "uls"), threshold=threshold, @@ -1097,7 +1089,6 @@ class ModelAttribute: data, with_function_leaves=False, with_nonbinary_nodes=True, - ignore_irrelevant_parameters=True, loss_ignore_scalar=False, submodel="uls", threshold=100, @@ -1127,11 +1118,12 @@ class ModelAttribute: loss = list() ffs_feasible = False - if ignore_irrelevant_parameters: - by_param = partition_by_param(data, parameters) - distinct_values_by_param_index = distinct_param_values(parameters) - std_lut = np.mean([np.std(v) for v in by_param.values()]) + if dfatool_rmt_relevance_method: irrelevant_params = list() + if dfatool_rmt_relevance_method == "std_by_param": + by_param = partition_by_param(data, parameters) + distinct_values_by_param_index = distinct_param_values(parameters) + std_lut = np.mean([np.std(v) for v in by_param.values()]) if loss_ignore_scalar: ffs_eligible_params = list() @@ -1182,7 +1174,7 @@ class ModelAttribute: loss.append(np.inf) continue - if ignore_irrelevant_parameters: + if dfatool_rmt_relevance_method == "std_by_param": std_by_param = _mean_std_by_params( by_param, distinct_values_by_param_index, @@ -1335,7 +1327,6 @@ class ModelAttribute: child_data, with_function_leaves=with_function_leaves, with_nonbinary_nodes=with_nonbinary_nodes, - ignore_irrelevant_parameters=ignore_irrelevant_parameters, loss_ignore_scalar=loss_ignore_scalar, submodel=submodel, threshold=threshold, -- cgit v1.2.3