author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-25 09:43:35 +0100
---|---|---
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-01-25 09:55:36 +0100
commit | eb34056dbf9e10be7bed3835600f4edd7f7a1ef3 (patch) |
tree | 31a181ed3d2bd36322fcaee19967e5a7b7edae33 /lib |
parent | 9e8f9774c42cac0904d56d8edfd4abd4b2b717d1 (diff) |
LMT: Tailor hyper-parameters towards higher accuracy (and longer training time)
Also, allow users to override them
Diffstat (limited to 'lib')
-rw-r--r-- | lib/parameters.py | 55
1 file changed, 54 insertions(+), 1 deletion(-)
```diff
diff --git a/lib/parameters.py b/lib/parameters.py
index e697eda..8627b72 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -1106,7 +1106,60 @@ class ModelAttribute:
         from sklearn.linear_model import LinearRegression
         from dfatool.lineartree import LinearTreeRegressor
 
-        lmt = LinearTreeRegressor(base_estimator=LinearRegression(), max_depth=20)
+        # max_depth : int, default=5
+        # The maximum depth of the tree considering only the splitting nodes.
+        # A higher value implies a higher training time.
+        max_depth = int(os.getenv("DFATOOL_LMT_MAX_DEPTH", "20"))
+
+        # min_samples_split : int or float, default=6
+        # The minimum number of samples required to split an internal node.
+        # The minimum valid number of samples in each node is 6.
+        # A lower value implies a higher training time.
+        # - If int, then consider `min_samples_split` as the minimum number.
+        # - If float, then `min_samples_split` is a fraction and
+        #   `ceil(min_samples_split * n_samples)` are the minimum
+        #   number of samples for each split.
+        if "." in os.getenv("DFATOOL_LMT_MIN_SAMPLES_SPLIT", ""):
+            min_samples_split = float(os.getenv("DFATOOL_LMT_MIN_SAMPLES_SPLIT"))
+        else:
+            min_samples_split = int(os.getenv("DFATOOL_LMT_MIN_SAMPLES_SPLIT", "6"))
+
+        # min_samples_leaf : int or float, default=0.1
+        # The minimum number of samples required to be at a leaf node.
+        # A split point at any depth will only be considered if it leaves at
+        # least `min_samples_leaf` training samples in each of the left and
+        # right branches.
+        # The minimum valid number of samples in each leaf is 3.
+        # A lower value implies a higher training time.
+        # - If int, then consider `min_samples_leaf` as the minimum number.
+        # - If float, then `min_samples_leaf` is a fraction and
+        #   `ceil(min_samples_leaf * n_samples)` are the minimum
+        #   number of samples for each node.
+        if "." in os.getenv("DFATOOL_LMT_MIN_SAMPLES_LEAF", ""):
+            min_samples_leaf = float(os.getenv("DFATOOL_LMT_MIN_SAMPLES_LEAF"))
+        else:
+            min_samples_leaf = int(os.getenv("DFATOOL_LMT_MIN_SAMPLES_LEAF", "3"))
+
+        # max_bins : int, default=25
+        # The maximum number of bins to use to search the optimal split in each
+        # feature. Features with a small number of unique values may use less than
+        # ``max_bins`` bins. Must be lower than 120 and larger than 10.
+        # A higher value implies a higher training time.
+        max_bins = int(os.getenv("DFATOOL_LMT_MAX_BINS", "120"))
+
+        # criterion : {"mse", "rmse", "mae", "poisson"}, default="mse"
+        # The function to measure the quality of a split. "poisson"
+        # requires ``y >= 0``.
+        criterion = os.getenv("DFATOOL_LMT_CRITERION", "mse")
+
+        lmt = LinearTreeRegressor(
+            base_estimator=LinearRegression(),
+            max_depth=max_depth,
+            min_samples_split=min_samples_split,
+            min_samples_leaf=min_samples_leaf,
+            max_bins=max_bins,
+            criterion=criterion,
+        )
         fit_parameters, category_to_index, ignore_index = param_to_ndarray(
             parameters, with_nan=False, categorial_to_scalar=categorial_to_scalar
         )
```
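As a usage sketch (not part of the commit): the patch reads its overrides from `DFATOOL_LMT_*` environment variables and treats any value containing a `.` as a fraction of the sample count, while everything else is parsed as an absolute count. The snippet below illustrates that convention; the helper `parse_count_or_fraction` is made up for illustration and does not exist in dfatool, it merely mirrors the `if "." in os.getenv(...)` branches added above.

```python
import os

# Hypothetical overrides, set before the LMT hyper-parameters are read.
os.environ["DFATOOL_LMT_MAX_DEPTH"] = "10"           # int: shallower tree, faster training
os.environ["DFATOOL_LMT_MIN_SAMPLES_SPLIT"] = "0.1"  # contains ".": 10 % of n_samples per split
os.environ["DFATOOL_LMT_MIN_SAMPLES_LEAF"] = "5"     # no ".": at least 5 samples per leaf


def parse_count_or_fraction(env_var, default):
    """Illustrative helper: replicate the patch's int-vs-float convention."""
    raw = os.getenv(env_var, "")
    if "." in raw:
        return float(raw)  # interpreted as a fraction of n_samples
    return int(raw) if raw else default  # absolute count, or the built-in default


print(parse_count_or_fraction("DFATOOL_LMT_MIN_SAMPLES_SPLIT", 6))  # 0.1
print(parse_count_or_fraction("DFATOOL_LMT_MIN_SAMPLES_LEAF", 3))   # 5
print(int(os.getenv("DFATOOL_LMT_MAX_DEPTH", "20")))                # 10
```

With these overrides in place, the `LinearTreeRegressor` constructed in the patch would use `max_depth=10`, a 10 % minimum split fraction, and a 5-sample leaf minimum instead of the new defaults (20, 6, and 3).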