diff options
Diffstat (limited to 'lib/parameters.py')
-rw-r--r-- | lib/parameters.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/lib/parameters.py b/lib/parameters.py index 4e98f54..51ff680 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -896,6 +896,7 @@ class ModelAttribute: with_nonbinary_nodes=True, with_sklearn_cart=False, with_xgboost=False, + with_lmt=False, loss_ignore_scalar=False, threshold=100, ): @@ -975,6 +976,26 @@ class ModelAttribute: ) return + if with_lmt: + from sklearn.linear_model import LinearRegression + from dfatool.lineartree import LinearTreeRegressor + + lmt = LinearTreeRegressor(base_estimator=LinearRegression()) + fit_parameters, category_to_index, ignore_index = param_to_ndarray( + parameters, with_nan=False, categorial_to_scalar=categorial_to_scalar + ) + if fit_parameters.shape[1] == 0: + logger.warning( + f"Cannot generate LMT due to lack of parameters: parameter shape is {np.array(parameters).shape}, fit_parameter shape is {fit_parameters.shape}" + ) + self.model_function = df.StaticFunction(np.mean(data)) + return + lmt.fit(fit_parameters, data) + self.model_function = df.LMTFunction( + np.mean(data), lmt, category_to_index, ignore_index + ) + return + if loss_ignore_scalar and not with_function_leaves: logger.warning( "build_dtree called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense." |