From 7ed42ca64831018aa7ea379651225d66412e62f0 Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Thu, 14 Feb 2019 08:29:26 +0100 Subject: improved cross-validation in analyze-archive; fallback for param_lut_model --- lib/dfatool.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) (limited to 'lib/dfatool.py') diff --git a/lib/dfatool.py b/lib/dfatool.py index 82ed35f..c4529ce 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -295,7 +295,7 @@ class CrossValidator: self.parameters = sorted(parameters) self.arg_count = arg_count - def montecarlo(self, model_getter, count = 2): + def montecarlo(self, model_getter, count = 200): """ Perform Monte Carlo cross-validation and return average model quality. @@ -1013,18 +1013,27 @@ class AnalyticModel: return static_mean_getter - def get_param_lut(self): + def get_param_lut(self, fallback = False): """ Get parameter-look-up-table model function: name, attribute, parameter values -> model value. The function can only give model values for parameter combinations - present in by_param. It raises KeyError for other values. + present in by_param. By default, it raises KeyError for other values. + + arguments: + fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values """ + static_model = self._get_model_from_dict(self.by_name, np.median) lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): param.extend(map(soft_cast_int, arg)) - return lut_model[(name, tuple(param))][key] + try: + return lut_model[(name, tuple(param))][key] + except KeyError: + if fallback: + return static_model[name][key] + raise return lut_median_getter @@ -1380,21 +1389,27 @@ class PTAModel: return static_mean_getter - def get_param_lut(self): + def get_param_lut(self, fallback = False): """ Get parameter-look-up-table model function: name, attribute, parameter values -> model value. The function can only give model values for parameter combinations - present in by_param. It raises KeyError for other values. + present in by_param. By default, it raises KeyError for other values. + + arguments: + fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values """ static_model = self._get_model_from_dict(self.by_name, np.median) lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): param.extend(map(soft_cast_int, arg)) - if (name, tuple(param)) in lut_model: + try: return lut_model[(name, tuple(param))][key] - return static_model[name][key] + except KeyError: + if fallback: + return static_model[name][key] + raise return lut_median_getter -- cgit v1.2.3