summaryrefslogtreecommitdiff
path: root/lib/plotter.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/plotter.py')
-rwxr-xr-xlib/plotter.py30
1 files changed, 20 insertions, 10 deletions
diff --git a/lib/plotter.py b/lib/plotter.py
index 2ec635f..24e2197 100755
--- a/lib/plotter.py
+++ b/lib/plotter.py
@@ -80,10 +80,10 @@ def plot_substate_thresholds_p(model, aggregate):
data = [aggregate[key]['sub_thresholds'] for key in keys]
boxplot(keys, None, None, data, 'Zustand', '% Clipping')
-def plot_y(Y, ylabel = None, title = None):
- plot_xy(np.arange(len(Y)), Y, ylabel = ylabel, title = title)
+def plot_y(Y, **kwargs):
+ plot_xy(np.arange(len(Y)), Y, **kwargs)
-def plot_xy(X, Y, xlabel = None, ylabel = None, title = None):
+def plot_xy(X, Y, xlabel = None, ylabel = None, title = None, output = None):
fig, ax1 = plt.subplots(figsize=(10,6))
if title != None:
fig.canvas.set_window_title(title)
@@ -91,14 +91,21 @@ def plot_xy(X, Y, xlabel = None, ylabel = None, title = None):
ax1.set_xlabel(xlabel)
if ylabel != None:
ax1.set_ylabel(ylabel)
- plt.subplots_adjust(left = 0.05, bottom = 0.05, right = 0.99, top = 0.99)
- plt.plot(X, Y, "rx")
- plt.show()
+ plt.subplots_adjust(left = 0.1, bottom = 0.1, right = 0.99, top = 0.99)
+ plt.plot(X, Y, "bo", markersize=2)
+ if output:
+ plt.savefig(output)
+ with open('{}.txt'.format(output), 'w') as f:
+ print('X Y', file=f)
+ for i in range(len(X)):
+ print('{} {}'.format(X[i], Y[i]), file=f)
+ else:
+ plt.show()
def _param_slice_eq(a, b, index):
return (*a[1][:index], *a[1][index+1:]) == (*b[1][:index], *b[1][index+1:]) and a[0] == b[0]
-def plot_param(model, state_or_trans, attribute, param_idx, xlabel = None, ylabel = None, title = None, extra_function = None):
+def plot_param(model, state_or_trans, attribute, param_idx, xlabel = None, ylabel = None, title = None, extra_function = None, output = None):
fig, ax1 = plt.subplots(figsize=(10,6))
if title != None:
fig.canvas.set_window_title(title)
@@ -106,7 +113,7 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel = None, ylabe
ax1.set_xlabel(xlabel)
if ylabel != None:
ax1.set_ylabel(ylabel)
- plt.subplots_adjust(left = 0.05, bottom = 0.05, right = 0.99, top = 0.99)
+ plt.subplots_adjust(left = 0.1, bottom = 0.1, right = 0.99, top = 0.99)
param_name = model.param_name(param_idx)
@@ -143,7 +150,7 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel = None, ylabe
v = by_other_param[k]
v['X'] = np.array(v['X'])
v['Y'] = np.array(v['Y'])
- plt.plot(v['X'], v['Y'], "rx", color=cm(i))
+ plt.plot(v['X'], v['Y'], "ro", color=cm(i), markersize=3)
YY2_legend.append(legend_sanitizer.sub('_', 'X_{}'.format(k)))
YY2.append(v['X'])
YY2_legend.append(legend_sanitizer.sub('_', 'Y_{}'.format(k)))
@@ -181,7 +188,10 @@ def plot_param(model, state_or_trans, attribute, param_idx, xlabel = None, ylabe
print(' '.join(map(str, elem)), file=f)
print(data_filename_base, function_filename)
- plt.show()
+ if output:
+ plt.savefig(output)
+ else:
+ plt.show()
def plot_param_fit(function, name, fitfunc, funp, parameters, datatype, index, X, Y, xaxis=None, yaxis=None):