summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2022-01-05 15:23:51 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2022-01-05 15:23:51 +0100
commit2c5bcd77f2c952cc5269ca3e4b6e0a7323ebd085 (patch)
tree93da4dc33c77855445e6aa21f45c12b1803861fa
parentd9aee2a314ae6d3fc0216893a4ccfd8bb66ffa9c (diff)
cross validation: return intermediate models used for XV
These are interesting for statistics, e.g. to determine the average dtree size
-rwxr-xr-xbin/analyze-archive.py14
-rwxr-xr-xbin/analyze-kconfig.py18
-rwxr-xr-xbin/analyze-timing.py8
-rw-r--r--lib/validation.py8
4 files changed, 28 insertions, 20 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py
index b29091d..0a2845c 100755
--- a/bin/analyze-archive.py
+++ b/bin/analyze-archive.py
@@ -800,9 +800,9 @@ if __name__ == "__main__":
)
if xv_method == "montecarlo":
- static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count)
+ static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count)
elif xv_method == "kfold":
- static_quality = xv.kfold(lambda m: m.get_static(), xv_count)
+ static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count)
else:
static_quality = model.assess(static_model)
@@ -811,9 +811,11 @@ if __name__ == "__main__":
lut_model = model.get_param_lut()
if xv_method == "montecarlo":
- lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count)
+ lut_quality, _ = xv.montecarlo(
+ lambda m: m.get_param_lut(fallback=True), xv_count
+ )
elif xv_method == "kfold":
- lut_quality = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count)
+ lut_quality, _ = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count)
else:
lut_quality = model.assess(lut_model)
@@ -933,9 +935,9 @@ if __name__ == "__main__":
)
if xv_method == "montecarlo":
- analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
+ analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
elif xv_method == "kfold":
- analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count)
+ analytic_quality, _ = xv.kfold(lambda m: m.get_fitted()[0], xv_count)
else:
analytic_quality = model.assess(param_model)
diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py
index bd9cccb..048c8c9 100755
--- a/bin/analyze-kconfig.py
+++ b/bin/analyze-kconfig.py
@@ -237,19 +237,21 @@ def main():
fit_duration = time.time() - fit_start_time
if xv_method == "montecarlo":
- static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count)
- analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
+ static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count)
+ analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
if lut_model:
- lut_quality = xv.montecarlo(
+ lut_quality, _ = xv.montecarlo(
lambda m: m.get_param_lut(fallback=True), xv_count
)
else:
lut_quality = None
elif xv_method == "kfold":
- static_quality = xv.kfold(lambda m: m.get_static(), xv_count)
- analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count)
+ static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count)
+ analytic_quality, _ = xv.kfold(lambda m: m.get_fitted()[0], xv_count)
if lut_model:
- lut_quality = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count)
+ lut_quality, _ = xv.kfold(
+ lambda m: m.get_param_lut(fallback=True), xv_count
+ )
else:
lut_quality = None
else:
@@ -315,9 +317,9 @@ def main():
json.dump(json_model, f, sort_keys=True, cls=dfatool.utils.NpEncoder)
if xv_method == "montecarlo":
- static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count)
+ static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count)
elif xv_method == "kfold":
- static_quality = xv.kfold(lambda m: m.get_static(), xv_count)
+ static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count)
else:
static_quality = model.assess(static_model)
diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py
index d67c553..c37ea65 100755
--- a/bin/analyze-timing.py
+++ b/bin/analyze-timing.py
@@ -305,7 +305,7 @@ if __name__ == "__main__":
)
if xv_method == "montecarlo":
- static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count)
+ static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count)
else:
static_quality = model.assess(static_model)
@@ -314,7 +314,9 @@ if __name__ == "__main__":
lut_model = model.get_param_lut()
if xv_method == "montecarlo":
- lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count)
+ lut_quality, _ = xv.montecarlo(
+ lambda m: m.get_param_lut(fallback=True), xv_count
+ )
else:
lut_quality = model.assess(lut_model)
@@ -412,7 +414,7 @@ if __name__ == "__main__":
print("{:10s} {:10s} {}".format("", "", info.model_args))
if xv_method == "montecarlo":
- analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
+ analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count)
else:
analytic_quality = model.assess(param_model)
diff --git a/lib/validation.py b/lib/validation.py
index 47395dc..5c65fe3 100644
--- a/lib/validation.py
+++ b/lib/validation.py
@@ -211,6 +211,7 @@ class CrossValidator:
def _generic_xv(self, model_getter, training_and_validation_sets):
ret = dict()
+ models = list()
for name in self.names:
ret[name] = dict()
@@ -223,7 +224,8 @@ class CrossValidator:
}
for training_and_validation_by_name in training_and_validation_sets:
- res = self._single_xv(model_getter, training_and_validation_by_name)
+ model, res = self._single_xv(model_getter, training_and_validation_by_name)
+ models.append(model)
for name in self.names:
for attribute in self.by_name[name]["attributes"]:
for measure in ("mae", "rmsd", "mape", "smape"):
@@ -238,7 +240,7 @@ class CrossValidator:
ret[name][attribute][f"{measure}_list"]
)
- return ret
+ return ret, models
def _single_xv(self, model_getter, tv_set_dict):
training = dict()
@@ -278,4 +280,4 @@ class CrossValidator:
validation, self.parameters, *self.args, **self.kwargs
)
- return validation_data.assess(training_model)
+ return training_model, validation_data.assess(training_model)