summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2019-02-14 08:29:26 +0100
committerDaniel Friesel <derf@finalrewind.org>2019-02-14 08:29:26 +0100
commit7ed42ca64831018aa7ea379651225d66412e62f0 (patch)
tree6a0ac767a39001901cd6ffc6d63a3ee1e9d61e56 /lib
parent906a01beeb24fafa691f585018a0ec809a88de08 (diff)
improved cross-validation in analyze-archive; fallback for param_lut_model
Diffstat (limited to 'lib')
-rwxr-xr-xlib/dfatool.py31
1 files changed, 23 insertions, 8 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 82ed35f..c4529ce 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -295,7 +295,7 @@ class CrossValidator:
self.parameters = sorted(parameters)
self.arg_count = arg_count
- def montecarlo(self, model_getter, count = 2):
+ def montecarlo(self, model_getter, count = 200):
"""
Perform Monte Carlo cross-validation and return average model quality.
@@ -1013,18 +1013,27 @@ class AnalyticModel:
return static_mean_getter
- def get_param_lut(self):
+ def get_param_lut(self, fallback = False):
"""
Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
The function can only give model values for parameter combinations
- present in by_param. It raises KeyError for other values.
+ present in by_param. By default, it raises KeyError for other values.
+
+ arguments:
+ fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values
"""
+ static_model = self._get_model_from_dict(self.by_name, np.median)
lut_model = self._get_model_from_dict(self.by_param, np.median)
def lut_median_getter(name, key, param, arg = [], **kwargs):
param.extend(map(soft_cast_int, arg))
- return lut_model[(name, tuple(param))][key]
+ try:
+ return lut_model[(name, tuple(param))][key]
+ except KeyError:
+ if fallback:
+ return static_model[name][key]
+ raise
return lut_median_getter
@@ -1380,21 +1389,27 @@ class PTAModel:
return static_mean_getter
- def get_param_lut(self):
+ def get_param_lut(self, fallback = False):
"""
Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
The function can only give model values for parameter combinations
- present in by_param. It raises KeyError for other values.
+ present in by_param. By default, it raises KeyError for other values.
+
+ arguments:
+ fallback -- Fall back to the (non-parameter-aware) static model when encountering unknown parameter values
"""
static_model = self._get_model_from_dict(self.by_name, np.median)
lut_model = self._get_model_from_dict(self.by_param, np.median)
def lut_median_getter(name, key, param, arg = [], **kwargs):
param.extend(map(soft_cast_int, arg))
- if (name, tuple(param)) in lut_model:
+ try:
return lut_model[(name, tuple(param))][key]
- return static_model[name][key]
+ except KeyError:
+ if fallback:
+ return static_model[name][key]
+ raise
return lut_median_getter