summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-01-19 13:20:06 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-01-19 13:20:06 +0100
commit1e54270e724f719e6b60ae02662cf7d8a175bf4c (patch)
tree3b871791d751fa0f8460fd998bd89c2152c8f41c /lib
parent51b639a524133297d3c5fd3c42187820b363e268 (diff)
Store n_samples in all relevant ModelFunction instances
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py4
-rw-r--r--lib/parameters.py17
2 files changed, 14 insertions, 7 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 0f472ee..2e7735a 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -1330,7 +1330,7 @@ class analytic:
return "analytic._{}({})".format(function_type, ref_str)
@staticmethod
- def function_powerset(fit_results, parameter_names, num_args=0):
+ def function_powerset(fit_results, parameter_names, num_args=0, **kwargs):
"""
Combine per-parameter regression results into a single multi-dimensional function.
@@ -1368,5 +1368,5 @@ class analytic:
)
)
return AnalyticFunction(
- None, buf, parameter_names, num_args, fit_by_param=fit_results
+ None, buf, parameter_names, num_args, fit_by_param=fit_results, **kwargs
)
diff --git a/lib/parameters.py b/lib/parameters.py
index 6183ccc..667bcae 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -909,7 +909,7 @@ class ModelAttribute:
for param_index, _ in enumerate(self.param_names):
if len(self.stats.distinct_values_by_param_index[param_index]) < 2:
ignore_param_indexes.append(param_index)
- x = df.FOLFunction(self.median, self.param_names)
+ x = df.FOLFunction(self.median, self.param_names, n_samples=self.data.shape[0])
x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes)
if x.fit_success:
self.model_function = x
@@ -942,7 +942,10 @@ class ModelAttribute:
pass
elif len(fit_result.keys()):
x = df.analytic.function_powerset(
- fit_result, self.param_names, self.arg_count
+ fit_result,
+ self.param_names,
+ self.arg_count,
+ n_samples=self.data.shape[0],
)
x.value = self.median
x.fit(self.by_param)
@@ -1042,7 +1045,11 @@ class ModelAttribute:
logger.debug("Fitting sklearn CART ...")
cart.fit(fit_parameters, data)
self.model_function = df.CARTFunction(
- np.mean(data), cart, category_to_index, ignore_index
+ np.mean(data),
+ cart,
+ category_to_index,
+ ignore_index,
+ n_samples=len(data),
)
logger.debug("Fitted sklearn CART")
return
@@ -1077,7 +1084,7 @@ class ModelAttribute:
return
xgb.fit(fit_parameters, np.reshape(data, (-1, 1)))
self.model_function = df.XGBoostFunction(
- np.mean(data), xgb, category_to_index, ignore_index
+ np.mean(data), xgb, category_to_index, ignore_index, n_samples=len(data)
)
output_filename = os.getenv("DFATOOL_XGB_DUMP_MODEL", None)
if output_filename:
@@ -1113,7 +1120,7 @@ class ModelAttribute:
return
logger.debug("Fitted LMT")
self.model_function = df.LMTFunction(
- np.mean(data), lmt, category_to_index, ignore_index
+ np.mean(data), lmt, category_to_index, ignore_index, n_samples=len(data)
)
return