diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-10-26 15:14:37 +0200 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-10-26 15:14:37 +0200 |
commit | 24c645137264c2f9d9146cda73677a67ec815ca3 (patch) | |
tree | ad6b67dc7ac7f0b28581a5b60fb67f43b231faf5 /lib | |
parent | d6a19d976b699e0b230b2e6c8fdd11a0c832ae83 (diff) |
allow custom standard deviation thresholds for decision tree compilation
Diffstat (limited to 'lib')
-rw-r--r-- | lib/model.py | 30 |
1 file changed, 26 insertions, 4 deletions
diff --git a/lib/model.py b/lib/model.py index 749eebb..b6318d7 100644 --- a/lib/model.py +++ b/lib/model.py @@ -76,6 +76,7 @@ class AnalyticModel: use_corrcoef=False, compute_stats=True, force_tree=False, + max_std=None, ): """ Create a new AnalyticModel and compute parameter statistics. @@ -119,6 +120,7 @@ class AnalyticModel: self.names = sorted(by_name.keys()) self.parameters = sorted(parameters) self.function_override = function_override.copy() + self.dtree_max_std = max_std self._use_corrcoef = use_corrcoef self._num_args = arg_count if self._num_args is None: @@ -138,8 +140,19 @@ class AnalyticModel: if force_tree: for name in self.names: for attr in self.by_name[name]["attributes"]: - # TODO specify correct threshold - self.build_dtree(name, attr, 0) + if max_std and name in max_std and attr in max_std[name]: + threshold = max_std[name][attr] + elif compute_stats: + threshold = (self.attr_by_name[name][attr].stats.std_param_lut,) + else: + threshold = 0 + with_function_leaves = bool( + int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1")) + ) + logger.debug( + f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})" + ) + self.build_dtree(name, attr, threshold, with_function_leaves) self.fit_done = True def __repr__(self): @@ -278,13 +291,20 @@ class AnalyticModel: with_function_leaves = bool( int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1")) ) + threshold = self.attr_by_name[name][attr].stats.std_param_lut + if ( + self.dtree_max_std + and name in self.dtree_max_std + and attr in self.dtree_max_std[name] + ): + threshold = self.dtree_max_std[name][attr] logger.debug( - f"build_dtree({name}, {attr}, threshold={self.attr_by_name[name][attr].stats.std_param_lut}, with_function_leaves={with_function_leaves})" + f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})" ) self.build_dtree( name, attr, - self.attr_by_name[name][attr].stats.std_param_lut, + threshold, 
with_function_leaves=with_function_leaves, ) else: @@ -513,6 +533,7 @@ class PTAModel(AnalyticModel): pta=None, pelt=None, compute_stats=True, + dtree_max_std=None, ): """ Prepare a new PTA energy model. @@ -556,6 +577,7 @@ class PTAModel(AnalyticModel): ) ) self.states_and_transitions = self.states + self.transitions + self.dtree_max_std = dtree_max_std self._parameter_names = sorted(parameters) self.parameters = sorted(parameters) |