author     Birte Kristina Friesel <birte.friesel@uos.de>    2024-01-19 13:20:06 +0100
committer  Birte Kristina Friesel <birte.friesel@uos.de>    2024-01-19 13:20:06 +0100
commit     1e54270e724f719e6b60ae02662cf7d8a175bf4c
tree       3b871791d751fa0f8460fd998bd89c2152c8f41c /lib
parent     51b639a524133297d3c5fd3c42187820b363e268
Store n_samples in all relevant ModelFunction instances
Diffstat (limited to 'lib')
-rw-r--r--  lib/functions.py  |  4
-rw-r--r--  lib/parameters.py | 17
2 files changed, 14 insertions(+), 7 deletions(-)
diff --git a/lib/functions.py b/lib/functions.py
index 0f472ee..2e7735a 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -1330,7 +1330,7 @@ class analytic:
         return "analytic._{}({})".format(function_type, ref_str)
 
     @staticmethod
-    def function_powerset(fit_results, parameter_names, num_args=0):
+    def function_powerset(fit_results, parameter_names, num_args=0, **kwargs):
         """
         Combine per-parameter regression results into a single multi-dimensional function.
 
@@ -1368,5 +1368,5 @@ class analytic:
             )
         )
         return AnalyticFunction(
-            None, buf, parameter_names, num_args, fit_by_param=fit_results
+            None, buf, parameter_names, num_args, fit_by_param=fit_results, **kwargs
         )
diff --git a/lib/parameters.py b/lib/parameters.py
index 6183ccc..667bcae 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -909,7 +909,7 @@ class ModelAttribute:
         for param_index, _ in enumerate(self.param_names):
             if len(self.stats.distinct_values_by_param_index[param_index]) < 2:
                 ignore_param_indexes.append(param_index)
-        x = df.FOLFunction(self.median, self.param_names)
+        x = df.FOLFunction(self.median, self.param_names, n_samples=self.data.shape[0])
         x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes)
         if x.fit_success:
             self.model_function = x
@@ -942,7 +942,10 @@ class ModelAttribute:
             pass
         elif len(fit_result.keys()):
             x = df.analytic.function_powerset(
-                fit_result, self.param_names, self.arg_count
+                fit_result,
+                self.param_names,
+                self.arg_count,
+                n_samples=self.data.shape[0],
             )
             x.value = self.median
             x.fit(self.by_param)
@@ -1042,7 +1045,11 @@ class ModelAttribute:
             logger.debug("Fitting sklearn CART ...")
             cart.fit(fit_parameters, data)
             self.model_function = df.CARTFunction(
-                np.mean(data), cart, category_to_index, ignore_index
+                np.mean(data),
+                cart,
+                category_to_index,
+                ignore_index,
+                n_samples=len(data),
             )
             logger.debug("Fitted sklearn CART")
             return
@@ -1077,7 +1084,7 @@ class ModelAttribute:
                 return
             xgb.fit(fit_parameters, np.reshape(data, (-1, 1)))
             self.model_function = df.XGBoostFunction(
-                np.mean(data), xgb, category_to_index, ignore_index
+                np.mean(data), xgb, category_to_index, ignore_index, n_samples=len(data)
             )
             output_filename = os.getenv("DFATOOL_XGB_DUMP_MODEL", None)
             if output_filename:
@@ -1113,7 +1120,7 @@ class ModelAttribute:
                 return
             logger.debug("Fitted LMT")
             self.model_function = df.LMTFunction(
-                np.mean(data), lmt, category_to_index, ignore_index
+                np.mean(data), lmt, category_to_index, ignore_index, n_samples=len(data)
             )
             return
 
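
The patch threads the sample count to the ModelFunction constructors instead of widening every intermediate signature: function_powerset now accepts **kwargs and forwards them to AnalyticFunction, while the FOLFunction, CARTFunction, XGBoostFunction, and LMTFunction call sites pass n_samples directly. Below is a minimal, self-contained sketch of that forwarding pattern; the simplified ModelFunction / AnalyticFunction classes and their n_samples attribute handling are illustrative assumptions, not the actual dfatool implementations (which are outside this diff).

# Minimal sketch of the kwargs-forwarding pattern used in this commit.
# The classes below are simplified stand-ins for dfatool's ModelFunction
# hierarchy (lib/functions.py); only the n_samples plumbing is shown.

class ModelFunction:
    def __init__(self, value, n_samples=None):
        self.value = value
        # Number of measurements the model was built from (attribute name
        # assumed to mirror the n_samples keyword added in this commit).
        self.n_samples = n_samples


class AnalyticFunction(ModelFunction):
    def __init__(
        self, value, function_str, parameter_names, num_args, fit_by_param=None, **kwargs
    ):
        # Unknown keywords (e.g. n_samples) are handed to the base class.
        super().__init__(value, **kwargs)
        self.function_str = function_str
        self.parameter_names = parameter_names
        self.num_args = num_args
        self.fit_by_param = fit_by_param


def function_powerset(fit_results, parameter_names, num_args=0, **kwargs):
    # The builder does not need to know about n_samples at all; it simply
    # forwards **kwargs, so call sites can attach new metadata without
    # changing this signature again.
    buf = "0"  # placeholder for the generated function string
    return AnalyticFunction(
        None, buf, parameter_names, num_args, fit_by_param=fit_results, **kwargs
    )


# Usage, analogous to the updated call site in lib/parameters.py:
mf = function_powerset({}, ["voltage", "rate"], n_samples=120)
assert mf.n_samples == 120

Forwarding via **kwargs keeps the change local to the constructors that actually store the value; the trade-off is that a misspelled keyword only surfaces as an error in the base class rather than at the call site.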