diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-10-26 15:14:37 +0200 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-10-26 15:14:37 +0200 |
commit | 24c645137264c2f9d9146cda73677a67ec815ca3 (patch) | |
tree | ad6b67dc7ac7f0b28581a5b60fb67f43b231faf5 | |
parent | d6a19d976b699e0b230b2e6c8fdd11a0c832ae83 (diff) |
allow custom standard deviation thresholds for decision tree compilation
-rwxr-xr-x | bin/analyze-kconfig.py | 25 | ||||
-rw-r--r-- | lib/model.py | 30 |
2 files changed, 51 insertions, 4 deletions
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py index 004e691..533e621 100755 --- a/bin/analyze-kconfig.py +++ b/bin/analyze-kconfig.py @@ -41,6 +41,12 @@ def main(): help="Build decision tree without checking for analytic functions first. Use this for large kconfig files.", ) parser.add_argument( + "--max-std", + type=str, + metavar="VALUE_OR_MAP", + help="Specify desired maximum standard deviation for decision tree generation, either as float (global) or <key>/<attribute>=<value>[,<key>/<attribute>=<value>,...]", + ) + parser.add_argument( "--export-model", type=str, help="Export kconfig-webconf NFP model to file", @@ -119,11 +125,29 @@ def main(): # Release memory observations = None + if args.max_std: + max_std = dict() + if "=" in args.max_std: + for kkv in args.max_std.split(","): + kk, v = kkv.split("=") + key, attr = kk.split("/") + if key not in max_std: + max_std[key] = dict() + max_std[key][attr] = float(v) + else: + for key in by_name.keys(): + max_std[key] = dict() + for attr in by_name[key]["attributes"]: + max_std[key][attr] = float(args.max_std) + else: + max_std = None + model = AnalyticModel( by_name, parameter_names, compute_stats=not args.force_tree, force_tree=args.force_tree, + max_std=max_std, ) if args.cross_validate: @@ -135,6 +159,7 @@ def main(): parameter_names, compute_stats=not args.force_tree, force_tree=args.force_tree, + max_std=max_std, ) else: xv_method = None diff --git a/lib/model.py b/lib/model.py index 749eebb..b6318d7 100644 --- a/lib/model.py +++ b/lib/model.py @@ -76,6 +76,7 @@ class AnalyticModel: use_corrcoef=False, compute_stats=True, force_tree=False, + max_std=None, ): """ Create a new AnalyticModel and compute parameter statistics. @@ -119,6 +120,7 @@ class AnalyticModel: self.names = sorted(by_name.keys()) self.parameters = sorted(parameters) self.function_override = function_override.copy() + self.dtree_max_std = max_std self._use_corrcoef = use_corrcoef self._num_args = arg_count if self._num_args is None: @@ -138,8 +140,19 @@ class AnalyticModel: if force_tree: for name in self.names: for attr in self.by_name[name]["attributes"]: - # TODO specify correct threshold - self.build_dtree(name, attr, 0) + if max_std and name in max_std and attr in max_std[name]: + threshold = max_std[name][attr] + elif compute_stats: + threshold = (self.attr_by_name[name][attr].stats.std_param_lut,) + else: + threshold = 0 + with_function_leaves = bool( + int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1")) + ) + logger.debug( + f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})" + ) + self.build_dtree(name, attr, threshold, with_function_leaves) self.fit_done = True def __repr__(self): @@ -278,13 +291,20 @@ class AnalyticModel: with_function_leaves = bool( int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1")) ) + threshold = self.attr_by_name[name][attr].stats.std_param_lut + if ( + self.dtree_max_std + and name in self.dtree_max_std + and attr in self.dtree_max_std[name] + ): + threshold = self.dtree_max_std[name][attr] logger.debug( - f"build_dtree({name}, {attr}, threshold={self.attr_by_name[name][attr].stats.std_param_lut}, with_function_leaves={with_function_leaves})" + f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})" ) self.build_dtree( name, attr, - self.attr_by_name[name][attr].stats.std_param_lut, + threshold, with_function_leaves=with_function_leaves, ) else: @@ -513,6 +533,7 @@ class PTAModel(AnalyticModel): pta=None, pelt=None, compute_stats=True, + dtree_max_std=None, ): """ Prepare a new PTA energy model. @@ -556,6 +577,7 @@ class PTAModel(AnalyticModel): ) ) self.states_and_transitions = self.states + self.transitions + self.dtree_max_std = dtree_max_std self._parameter_names = sorted(parameters) self.parameters = sorted(parameters) |