author    Birte Kristina Friesel <birte.friesel@uos.de>  2024-02-21 11:38:48 +0100
committer Birte Kristina Friesel <birte.friesel@uos.de>  2024-02-21 11:38:48 +0100
commit    c3dbe93034bdeff9dba534d29b04daa527d70241 (patch)
tree      14cdfd0047a01880ca30e96cf6d347009777a80e
parent    6bf411afb9289408c62e1696d8fb2b4da47a9fab (diff)
move (de)cart, lmt, xgb model generation into separate ModelAttribute functions
-rw-r--r--  lib/model.py       23
-rw-r--r--  lib/parameters.py  160
2 files changed, 99 insertions(+), 84 deletions(-)
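The commit replaces four per-backend environment variables (DFATOOL_DTREE_SKLEARN_CART, DFATOOL_DTREE_SKLEARN_DECART, DFATOOL_DTREE_LMT, DFATOOL_USE_XGBOOST) with a single DFATOOL_MODEL switch that build_fitted() dispatches on. A minimal usage sketch; the AnalyticModel instance `model` and its construction are assumed here, not shown in this diff:

    import os

    # Valid values per the dispatch added below:
    # "rmt" (default), "cart", "decart", "lmt", "xgb"
    os.environ["DFATOOL_MODEL"] = "cart"
    model.build_fitted()  # calls ModelAttribute.build_cart() per attribute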
diff --git a/lib/model.py b/lib/model.py
index 9266153..972547d 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -4,7 +4,7 @@ import logging
import numpy as np
import os
from .automata import PTA
-from .functions import StaticFunction, SubstateFunction, SplitFunction
+import dfatool.functions as df
from .parameters import (
ModelAttribute,
ParamType,
@@ -295,7 +295,22 @@ class AnalyticModel:
def build_fitted(self, safe_functions_enabled=False):
- if self.force_tree:
+ model_type = os.getenv("DFATOOL_MODEL", "rmt")
+
+ if model_type != "rmt":
+ for name in self.names:
+ for attr in self.by_name[name]["attributes"]:
+ if model_type == "cart":
+ self.attr_by_name[name][attr].build_cart()
+ elif model_type == "decart":
+ self.attr_by_name[name][attr].build_decart()
+ elif model_type == "lmt":
+ self.attr_by_name[name][attr].build_lmt()
+ elif model_type == "xgb":
+ self.attr_by_name[name][attr].build_xgb()
+ else:
+ logger.error("build_fitted: unknown model type: {model_type}")
+ elif self.force_tree:
for name in self.names:
for attr in self.by_name[name]["attributes"]:
if (
@@ -392,7 +407,7 @@ class AnalyticModel:
model_info = self.attr_by_name[name][key].model_function
# shortcut
- if type(model_info) is StaticFunction:
+ if type(model_info) is df.StaticFunction:
if "params" in kwargs:
return [static_model[name][key] for p in kwargs["params"]]
return static_model[name][key]
@@ -1018,7 +1033,7 @@ class PTAModel(AnalyticModel):
)
)
- self.attr_by_name[p_name]["power"].model_function = SubstateFunction(
+ self.attr_by_name[p_name]["power"].model_function = df.SubstateFunction(
self.attr_by_name[p_name]["power"].get_static(),
sequence_by_count,
self.attr_by_name[p_name]["substate_count"].model_function,
diff --git a/lib/parameters.py b/lib/parameters.py
index 83063c2..fa85b7a 100644
--- a/lib/parameters.py
+++ b/lib/parameters.py
@@ -919,16 +919,92 @@ class ModelAttribute:
if x.fit_success:
self.model_function = x
+ def build_cart(self):
+ mf = df.CARTFunction(
+ np.mean(self.data),
+ n_samples=len(self.data),
+ param_names=self.param_names,
+ arg_count=self.arg_count,
+ ).fit(
+ self.param_values,
+ self.data,
+ )
+
+ if mf.fit_success:
+ self.model_function = mf
+ return True
+ else:
+ logger.warning(f"CART generation for {self.name} {self.attr} faled")
+ self.model_function = df.StaticFunction(
+ np.mean(self.data), n_samples=len(self.data)
+ )
+ return False
+
+ def build_decart(self):
+ mf = df.CARTFunction(
+ np.mean(self.data),
+ n_samples=len(self.data),
+ param_names=self.param_names,
+ arg_count=self.arg_count,
+ decart=True,
+ ).fit(
+ self.param_values,
+ self.data,
+ scalar_param_indexes=self.scalar_param_indexes,
+ )
+
+ if mf.fit_success:
+ self.model_function = mf
+ return True
+ else:
+ logger.warning(f"DECART generation for {self.name} {self.attr} faled")
+ self.model_function = df.StaticFunction(
+ np.mean(self.data), n_samples=len(self.data)
+ )
+ return False
+
+ def build_xgb(self):
+ mf = df.XGBoostFunction(
+ np.mean(self.data),
+ n_samples=len(self.data),
+ param_names=self.param_names,
+ arg_count=self.arg_count,
+ ).fit(self.param_values, self.data)
+
+ if mf.fit_success:
+ self.model_function = mf
+ return True
+ else:
+ logger.warning(f"XGB generation for {self.name} {self.attr} faled")
+ self.model_function = df.StaticFunction(
+ np.mean(self.data), n_samples=len(self.data)
+ )
+ return False
+
+ def build_lmt(self):
+ mf = df.LMTFunction(
+ np.mean(self.data),
+ n_samples=len(self.data),
+ param_names=self.param_names,
+ arg_count=self.arg_count,
+ ).fit(self.param_values, self.data)
+
+ if mf.fit_success:
+ self.model_function = mf
+ return True
+ else:
+ logger.warning(f"LMT generation for {self.name} {self.attr} faled")
+ self.model_function = df.StaticFunction(
+ np.mean(self.data), n_samples=len(self.data)
+ )
+ return False
+
def build_dtree(
self,
parameters,
data,
with_function_leaves=None,
with_nonbinary_nodes=None,
- with_sklearn_cart=None,
- with_sklearn_decart=None,
- with_lmt=None,
- with_xgboost=None,
with_gplearn_symreg=None,
ignore_irrelevant_parameters=None,
loss_ignore_scalar=None,
@@ -941,10 +1017,6 @@ class ModelAttribute:
:param data: Measurements. [data 1, data 2, data 3, ...]
:param with_function_leaves: Use fitted function sets to generate function leaves for scalar parameters
:param with_nonbinary_nodes: Allow non-binary nodes for enum and scalar parameters (i.e., nodes with more than two children)
- :param with_sklearn_cart: Use `sklearn.tree.DecisionTreeRegressor` CART implementation for tree generation. Does not support categorical (enum)
- and sparse parameters. Both are ignored during fitting. All other options are ignored as well.
- :param with_sklearn_decart: Use `sklearn.tree.DecisionTreeRegressor` CART implementation in DECART mode for tree generation. CART limitations
- apply; additionally, scalar parameters are ignored during fitting.
:param loss_ignore_scalar: Ignore scalar parameters when computing the loss for split candidates. Only sensible if with_function_leaves is enabled.
:param threshold: Return a StaticFunction leaf node if std(data) < threshold. Default 100.
@@ -959,16 +1031,6 @@ class ModelAttribute:
with_nonbinary_nodes = bool(
int(os.getenv("DFATOOL_DTREE_NONBINARY_NODES", "1"))
)
- if with_sklearn_cart is None:
- with_sklearn_cart = bool(int(os.getenv("DFATOOL_DTREE_SKLEARN_CART", "0")))
- if with_sklearn_decart is None:
- with_sklearn_decart = bool(
- int(os.getenv("DFATOOL_DTREE_SKLEARN_DECART", "0"))
- )
- if with_lmt is None:
- with_lmt = bool(int(os.getenv("DFATOOL_DTREE_LMT", "0")))
- if with_xgboost is None:
- with_xgboost = bool(int(os.getenv("DFATOOL_USE_XGBOOST", "0")))
if with_gplearn_symreg is None:
with_gplearn_symreg = bool(int(os.getenv("DFATOOL_USE_SYMREG", "0")))
if ignore_irrelevant_parameters is None:
@@ -980,68 +1042,6 @@ class ModelAttribute:
int(os.getenv("DFATOOL_DTREE_LOSS_IGNORE_SCALAR", "0"))
)
- if with_sklearn_cart or with_sklearn_decart:
- mf = df.CARTFunction(
- np.mean(data),
- n_samples=len(data),
- param_names=self.param_names,
- arg_count=self.arg_count,
- decart=with_sklearn_decart,
- )
-
- mf.fit(
- parameters,
- data,
- scalar_param_indexes=self.scalar_param_indexes,
- )
-
- if mf.fit_success:
- self.model_function = mf
- else:
- logger.warning(f"CART generation for {self.name} {self.attr} faled")
- self.model_function = df.StaticFunction(
- np.mean(data), n_samples=len(data)
- )
- return
-
- if with_xgboost:
- mf = df.XGBoostFunction(
- np.mean(data),
- n_samples=len(data),
- param_names=self.param_names,
- arg_count=self.arg_count,
- )
-
- mf.fit(parameters, data)
-
- if mf.fit_success:
- self.model_function = mf
- else:
- logger.warning(f"XGB generation for {self.name} {self.attr} faled")
- self.model_function = df.StaticFunction(
- np.mean(data), n_samples=len(data)
- )
- return
-
- if with_lmt:
- mf = df.LMTFunction(
- np.mean(data),
- n_samples=len(data),
- param_names=self.param_names,
- arg_count=self.arg_count,
- )
-
- mf.fit(parameters, data)
-
- if mf.fit_success:
- self.model_function = mf
- else:
- logger.warning(f"LMT generation for {self.name} {self.attr} faled")
- self.model_function = df.StaticFunction(
- np.mean(data), n_samples=len(data)
- )
- return
-
if loss_ignore_scalar and not with_function_leaves:
logger.warning(
"build_dtree {self.name} {self.attr} called with loss_ignore_scalar=True, with_function_leaves=False. This does not make sense."