summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-21 11:38:48 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-21 11:38:48 +0100
commitc3dbe93034bdeff9dba534d29b04daa527d70241 (patch)
tree14cdfd0047a01880ca30e96cf6d347009777a80e /lib/model.py
parent6bf411afb9289408c62e1696d8fb2b4da47a9fab (diff)
move (de)cart, lmt, xgb model generation into separate ModelAttribute functions
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py23
1 files changed, 19 insertions, 4 deletions
diff --git a/lib/model.py b/lib/model.py
index 9266153..972547d 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -4,7 +4,7 @@ import logging
import numpy as np
import os
from .automata import PTA
-from .functions import StaticFunction, SubstateFunction, SplitFunction
+import dfatool.functions as df
from .parameters import (
ModelAttribute,
ParamType,
@@ -295,7 +295,22 @@ class AnalyticModel:
def build_fitted(self, safe_functions_enabled=False):
- if self.force_tree:
+ model_type = os.getenv("DFATOOL_MODEL", "rmt")
+
+ if model_type != "rmt":
+ for name in self.names:
+ for attr in self.by_name[name]["attributes"]:
+ if model_type == "cart":
+ self.attr_by_name[name][attr].build_cart()
+ elif model_type == "decart":
+ self.attr_by_name[name][attr].build_decart()
+ elif model_type == "lmt":
+ self.attr_by_name[name][attr].build_lmt()
+ elif model_type == "xgb":
+ self.attr_by_name[name][attr].build_xgb()
+ else:
+ logger.error("build_fitted: unknown model type: {model_type}")
+ elif self.force_tree:
for name in self.names:
for attr in self.by_name[name]["attributes"]:
if (
@@ -392,7 +407,7 @@ class AnalyticModel:
model_info = self.attr_by_name[name][key].model_function
# shortcut
- if type(model_info) is StaticFunction:
+ if type(model_info) is df.StaticFunction:
if "params" in kwargs:
return [static_model[name][key] for p in kwargs["params"]]
return static_model[name][key]
@@ -1018,7 +1033,7 @@ class PTAModel(AnalyticModel):
)
)
- self.attr_by_name[p_name]["power"].model_function = SubstateFunction(
+ self.attr_by_name[p_name]["power"].model_function = df.SubstateFunction(
self.attr_by_name[p_name]["power"].get_static(),
sequence_by_count,
self.attr_by_name[p_name]["substate_count"].model_function,