summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-10-26 15:14:37 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2021-10-26 15:14:37 +0200
commit24c645137264c2f9d9146cda73677a67ec815ca3 (patch)
treead6b67dc7ac7f0b28581a5b60fb67f43b231faf5
parentd6a19d976b699e0b230b2e6c8fdd11a0c832ae83 (diff)
allow custom standard deviation thresholds for decision tree compilation
-rwxr-xr-xbin/analyze-kconfig.py25
-rw-r--r--lib/model.py30
2 files changed, 51 insertions, 4 deletions
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index 004e691..533e621 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -41,6 +41,12 @@ def main():
help="Build decision tree without checking for analytic functions first. Use this for large kconfig files.",
)
parser.add_argument(
+ "--max-std",
+ type=str,
+ metavar="VALUE_OR_MAP",
+ help="Specify desired maximum standard deviation for decision tree generation, either as float (global) or <key>/<attribute>=<value>[,<key>/<attribute>=<value>,...]",
+ )
+ parser.add_argument(
"--export-model",
type=str,
help="Export kconfig-webconf NFP model to file",
@@ -119,11 +125,29 @@ def main():
# Release memory
observations = None
+ if args.max_std:
+ max_std = dict()
+ if "=" in args.max_std:
+ for kkv in args.max_std.split(","):
+ kk, v = kkv.split("=")
+ key, attr = kk.split("/")
+ if key not in max_std:
+ max_std[key] = dict()
+ max_std[key][attr] = float(v)
+ else:
+ for key in by_name.keys():
+ max_std[key] = dict()
+ for attr in by_name[key]["attributes"]:
+ max_std[key][attr] = float(args.max_std)
+ else:
+ max_std = None
+
model = AnalyticModel(
by_name,
parameter_names,
compute_stats=not args.force_tree,
force_tree=args.force_tree,
+ max_std=max_std,
)
if args.cross_validate:
@@ -135,6 +159,7 @@ def main():
parameter_names,
compute_stats=not args.force_tree,
force_tree=args.force_tree,
+ max_std=max_std,
)
else:
xv_method = None
diff --git a/lib/model.py b/lib/model.py
index 749eebb..b6318d7 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -76,6 +76,7 @@ class AnalyticModel:
use_corrcoef=False,
compute_stats=True,
force_tree=False,
+ max_std=None,
):
"""
Create a new AnalyticModel and compute parameter statistics.
@@ -119,6 +120,7 @@ class AnalyticModel:
self.names = sorted(by_name.keys())
self.parameters = sorted(parameters)
self.function_override = function_override.copy()
+ self.dtree_max_std = max_std
self._use_corrcoef = use_corrcoef
self._num_args = arg_count
if self._num_args is None:
@@ -138,8 +140,19 @@ class AnalyticModel:
if force_tree:
for name in self.names:
for attr in self.by_name[name]["attributes"]:
- # TODO specify correct threshold
- self.build_dtree(name, attr, 0)
+ if max_std and name in max_std and attr in max_std[name]:
+ threshold = max_std[name][attr]
+ elif compute_stats:
+ threshold = (self.attr_by_name[name][attr].stats.std_param_lut,)
+ else:
+ threshold = 0
+ with_function_leaves = bool(
+ int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1"))
+ )
+ logger.debug(
+ f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})"
+ )
+ self.build_dtree(name, attr, threshold, with_function_leaves)
self.fit_done = True
def __repr__(self):
@@ -278,13 +291,20 @@ class AnalyticModel:
with_function_leaves = bool(
int(os.getenv("DFATOOL_DTREE_FUNCTION_LEAVES", "1"))
)
+ threshold = self.attr_by_name[name][attr].stats.std_param_lut
+ if (
+ self.dtree_max_std
+ and name in self.dtree_max_std
+ and attr in self.dtree_max_std[name]
+ ):
+ threshold = self.dtree_max_std[name][attr]
logger.debug(
- f"build_dtree({name}, {attr}, threshold={self.attr_by_name[name][attr].stats.std_param_lut}, with_function_leaves={with_function_leaves})"
+ f"build_dtree({name}, {attr}, threshold={threshold}, with_function_leaves={with_function_leaves})"
)
self.build_dtree(
name,
attr,
- self.attr_by_name[name][attr].stats.std_param_lut,
+ threshold,
with_function_leaves=with_function_leaves,
)
else:
@@ -513,6 +533,7 @@ class PTAModel(AnalyticModel):
pta=None,
pelt=None,
compute_stats=True,
+ dtree_max_std=None,
):
"""
Prepare a new PTA energy model.
@@ -556,6 +577,7 @@ class PTAModel(AnalyticModel):
)
)
self.states_and_transitions = self.states + self.transitions
+ self.dtree_max_std = dtree_max_std
self._parameter_names = sorted(parameters)
self.parameters = sorted(parameters)