diff options
Diffstat (limited to 'lib/dfatool.py')
-rwxr-xr-x | lib/dfatool.py | 35 |
1 files changed, 18 insertions, 17 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py index 40cacef..422b838 100755 --- a/lib/dfatool.py +++ b/lib/dfatool.py @@ -29,7 +29,7 @@ def is_numeric(n): except ValueError: return False -def _soft_cast_int(n): +def soft_cast_int(n): if n == None or n == '': return None try: @@ -48,9 +48,9 @@ def float_or_nan(n): def _elem_param_and_arg_list(elem): param_dict = elem['parameter'] paramkeys = sorted(param_dict.keys()) - paramvalue = [_soft_cast_int(param_dict[x]) for x in paramkeys] + paramvalue = [soft_cast_int(param_dict[x]) for x in paramkeys] if arg_support_enabled and 'args' in elem: - paramvalue.extend(map(_soft_cast_int, elem['args'])) + paramvalue.extend(map(soft_cast_int, elem['args'])) return paramvalue def _arg_name(arg_index): @@ -269,13 +269,13 @@ class RawData: online_trace_part['offline'].append(offline_trace_part) paramkeys = sorted(online_trace_part['parameter'].keys()) - paramvalue = [_soft_cast_int(online_trace_part['parameter'][x]) for x in paramkeys] + paramvalue = [soft_cast_int(online_trace_part['parameter'][x]) for x in paramkeys] # NB: Unscheduled transitions do not have an 'args' field set. # However, they should only be caused by interrupts, and # interrupts don't have args anyways. if arg_support_enabled and 'args' in online_trace_part: - paramvalue.extend(map(_soft_cast_int, online_trace_part['args'])) + paramvalue.extend(map(soft_cast_int, online_trace_part['args'])) if not 'offline_aggregates' in online_trace_part: online_trace_part['offline_aggregates'] = { @@ -437,8 +437,7 @@ class AnalyticFunction: X[i].extend([np.nan] * len(val[model_attribute])) elif key[0] == state_or_tran and len(key[1]) != dimension: print('[W] Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.'.format(state_or_tran, model_attribute, len(key[1]), dimension)) - for i in range(dimension): - X[i] = np.array(X[i]) + X = np.array(X) Y = np.array(Y) return X, Y, num_valid, num_total @@ -711,7 +710,7 @@ def _mean_std_by_param(by_param, state_or_tran, key, param_index): class EnergyModel: - def __init__(self, preprocessed_data): + def __init__(self, preprocessed_data, ignore_trace_indexes = None): self.traces = preprocessed_data self.by_name = {} self.by_param = {} @@ -720,13 +719,15 @@ class EnergyModel: np.seterr('raise') self._parameter_names = sorted(self.traces[0]['trace'][0]['parameter'].keys()) self._num_args = {} - for runidx, run in enumerate(self.traces): - # if opts['ignore-trace-idx'] != runidx - for i, elem in enumerate(run['trace']): - if elem['name'] != 'UNINITIALIZED': - self._load_run_elem(i, elem) - if elem['isa'] == 'transition' and not elem['name'] in self._num_args and 'args' in elem: - self._num_args[elem['name']] = len(elem['args']) + for run in self.traces: + if ignore_trace_indexes == None or int(run['id']) not in ignore_trace_indexes: + for i, elem in enumerate(run['trace']): + if elem['name'] != 'UNINITIALIZED': + self._load_run_elem(i, elem) + if elem['isa'] == 'transition' and not elem['name'] in self._num_args and 'args' in elem: + self._num_args[elem['name']] = len(elem['args']) + else: + print('[I] ignored trace index #{:d}'.format(int(run['id']))) self._aggregate_to_ndarray(self.by_name) self._compute_all_param_statistics() @@ -851,7 +852,7 @@ class EnergyModel: lut_model = self._get_model_from_dict(self.by_param, np.median) def lut_median_getter(name, key, param, arg = [], **kwargs): - param.extend(map(_soft_cast_int, arg)) + param.extend(map(soft_cast_int, arg)) return lut_model[(name, tuple(param))][key] return lut_median_getter @@ -911,7 +912,7 @@ class EnergyModel: state_or_tran, model_attribute, result['key'][2], fit_result['best_rmsd'], fit_result['mean_rmsd'], fit_result['median_rmsd'])) elif fit_result['best_rmsd'] >= 0.5 * min(fit_result['mean_rmsd'], fit_result['median_rmsd']): - print('[I] Not modeling {} {} as function o {}: best ({:.0f}) is not much better than ({:.0f}, {:.0f})'.format( + print('[I] Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ({:.0f}, {:.0f})'.format( state_or_tran, model_attribute, result['key'][2], fit_result['best_rmsd'], fit_result['mean_rmsd'], fit_result['median_rmsd'])) else: |