diff options
Diffstat (limited to 'lib/plotter.py')
-rwxr-xr-x | lib/plotter.py | 65 |
1 files changed, 52 insertions, 13 deletions
diff --git a/lib/plotter.py b/lib/plotter.py index 6ee2691..2ec635f 100755 --- a/lib/plotter.py +++ b/lib/plotter.py @@ -3,6 +3,7 @@ import itertools import numpy as np import matplotlib.pyplot as plt +import re from matplotlib.patches import Polygon def float_or_nan(n): @@ -97,7 +98,7 @@ def plot_xy(X, Y, xlabel = None, ylabel = None, title = None): def _param_slice_eq(a, b, index): return (*a[1][:index], *a[1][index+1:]) == (*b[1][:index], *b[1][index+1:]) and a[0] == b[0] -def plot_param(model, state_or_trans, attribute, param_idx, xlabel = None, ylabel = None, title = None, extra_functions = []): +def plot_param(model, state_or_trans, attribute, param_idx, xlabel = None, ylabel = None, title = None, extra_function = None): fig, ax1 = plt.subplots(figsize=(10,6)) if title != None: fig.canvas.set_window_title(title) @@ -107,10 +108,19 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel = None, ylabe ax1.set_ylabel(ylabel) plt.subplots_adjust(left = 0.05, bottom = 0.05, right = 0.99, top = 0.99) + param_name = model.param_name(param_idx) + + function_filename = 'plot_param_{}_{}_{}.txt'.format(state_or_trans, attribute, param_name) + data_filename_base = 'measurements_{}_{}_{}'.format(state_or_trans, attribute, param_name) + param_model, param_info = model.get_fitted() by_other_param = {} + XX = [] + + legend_sanitizer = re.compile(r'[^0-9a-zA-Z]+') + for k, v in model.by_param.items(): if k[0] == state_or_trans: other_param_key = (*k[1][:param_idx], *k[1][param_idx+1:]) @@ -118,30 +128,59 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel = None, ylabe by_other_param[other_param_key] = {'X': [], 'Y': []} by_other_param[other_param_key]['X'].extend([float(k[1][param_idx])] * len(v[attribute])) by_other_param[other_param_key]['Y'].extend(v[attribute]) + XX.extend(by_other_param[other_param_key]['X']) + + XX = np.array(XX) + x_range = int((XX.max() - XX.min()) * 10) + xsp = np.linspace(XX.min(), XX.max(), x_range) + YY = [xsp] + YY_legend = [param_name] + YY2 = [] + YY2_legend = [] cm = plt.get_cmap('brg', len(by_other_param)) - for i, k in enumerate(by_other_param): + for i, k in sorted(enumerate(by_other_param), key = lambda x: x[1]): v = by_other_param[k] v['X'] = np.array(v['X']) v['Y'] = np.array(v['Y']) plt.plot(v['X'], v['Y'], "rx", color=cm(i)) - x_range = int((v['X'].max() - v['X'].min()) * 2) - xsp = np.linspace(v['X'].min(), v['X'].max(), x_range) + YY2_legend.append(legend_sanitizer.sub('_', 'X_{}'.format(k))) + YY2.append(v['X']) + YY2_legend.append(legend_sanitizer.sub('_', 'Y_{}'.format(k))) + YY2.append(v['Y']) + + sanitized_k = legend_sanitizer.sub('_', str(k)) + with open('{}_{}.txt'.format(data_filename_base, sanitized_k), 'w') as f: + print('X Y', file=f) + for i in range(len(v['X'])): + print('{} {}'.format(v['X'][i], v['Y'][i]), file=f) + + #x_range = int((v['X'].max() - v['X'].min()) * 10) + #xsp = np.linspace(v['X'].min(), v['X'].max(), x_range) if param_model: ysp = [] for x in xsp: xarg = [*k[:param_idx], x, *k[param_idx:]] ysp.append(param_model(state_or_trans, attribute, param = xarg)) plt.plot(xsp, ysp, "r-", color=cm(i), linewidth=0.5) - if len(extra_functions) != 0: - for f in extra_functions: - ysp = [] - with np.errstate(divide='ignore', invalid='ignore'): - for x in xsp: - xarg = [*k[:param_idx], x, *k[param_idx:]] - ysp.append(f(*xarg)) - plt.plot(xsp, ysp, "r--", color=cm(i), linewidth=1, dashes=(3, 3)) - + YY.append(ysp) + YY_legend.append(legend_sanitizer.sub('_', 'regr_{}'.format(k))) + if extra_function != None: + ysp = [] + with np.errstate(divide='ignore', invalid='ignore'): + for x in xsp: + xarg = [*k[:param_idx], x, *k[param_idx:]] + ysp.append(extra_function(*xarg)) + plt.plot(xsp, ysp, "r--", color=cm(i), linewidth=1, dashes=(3, 3)) + YY.append(ysp) + YY_legend.append(legend_sanitizer.sub('_', 'symb_{}'.format(k))) + + with open(function_filename, 'w') as f: + print(' '.join(YY_legend), file=f) + for elem in np.array(YY).T: + print(' '.join(map(str, elem)), file=f) + + print(data_filename_base, function_filename) plt.show() |