diff options
Diffstat (limited to 'lib/dfatool.py')
-rwxr-xr-x | lib/dfatool.py | 31 |
1 files changed, 23 insertions, 8 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 82ed35f..c4529ce 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -295,7 +295,7 @@ class CrossValidator: self.parameters = sorted(parameters) self.arg_count = arg_count - def montecarlo(self, model_getter, count = 2): + def montecarlo(self, model_getter, count = 200): """ Perform Monte Carlo cross-validation and return average model quality. @@ -1013,18 +1013,27 @@ class AnalyticModel: return static_mean_getter - def get_param_lut(self): + def get_param_lut(self, fallback = False): """ Get parameter-look-up-table model function: name, attribute, parameter values -> model value. The function can only give model values for parameter combinations - present in by_param. It raises KeyError for other values. + present in by_param. By default, it raises KeyError for other values. + + arguments: + fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values """ + static_model = self._get_model_from_dict(self.by_name, np.median) lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): param.extend(map(soft_cast_int, arg)) - return lut_model[(name, tuple(param))][key] + try: + return lut_model[(name, tuple(param))][key] + except KeyError: + if fallback: + return static_model[name][key] + raise return lut_median_getter @@ -1380,21 +1389,27 @@ class PTAModel: return static_mean_getter - def get_param_lut(self): + def get_param_lut(self, fallback = False): """ Get parameter-look-up-table model function: name, attribute, parameter values -> model value. The function can only give model values for parameter combinations - present in by_param. It raises KeyError for other values. + present in by_param. By default, it raises KeyError for other values. + + arguments: + fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values """ static_model = self._get_model_from_dict(self.by_name, np.median) lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): param.extend(map(soft_cast_int, arg)) - if (name, tuple(param)) in lut_model: + try: return lut_model[(name, tuple(param))][key] - return static_model[name][key] + except KeyError: + if fallback: + return static_model[name][key] + raise return lut_median_getter |