diff options
author | Daniel Friesel <derf@finalrewind.org> | 2019-02-11 12:02:38 +0100 |
---|---|---|
committer | Daniel Friesel <derf@finalrewind.org> | 2019-02-11 12:02:38 +0100 |
commit | 87ad9424ad5490205576c9bfe6c9c070ae9d3b1c (patch) | |
tree | 99e351a7b7c5933d48ef7b2e40bb6f2336f06861 /lib | |
parent | 427374cab40dd1b658d3b7cf219a709062d79b8c (diff) |
cleanup, doku
Diffstat (limited to 'lib')
-rwxr-xr-x | lib/dfatool.py | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index ba94226..2a9ba1a 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -1199,6 +1199,7 @@ class PTAModel: self._parameter_names = sorted(parameters) self._num_args = arg_count self.traces = traces + self.stats = ParamStats(self.by_name, self.by_param, self._parameter_names, self._num_args, self._use_corrcoef) self.cache = {} np.seterr('raise') self._outlier_threshold = discard_outliers @@ -1208,16 +1209,12 @@ class PTAModel: self.hwmodel = hwmodel self.ignore_trace_indexes = ignore_trace_indexes self._aggregate_to_ndarray(self.by_name) - self._compute_all_param_statistics() def distinct_param_values(self, state_or_tran, param_index = None, arg_index = None): if param_index != None: param_values = map(lambda x: x[param_index], self.by_name[state_or_tran]['param']) return sorted(set(param_values)) - def _compute_all_param_statistics(self): - self.stats = ParamStats(self.by_name, self.by_param, self._parameter_names, self._num_args, self._use_corrcoef) - def _aggregate_to_ndarray(self, aggregate): for elem in aggregate.values(): for key in elem['attributes']: @@ -1248,6 +1245,11 @@ class PTAModel: return model def get_static(self): + """ + Get static model function: name, attribute -> model value. + + Uses the median of by_name for modeling. + """ static_model = self._get_model_from_dict(self.by_name, np.median) def static_median_getter(name, key, **kwargs): @@ -1256,6 +1258,11 @@ class PTAModel: return static_median_getter def get_static_using_mean(self): + """ + Get static model function: name, attribute -> model value. + + Uses the mean of by_name for modeling. + """ static_model = self._get_model_from_dict(self.by_name, np.mean) def static_mean_getter(name, key, **kwargs): @@ -1264,6 +1271,12 @@ class PTAModel: return static_mean_getter def get_param_lut(self): + """ + Get parameter-look-up-table model function: name, attribute, parameter values -> model value. + + The function can only give model values for parameter combinations + present in by_param. It raises KeyError for other values. + """ lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): @@ -1286,7 +1299,13 @@ class PTAModel: return str(param_index) def get_fitted(self, safe_functions_enabled = False): + """ + Get paramete-aware model function and model information function. + Returns two functions: + model_function(name, attribute, param=parameter values) -> model value. + model_info(name, attribute) -> {'fit_result' : ..., 'function' : ... } or None + """ if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache: return self.cache['fitted_model_getter'], self.cache['fitted_info_getter'] @@ -1389,6 +1408,21 @@ class PTAModel: return self.by_name[state_or_trans]['attributes'] def assess(self, model_function): + """ + Calculate MAE, SMAPE, etc. of model_function for each by_name entry. + + state/transition/... name and parameter values are fed into model_function. + The by_name entries of this PTAModel are used as ground truth and + compared with the values predicted by model_function. + + If 'traces' was set when creating this object, the model quality is + also assessed on a per-trace basis. + + For proper model assessments, the data used to generate model_function + and the data fed into this AnalyticModel instance must be mutually + exclusive (e.g. by performing cross validation). Otherwise, + overfitting cannot be detected. + """ detailed_results = {} model_energy_list = [] real_energy_list = [] |