summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
authorBirte Kristina Friesel <birte.friesel@uos.de>2024-02-20 07:23:36 +0100
committerBirte Kristina Friesel <birte.friesel@uos.de>2024-02-20 07:23:36 +0100
commitddf6e00b3b16a07b994107a79450909f43588445 (patch)
treebeb8d4cb2b9de4a20e427fe30caa177599710cb8 /lib/model.py
parent8290d9ee20b5d2305c5fd519bf534029f36c30d9 (diff)
Re-add (very very basic, for now) Symbolic Regression support
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/lib/model.py b/lib/model.py
index 5844113..e2a455d 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -296,9 +296,11 @@ class AnalyticModel:
if fallback:
return list(
map(
- lambda p: lut_model[name][key][tuple(p)]
- if tuple(p) in lut_model[name][key]
- else static_model[name][key],
+ lambda p: (
+ lut_model[name][key][tuple(p)]
+ if tuple(p) in lut_model[name][key]
+ else static_model[name][key]
+ ),
params,
)
)
@@ -320,6 +322,7 @@ class AnalyticModel:
paramfit = ParamFit()
tree_allowed = bool(int(os.getenv("DFATOOL_DTREE_ENABLED", "1")))
use_fol = bool(int(os.getenv("DFATOOL_FIT_FOL", "0")))
+ use_symreg = bool(int(os.getenv("DFATOOL_FIT_SYMREG", "0")))
tree_required = dict()
for name in self.names:
@@ -329,6 +332,8 @@ class AnalyticModel:
self.attr_by_name[name][attr].fit_override_function()
elif use_fol:
self.attr_by_name[name][attr].build_fol_model()
+ elif use_symreg:
+ self.attr_by_name[name][attr].build_symreg_model()
elif self.attr_by_name[name][
attr
].all_relevant_parameters_are_none_or_numeric():
@@ -395,9 +400,11 @@ class AnalyticModel:
return model_function.eval_arr(kwargs["params"])
return list(
map(
- lambda p: model_function.eval(p)
- if model_function.is_predictable(p)
- else static_model[name][key],
+ lambda p: (
+ model_function.eval(p)
+ if model_function.is_predictable(p)
+ else static_model[name][key]
+ ),
kwargs["params"],
)
)