 -rwxr-xr-x  lib/dfatool.py | 122
 1 file changed, 120 insertions(+), 2 deletions(-)
diff --git a/lib/dfatool.py b/lib/dfatool.py
index b71df98..87759fd 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -783,12 +783,14 @@ class AnalyticModel:
     All list except for 'attributes' must have the same length.
     """
 
-    def __init__(self, by_name, by_param, parameters):
+    def __init__(self, by_name, by_param, parameters, verbose = True):
+        self.cache = dict()
         self.by_name = by_name
         self.by_param = by_param
         self.parameters = sorted(parameters)
+        self.verbose = verbose
-        self.stats = ParamStats(self.by_name, self.by_param, self.parameters)
+        self.stats = ParamStats(self.by_name, self.by_param, self.parameters, {})
 
     def _fit(self):
         paramfit = ParallelParamFit(self.by_param)
@@ -812,6 +814,122 @@ class AnalyticModel:
                     x = analytic.function_powerset(fit_result, parameters)
                     x.fit(by_param, fname, attribute)
 
+    def names(self):
+        return sorted(self.by_name.keys())
+
+    def _get_model_from_dict(self, model_dict, model_function):
+        model = {}
+        for name, elem in model_dict.items():
+            model[name] = {}
+            for key in elem['attributes']:
+                try:
+                    model[name][key] = model_function(elem[key])
+                except RuntimeWarning:
+                    vprint(self.verbose, '[W] Got no data for {} {}'.format(name, key))
+                except FloatingPointError as fpe:
+                    vprint(self.verbose, '[W] Got no data for {} {}: {}'.format(name, key, fpe))
+        return model
+
+    def get_static(self):
+        static_model = self._get_model_from_dict(self.by_name, np.median)
+
+        def static_median_getter(name, key, **kwargs):
+            return static_model[name][key]
+
+        return static_median_getter
+
+    def get_static_using_mean(self):
+        static_model = self._get_model_from_dict(self.by_name, np.mean)
+
+        def static_mean_getter(name, key, **kwargs):
+            return static_model[name][key]
+
+        return static_mean_getter
+
+    def get_param_lut(self):
+        lut_model = self._get_model_from_dict(self.by_param, np.median)
+
+        def lut_median_getter(name, key, param, arg = [], **kwargs):
+            param.extend(map(soft_cast_int, arg))
+            return lut_model[(name, tuple(param))][key]
+
+        return lut_median_getter
+
+    def get_fitted(self, safe_functions_enabled = False):
+
+        if 'fitted_model_getter' in self.cache and 'fitted_info_getter' in self.cache:
+            return self.cache['fitted_model_getter'], self.cache['fitted_info_getter']
+
+        static_model = self._get_model_from_dict(self.by_name, np.median)
+        param_model = dict([[name, {}] for name in self.by_name.keys()])
+        paramfit = ParallelParamFit(self.by_param)
+
+        for name in self.by_name.keys():
+            for attribute in self.by_name[name]['attributes']:
+                for param_index, param in enumerate(self.parameters):
+                    ratio = self.stats.param_dependence_ratio(name, attribute, param)
+                    if self.stats.depends_on_param(name, attribute, param):
+                        paramfit.enqueue(name, attribute, param_index, param, False)
+
+        paramfit.fit()
+
+        for name in self.by_name.keys():
+            for attribute in self.by_name[name]['attributes']:
+                fit_result = {}
+                for result in paramfit.results:
+                    if result['key'][0] == name and result['key'][1] == attribute and result['result']['best'] != None:
+                        this_result = result['result']
+                        if this_result['best_rmsd'] >= min(this_result['mean_rmsd'], this_result['median_rmsd']):
+                            vprint(self.verbose, '[I] Not modeling {} {} as function of {}: best ({:.0f}) is worse than ref ({:.0f}, {:.0f})'.format(
+                                name, attribute, result['key'][2], this_result['best_rmsd'],
+                                this_result['mean_rmsd'], this_result['median_rmsd']))
+                        # See notes on depends_on_param
+                        elif this_result['best_rmsd'] >= 0.8 * min(this_result['mean_rmsd'], this_result['median_rmsd']):
+                            vprint(self.verbose, '[I] Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ({:.0f}, {:.0f})'.format(
+                                name, attribute, result['key'][2], this_result['best_rmsd'],
+                                this_result['mean_rmsd'], this_result['median_rmsd']))
+                        else:
+                            fit_result[result['key'][2]] = this_result
+
+                if len(fit_result.keys()):
+                    x = analytic.function_powerset(fit_result, self.parameters)
+                    x.fit(self.by_param, name, attribute)
+
+                    if x.fit_success:
+                        param_model[name][attribute] = {
+                            'fit_result': fit_result,
+                            'function' : x
+                        }
+
+        def model_getter(name, key, **kwargs):
+            if key in param_model[name]:
+                param_list = kwargs['param']
+                param_function = param_model[name][key]['function']
+                if param_function.is_predictable(param_list):
+                    return param_function.eval(param_list)
+            return static_model[name][key]
+
+        def info_getter(name, key):
+            if key in param_model[name]:
+                return param_model[name][key]
+            return None
+
+        self.cache['fitted_model_getter'] = model_getter
+        self.cache['fitted_info_getter'] = info_getter
+
+        return model_getter, info_getter
+
+    def assess(self, model_function):
+        detailed_results = {}
+        for name, elem in sorted(self.by_name.items()):
+            detailed_results[name] = {}
+            for attribute in elem['attributes']:
+                predicted_data = np.array(list(map(lambda i: model_function(name, attribute, param=elem['param'][i]), range(len(elem[attribute])))))
+                measures = regression_measures(predicted_data, elem[attribute])
+                detailed_results[name][attribute] = measures
+
+        return detailed_results
+
 class PTAModel:
     u"""
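
The commit gives AnalyticModel a getter-based interface: get_static() and get_param_lut() return plain lookup callables, get_fitted() returns a (model, info) pair that falls back to the static model whenever a parameter combination is not predictable, and assess() computes regression measures for any such callable. Below is a minimal usage sketch; the module import path, the toy by_name / by_param contents, and the 'txpower' parameter name are illustrative assumptions and not part of the commit.

    # Illustrative sketch only. The data layout follows the class docstring:
    # per-name dicts with an 'attributes' list, a 'param' list, and one
    # measurement array per attribute; by_param is keyed by (name, param tuple).
    import numpy as np
    from dfatool import AnalyticModel   # assumed import path for lib/dfatool.py

    by_name = {
        'TX': {
            'attributes': ['power'],
            'param': [[10], [10], [20], [20]],
            'power': np.array([12.1, 11.9, 22.2, 21.8]),
        },
    }
    by_param = {
        ('TX', (10,)): {'attributes': ['power'], 'power': np.array([12.1, 11.9])},
        ('TX', (20,)): {'attributes': ['power'], 'power': np.array([22.2, 21.8])},
    }

    model = AnalyticModel(by_name, by_param, ['txpower'], verbose=False)

    static_model = model.get_static()             # per-attribute median, parameters ignored
    lut_model = model.get_param_lut()             # median per (name, parameter tuple)
    param_model, param_info = model.get_fitted()  # fitted function with static fallback

    print(static_model('TX', 'power'))            # median over all four measurements
    print(lut_model('TX', 'power', param=[20]))   # median of the txpower=20 group
    print(param_model('TX', 'power', param=[20])) # fitted value, or static if not predictable
    print(model.assess(param_model))              # regression measures per name and attribute

Note that get_fitted() caches its two getters in self.cache, so repeated calls skip the parallel fitting step, and that the LUT getter extends its param argument in place with any soft-cast arg values before the lookup.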