summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-06-03 14:05:48 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2022-06-03 14:05:48 +0200
commit3246558d3480552f1c5a2b9bdb1267bf7d9007b5 (patch)
treedd2fbe7d25900e69727a70be270fa03713d0a758 /lib
parent205f8c9edc49218555eae58da3c7532f081e7754 (diff)
FOL: Add second order support
Diffstat (limited to 'lib')
-rw-r--r--lib/functions.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/lib/functions.py b/lib/functions.py
index aa328ef..cd3c7d1 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -612,6 +612,7 @@ class FOLFunction(ModelFunction):
categorial_to_scalar = bool(
int(os.getenv("DFATOOL_PARAM_CATEGORIAL_TO_SCALAR", "0"))
)
+ second_order = bool(int(os.getenv("DFATOOL_FOL_SECOND_ORDER", "0")))
fit_parameters, categorial_to_index, ignore_index = param_to_ndarray(
param_values,
with_nan=False,
@@ -621,10 +622,23 @@ class FOLFunction(ModelFunction):
self.categorial_to_index = categorial_to_index
self.ignore_index = ignore_index
fit_parameters = fit_parameters.swapaxes(0, 1)
- num_vars = fit_parameters.shape[0]
- funbuf = "lambda reg_param, model_param: 0"
- for i in range(num_vars):
- funbuf += f" + reg_param[{i}] * model_param[{i}]"
+
+ if second_order:
+ num_param = fit_parameters.shape[0]
+ num_vars = 0
+ funbuf = "lambda reg_param, model_param: 0"
+ for i in range(num_param):
+ funbuf += f" + reg_param[{num_vars}] * model_param[{i}]"
+ num_vars += 1
+ for j in range(i + 1, num_param):
+ funbuf += f" + reg_param[{num_vars}] * model_param[{i}] * model_param[{j}]"
+ num_vars += 1
+ else:
+ num_vars = fit_parameters.shape[0]
+ funbuf = "lambda reg_param, model_param: 0"
+ for i in range(num_vars):
+ funbuf += f" + reg_param[{i}] * model_param[{i}]"
+
self._function_str = self.model_function = funbuf
self._function = eval(funbuf)