diff options
Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 50 |
1 files changed, 27 insertions, 23 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 1ae4e4c..bafc2a5 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -830,29 +830,6 @@ class ModelAttribute: return False return True - def build_symreg_model(self): - ignore_irrelevant = bool( - int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0")) - ) - ignore_param_indexes = list() - if ignore_irrelevant: - for param_index, param in enumerate(self.param_names): - if not self.stats.depends_on_param(param): - ignore_param_indexes.append(param_index) - x = df.SymbolicRegressionFunction( - self.median, - self.param_names, - n_samples=self.data.shape[0], - num_args=self.arg_count, - ) - x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes) - if x.fit_success: - self.model_function = x - else: - logger.debug( - f"Symbolic Regression model generation for {self.name} {self.attr} failed." - ) - def fit_override_function(self): function_str = self.function_override x = df.AnalyticFunction( @@ -986,6 +963,33 @@ class ModelAttribute: ) return False + def build_symreg(self): + ignore_irrelevant = bool( + int(os.getenv("DFATOOL_RMT_IGNORE_IRRELEVANT_PARAMS", "0")) + ) + ignore_param_indexes = list() + if ignore_irrelevant: + for param_index, param in enumerate(self.param_names): + if not self.stats.depends_on_param(param): + ignore_param_indexes.append(param_index) + x = df.SymbolicRegressionFunction( + np.mean(self.data), + n_samples=self.data.shape[0], + param_names=self.param_names, + arg_count=self.arg_count, + ).fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes) + if x.fit_success: + self.model_function = x + return True + else: + logger.debug( + f"Symbolic Regression model generation for {self.name} {self.attr} failed." + ) + self.model_function = df.StaticFunction( + np.mean(self.data), n_samples=len(self.data) + ) + return False + def build_xgb(self): mf = df.XGBoostFunction( np.mean(self.data), |