diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-05 15:23:51 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2022-01-05 15:23:51 +0100 |
commit | 2c5bcd77f2c952cc5269ca3e4b6e0a7323ebd085 (patch) | |
tree | 93da4dc33c77855445e6aa21f45c12b1803861fa | |
parent | d9aee2a314ae6d3fc0216893a4ccfd8bb66ffa9c (diff) |
cross validation: return intermediate models used for XV
The intermediate models are useful for statistics, e.g. to determine the average decision tree (dtree) size.
-rwxr-xr-x | bin/analyze-archive.py | 14 | ||||
-rwxr-xr-x | bin/analyze-kconfig.py | 18 | ||||
-rwxr-xr-x | bin/analyze-timing.py | 8 | ||||
-rw-r--r-- | lib/validation.py | 8 |
4 files changed, 28 insertions, 20 deletions
diff --git a/bin/analyze-archive.py b/bin/analyze-archive.py index b29091d..0a2845c 100755 --- a/bin/analyze-archive.py +++ b/bin/analyze-archive.py @@ -800,9 +800,9 @@ if __name__ == "__main__": ) if xv_method == "montecarlo": - static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) + static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count) elif xv_method == "kfold": - static_quality = xv.kfold(lambda m: m.get_static(), xv_count) + static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count) else: static_quality = model.assess(static_model) @@ -811,9 +811,11 @@ if __name__ == "__main__": lut_model = model.get_param_lut() if xv_method == "montecarlo": - lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count) + lut_quality, _ = xv.montecarlo( + lambda m: m.get_param_lut(fallback=True), xv_count + ) elif xv_method == "kfold": - lut_quality = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count) + lut_quality, _ = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count) else: lut_quality = model.assess(lut_model) @@ -933,9 +935,9 @@ if __name__ == "__main__": ) if xv_method == "montecarlo": - analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) elif xv_method == "kfold": - analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count) + analytic_quality, _ = xv.kfold(lambda m: m.get_fitted()[0], xv_count) else: analytic_quality = model.assess(param_model) diff --git a/bin/analyze-kconfig.py b/bin/analyze-kconfig.py index bd9cccb..048c8c9 100755 --- a/bin/analyze-kconfig.py +++ b/bin/analyze-kconfig.py @@ -237,19 +237,21 @@ def main(): fit_duration = time.time() - fit_start_time if xv_method == "montecarlo": - static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) - analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + static_quality, _ = xv.montecarlo(lambda m: 
m.get_static(), xv_count) + analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) if lut_model: - lut_quality = xv.montecarlo( + lut_quality, _ = xv.montecarlo( lambda m: m.get_param_lut(fallback=True), xv_count ) else: lut_quality = None elif xv_method == "kfold": - static_quality = xv.kfold(lambda m: m.get_static(), xv_count) - analytic_quality = xv.kfold(lambda m: m.get_fitted()[0], xv_count) + static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count) + analytic_quality, _ = xv.kfold(lambda m: m.get_fitted()[0], xv_count) if lut_model: - lut_quality = xv.kfold(lambda m: m.get_param_lut(fallback=True), xv_count) + lut_quality, _ = xv.kfold( + lambda m: m.get_param_lut(fallback=True), xv_count + ) else: lut_quality = None else: @@ -315,9 +317,9 @@ def main(): json.dump(json_model, f, sort_keys=True, cls=dfatool.utils.NpEncoder) if xv_method == "montecarlo": - static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) + static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count) elif xv_method == "kfold": - static_quality = xv.kfold(lambda m: m.get_static(), xv_count) + static_quality, _ = xv.kfold(lambda m: m.get_static(), xv_count) else: static_quality = model.assess(static_model) diff --git a/bin/analyze-timing.py b/bin/analyze-timing.py index d67c553..c37ea65 100755 --- a/bin/analyze-timing.py +++ b/bin/analyze-timing.py @@ -305,7 +305,7 @@ if __name__ == "__main__": ) if xv_method == "montecarlo": - static_quality = xv.montecarlo(lambda m: m.get_static(), xv_count) + static_quality, _ = xv.montecarlo(lambda m: m.get_static(), xv_count) else: static_quality = model.assess(static_model) @@ -314,7 +314,9 @@ if __name__ == "__main__": lut_model = model.get_param_lut() if xv_method == "montecarlo": - lut_quality = xv.montecarlo(lambda m: m.get_param_lut(fallback=True), xv_count) + lut_quality, _ = xv.montecarlo( + lambda m: m.get_param_lut(fallback=True), xv_count + ) else: lut_quality = model.assess(lut_model) @@ 
-412,7 +414,7 @@ if __name__ == "__main__": print("{:10s} {:10s} {}".format("", "", info.model_args)) if xv_method == "montecarlo": - analytic_quality = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) + analytic_quality, _ = xv.montecarlo(lambda m: m.get_fitted()[0], xv_count) else: analytic_quality = model.assess(param_model) diff --git a/lib/validation.py b/lib/validation.py index 47395dc..5c65fe3 100644 --- a/lib/validation.py +++ b/lib/validation.py @@ -211,6 +211,7 @@ class CrossValidator: def _generic_xv(self, model_getter, training_and_validation_sets): ret = dict() + models = list() for name in self.names: ret[name] = dict() @@ -223,7 +224,8 @@ class CrossValidator: } for training_and_validation_by_name in training_and_validation_sets: - res = self._single_xv(model_getter, training_and_validation_by_name) + model, res = self._single_xv(model_getter, training_and_validation_by_name) + models.append(model) for name in self.names: for attribute in self.by_name[name]["attributes"]: for measure in ("mae", "rmsd", "mape", "smape"): @@ -238,7 +240,7 @@ class CrossValidator: ret[name][attribute][f"{measure}_list"] ) - return ret + return ret, models def _single_xv(self, model_getter, tv_set_dict): training = dict() @@ -278,4 +280,4 @@ class CrossValidator: validation, self.parameters, *self.args, **self.kwargs ) - return validation_data.assess(training_model) + return training_model, validation_data.assess(training_model) |