summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-09-23 15:22:19 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2022-09-23 15:22:19 +0200
commit656e4a22c55cd785a8f6fe079adfb7d249f42e1e (patch)
tree2cd606578c02045cc18efd1bc7d3a39ee5edf56c
parente224ced44f95880d86f4913396d7c621fe2f2db1 (diff)
do not build dtree in static and LUT cross-validation runs
-rwxr-xr-xbin/analyze-kconfig.py10
-rw-r--r--lib/validation.py25
2 files changed, 23 insertions, 12 deletions
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index cea1c9e..c33292e 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -421,10 +421,12 @@ def main():
logging.debug(f"model.get_fitted(...) took {fit_duration : 7.1f} seconds")
if xv_method == "montecarlo":
- static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count)
+ static_quality, _ = xv.montecarlo(
+ lambda m: m.get_static(), xv_count, static=True
+ )
if lut_model:
lut_quality, _ = xv.montecarlo(
- lambda m: m.get_param_lut(fallback=True), xv_count
+ lambda m: m.get_param_lut(fallback=True), xv_count, static=True
)
else:
lut_quality = None
@@ -433,10 +435,10 @@ def main():
lambda m: m.get_fitted()[0], xv_count
)
elif xv_method == "kfold":
- static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count)
+ static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count, static=True)
if lut_model:
lut_quality, _ = xv.kfold(
- lambda m: m.get_param_lut(fallback=True), xv_count
+ lambda m: m.get_param_lut(fallback=True), xv_count, static=True
)
else:
lut_quality = None
diff --git a/lib/validation.py b/lib/validation.py
index 0e735a0..95815ac 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -107,7 +107,7 @@ class CrossValidator:
self.args = args
self.kwargs = kwargs
- def kfold(self, model_getter, k=10):
+ def kfold(self, model_getter, k=10, static=False):
"""
Perform k-fold cross-validation and return average model quality.
@@ -159,9 +159,11 @@ class CrossValidator:
for name in self.names:
training_and_validation_sets[i][name] = subsets_by_name[name][i]
- return self._generic_xv(model_getter, training_and_validation_sets)
+ return self._generic_xv(
+ model_getter, training_and_validation_sets, static=static
+ )
- def montecarlo(self, model_getter, count=200):
+ def montecarlo(self, model_getter, count=200, static=False):
"""
Perform Monte Carlo cross-validation and return average model quality.
@@ -208,9 +210,11 @@ class CrossValidator:
for name in self.names:
training_and_validation_sets[i][name] = subsets_by_name[name][i]
- return self._generic_xv(model_getter, training_and_validation_sets)
+ return self._generic_xv(
+ model_getter, training_and_validation_sets, static=static
+ )
- def _generic_xv(self, model_getter, training_and_validation_sets):
+ def _generic_xv(self, model_getter, training_and_validation_sets, static=False):
ret = dict()
models = list()
@@ -225,7 +229,9 @@ class CrossValidator:
}
for training_and_validation_by_name in training_and_validation_sets:
- model, res = self._single_xv(model_getter, training_and_validation_by_name)
+ model, res = self._single_xv(
+ model_getter, training_and_validation_by_name, static=static
+ )
models.append(model)
for name in self.names:
for attribute in self.by_name[name]["attributes"]:
@@ -249,7 +255,7 @@ class CrossValidator:
return ret, models
- def _single_xv(self, model_getter, tv_set_dict):
+ def _single_xv(self, model_getter, tv_set_dict, static=False):
training = dict()
validation = dict()
for name in self.names:
@@ -279,8 +285,11 @@ class CrossValidator:
for idx in validation_subset:
validation[name]["param"].append(self.by_name[name]["param"][idx])
+ kwargs = self.kwargs.copy()
+ if static:
+ kwargs["force_tree"] = False
training_data = self.model_class(
- training, self.parameters, *self.args, **self.kwargs
+ training, self.parameters, *self.args, **kwargs
)
training_model = model_getter(training_data)
kwargs = self.kwargs.copy()