summaryrefslogtreecommitdiff
path: root/lib/parameters.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/parameters.py')
-rw-r--r--lib/parameters.py75
1 files changed, 40 insertions, 35 deletions
diff --git a/lib/parameters.py b/lib/parameters.py
index fa85b7a..d3d2659 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -830,35 +830,6 @@ class ModelAttribute:
return False
return True
- def build_fol_model(self):
- ignore_irrelevant = bool(
- int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0"))
- )
- ignore_param_indexes = list()
- if ignore_irrelevant:
- for param_index, param in enumerate(self.param_names):
- if not self.stats.depends_on_param(param):
- ignore_param_indexes.append(param_index)
- if not self.stats:
- logger.warning(
- "build_fol_model called with ModelAttribute.stats unavailable -- overfitting likely"
- )
- else:
- for param_index, _ in enumerate(self.param_names):
- if len(self.stats.distinct_values_by_param_index[param_index]) < 2:
- ignore_param_indexes.append(param_index)
- x = df.FOLFunction(
- self.median,
- self.param_names,
- n_samples=self.data.shape[0],
- num_args=self.arg_count,
- )
- x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes)
- if x.fit_success:
- self.model_function = x
- else:
- logger.warning(f"Fit of first-order linear model function failed.")
-
def build_symreg_model(self):
ignore_irrelevant = bool(
int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0"))
@@ -963,8 +934,42 @@ class ModelAttribute:
)
return False
- def build_xgb(self):
- mf = df.XGBoostFunction(
+ def build_fol(self):
+ ignore_irrelevant = bool(
+ int(os.getenv("DFATOOL_DTREE_IGNORE_IRRELEVANT_PARAMS", "0"))
+ )
+ ignore_param_indexes = list()
+ if ignore_irrelevant:
+ for param_index, param in enumerate(self.param_names):
+ if not self.stats.depends_on_param(param):
+ ignore_param_indexes.append(param_index)
+ if not self.stats:
+ logger.warning(
+ "build_fol_model called with ModelAttribute.stats unavailable -- overfitting likely"
+ )
+ else:
+ for param_index, _ in enumerate(self.param_names):
+ if len(self.stats.distinct_values_by_param_index[param_index]) < 2:
+ ignore_param_indexes.append(param_index)
+ x = df.FOLFunction(
+ self.median,
+ self.param_names,
+ n_samples=self.data.shape[0],
+ num_args=self.arg_count,
+ )
+ x.fit(self.param_values, self.data, ignore_param_indexes=ignore_param_indexes)
+ if x.fit_success:
+ self.model_function = x
+ return True
+ else:
+ logger.warning(f"Fit of first-order linear model function failed.")
+ self.model_function = df.StaticFunction(
+ np.mean(self.data), n_samples=len(self.data)
+ )
+ return False
+
+ def build_lmt(self):
+ mf = df.LMTFunction(
np.mean(self.data),
n_samples=len(self.data),
param_names=self.param_names,
@@ -975,14 +980,14 @@ class ModelAttribute:
self.model_function = mf
return True
else:
- logger.warning(f"XGB generation for {self.name} {self.attr} faled")
+ logger.warning(f"LMT generation for {self.name} {self.attr} faled")
self.model_function = df.StaticFunction(
np.mean(self.data), n_samples=len(self.data)
)
return False
- def build_lmt(self):
- mf = df.LMTFunction(
+ def build_xgb(self):
+ mf = df.XGBoostFunction(
np.mean(self.data),
n_samples=len(self.data),
param_names=self.param_names,
@@ -993,7 +998,7 @@ class ModelAttribute:
self.model_function = mf
return True
else:
- logger.warning(f"LMT generation for {self.name} {self.attr} faled")
+ logger.warning(f"XGB generation for {self.name} {self.attr} faled")
self.model_function = df.StaticFunction(
np.mean(self.data), n_samples=len(self.data)
)