summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/functions.py23
-rw-r--r--lib/parameters.py23
2 files changed, 29 insertions, 17 deletions
diff --git a/lib/functions.py b/lib/functions.py
index 4be32c6..1b8612e 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -177,11 +177,12 @@ class ModelFunction:
:type value_error: dict, optional
"""
- def __init__(self, value):
+ def __init__(self, value, n_samples=None):
# a model always has a static (median/mean) value. For StaticFunction, it's the only data point.
# For more complex models, it's usede both as fallback in case the model cannot predict the current
# parameter combination, and for use cases requiring static models
self.value = value
+ self.n_samples = n_samples
# A ModelFunction may track its own accuracy, both of the static value and of the eval() method.
# However, it does not specify how the accuracy was calculated (e.g. which data was used and whether cross-validation was performed)
@@ -213,6 +214,7 @@ class ModelFunction:
"""Convert model to JSON."""
ret = {
"value": self.value,
+ "n_samples": self.n_samples,
"valueError": self.value_error,
"functionError": self.function_error,
}
@@ -308,8 +310,8 @@ class StaticFunction(ModelFunction):
class SplitFunction(ModelFunction):
- def __init__(self, value, param_index, child):
- super().__init__(value)
+ def __init__(self, value, param_index, child, **kwargs):
+ super().__init__(value, **kwargs)
self.param_index = param_index
self.child = child
@@ -416,8 +418,8 @@ class SplitFunction(ModelFunction):
class SubstateFunction(ModelFunction):
- def __init__(self, value, sequence_by_count, count_model, sub_model):
- super().__init__(value)
+ def __init__(self, value, sequence_by_count, count_model, sub_model, **kwargs):
+ super().__init__(value, **kwargs)
self.sequence_by_count = sequence_by_count
self.count_model = count_model
self.sub_model = sub_model
@@ -475,8 +477,8 @@ class SKLearnRegressionFunction(ModelFunction):
always_predictable = True
has_eval_arr = True
- def __init__(self, value, regressor, categorial_to_index, ignore_index):
- super().__init__(value)
+ def __init__(self, value, regressor, categorial_to_index, ignore_index, **kwargs):
+ super().__init__(value, **kwargs)
self.regressor = regressor
self.categorial_to_index = categorial_to_index
self.ignore_index = ignore_index
@@ -734,8 +736,8 @@ class XGBoostFunction(SKLearnRegressionFunction):
class FOLFunction(ModelFunction):
always_predictable = True
- def __init__(self, value, parameters, num_args=0):
- super().__init__(value)
+ def __init__(self, value, parameters, num_args=0, **kwargs):
+ super().__init__(value, **kwargs)
self.parameter_names = parameters
self._num_args = num_args
self.fit_success = False
@@ -904,6 +906,7 @@ class AnalyticFunction(ModelFunction):
num_args=0,
regression_args=None,
fit_by_param=None,
+ **kwargs,
):
"""
Create a new AnalyticFunction object from a function string.
@@ -923,7 +926,7 @@ class AnalyticFunction(ModelFunction):
both for function usage and least squares optimization.
If unset, defaults to [1, 1, 1, ...]
"""
- super().__init__(value)
+ super().__init__(value, **kwargs)
self._parameter_names = parameters
self._num_args = num_args
self.model_function = function_str
diff --git a/lib/parameters.py b/lib/parameters.py
index dd435e6..0f23ef8 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -920,6 +920,7 @@ class ModelAttribute:
function_str,
self.param_names,
self.arg_count,
+ n_samples=self.data.shape[0],
# fit_by_param=fit_result,
)
x.fit(self.by_param)
@@ -1031,7 +1032,9 @@ class ModelAttribute:
logger.warning(
f"Cannot generate CART for {self.name} {self.attr} due to lack of parameters: parameter shape is {np.array(parameters).shape}, fit_parameter shape is {fit_parameters.shape}"
)
- self.model_function = df.StaticFunction(np.mean(data))
+ self.model_function = df.StaticFunction(
+ np.mean(data), n_samples=len(data)
+ )
return
logger.debug("Fitting sklearn CART ...")
cart.fit(fit_parameters, data)
@@ -1065,7 +1068,9 @@ class ModelAttribute:
logger.warning(
f"Cannot run XGBoost for {self.name} {self.attr} due to lack of parameters: parameter shape is {np.array(parameters).shape}, fit_parameter shape is {fit_parameters.shape}"
)
- self.model_function = df.StaticFunction(np.mean(data))
+ self.model_function = df.StaticFunction(
+ np.mean(data), n_samples=len(data)
+ )
return
xgb.fit(fit_parameters, np.reshape(data, (-1, 1)))
self.model_function = df.XGBoostFunction(
@@ -1090,14 +1095,18 @@ class ModelAttribute:
logger.warning(
f"Cannot generate LMT for {self.name} {self.attr} due to lack of parameters: parameter shape is {np.array(parameters).shape}, fit_parameter shape is {fit_parameters.shape}"
)
- self.model_function = df.StaticFunction(np.mean(data))
+ self.model_function = df.StaticFunction(
+ np.mean(data), n_samples=len(data)
+ )
return
logger.debug("Fitting LMT ...")
try:
lmt.fit(fit_parameters, data)
except np.linalg.LinAlgError as e:
logger.error(f"LMT generation for {self.name} {self.attr} failed: {e}")
- self.model_function = df.StaticFunction(np.mean(data))
+ self.model_function = df.StaticFunction(
+ np.mean(data), n_samples=len(data)
+ )
return
logger.debug("Fitted LMT")
self.model_function = df.LMTFunction(
@@ -1156,7 +1165,7 @@ class ModelAttribute:
param_count = nonarg_count + self.arg_count
# TODO eigentlich muss threshold hier auf Basis der aktuellen Messdatenpartition neu berechnet werden
if param_count == 0 or np.std(data) <= threshold:
- return df.StaticFunction(np.mean(data))
+ return df.StaticFunction(np.mean(data), n_samples=len(data))
# sf.value_error["std"] = np.std(data)
loss = list()
@@ -1294,7 +1303,7 @@ class ModelAttribute:
paramfit.fit()
ma.set_data_from_paramfit(paramfit)
return ma.model_function
- return df.StaticFunction(np.mean(data))
+ return df.StaticFunction(np.mean(data), n_samples=len(data))
split_feasible = True
if loss_ignore_scalar:
@@ -1365,4 +1374,4 @@ class ModelAttribute:
assert len(child.values()) >= 2
- return df.SplitFunction(np.mean(data), symbol_index, child)
+ return df.SplitFunction(np.mean(data), symbol_index, child, n_samples=len(data))