From 24c645137264c2f9d9146cda73677a67ec815ca3 Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Tue, 26 Oct 2021 15:14:37 +0200 Subject: allow custom standard deviation thresholds for decision tree compilation --- lib/model.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) (limited to 'lib') diff --git a/lib/model.py b/lib/model.py index 749eebb..b6318d7 100644 --- a/lib/model.py +++ b/lib/model.py @@ -76,6 +76,7 @@ class AnalyticModel: use_corrcoef=False, compute_stats=True, force_tree=False, + max_std=None, ): """ Create a new AnalyticModel and compute parameter statistics. @@ -119,6 +120,7 @@ class AnalyticModel: self.names = sorted(by_name.keys()) self.parameters = sorted(parameters) self.function_override = function_override.copy() + self.dtree_max_std = max_std self._use_corrcoef = use_corrcoef self._num_args = arg_count if self._num_args is None: @@ -138,8 +140,19 @@ class AnalyticModel: if force_tree: for name in self.names: for attr in self.by_name[name]["attributes"]: - # TODO specify correct threshold - self.build_dtree(name, attr, 0) + if max_std and name in max_std and attr in max_std[name]: + threshold = max_std[name][attr] + elif compute_stats: + threshold = (self.attr_by_name[name][attr].stats.std_param_lut,) + else: + threshold = 0 + with_function_leaves = bool( + int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1")) + ) + logger.debug( + f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})" + ) + self.build_dtree(name, attr, threshold, with_function_leaves) self.fit_done = True def __repr__(self): @@ -278,13 +291,20 @@ class AnalyticModel: with_function_leaves = bool( int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1")) ) + threshold = self.attr_by_name[name][attr].stats.std_param_lut + if ( + self.dtree_max_std + and name in self.dtree_max_std + and attr in self.dtree_max_std[name] + ): + threshold = self.dtree_max_std[name][attr] logger.debug( - f"build_dtree({name}, {attr}, threshold={self.attr_by_name[name][attr].stats.std_param_lut}, with_function_leaves={with_function_leaves})" + f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})" ) self.build_dtree( name, attr, - self.attr_by_name[name][attr].stats.std_param_lut, + threshold, with_function_leaves=with_function_leaves, ) else: @@ -513,6 +533,7 @@ class PTAModel(AnalyticModel): pta=None, pelt=None, compute_stats=True, + dtree_max_std=None, ): """ Prepare a new PTA energy model. @@ -556,6 +577,7 @@ class PTAModel(AnalyticModel): ) ) self.states_and_transitions = self.states + self.transitions + self.dtree_max_std = dtree_max_std self._parameter_names = sorted(parameters) self.parameters = sorted(parameters) -- cgit v1.2.3