diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-19 13:20:06 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-19 13:20:06 +0100 |
commit | 1e54270e724f719e6b60ae02662cf7d8a175bf4c (patch) | |
tree | 3b871791d751fa0f8460fd998bd89c2152c8f41c /lib/parameters.py | |
parent | 51b639a524133297d3c5fd3c42187820b363e268 (diff) |
Store n_samples in all relevant ModelFunction instances
Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 6183ccc..667bcae 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -909,7 +909,7 @@ class ModelAttribute: for param_index, _ in enumerate(self.param_names): if len(self.stats.distinct_values_by_param_index[param_index]) < 2: ignore_param_indexes.append(param_index) - x = df.FOLFunction(self.median, self.param_names) + x = df.FOLFunction(self.median, self.param_names, n_samples=self.data.shape[0]) x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes) if x.fit_success: self.model_function = x @@ -942,7 +942,10 @@ class ModelAttribute: pass elif len(fit_result.keys()): x = df.analytic.function_powerset( - fit_result, self.param_names, self.arg_count + fit_result, + self.param_names, + self.arg_count, + n_samples=self.data.shape[0], ) x.value = self.median x.fit(self.by_param) @@ -1042,7 +1045,11 @@ class ModelAttribute: logger.debug("Fitting sklearn CART ...") cart.fit(fit_parameters, data) self.model_function = df.CARTFunction( - np.mean(data), cart, category_to_index, ignore_index + np.mean(data), + cart, + category_to_index, + ignore_index, + n_samples=len(data), ) logger.debug("Fitted sklearn CART") return @@ -1077,7 +1084,7 @@ class ModelAttribute: return xgb.fit(fit_parameters, np.reshape(data, (-1, 1))) self.model_function = df.XGBoostFunction( - np.mean(data), xgb, category_to_index, ignore_index + np.mean(data), xgb, category_to_index, ignore_index, n_samples=len(data) ) output_filename = os.getenv("DFATOOL_XGB_DUMP_MODEL", None) if output_filename: @@ -1113,7 +1120,7 @@ class ModelAttribute: return logger.debug("Fitted LMT") self.model_function = df.LMTFunction( - np.mean(data), lmt, category_to_index, ignore_index + np.mean(data), lmt, category_to_index, ignore_index, n_samples=len(data) ) return |