summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-10-26 15:14:37 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2021-10-26 15:14:37 +0200
commit24c645137264c2f9d9146cda73677a67ec815ca3 (patch)
treead6b67dc7ac7f0b28581a5b60fb67f43b231faf5 /lib
parentd6a19d976b699e0b230b2e6c8fdd11a0c832ae83 (diff)
allow custom standard deviation thresholds for decision tree compilation
Diffstat (limited to 'lib')
-rw-r--r--lib/model.py30
1 files changed, 26 insertions, 4 deletions
diff --git a/lib/model.py b/lib/model.py
index 749eebb..b6318d7 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -76,6 +76,7 @@ class AnalyticModel:
use_corrcoef=False,
compute_stats=True,
force_tree=False,
+ max_std=None,
):
"""
Create a new AnalyticModel and compute parameter statistics.
@@ -119,6 +120,7 @@ class AnalyticModel:
self.names = sorted(by_name.keys())
self.parameters = sorted(parameters)
self.function_override = function_override.copy()
+ self.dtree_max_std = max_std
self._use_corrcoef = use_corrcoef
self._num_args = arg_count
if self._num_args is None:
@@ -138,8 +140,19 @@ class AnalyticModel:
if force_tree:
for name in self.names:
for attr in self.by_name[name]["attributes"]:
- # TODO specify correct threshold
- self.build_dtree(name, attr, 0)
+ if max_std and name in max_std and attr in max_std[name]:
+ threshold = max_std[name][attr]
+ elif compute_stats:
+ threshold = (self.attr_by_name[name][attr].stats.std_param_lut,)
+ else:
+ threshold = 0
+ with_function_leaves = bool(
+ int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1"))
+ )
+ logger.debug(
+ f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})"
+ )
+ self.build_dtree(name, attr, threshold, with_function_leaves)
self.fit_done = True
def __repr__(self):
@@ -278,13 +291,20 @@ class AnalyticModel:
with_function_leaves = bool(
int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1"))
)
+ threshold = self.attr_by_name[name][attr].stats.std_param_lut
+ if (
+ self.dtree_max_std
+ and name in self.dtree_max_std
+ and attr in self.dtree_max_std[name]
+ ):
+ threshold = self.dtree_max_std[name][attr]
logger.debug(
- f"build_dtree({name}, {attr}, threshold={self.attr_by_name[name][attr].stats.std_param_lut}, with_function_leaves={with_function_leaves})"
+ f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})"
)
self.build_dtree(
name,
attr,
- self.attr_by_name[name][attr].stats.std_param_lut,
+ threshold,
with_function_leaves=with_function_leaves,
)
else:
@@ -513,6 +533,7 @@ class PTAModel(AnalyticModel):
pta=None,
pelt=None,
compute_stats=True,
+ dtree_max_std=None,
):
"""
Prepare a new PTA energy model.
@@ -556,6 +577,7 @@ class PTAModel(AnalyticModel):
)
)
self.states_and_transitions = self.states + self.transitions
+ self.dtree_max_std = dtree_max_std
self._parameter_names = sorted(parameters)
self.parameters = sorted(parameters)