summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <derf@finalrewind.org>2018-03-01 11:01:06 +0100
committerDaniel Friesel <derf@finalrewind.org>2018-03-01 11:01:06 +0100
commita5e5cf1f9708cdd20122b4e267551272ea260d4f (patch)
treedd86c68897265da9b3f2fc1b5b6dcc6dea85ac0a /lib
parent25b4281ef57486b925fd86f2ce9d8a0fff608d29 (diff)
re-add argument support, starting with --ignored-trace-indexes
Diffstat (limited to 'lib')
-rwxr-xr-xlib/dfatool.py35
1 files changed, 18 insertions, 17 deletions
diff --git a/lib/dfatool.py b/lib/dfatool.py
index 40cacef..422b838 100755
--- a/lib/dfatool.py
+++ b/lib/dfatool.py
@@ -29,7 +29,7 @@ def is_numeric(n):
except ValueError:
return False
-def _soft_cast_int(n):
+def soft_cast_int(n):
if n == None or n == '':
return None
try:
@@ -48,9 +48,9 @@ def float_or_nan(n):
def _elem_param_and_arg_list(elem):
param_dict = elem['parameter']
paramkeys = sorted(param_dict.keys())
- paramvalue = [_soft_cast_int(param_dict[x]) for x in paramkeys]
+ paramvalue = [soft_cast_int(param_dict[x]) for x in paramkeys]
if arg_support_enabled and 'args' in elem:
- paramvalue.extend(map(_soft_cast_int, elem['args']))
+ paramvalue.extend(map(soft_cast_int, elem['args']))
return paramvalue
def _arg_name(arg_index):
@@ -269,13 +269,13 @@ class RawData:
online_trace_part['offline'].append(offline_trace_part)
paramkeys = sorted(online_trace_part['parameter'].keys())
- paramvalue = [_soft_cast_int(online_trace_part['parameter'][x]) for x in paramkeys]
+ paramvalue = [soft_cast_int(online_trace_part['parameter'][x]) for x in paramkeys]
# NB: Unscheduled transitions do not have an 'args' field set.
# However, they should only be caused by interrupts, and
# interrupts don't have args anyways.
if arg_support_enabled and 'args' in online_trace_part:
- paramvalue.extend(map(_soft_cast_int, online_trace_part['args']))
+ paramvalue.extend(map(soft_cast_int, online_trace_part['args']))
if not 'offline_aggregates' in online_trace_part:
online_trace_part['offline_aggregates'] = {
@@ -437,8 +437,7 @@ class AnalyticFunction:
X[i].extend([np.nan] * len(val[model_attribute]))
elif key[0] == state_or_tran and len(key[1]) != dimension:
print('[W] Invalid parameter key length while gathering fit data for {}/{}. is {}, want {}.'.format(state_or_tran, model_attribute, len(key[1]), dimension))
- for i in range(dimension):
- X[i] = np.array(X[i])
+ X = np.array(X)
Y = np.array(Y)
return X, Y, num_valid, num_total
@@ -711,7 +710,7 @@ def _mean_std_by_param(by_param, state_or_tran, key, param_index):
class EnergyModel:
- def __init__(self, preprocessed_data):
+ def __init__(self, preprocessed_data, ignore_trace_indexes = None):
self.traces = preprocessed_data
self.by_name = {}
self.by_param = {}
@@ -720,13 +719,15 @@ class EnergyModel:
np.seterr('raise')
self._parameter_names = sorted(self.traces[0]['trace'][0]['parameter'].keys())
self._num_args = {}
- for runidx, run in enumerate(self.traces):
- # if opts['ignore-trace-idx'] != runidx
- for i, elem in enumerate(run['trace']):
- if elem['name'] != 'UNINITIALIZED':
- self._load_run_elem(i, elem)
- if elem['isa'] == 'transition' and not elem['name'] in self._num_args and 'args' in elem:
- self._num_args[elem['name']] = len(elem['args'])
+ for run in self.traces:
+ if ignore_trace_indexes == None or int(run['id']) not in ignore_trace_indexes:
+ for i, elem in enumerate(run['trace']):
+ if elem['name'] != 'UNINITIALIZED':
+ self._load_run_elem(i, elem)
+ if elem['isa'] == 'transition' and not elem['name'] in self._num_args and 'args' in elem:
+ self._num_args[elem['name']] = len(elem['args'])
+ else:
+ print('[I] ignored trace index #{:d}'.format(int(run['id'])))
self._aggregate_to_ndarray(self.by_name)
self._compute_all_param_statistics()
@@ -851,7 +852,7 @@ class EnergyModel:
lut_model = self._get_model_from_dict(self.by_param, np.median)
def lut_median_getter(name, key, param, arg = [], **kwargs):
- param.extend(map(_soft_cast_int, arg))
+ param.extend(map(soft_cast_int, arg))
return lut_model[(name, tuple(param))][key]
return lut_median_getter
@@ -911,7 +912,7 @@ class EnergyModel:
state_or_tran, model_attribute, result['key'][2], fit_result['best_rmsd'],
fit_result['mean_rmsd'], fit_result['median_rmsd']))
elif fit_result['best_rmsd'] >= 0.5 * min(fit_result['mean_rmsd'], fit_result['median_rmsd']):
- print('[I] Not modeling {} {} as function o {}: best ({:.0f}) is not much better than ({:.0f}, {:.0f})'.format(
+ print('[I] Not modeling {} {} as function of {}: best ({:.0f}) is not much better than ({:.0f}, {:.0f})'.format(
state_or_tran, model_attribute, result['key'][2], fit_result['best_rmsd'],
fit_result['mean_rmsd'], fit_result['median_rmsd']))
else: