summaryrefslogtreecommitdiff
path: root/lib/parameters.py
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-19 13:20:06 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-19 13:20:06 +0100
commit1e54270e724f719e6b60ae02662cf7d8a175bf4c (patch)
tree3b871791d751fa0f8460fd998bd89c2152c8f41c /lib/parameters.py
parent51b639a524133297d3c5fd3c42187820b363e268 (diff)
Store n_samples in all relevant ModelFunction instances
Diffstat (limited to 'lib/parameters.py')
-rw-r--r--lib/parameters.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index 6183ccc..667bcae 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -909,7 +909,7 @@ class ModelAttribute:
for param_index, _ in enumerate(self.param_names):
if len(self.stats.distinct_values_by_param_index[param_index]) < 2:
ignore_param_indexes.append(param_index)
- x = df.FOLFunction(self.median, self.param_names)
+ x = df.FOLFunction(self.median, self.param_names, n_samples=self.data.shape[0])
x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes)
if x.fit_success:
self.model_function = x
@@ -942,7 +942,10 @@ class ModelAttribute:
pass
elif len(fit_result.keys()):
x = df.analytic.function_powerset(
- fit_result, self.param_names, self.arg_count
+ fit_result,
+ self.param_names,
+ self.arg_count,
+ n_samples=self.data.shape[0],
)
x.value = self.median
x.fit(self.by_param)
@@ -1042,7 +1045,11 @@ class ModelAttribute:
logger.debug("Fitting sklearn CART ...")
cart.fit(fit_parameters, data)
self.model_function = df.CARTFunction(
- np.mean(data), cart, category_to_index, ignore_index
+ np.mean(data),
+ cart,
+ category_to_index,
+ ignore_index,
+ n_samples=len(data),
)
logger.debug("Fitted sklearn CART")
return
@@ -1077,7 +1084,7 @@ class ModelAttribute:
return
xgb.fit(fit_parameters, np.reshape(data, (-1, 1)))
self.model_function = df.XGBoostFunction(
- np.mean(data), xgb, category_to_index, ignore_index
+ np.mean(data), xgb, category_to_index, ignore_index, n_samples=len(data)
)
output_filename = os.getenv("DFATOOL_XGB_DUMP_MODEL", None)
if output_filename:
@@ -1113,7 +1120,7 @@ class ModelAttribute:
return
logger.debug("Fitted LMT")
self.model_function = df.LMTFunction(
- np.mean(data), lmt, category_to_index, ignore_index
+ np.mean(data), lmt, category_to_index, ignore_index, n_samples=len(data)
)
return