diff options
author | Daniel Friesel <derf@finalrewind.org> | 2019-02-14 08:29:26 +0100 |
---|---|---|
committer | Daniel Friesel <derf@finalrewind.org> | 2019-02-14 08:29:26 +0100 |
commit | 7ed42ca64831018aa7ea379651225d66412e62f0 (patch) | |
tree | 6a0ac767a39001901cd6ffc6d63a3ee1e9d61e56 /lib | |
parent | 906a01beeb24fafa691f585018a0ec809a88de08 (diff) |
improved cross-validation in analyze-archive; fallback for param_lut_model
Diffstat (limited to 'lib')
-rwxr-xr-x | lib/dfatool.py | 31 |
1 files changed, 23 insertions, 8 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 82ed35f..c4529ce 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -295,7 +295,7 @@ class CrossValidator: self.parameters = sorted(parameters) self.arg_count = arg_count - def montecarlo(self, model_getter, count = 2): + def montecarlo(self, model_getter, count = 200): """ Perform Monte Carlo cross-validation and return average model quality. @@ -1013,18 +1013,27 @@ class AnalyticModel: return static_mean_getter - def get_param_lut(self): + def get_param_lut(self, fallback = False): """ Get parameter-look-up-table model function: name, attribute, parameter values -> model value. The function can only give model values for parameter combinations - present in by_param. It raises KeyError for other values. + present in by_param. By default, it raises KeyError for other values. + + arguments: + fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values """ + static_model = self._get_model_from_dict(self.by_name, np.median) lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): param.extend(map(soft_cast_int, arg)) - return lut_model[(name, tuple(param))][key] + try: + return lut_model[(name, tuple(param))][key] + except KeyError: + if fallback: + return static_model[name][key] + raise return lut_median_getter @@ -1380,21 +1389,27 @@ class PTAModel: return static_mean_getter - def get_param_lut(self): + def get_param_lut(self, fallback = False): """ Get parameter-look-up-table model function: name, attribute, parameter values -> model value. The function can only give model values for parameter combinations - present in by_param. It raises KeyError for other values. + present in by_param. By default, it raises KeyError for other values. + + arguments: + fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values """ static_model = self._get_model_from_dict(self.by_name, np.median) lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): param.extend(map(soft_cast_int, arg)) - if (name, tuple(param)) in lut_model: + try: return lut_model[(name, tuple(param))][key] - return static_model[name][key] + except KeyError: + if fallback: + return static_model[name][key] + raise return lut_median_getter |