diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 12:35:45 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 12:35:45 +0100 |
commit | d5950d6f31de5403ed61124d799cd0fafe491b06 (patch) | |
tree | 8c38e4001013cb13c960922fe68b33b2b389df96 /lib | |
parent | 5d83f255f05c3b74df0ace1f70b260959b392eca (diff) |
Replace DFATOOL_FIT_FOL with DFATOOL_MODEL=fol
Diffstat (limited to 'lib')
-rw-r--r-- | lib/model.py | 5 | ||||
-rw-r--r-- | lib/parameters.py | 75 |
2 files changed, 42 insertions, 38 deletions
diff --git a/lib/model.py b/lib/model.py index 972547d..6718090 100644 --- a/lib/model.py +++ b/lib/model.py @@ -304,6 +304,8 @@ class AnalyticModel: self.attr_by_name[name][attr].build_cart() elif model_type == "decart": self.attr_by_name[name][attr].build_decart() + elif model_type == "fol": + self.attr_by_name[name][attr].build_fol() elif model_type == "lmt": self.attr_by_name[name][attr].build_lmt() elif model_type == "xgb": @@ -332,7 +334,6 @@ class AnalyticModel: else: paramfit = ParamFit() tree_allowed = bool(int(os.getenv("DFATOOL_DTREE_ENABLED", "1"))) - use_fol = bool(int(os.getenv("DFATOOL_FIT_FOL", "0"))) use_symreg = bool(int(os.getenv("DFATOOL_FIT_SYMREG", "0"))) tree_required = dict() @@ -341,8 +342,6 @@ class AnalyticModel: for attr in self.attr_by_name[name].keys(): if self.attr_by_name[name][attr].function_override is not None: self.attr_by_name[name][attr].fit_override_function() - elif use_fol: - self.attr_by_name[name][attr].build_fol_model() elif use_symreg: self.attr_by_name[name][attr].build_symreg_model() elif self.attr_by_name[name][ diff --git a/lib/parameters.py b/lib/parameters.py index fa85b7a..d3d2659 100644 --- a/lib/parameters.py +++ b/lib/parameters.py @@ -830,35 +830,6 @@ class ModelAttribute: return False return True - def build_fol_model(self): - ignore_irrelevant = bool( - int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0")) - ) - ignore_param_indexes = list() - if ignore_irrelevant: - for param_index, param in enumerate(self.param_names): - if not self.stats.depends_on_param(param): - ignore_param_indexes.append(param_index) - if not self.stats: - logger.warning( - "build_fol_model called with ModelAttribute.stats unavailable -- overfitting likely" - ) - else: - for param_index, _ in enumerate(self.param_names): - if len(self.stats.distinct_values_by_param_index[param_index]) < 2: - ignore_param_indexes.append(param_index) - x = df.FOLFunction( - self.median, - self.param_names, - n_samples=self.data.shape[0], - num_args=self.arg_count, - ) - x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes) - if x.fit_success: - self.model_function = x - else: - logger.warning(f"Fit of first-order linear model function failed.") - def build_symreg_model(self): ignore_irrelevant = bool( int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0")) @@ -963,8 +934,42 @@ class ModelAttribute: ) return False - def build_xgb(self): - mf = df.XGBoostFunction( + def build_fol(self): + ignore_irrelevant = bool( + int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0")) + ) + ignore_param_indexes = list() + if ignore_irrelevant: + for param_index, param in enumerate(self.param_names): + if not self.stats.depends_on_param(param): + ignore_param_indexes.append(param_index) + if not self.stats: + logger.warning( + "build_fol_model called with ModelAttribute.stats unavailable -- overfitting likely" + ) + else: + for param_index, _ in enumerate(self.param_names): + if len(self.stats.distinct_values_by_param_index[param_index]) < 2: + ignore_param_indexes.append(param_index) + x = df.FOLFunction( + self.median, + self.param_names, + n_samples=self.data.shape[0], + num_args=self.arg_count, + ) + x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes) + if x.fit_success: + self.model_function = x + return True + else: + logger.warning(f"Fit of first-order linear model function failed.") + self.model_function = df.StaticFunction( + np.mean(self.data), n_samples=len(self.data) + ) + return False + + def build_lmt(self): + mf = df.LMTFunction( np.mean(self.data), n_samples=len(self.data), param_names=self.param_names, @@ -975,14 +980,14 @@ class ModelAttribute: self.model_function = mf return True else: - logger.warning(f"XGB generation for {self.name} {self.attr} faled") + logger.warning(f"LMT generation for {self.name} {self.attr} faled") self.model_function = df.StaticFunction( np.mean(self.data), n_samples=len(self.data) ) return False - def build_lmt(self): - mf = df.LMTFunction( + def build_xgb(self): + mf = df.XGBoostFunction( np.mean(self.data), n_samples=len(self.data), param_names=self.param_names, @@ -993,7 +998,7 @@ class ModelAttribute: self.model_function = mf return True else: - logger.warning(f"LMT generation for {self.name} {self.attr} faled") + logger.warning(f"XGB generation for {self.name} {self.attr} faled") self.model_function = df.StaticFunction( np.mean(self.data), n_samples=len(self.data) ) |