author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 11:38:48 +0100
---|---|---
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 11:38:48 +0100
commit | c3dbe93034bdeff9dba534d29b04daa527d70241 |
tree | 14cdfd0047a01880ca30e96cf6d347009777a80e |
parent | 6bf411afb9289408c62e1696d8fb2b4da47a9fab |
move (de)cart, lmt, xgb model generation into separate ModelAttribute functions
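
With this change, selecting a model type no longer goes through per-type arguments and environment variables (`DFATOOL_DTREE_SKLEARN_CART`, `DFATOOL_DTREE_SKLEARN_DECART`, `DFATOOL_DTREE_LMT`, `DFATOOL_USE_XGBOOST`); instead, `build_fitted()` dispatches on a single `DFATOOL_MODEL` variable. A minimal sketch of the resulting workflow, assuming `AnalyticModel(by_name, parameter_names)` matches the constructor used elsewhere in dfatool (the measurement data below is a hypothetical placeholder):

```python
import os

# One switch instead of four: values supported after this commit are
# "rmt" (default), "cart", "decart", "lmt", and "xgb".
os.environ["DFATOOL_MODEL"] = "cart"

from dfatool.model import AnalyticModel

# Hypothetical measurement data; the by_name layout ("attributes" plus
# per-attribute value lists) follows the structure that build_fitted()
# iterates over in this diff.
by_name = {
    "TX": {
        "attributes": ["power"],
        "param": [(10,), (20,), (30,)],
        "power": [12.1, 18.4, 25.0],
    }
}

model = AnalyticModel(by_name, ["datarate"])
model.build_fitted()  # dispatches to ModelAttribute.build_cart() per attribute
```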
-rw-r--r-- | lib/model.py | 23
-rw-r--r-- | lib/parameters.py | 160

2 files changed, 99 insertions, 84 deletions
```diff
diff --git a/lib/model.py b/lib/model.py
index 9266153..972547d 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -4,7 +4,7 @@ import logging
 import numpy as np
 import os
 from .automata import PTA
-from .functions import StaticFunction, SubstateFunction, SplitFunction
+import dfatool.functions as df
 from .parameters import (
     ModelAttribute,
     ParamType,
@@ -295,7 +295,22 @@ class AnalyticModel:
 
     def build_fitted(self, safe_functions_enabled=False):
-        if self.force_tree:
+        model_type = os.getenv("DFATOOL_MODEL", "rmt")
+
+        if model_type != "rmt":
+            for name in self.names:
+                for attr in self.by_name[name]["attributes"]:
+                    if model_type == "cart":
+                        self.attr_by_name[name][attr].build_cart()
+                    elif model_type == "decart":
+                        self.attr_by_name[name][attr].build_decart()
+                    elif model_type == "lmt":
+                        self.attr_by_name[name][attr].build_lmt()
+                    elif model_type == "xgb":
+                        self.attr_by_name[name][attr].build_xgb()
+                    else:
+                        logger.error(f"build_fitted: unknown model type: {model_type}")
+        elif self.force_tree:
             for name in self.names:
                 for attr in self.by_name[name]["attributes"]:
                     if (
@@ -392,7 +407,7 @@ class AnalyticModel:
         model_info = self.attr_by_name[name][key].model_function
 
         # shortcut
-        if type(model_info) is StaticFunction:
+        if type(model_info) is df.StaticFunction:
             if "params" in kwargs:
                 return [static_model[name][key] for p in kwargs["params"]]
             return static_model[name][key]
@@ -1018,7 +1033,7 @@ class PTAModel(AnalyticModel):
                 )
             )
 
-        self.attr_by_name[p_name]["power"].model_function = SubstateFunction(
+        self.attr_by_name[p_name]["power"].model_function = df.SubstateFunction(
             self.attr_by_name[p_name]["power"].get_static(),
             sequence_by_count,
             self.attr_by_name[p_name]["substate_count"].model_function,
diff --git a/lib/parameters.py b/lib/parameters.py
index 83063c2..fa85b7a 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -919,16 +919,92 @@ class ModelAttribute:
             if x.fit_success:
                 self.model_function = x
 
+    def build_cart(self):
+        mf = df.CARTFunction(
+            np.mean(self.data),
+            n_samples=len(self.data),
+            param_names=self.param_names,
+            arg_count=self.arg_count,
+        ).fit(
+            self.param_values,
+            self.data,
+        )
+
+        if mf.fit_success:
+            self.model_function = mf
+            return True
+        else:
+            logger.warning(f"CART generation for {self.name} {self.attr} failed")
+            self.model_function = df.StaticFunction(
+                np.mean(self.data), n_samples=len(self.data)
+            )
+            return False
+
+    def build_decart(self):
+        mf = df.CARTFunction(
+            np.mean(self.data),
+            n_samples=len(self.data),
+            param_names=self.param_names,
+            arg_count=self.arg_count,
+            decart=True,
+        ).fit(
+            self.param_values,
+            self.data,
+            scalar_param_indexes=self.scalar_param_indexes,
+        )
+
+        if mf.fit_success:
+            self.model_function = mf
+            return True
+        else:
+            logger.warning(f"DECART generation for {self.name} {self.attr} failed")
+            self.model_function = df.StaticFunction(
+                np.mean(self.data), n_samples=len(self.data)
+            )
+            return False
+
+    def build_xgb(self):
+        mf = df.XGBoostFunction(
+            np.mean(self.data),
+            n_samples=len(self.data),
+            param_names=self.param_names,
+            arg_count=self.arg_count,
+        ).fit(self.param_values, self.data)
+
+        if mf.fit_success:
+            self.model_function = mf
+            return True
+        else:
+            logger.warning(f"XGB generation for {self.name} {self.attr} failed")
+            self.model_function = df.StaticFunction(
+                np.mean(self.data), n_samples=len(self.data)
+            )
+            return False
+
+    def build_lmt(self):
+        mf = df.LMTFunction(
+            np.mean(self.data),
+            n_samples=len(self.data),
+            param_names=self.param_names,
+            arg_count=self.arg_count,
+        ).fit(self.param_values, self.data)
+
+        if mf.fit_success:
+            self.model_function = mf
+            return True
+        else:
+            logger.warning(f"LMT generation for {self.name} {self.attr} failed")
+            self.model_function = df.StaticFunction(
+                np.mean(self.data), n_samples=len(self.data)
+            )
+            return False
+
     def build_dtree(
         self,
         parameters,
         data,
         with_function_leaves=None,
         with_nonbinary_nodes=None,
-        with_sklearn_cart=None,
-        with_sklearn_decart=None,
-        with_lmt=None,
-        with_xgboost=None,
         with_gplearn_symreg=None,
         ignore_irrelevant_parameters=None,
         loss_ignore_scalar=None,
@@ -941,10 +1017,6 @@ class ModelAttribute:
         :param data: Measurements. [data 1, data 2, data 3, ...]
         :param with_function_leaves: Use fitted function sets to generate function leaves for scalar parameters
         :param with_nonbinary_nodes: Allow non-binary nodes for enum and scalar parameters (i.e., nodes with more than two children)
-        :param with_sklearn_cart: Use `sklearn.tree.DecisionTreeRegressor` CART implementation for tree generation. Does not support categorical (enum)
-            and sparse parameters. Both are ignored during fitting. All other options are ignored as well.
-        :param with_sklearn_decart: Use `sklearn.tree.DecisionTreeRegressor` CART implementation in DECART mode for tree generation. CART limitations
-            apply; additionaly, scalar parameters are ignored during fitting.
         :param loss_ignore_scalar: Ignore scalar parameters when computing the loss for split candidates. Only sensible if with_function_leaves is enabled.
 
         :param threshold: Return a StaticFunction leaf node if std(data) < threshold. Default 100.
@@ -959,16 +1031,6 @@ class ModelAttribute:
             with_nonbinary_nodes = bool(
                 int(os.getenv("DFATOOL_DTREE_NONBINARY_NODES", "1"))
             )
-        if with_sklearn_cart is None:
-            with_sklearn_cart = bool(int(os.getenv("DFATOOL_DTREE_SKLEARN_CART", "0")))
-        if with_sklearn_decart is None:
-            with_sklearn_decart = bool(
-                int(os.getenv("DFATOOL_DTREE_SKLEARN_DECART", "0"))
-            )
-        if with_lmt is None:
-            with_lmt = bool(int(os.getenv("DFATOOL_DTREE_LMT", "0")))
-        if with_xgboost is None:
-            with_xgboost = bool(int(os.getenv("DFATOOL_USE_XGBOOST", "0")))
         if with_gplearn_symreg is None:
             with_gplearn_symreg = bool(int(os.getenv("DFATOOL_USE_SYMREG", "0")))
         if ignore_irrelevant_parameters is None:
@@ -980,68 +1042,6 @@ class ModelAttribute:
                 int(os.getenv("DFATOOL_DTREE_LOSS_IGNORE_SCALAR", "0"))
             )
 
-        if with_sklearn_cart or with_sklearn_decart:
-            mf = df.CARTFunction(
-                np.mean(data),
-                n_samples=len(data),
-                param_names=self.param_names,
-                arg_count=self.arg_count,
-                decart=with_sklearn_decart,
-            )
-
-            mf.fit(
-                parameters,
-                data,
-                scalar_param_indexes=self.scalar_param_indexes,
-            )
-
-            if mf.fit_success:
-                self.model_function = mf
-            else:
-                logger.warning(f"CART generation for {self.name} {self.attr} faled")
-                self.model_function = df.StaticFunction(
-                    np.mean(data), n_samples=len(data)
-                )
-            return
-
-        if with_xgboost:
-            mf = df.XGBoostFunction(
-                np.mean(data),
-                n_samples=len(data),
-                param_names=self.param_names,
-                arg_count=self.arg_count,
-            )
-
-            mf.fit(parameters, data)
-
-            if mf.fit_success:
-                self.model_function = mf
-            else:
-                logger.warning(f"XGB generation for {self.name} {self.attr} faled")
-                self.model_function = df.StaticFunction(
-                    np.mean(data), n_samples=len(data)
-                )
-            return
-
-        if with_lmt:
-            mf = df.LMTFunction(
-                np.mean(data),
-                n_samples=len(data),
-                param_names=self.param_names,
-                arg_count=self.arg_count,
-            )
-
-            mf.fit(parameters, data)
-
-            if mf.fit_success:
-                self.model_function = mf
-            else:
-                logger.warning(f"LMT generation for {self.name} {self.attr} faled")
-                self.model_function = df.StaticFunction(
-                    np.mean(data), n_samples=len(data)
-                )
-            return
-
         if loss_ignore_scalar and not with_function_leaves:
             logger.warning(
                 "build_dtree {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense."
```
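
The new `ModelAttribute` builders also share a uniform contract that the old inline `build_dtree` blocks did not expose: each fits exactly one model type on the attribute's stored `param_values` and `data`, returns `True` on success, and on failure logs a warning, installs a `StaticFunction` over the data mean, and returns `False`. A hedged sketch, continuing the example above ("TX" and "power" remain hypothetical names):

```python
# Grab a single attribute's model container from the fitted AnalyticModel.
attr = model.attr_by_name["TX"]["power"]

if not attr.build_lmt():
    # Fit failed: build_lmt() returned False and model_function is now a
    # StaticFunction fallback (mean of the observed data).
    print(f"LMT fit failed; fallback: {type(attr.model_function).__name__}")
```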