summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2019-02-07 08:40:37 +0100
committerDaniel Friesel <derf@finalrewind.org>2019-02-07 08:40:37 +0100
commitcfe740a107964c805451ad0a59eeff0049d5bac1 (patch)
tree35cf13edffbc4cee59105bebcb9350577b0fddb7 /lib
parente24681055ded2f273979c2ec05ce0c86651bca50 (diff)
Use ParallelParamFit class for parallel fitting
Diffstat (limited to 'lib')
-rwxr-xr-xlib/dfatool.py56
1 files changed, 39 insertions, 17 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 3048f43..372adee 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -636,6 +636,40 @@ class RawData:
'num_valid' : num_valid
}
+class ParallelParamFit:
+ """
+ Fit a set of functions on parameterized measurements.
+
+ One parameter is variale, all others are fixed. Reports the best-fitting
+ function type for each parameter.
+ """
+
+ def __init__(self, by_param):
+ """Create a new ParallelParamFit object."""
+ self.fit_queue = []
+ self.by_param = by_param
+
+ def enqueue(self, state_or_tran, attribute, param_index, param_name, safe_functions_enabled = False):
+ """
+ Add state_or_tran/attribute/param_name to fit queue.
+
+ This causes fit() to compute the best-fitting function for this model part.
+ """
+ self.fit_queue.append({
+ 'key' : [state_or_tran, attribute, param_name],
+ 'args' : [self.by_param, state_or_tran, attribute, param_index, safe_functions_enabled]
+ })
+
+ def fit(self):
+ """
+ Fit functions on previously enqueue data.
+
+ Fitting is one in parallel with one process per core.
+
+ Results can be accessed using the public ParallelParamFit.results object.
+ """
+ with Pool() as pool:
+ self.results = pool.map(_try_fits_parallel, self.fit_queue)
def _try_fits_parallel(arg):
return {
@@ -1025,7 +1059,7 @@ class EnergyModel:
static_model = self._get_model_from_dict(self.by_name, np.median)
param_model = dict([[state_or_tran, {}] for state_or_tran in self.by_name.keys()])
- fit_queue = []
+ paramfit = ParallelParamFit(self.by_param)
for state_or_tran in self.by_name.keys():
param_keys = filter(lambda k: k[0] == state_or_tran, self.by_param.keys())
param_subdict = dict(map(lambda k: [k, self.by_param[k]], param_keys))
@@ -1033,24 +1067,12 @@ class EnergyModel:
fit_results = {}
for parameter_index, parameter_name in enumerate(self._parameter_names):
if self.depends_on_param(state_or_tran, model_attribute, parameter_name):
- fit_queue.append({
- 'key' : [state_or_tran, model_attribute, parameter_name],
- 'args' : [self.by_param, state_or_tran, model_attribute, parameter_index, safe_functions_enabled]
- })
- #fit_results[parameter_name] = _try_fits(self.by_param, state_or_tran, model_attribute, parameter_index)
- #print('{} {} is {}'.format(state_or_tran, parameter_name, fit_results[parameter_name]['best']))
+ paramfit.enqueue(state_or_tran, model_attribute, parameter_index, parameter_name, safe_functions_enabled)
if arg_support_enabled and self.by_name[state_or_tran]['isa'] == 'transition':
for arg_index in range(self._num_args[state_or_tran]):
if self.depends_on_arg(state_or_tran, model_attribute, arg_index):
- fit_queue.append({
- 'key' : [state_or_tran, model_attribute, arg_index],
- 'args' : [param_subdict, state_or_tran, model_attribute, len(self._parameter_names) + arg_index, safe_functions_enabled]
- })
- #fit_results[_arg_name(arg_index)] = _try_fits(self.by_param, state_or_tran, model_attribute, len(self._parameter_names) + arg_index)
- #if 'args' in self.by_name[state_or_tran]:
- # for i, arg in range(len(self.by_name
- with Pool() as pool:
- all_fit_results = pool.map(_try_fits_parallel, fit_queue)
+ paramfit.enqueue(state_or_tran, model_attribute, len(self._parameter_names) + arg_index, arg_index, safe_functions_enabled)
+ paramfit.fit()
for state_or_tran in self.by_name.keys():
num_args = 0
@@ -1058,7 +1080,7 @@ class EnergyModel:
num_args = self._num_args[state_or_tran]
for model_attribute in self.by_name[state_or_tran]['attributes']:
fit_results = {}
- for result in all_fit_results:
+ for result in paramfit.results:
if result['key'][0] == state_or_tran and result['key'][1] == model_attribute:
fit_result = result['result']
if fit_result['best_rmsd'] >= min(fit_result['mean_rmsd'], fit_result['median_rmsd']):