summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-11-16 08:54:49 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-11-16 08:54:49 +0100
commit76832a396075c30cea7d272cff69dd75354a057c (patch)
treef4f33cae7aad1be04b57c1705028d8ba79078c3a
parentee0d122605120edfeff3837b8d358a1ef3ac5a3d (diff)
Add env var for safe functions; disable unsafe functions in that case
-rw-r--r--README.md1
-rw-r--r--lib/functions.py8
2 files changed, 8 insertions, 1 deletions
diff --git a/README.md b/README.md
index b547471..e2f499a 100644
--- a/README.md
+++ b/README.md
@@ -35,3 +35,4 @@ The following variables may be set to alter the behaviour of dfatool components.
| `DFATOOL_KCONF_IGNORE_NUMERIC` | **0**, 1 | Ignore numeric (int/hex) configuration options. Useful for comparison with CART/DECART. |
| `DFATOOL_KCONF_IGNORE_STRING` | **0**, 1 | Ignore string configuration options. Useful for comparison with CART/DECART. |
| `DFATOOL_FIT_LINEAR_ONLY` | **0**, 1 | Only consider linear functions (a + bx) in regression analysis. Useful for comparison with Linear Model Trees / M5. |
+| `DFATOOL_REGRESSION_SAFE_FUNCTIONS` | **0**, 1 | Use safe functions only (e.g. 1/x returnning 1 for x==0) |
diff --git a/lib/functions.py b/lib/functions.py
index 7cd06c0..0a488dc 100644
--- a/lib/functions.py
+++ b/lib/functions.py
@@ -740,19 +740,25 @@ class analytic:
# ),
}
- if safe_functions_enabled:
+ if safe_functions_enabled or bool(
+ int(os.getenv("DFATOOL_REGRESSION_SAFE_FUNCTIONS", "0"))
+ ):
+ functions.pop("logarithmic1")
+ functions.pop("logarithmic")
functions["safe_log"] = ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * analytic._safe_log(model_param),
lambda model_param: True,
2,
)
+ functions.pop("inverse")
functions["safe_inv"] = ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * analytic._safe_inv(model_param),
lambda model_param: True,
2,
)
+ functions.pop("sqrt")
functions["safe_sqrt"] = ParamFunction(
lambda reg_param, model_param: reg_param[0]
+ reg_param[1] * analytic._safe_sqrt(model_param),