summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2019-02-11 12:02:38 +0100
committerDaniel Friesel <derf@finalrewind.org>2019-02-11 12:02:38 +0100
commit87ad9424ad5490205576c9bfe6c9c070ae9d3b1c (patch)
tree99e351a7b7c5933d48ef7b2e40bb6f2336f06861 /lib
parent427374cab40dd1b658d3b7cf219a709062d79b8c (diff)
cleanup, doku
Diffstat (limited to 'lib')
-rwxr-xr-xlib/dfatool.py42
1 files changed, 38 insertions, 4 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index ba94226..2a9ba1a 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -1199,6 +1199,7 @@ class PTAModel:
self._parameter_names = sorted(parameters)
self._num_args = arg_count
self.traces = traces
+ self.stats = ParamStats(self.by_name, self.by_param, self._parameter_names, self._num_args, self._use_corrcoef)
self.cache = {}
np.seterr('raise')
self._outlier_threshold = discard_outliers
@@ -1208,16 +1209,12 @@ class PTAModel:
self.hwmodel = hwmodel
self.ignore_trace_indexes = ignore_trace_indexes
self._aggregate_to_ndarray(self.by_name)
- self._compute_all_param_statistics()
def distinct_param_values(self, state_or_tran, param_index = None, arg_index = None):
if param_index != None:
param_values = map(lambda x: x[param_index], self.by_name[state_or_tran]['param'])
return sorted(set(param_values))
- def _compute_all_param_statistics(self):
- self.stats = ParamStats(self.by_name, self.by_param, self._parameter_names, self._num_args, self._use_corrcoef)
-
def _aggregate_to_ndarray(self, aggregate):
for elem in aggregate.values():
for key in elem['attributes']:
@@ -1248,6 +1245,11 @@ class PTAModel:
return model
def get_static(self):
+ """
+ Get static model function: name, attribute -> model value.
+
+ Uses the median of by_name for modeling.
+ """
static_model = self._get_model_from_dict(self.by_name, np.median)
def static_median_getter(name, key, **kwargs):
@@ -1256,6 +1258,11 @@ class PTAModel:
return static_median_getter
def get_static_using_mean(self):
+ """
+ Get static model function: name, attribute -> model value.
+
+ Uses the mean of by_name for modeling.
+ """
static_model = self._get_model_from_dict(self.by_name, np.mean)
def static_mean_getter(name, key, **kwargs):
@@ -1264,6 +1271,12 @@ class PTAModel:
return static_mean_getter
def get_param_lut(self):
+ """
+ Get parameter-look-up-table model function: name, attribute, parameter values -> model value.
+
+ The function can only give model values for parameter combinations
+ present in by_param. It raises KeyError for other values.
+ """
lut_model = self._get_model_from_dict(self.by_param, np.median)
def lut_median_getter(name, key, param, arg = [], **kwargs):
@@ -1286,7 +1299,13 @@ class PTAModel:
return str(param_index)
def get_fitted(self, safe_functions_enabled = False):
+ """
+ Get paramete-aware model function and model information function.
+ Returns two functions:
+ model_function(name, attribute, param=parameter values) -> model value.
+ model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None
+ """
if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache:
return self.cache['fitted_model_getter'], self.cache['fitted_info_getter']
@@ -1389,6 +1408,21 @@ class PTAModel:
return self.by_name[state_or_trans]['attributes']
def assess(self, model_function):
+ """
+ Calculate MAE, SMAPE, etc. of model_function for each by_name entry.
+
+ state/transition/... name and parameter values are fed into model_function.
+ The by_name entries of this PTAModel are used as ground truth and
+ compared with the values predicted by model_function.
+
+ If 'traces' was set when creating this object, the model quality is
+ also assessed on a per-trace basis.
+
+ For proper model assessments, the data used to generate model_function
+ and the data fed into this AnalyticModel instance must be mutually
+ exclusive (e.g. by performing cross validation). Otherwise,
+ overfitting cannot be detected.
+ """
detailed_results = {}
model_energy_list = []
real_energy_list = []