diff options
author | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 11:38:48 +0100 |
---|---|---|
committer | Birte Kristina Friesel <birte.friesel@uos.de> | 2024-02-21 11:38:48 +0100 |
commit | c3dbe93034bdeff9dba534d29b04daa527d70241 (patch) | |
tree | 14cdfd0047a01880ca30e96cf6d347009777a80e /lib/model.py | |
parent | 6bf411afb9289408c62e1696d8fb2b4da47a9fab (diff) |
move (de)cart, lmt, xgb model generation into separate ModelAttribute functions
Diffstat (limited to 'lib/model.py')
-rw-r--r-- | lib/model.py | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/lib/model.py b/lib/model.py index 9266153..972547d 100644 --- a/lib/model.py +++ b/lib/model.py @@ -4,7 +4,7 @@ import logging import numpy as np import os from .automata import PTA -from .functions import StaticFunction, SubstateFunction, SplitFunction +import dfatool.functions as df from .parameters import ( ModelAttribute, ParamType, @@ -295,7 +295,22 @@ class AnalyticModel: def build_fitted(self, safe_functions_enabled=False): - if self.force_tree: + model_type = os.getenv("DFATOOL_MODEL", "rmt") + + if model_type != "rmt": + for name in self.names: + for attr in self.by_name[name]["attributes"]: + if model_type == "cart": + self.attr_by_name[name][attr].build_cart() + elif model_type == "decart": + self.attr_by_name[name][attr].build_decart() + elif model_type == "lmt": + self.attr_by_name[name][attr].build_lmt() + elif model_type == "xgb": + self.attr_by_name[name][attr].build_xgb() + else: + logger.error("build_fitted: unknown model type: {model_type}") + elif self.force_tree: for name in self.names: for attr in self.by_name[name]["attributes"]: if ( @@ -392,7 +407,7 @@ class AnalyticModel: model_info = self.attr_by_name[name][key].model_function # shortcut - if type(model_info) is StaticFunction: + if type(model_info) is df.StaticFunction: if "params" in kwargs: return [static_model[name][key] for p in kwargs["params"]] return static_model[name][key] @@ -1018,7 +1033,7 @@ class PTAModel(AnalyticModel): ) ) - self.attr_by_name[p_name]["power"].model_function = SubstateFunction( + self.attr_by_name[p_name]["power"].model_function = df.SubstateFunction( self.attr_by_name[p_name]["power"].get_static(), sequence_by_count, self.attr_by_name[p_name]["substate_count"].model_function, |