diff options
author | Daniel Friesel <derf@finalrewind.org> | 2019-02-07 08:40:37 +0100 |
---|---|---|
committer | Daniel Friesel <derf@finalrewind.org> | 2019-02-07 08:40:37 +0100 |
commit | cfe740a107964c805451ad0a59eeff0049d5bac1 (patch) | |
tree | 35cf13edffbc4cee59105bebcb9350577b0fddb7 /lib | |
parent | e24681055ded2f273979c2ec05ce0c86651bca50 (diff) |
Use ParallelParamFit class for parallel fitting
Diffstat (limited to 'lib')
-rwxr-xr-x | lib/dfatool.py | 56 |
1 files changed, 39 insertions, 17 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 3048f43..372adee 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -636,6 +636,40 @@ class RawData: 'num_valid' : num_valid } +class ParallelParamFit: + """ + Fit a set of functions on parameterized measurements. + + One parameter is variale, all others are fixed. Reports the best-fitting + function type for each parameter. + """ + + def __init__(self, by_param): + """Create a new ParallelParamFit object.""" + self.fit_queue = [] + self.by_param = by_param + + def enqueue(self, state_or_tran, attribute, param_index, param_name, safe_functions_enabled = False): + """ + Add state_or_tran/attribute/param_name to fit queue. + + This causes fit() to compute the best-fitting function for this model part. + """ + self.fit_queue.append({ + 'key' : [state_or_tran, attribute, param_name], + 'args' : [self.by_param, state_or_tran, attribute, param_index, safe_functions_enabled] + }) + + def fit(self): + """ + Fit functions on previously enqueue data. + + Fitting is one in parallel with one process per core. + + Results can be accessed using the public ParallelParamFit.results object. + """ + with Pool() as pool: + self.results = pool.map(_try_fits_parallel, self.fit_queue) def _try_fits_parallel(arg): return { @@ -1025,7 +1059,7 @@ class EnergyModel: static_model = self._get_model_from_dict(self.by_name, np.median) param_model = dict([[state_or_tran, {}] for state_or_tran in self.by_name.keys()]) - fit_queue = [] + paramfit = ParallelParamFit(self.by_param) for state_or_tran in self.by_name.keys(): param_keys = filter(lambda k: k[0] == state_or_tran, self.by_param.keys()) param_subdict = dict(map(lambda k: [k, self.by_param[k]], param_keys)) @@ -1033,24 +1067,12 @@ class EnergyModel: fit_results = {} for parameter_index, parameter_name in enumerate(self._parameter_names): if self.depends_on_param(state_or_tran, model_attribute, parameter_name): - fit_queue.append({ - 'key' : [state_or_tran, model_attribute, parameter_name], - 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index, safe_functions_enabled] - }) - #fit_results[parameter_name] = _try_fits(self.by_param, state_or_tran, model_attribute, parameter_index) - #print('{} {} is {}'.format(state_or_tran, parameter_name, fit_results[parameter_name]['best'])) + paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled) if arg_support_enabled and self.by_name[state_or_tran]['isa'] == 'transition': for arg_index in range(self._num_args[state_or_tran]): if self.depends_on_arg(state_or_tran, model_attribute, arg_index): - fit_queue.append({ - 'key' : [state_or_tran, model_attribute, arg_index], - 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index, safe_functions_enabled] - }) - #fit_results[_arg_name(arg_index)] = _try_fits(self.by_param, state_or_tran, model_attribute, len(self._parameter_names) + arg_index) - #if 'args' in self.by_name[state_or_tran]: - # for i, arg in range(len(self.by_name - with Pool() as pool: - all_fit_results = pool.map(_try_fits_parallel, fit_queue) + paramfit.enqueue(state_or_tran, model_attribute, len(self._parameter_names) + arg_index, arg_index, safe_functions_enabled) + paramfit.fit() for state_or_tran in self.by_name.keys(): num_args = 0 @@ -1058,7 +1080,7 @@ class EnergyModel: num_args = self._num_args[state_or_tran] for model_attribute in self.by_name[state_or_tran]['attributes']: fit_results = {} - for result in all_fit_results: + for result in paramfit.results: if result['key'][0] == state_or_tran and result['key'][1] == model_attribute: fit_result = result['result'] if fit_result['best_rmsd'] >= min(fit_result['mean_rmsd'], fit_result['median_rmsd']): |