QuLab 2.0.1__cp310-cp310-macosx_11_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. QuLab-2.0.1.dist-info/LICENSE +21 -0
  2. QuLab-2.0.1.dist-info/METADATA +95 -0
  3. QuLab-2.0.1.dist-info/RECORD +82 -0
  4. QuLab-2.0.1.dist-info/WHEEL +5 -0
  5. QuLab-2.0.1.dist-info/entry_points.txt +2 -0
  6. QuLab-2.0.1.dist-info/top_level.txt +1 -0
  7. qulab/__init__.py +1 -0
  8. qulab/__main__.py +24 -0
  9. qulab/fun.cpython-310-darwin.so +0 -0
  10. qulab/monitor/__init__.py +1 -0
  11. qulab/monitor/__main__.py +8 -0
  12. qulab/monitor/config.py +41 -0
  13. qulab/monitor/dataset.py +77 -0
  14. qulab/monitor/event_queue.py +54 -0
  15. qulab/monitor/mainwindow.py +234 -0
  16. qulab/monitor/monitor.py +93 -0
  17. qulab/monitor/ploter.py +123 -0
  18. qulab/monitor/qt_compat.py +16 -0
  19. qulab/monitor/toolbar.py +265 -0
  20. qulab/scan/__init__.py +4 -0
  21. qulab/scan/base.py +548 -0
  22. qulab/scan/dataset.py +0 -0
  23. qulab/scan/expression.py +472 -0
  24. qulab/scan/optimize.py +0 -0
  25. qulab/scan/scanner.py +270 -0
  26. qulab/scan/transforms.py +16 -0
  27. qulab/scan/utils.py +37 -0
  28. qulab/storage/__init__.py +0 -0
  29. qulab/storage/__main__.py +51 -0
  30. qulab/storage/backend/__init__.py +0 -0
  31. qulab/storage/backend/redis.py +204 -0
  32. qulab/storage/base_dataset.py +352 -0
  33. qulab/storage/chunk.py +60 -0
  34. qulab/storage/dataset.py +127 -0
  35. qulab/storage/file.py +273 -0
  36. qulab/storage/models/__init__.py +22 -0
  37. qulab/storage/models/base.py +4 -0
  38. qulab/storage/models/config.py +28 -0
  39. qulab/storage/models/file.py +89 -0
  40. qulab/storage/models/ipy.py +58 -0
  41. qulab/storage/models/models.py +88 -0
  42. qulab/storage/models/record.py +161 -0
  43. qulab/storage/models/report.py +22 -0
  44. qulab/storage/models/tag.py +93 -0
  45. qulab/storage/storage.py +95 -0
  46. qulab/sys/__init__.py +0 -0
  47. qulab/sys/chat.py +688 -0
  48. qulab/sys/device/__init__.py +3 -0
  49. qulab/sys/device/basedevice.py +221 -0
  50. qulab/sys/device/loader.py +86 -0
  51. qulab/sys/device/utils.py +46 -0
  52. qulab/sys/drivers/FakeInstrument.py +52 -0
  53. qulab/sys/drivers/__init__.py +0 -0
  54. qulab/sys/ipy_events.py +125 -0
  55. qulab/sys/net/__init__.py +0 -0
  56. qulab/sys/net/bencoder.py +205 -0
  57. qulab/sys/net/cli.py +169 -0
  58. qulab/sys/net/dhcp.py +543 -0
  59. qulab/sys/net/dhcpd.py +176 -0
  60. qulab/sys/net/kad.py +1142 -0
  61. qulab/sys/net/kcp.py +192 -0
  62. qulab/sys/net/nginx.py +192 -0
  63. qulab/sys/progress.py +190 -0
  64. qulab/sys/rpc/__init__.py +0 -0
  65. qulab/sys/rpc/client.py +0 -0
  66. qulab/sys/rpc/exceptions.py +96 -0
  67. qulab/sys/rpc/msgpack.py +1052 -0
  68. qulab/sys/rpc/msgpack.pyi +41 -0
  69. qulab/sys/rpc/rpc.py +412 -0
  70. qulab/sys/rpc/serialize.py +139 -0
  71. qulab/sys/rpc/server.py +29 -0
  72. qulab/sys/rpc/socket.py +29 -0
  73. qulab/sys/rpc/utils.py +25 -0
  74. qulab/sys/rpc/worker.py +0 -0
  75. qulab/version.py +1 -0
  76. qulab/visualization/__init__.py +188 -0
  77. qulab/visualization/__main__.py +71 -0
  78. qulab/visualization/_autoplot.py +457 -0
  79. qulab/visualization/plot_layout.py +408 -0
  80. qulab/visualization/plot_seq.py +90 -0
  81. qulab/visualization/qdat.py +152 -0
  82. qulab/visualization/widgets.py +86 -0
@@ -0,0 +1,188 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+
4
+ from ._autoplot import autoplot
5
+
6
+
7
def plotLine(c0, c1, ax, **kwargs):
    """Draw the straight segment between complex points ``c0`` and ``c1``.

    The segment is sampled at 11 evenly spaced points and drawn on ``ax``
    with real parts along the horizontal axis and imaginary parts along the
    vertical axis. Extra keyword arguments are forwarded to ``ax.plot``.
    """
    fraction = np.linspace(0, 1, 11)
    points = (c1 - c0) * fraction + c0
    xs = points.real
    ys = points.imag
    ax.plot(xs, ys, **kwargs)
11
+
12
+
13
def plotCircle(c0, r, ax, **kwargs):
    """Draw a circle of radius ``r`` centred on the complex point ``c0``.

    The angle is sampled at 1001 points over a full turn; the curve is drawn
    on ``ax`` as (real, imag). Keyword arguments are forwarded to ``ax.plot``.
    """
    angle = np.linspace(0, 1, 1001) * 2 * np.pi
    unit_circle = np.exp(1j * angle)
    ring = c0 + r * unit_circle
    ax.plot(ring.real, ring.imag, **kwargs)
17
+
18
+
19
def plotEllipse(c0, a, b, phi, ax, **kwargs):
    """Draw an ellipse centred on the complex point ``c0``.

    A 1001-point unit circle has its real part scaled by ``a`` and its
    imaginary part by ``b``. It is then rotated by ``phi`` radians and
    shifted to ``c0``. Keyword arguments are forwarded to ``ax.plot``.
    """
    angle = np.linspace(0, 1, 1001) * 2 * np.pi
    unit_circle = np.exp(1j * angle)
    stretched = unit_circle.real * a + 1j * unit_circle.imag * b
    rotation = np.exp(1j * phi)
    curve = c0 + stretched * rotation
    ax.plot(curve.real, curve.imag, **kwargs)
24
+
25
+
26
def plotDistribution(s0,
                     s1,
                     fig=None,
                     axes=None,
                     info=None,
                     hotThresh=10000,
                     logy=False):
    """Plot two sets of complex samples together with threshold info.

    Left axes (``ax1``):
        The samples in the complex plane. With fewer than ``hotThresh``
        points in total this is a scatter plot; otherwise it is the
        difference of two normalized 2-D histograms (``H1 - H0``). On top
        are drawn ellipses, centre markers and the threshold line, all
        derived from ``info``.

    Right axes (``ax2``):
        Histograms of ``info['signal']``, overlaid with curves from
        ``mult_gaussian_pdf``. A twin y axis shows the curves in
        ``info['cdf']`` and a vertical line at the threshold.

    Args:
        s0, s1: arrays of complex samples (presumably measured with the
            system prepared in two different states -- TODO confirm with
            callers).
        fig: figure to draw on when ``axes`` is not given. A new figure is
            created when both are None.
        axes: optional ``(ax1, ax2)`` pair to draw on.
        info: optional earlier result. Only its ``'threshold'`` and ``'phi'``
            entries are reused; everything else is recomputed.
        hotThresh: total sample count at which the scatter plot switches to
            the 2-D histogram view.
        logy: use a logarithmic y axis for the right-hand histograms.

    Returns:
        The ``info`` mapping returned by
        ``waveforms.math.fit.get_threshold_info``.
    """
    from waveforms.math.fit import get_threshold_info, mult_gaussian_pdf

    if info is None:
        info = get_threshold_info(s0, s1)
    else:
        info = get_threshold_info(s0, s1, info['threshold'], info['phi'])
    thr, phi = info['threshold'], info['phi']
    # visibility, p0, p1 = info['visibility']
    # print(
    # f"thr={thr:.6f}, phi={phi:.6f}, visibility={visibility:.3f}, {p0}, {1-p1}"
    # )

    if axes is not None:
        ax1, ax2 = axes
    else:
        if fig is None:
            fig = plt.figure()
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)

    if (len(s0) + len(s1)) < hotThresh:
        # Few enough points: plain semi-transparent scatter of both sets.
        ax1.plot(np.real(s0), np.imag(s0), '.', alpha=0.2)
        ax1.plot(np.real(s1), np.imag(s1), '.', alpha=0.2)
    else:
        # Too many points to scatter. First compute 50x50 bin edges shared by
        # both sets from the pooled samples (``bins`` = [x_edges, y_edges]).
        _, *bins = np.histogram2d(np.real(np.hstack([s0, s1])),
                                  np.imag(np.hstack([s0, s1])),
                                  bins=50)

        H0, *_ = np.histogram2d(np.real(s0),
                                np.imag(s0),
                                bins=bins,
                                density=True)
        H1, *_ = np.histogram2d(np.real(s1),
                                np.imag(s1),
                                bins=bins,
                                density=True)
        vlim = max(np.max(np.abs(H0)), np.max(np.abs(H1)))

        # Colour encodes H1 - H0 on a diverging map. Opacity encodes the
        # larger of the two densities, so empty regions fade out. The
        # transposes put the imaginary axis along image rows to match
        # origin='lower' and the extent below.
        ax1.imshow(H1.T - H0.T,
                   alpha=(np.fmax(H0.T, H1.T) / vlim).clip(0, 1),
                   interpolation='nearest',
                   origin='lower',
                   cmap='coolwarm',
                   vmin=-vlim,
                   vmax=vlim,
                   extent=(bins[0][0], bins[0][-1], bins[1][0], bins[1][-1]))

    ax1.axis('equal')
    ax1.set_xticks([])
    ax1.set_yticks([])
    for s in ax1.spines.values():
        s.set_visible(False)

    # c0, c1 = info['center']
    # a0, b0, a1, b1 = info['std']
    # Layout of info['params'], as used below:
    #   - indices 0/1: centre of set 0 / set 1;
    #   - indices 2/3: ellipse half-widths;
    #   - index 4/5 of params[0]: mixture weights for s0 / s1;
    #   - index 6: an extra rotation angle.
    # params[0][k] and params[1][k] are the two coordinate components, in the
    # frame rotated by -phi.
    # NOTE(review): layout inferred from usage only; confirm against
    # get_threshold_info.
    params = info['params']
    r0, i0, r1, i1 = params[0][0], params[1][0], params[0][1], params[1][1]
    a0, b0, a1, b1 = params[0][2], params[1][2], params[0][3], params[1][3]
    # Rotate the centres by phi back into the original IQ plane.
    c0 = (r0 + 1j * i0) * np.exp(1j * phi)
    c1 = (r1 + 1j * i1) * np.exp(1j * phi)
    phi0 = phi + params[0][6]
    phi1 = phi + params[1][6]
    # Ellipses with semi-axes of twice the stored widths.
    plotEllipse(c0, 2 * a0, 2 * b0, phi0, ax1)
    plotEllipse(c1, 2 * a1, 2 * b1, phi1, ax1)

    im0, im1 = info['idle']
    lim = min(im0.min(), im1.min()), max(im0.max(), im1.max())
    # Threshold line. Plotting (t.imag, t.real) swaps the components, which
    # is the same as plotting (thr + 1j * v) * exp(1j * phi) for v spanning
    # ``lim``: the line "rotated real part == thr" in the original plane.
    t = (np.linspace(lim[0], lim[1], 3) + 1j * thr) * np.exp(-1j * phi)
    ax1.plot(t.imag, t.real, 'k--')

    ax1.plot(np.real(c0), np.imag(c0), 'o', color='C3')
    ax1.plot(np.real(c1), np.imag(c1), 'o', color='C4')

    re0, re1 = info['signal']
    x, a, b, c = info['cdf']

    xrange = (min(re0.min(), re1.min()), max(re0.max(), re1.max()))

    # Both histograms use the same 80 bins over the pooled range.
    n0, bins0, *_ = ax2.hist(re0, bins=80, range=xrange, alpha=0.5)
    n1, bins1, *_ = ax2.hist(re1, bins=80, range=xrange, alpha=0.5)

    x_range = np.linspace(x.min(), x.max(), 1001)
    *_, cov0, cov1 = info['std']
    # Overlay mult_gaussian_pdf (presumably a two-component Gaussian-mixture
    # density) for each set. It is scaled from a density to counts per bin
    # by multiplying by total count * bin width.
    ax2.plot(
        x_range,
        np.sum(n0) * (bins0[1] - bins0[0]) *
        mult_gaussian_pdf(x_range, [r0, r1], [
            np.sqrt(cov0[0, 0]), np.sqrt(cov1[0, 0])
        ], [params[0][4], 1 - params[0][4]]))
    ax2.plot(
        x_range,
        np.sum(n1) * (bins1[1] - bins1[0]) *
        mult_gaussian_pdf(x_range, [r0, r1], [
            np.sqrt(cov0[0, 0]), np.sqrt(cov1[0, 0])
        ], [params[0][5], 1 - params[0][5]]))
    ax2.set_ylabel('Count')
    ax2.set_xlabel('Projection Axes')
    if logy:
        ax2.set_yscale('log')
        # Lower limit 0.1 keeps empty bins (count 0) from breaking the log axis.
        ax2.set_ylim(0.1, max(np.sum(n0), np.sum(n1)))

    # Twin axis for the curves in info['cdf'] (range 0..1.1) and the threshold.
    ax3 = ax2.twinx()
    ax3.plot(x, a, '--', lw=1, color='C0')
    ax3.plot(x, b, '--', lw=1, color='C1')
    ax3.plot(x, c, 'k--', alpha=0.5, lw=1)
    ax3.set_ylim(0, 1.1)
    ax3.vlines(thr, 0, 1.1, 'k', alpha=0.5)
    ax3.set_ylabel('Integral Probability')

    return info
143
+
144
+
145
# The 21 ALLXY gate-pair sequences, in measurement order. Each entry is a
# (first, second) pair of pulse labels; 'I' means no pulse.
ALLXYSeq = [('I', 'I'), ('X', 'X'), ('Y', 'Y'), ('X', 'Y'), ('Y', 'X'),
            ('X/2', 'I'), ('Y/2', 'I'), ('X/2', 'Y/2'), ('Y/2', 'X/2'),
            ('X/2', 'Y'), ('Y/2', 'X'), ('X', 'Y/2'), ('Y', 'X/2'),
            ('X/2', 'X'), ('X', 'X/2'), ('Y/2', 'Y'), ('Y', 'Y/2'), ('X', 'I'),
            ('Y', 'I'), ('X/2', 'X/2'), ('Y/2', 'Y/2')]
150
+
151
+
152
def plotALLXY(data, ax=None):
    """Plot ALLXY results with one labelled tick per gate pair.

    ``len(data)`` must be a multiple of ``len(ALLXYSeq)``: consecutive groups
    of ``repeat`` points belong to one sequence. Each tick sits at the centre
    of its group and is labelled ``"first,second"``. Draws on ``ax``, or on
    the current pyplot axes when ``ax`` is None.
    """
    n_seq = len(ALLXYSeq)
    assert len(data) % n_seq == 0

    target = plt.gca() if ax is None else ax

    target.plot(np.array(data), 'o-')
    repeat = len(data) // n_seq
    centres = np.arange(n_seq) * repeat + 0.5 * (repeat - 1)
    target.set_xticks(centres)
    labels = [','.join(pair) for pair in ALLXYSeq]
    target.set_xticklabels(labels, rotation=60)
    target.grid(which='major')
163
+
164
+
165
def plot_mat(rho, title='$\\chi$', cmap='coolwarm'):
    """Show the real and imaginary parts of a matrix side by side.

    Both panels share a symmetric colour range ``[-max|rho|, max|rho|]`` and
    one horizontal colour bar. The figure is shown with ``plt.show()``.

    Args:
        rho: 2-D real or complex array-like (e.g. a density or process
            matrix). Non-square matrices are supported; x ticks follow
            columns and y ticks follow rows.
        title: figure title (matplotlib mathtext is allowed).
        cmap: colormap name used for both panels.
    """
    rho = np.asarray(rho)
    lim = np.abs(rho).max()
    n_rows, n_cols = rho.shape

    fig = plt.figure(figsize=(6, 4))
    fig.suptitle(title)

    # Add both panels to the figure created above, rather than to whatever
    # pyplot currently regards as the active figure.
    ax_re = fig.add_subplot(121)
    image = ax_re.imshow(rho.real, vmin=-lim, vmax=lim, cmap=cmap)
    ax_re.set_title('Re')

    ax_im = fig.add_subplot(122)
    ax_im.imshow(rho.imag, vmin=-lim, vmax=lim, cmap=cmap)
    ax_im.set_title('Im')

    for ax in (ax_re, ax_im):
        # Columns run along x and rows along y (matters for non-square rho).
        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))

    fig.subplots_adjust(bottom=0.2, right=0.9, top=0.95)

    # Both panels use the same vmin/vmax, so one colour bar serves both.
    cbar_ax = fig.add_axes([0.15, 0.15, 0.7, 0.05])
    fig.colorbar(image, cax=cbar_ax, orientation='horizontal')
    plt.show()
@@ -0,0 +1,71 @@
1
+ import pathlib
2
+ import pickle
3
+
4
+ import click
5
+ import dill
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+
9
+ from .qdat import draw as draw_qdat
10
+
11
# Drawing functions keyed by file suffix. ``plot`` looks up the data file's
# suffix here when ``draw_common`` fails; unknown suffixes fall back to
# ``draw_error``.
default_draw_methods = {
    '.qdat': draw_qdat,
}
14
+
15
+
16
def load_data(fname):
    """Load a data record, either by record id or from a pickle file.

    Loading order:

    1. If the optional ``home.hkxu.tools`` helper is importable and
       ``str(fname)`` parses as an integer, fetch that record and return its
       ``.data``.
    2. Otherwise open ``fname`` as a file and unpickle it with :mod:`pickle`.
    3. If ``pickle`` fails, rewind and retry with ``dill``.

    Args:
        fname: record id or file path (``str`` or ``pathlib.Path``).

    Returns:
        The loaded object.

    Raises:
        OSError (e.g. FileNotFoundError): if the file cannot be opened.
        Whatever ``dill.load`` raises when neither unpickler succeeds.

    Warning:
        Unpickling can execute arbitrary code. Only open trusted files.
    """
    try:
        # Site-specific record database. When it is not installed, or fname
        # is not an integer id, fall through to reading a file.
        from home.hkxu.tools import get_record_by_id
        record_id = int(str(fname))
        return get_record_by_id(record_id).data
    except Exception:
        # Deliberate best-effort. ``Exception`` rather than a bare except, so
        # KeyboardInterrupt / SystemExit are not swallowed.
        pass
    # SECURITY: pickle/dill run code embedded in the file; fname must point
    # at trusted data only.
    with open(fname, 'rb') as f:
        try:
            return pickle.load(f)
        except Exception:
            # Rewind and retry with dill, which extends the pickle protocol.
            f.seek(0)
            return dill.load(f)
30
+
31
+
32
def draw_common(data):
    """Draw ``data`` using the plot script stored in its metadata.

    ``data['meta']['plot_script']`` is executed with ``plt``, ``np`` and
    ``result`` (the whole ``data`` object) as its globals.

    If the script is missing, blank or raises, fall back to the site-specific
    ``home.hkxu.tools.plot_record(data['meta']['id'])``. Errors from that
    fallback (e.g. ImportError when the helper is not installed) propagate
    to the caller.

    SECURITY: this executes arbitrary code taken from the data file. Only
    use it on trusted data.
    """
    try:
        script = data['meta']['plot_script']
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would run an empty script and skip the
        # fallback.
        if not script.strip():
            raise ValueError("empty plot script")
        global_namespace = {'plt': plt, 'np': np, 'result': data}
        exec(script, global_namespace)
    except Exception:
        from home.hkxu.tools import plot_record
        plot_record(data['meta']['id'])
41
+
42
+
43
def draw_error(data, text="No validate plot script found"):
    """Create a new figure that shows ``text`` centred on a blank canvas.

    ``data`` is ignored. It is accepted so this function can stand in for
    the drawing functions in ``default_draw_methods``. Returns the figure.
    """
    fig = plt.figure()
    canvas = fig.add_subplot(111)
    canvas.set_axis_off()
    canvas.text(0.5, 0.5, text, ha='center', va='center')
    return fig
49
+
50
+
51
@click.command()
@click.argument('fname', default='')
def plot(fname):
    """Plot the data in the file."""
    # (The docstring above is click's --help text; extra notes live in
    # comments so the CLI help stays unchanged.)
    #
    # Drawing strategy:
    #   1. Load the data (record id or pickle file).
    #   2. Try its embedded plot script via draw_common.
    #   3. If that fails, use the drawer registered for the file suffix,
    #      falling back to draw_error.
    # Load failures are shown as an error figure instead of a traceback.
    try:
        fname = pathlib.Path(fname)
        data = load_data(fname)
        try:
            draw_common(data)
        except Exception:
            default_draw_methods.get(fname.suffix, draw_error)(data)
    except FileNotFoundError:
        draw_error(None, text=f"File {fname} not found.")
    except (pickle.UnpicklingError, EOFError):
        # EOFError: empty or truncated file, i.e. not a valid pickle either.
        draw_error(None, text=f"File {fname} is not a pickle file.")
    except OSError as e:
        # For example, the default fname '' resolves to '.', a directory.
        draw_error(None, text=f"Cannot read {fname}: {e}")

    plt.show()
68
+
69
+
70
# Run the click command when this module is executed directly.
if __name__ == '__main__':
    plot()
@@ -0,0 +1,457 @@
1
+ import math
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from matplotlib.colors import LogNorm
6
+ from matplotlib.ticker import EngFormatter, LogFormatterSciNotation
7
+ from scipy.interpolate import griddata
8
+
9
+
10
def good_for_logscale(x, threshold=4):
    """Heuristically decide whether ``x`` is better shown on a log axis.

    Returns False if any value is <= 0. Otherwise the non-NaN values are
    split at the midpoint of their range. If the values at or below the
    midpoint outnumber those at or above it by more than ``threshold``
    times, the data is crowded toward its low end (as log-spaced data is)
    and True is returned.

    Empty or all-NaN input returns False instead of raising.

    Args:
        x: array-like of real numbers.
        threshold: required ratio of low-half to high-half counts.

    Returns:
        bool
    """
    x = np.asarray(x)
    if x.size == 0 or np.any(x <= 0):
        return False
    valid = x[~np.isnan(x)]
    if valid.size == 0:
        # All NaN: nanmin/nanmax would warn and the ratio would divide by 0.
        return False
    mid = (valid.min() + valid.max()) / 2
    low = np.count_nonzero(valid <= mid)
    high = np.count_nonzero(valid >= mid)  # >= 1 because max >= mid
    return bool(low / high > threshold)
19
+
20
+
21
def equal_logspace(x):
    """Return True if ``x`` is (numerically) evenly spaced on a log scale.

    ``x`` is compared with ``np.allclose`` against ``len(x)`` points spaced
    evenly in log10 between its first and last elements.
    """
    first_exp = np.log10(x[0])
    last_exp = np.log10(x[-1])
    reference = np.logspace(first_exp, last_exp, len(x))
    return np.allclose(x, reference)
24
+
25
+
26
def equal_linspace(x):
    """Return True if ``x`` is (numerically) evenly spaced on a linear scale.

    ``x`` is compared with ``np.allclose`` against ``len(x)`` evenly spaced
    points running from its first to its last element.
    """
    start, stop, count = x[0], x[-1], len(x)
    return np.allclose(x, np.linspace(start, stop, count))
29
+
30
+
31
def as_1d_data(x, y, z):
    """Flatten gridded or scattered data into matching 1-D ``x, y, z`` arrays.

    Accepted input layouts:

    * ``z`` 1-D: the inputs are returned as arrays unchanged (treated as
      parallel lists of scattered points).
    * ``z`` 2-D with 1-D ``x`` and ``y``: ``z[i, j]`` is the value at
      ``(x[j], y[i])``. The axes are expanded with ``np.meshgrid`` and
      everything is raveled in row-major order.
    * ``z`` 2-D with ``x`` and ``y`` already of ``z``'s shape (e.g. the output
      of ``np.meshgrid``): all three are raveled directly.

    Raises:
        ValueError: if ``z`` is neither 1-D nor 2-D.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    if z.ndim == 1:
        return x, y, z

    if z.ndim == 2:
        # Meshgridding coordinates that are already full grids would square
        # their size and break the x/y/z correspondence.
        if not (x.shape == y.shape == z.shape):
            x, y = np.meshgrid(x, y)
        return x.ravel(), y.ravel(), z.ravel()

    raise ValueError("z must be 1D or 2D")
43
+
44
+
45
def griddata_logx_logy(x, y, z, shape=(401, 401)):
    """Resample data onto a grid that is uniform in log10(x) and log10(y).

    The nearest-neighbour lookup is done in (log10 x, log10 y) coordinates,
    the coordinates the image is displayed in. Otherwise cell boundaries
    would sit at linear midpoints and look lopsided on log axes.
    ``rescale=True`` normalizes both axes so neither dominates the distance.

    Args:
        x, y, z: data in any layout accepted by ``as_1d_data``; x and y must
            be positive.
        shape: ``(number of x samples, number of y samples)``.

    Returns:
        ``(xspace, yspace, zgrid)`` where ``zgrid[i, j]`` is the value at
        ``(xspace[j], yspace[i])``.
    """
    x, y, z = as_1d_data(x, y, z)
    log_x, log_y = np.log10(x), np.log10(y)
    log_xspace = np.linspace(log_x.min(), log_x.max(), shape[0])
    log_yspace = np.linspace(log_y.min(), log_y.max(), shape[1])
    xgrid, ygrid = np.meshgrid(log_xspace, log_yspace)
    zgrid = griddata((log_x, log_y), z, (xgrid, ygrid),
                     method='nearest', rescale=True)
    # Same values as np.logspace over the same exponent range.
    return 10.0**log_xspace, 10.0**log_yspace, zgrid
52
+
53
+
54
def griddata_logx_linear_y(x, y, z, shape=(401, 401)):
    """Resample data onto a grid uniform in log10(x) and linear in y.

    The nearest-neighbour lookup is done in (log10 x, y) coordinates, the
    coordinates the image is displayed in. ``rescale=True`` normalizes both
    axes so neither dominates the distance.

    Args:
        x, y, z: data in any layout accepted by ``as_1d_data``; x must be
            positive.
        shape: ``(number of x samples, number of y samples)``.

    Returns:
        ``(xspace, yspace, zgrid)`` where ``zgrid[i, j]`` is the value at
        ``(xspace[j], yspace[i])``.
    """
    x, y, z = as_1d_data(x, y, z)
    log_x = np.log10(x)
    log_xspace = np.linspace(log_x.min(), log_x.max(), shape[0])
    yspace = np.linspace(y.min(), y.max(), shape[1])
    xgrid, ygrid = np.meshgrid(log_xspace, yspace)
    zgrid = griddata((log_x, y), z, (xgrid, ygrid),
                     method='nearest', rescale=True)
    # Same values as np.logspace over the same exponent range.
    return 10.0**log_xspace, yspace, zgrid
61
+
62
+
63
def griddata_linear_x_logy(x, y, z, shape=(401, 401)):
    """Resample data onto a grid linear in x and uniform in log10(y).

    The nearest-neighbour lookup is done in (x, log10 y) coordinates, the
    coordinates the image is displayed in. ``rescale=True`` normalizes both
    axes so neither dominates the distance.

    Args:
        x, y, z: data in any layout accepted by ``as_1d_data``; y must be
            positive.
        shape: ``(number of x samples, number of y samples)``.

    Returns:
        ``(xspace, yspace, zgrid)`` where ``zgrid[i, j]`` is the value at
        ``(xspace[j], yspace[i])``.
    """
    x, y, z = as_1d_data(x, y, z)
    log_y = np.log10(y)
    xspace = np.linspace(x.min(), x.max(), shape[0])
    log_yspace = np.linspace(log_y.min(), log_y.max(), shape[1])
    xgrid, ygrid = np.meshgrid(xspace, log_yspace)
    zgrid = griddata((x, log_y), z, (xgrid, ygrid),
                     method='nearest', rescale=True)
    # Same values as np.logspace over the same exponent range.
    return xspace, 10.0**log_yspace, zgrid
70
+
71
+
72
def griddata_linear_x_linear_y(x, y, z, shape=(401, 401)):
    """Resample data onto a grid that is uniform in both x and y.

    Uses nearest-neighbour lookup with ``rescale=True``. Otherwise an axis
    in large units (e.g. Hz) would dominate the Euclidean distance and
    scattered points would effectively be matched on that axis alone.

    Args:
        x, y, z: data in any layout accepted by ``as_1d_data``.
        shape: ``(number of x samples, number of y samples)``.

    Returns:
        ``(xspace, yspace, zgrid)`` where ``zgrid[i, j]`` is the value at
        ``(xspace[j], yspace[i])``.
    """
    x, y, z = as_1d_data(x, y, z)
    xspace = np.linspace(x.min(), x.max(), shape[0])
    yspace = np.linspace(y.min(), y.max(), shape[1])
    xgrid, ygrid = np.meshgrid(xspace, yspace)
    zgrid = griddata((x, y), z, (xgrid, ygrid),
                     method='nearest', rescale=True)
    return xspace, yspace, zgrid
79
+
80
+
81
+ def _get_log_ticks(x):
82
+ log10x = np.log10(x)
83
+
84
+ major_ticks = np.array(
85
+ range(math.floor(log10x[0]) - 1,
86
+ math.ceil(log10x[-1]) + 1))
87
+ minor_ticks = np.hstack([
88
+ np.log10(np.linspace(2, 10, 9, endpoint=False)) + x
89
+ for x in major_ticks
90
+ ])
91
+
92
+ major_ticks = major_ticks[(major_ticks >= log10x[0]) *
93
+ (major_ticks <= log10x[-1])]
94
+ minor_ticks = minor_ticks[(minor_ticks >= log10x[0]) *
95
+ (minor_ticks <= log10x[-1])]
96
+
97
+ return log10x, major_ticks, minor_ticks
98
+
99
+
100
class MyLogFormatter(EngFormatter):
    """Tick formatter for axes whose data coordinates are log10 values.

    Every value this formatter receives is an exponent ``x``, and the text
    produced describes ``10.0**x``:

    * Without a unit (None or ''): tick labels come from matplotlib's
      ``LogFormatterSciNotation`` and are wrapped in ``$...$``.
    * With a unit: the parent EngFormatter handles the ticks, and
      ``format_eng`` below converts the exponent to ``10.0**x`` before
      engineering formatting.
    """

    def format_ticks(self, values):
        """Return one label per tick position in ``values`` (log10 units)."""
        if self.unit is None or self.unit == '':
            fmt = LogFormatterSciNotation()
            return [f"${fmt.format_data(10.0**x)}$" for x in values]
        else:
            return super().format_ticks(values)

    def format_eng(self, x):
        """Format ``10.0**x``; plain ``:g`` formatting when there is no unit."""
        if self.unit is None or self.unit == '':
            # NOTE(review): this normalizes a None unit to '' as a side
            # effect. Presumably it keeps EngFormatter code that appends
            # self.unit after calling format_eng from emitting "None". Verify
            # against the matplotlib version in use before changing it.
            self.unit = ''
            return f"{10.0**x:g}"
        else:
            return super().format_eng(10.0**x)
115
+
116
+
117
def imshow_logx(x, y, z, x_unit=None, ax=None, **kwargs):
    """Display ``z`` as an image with a logarithmic x axis.

    The image is placed in (log10 x, y) data coordinates, with each pixel
    centred on its sample. This assumes x is evenly spaced in log10 and y is
    evenly spaced. X ticks come from ``_get_log_ticks`` and major labels use
    ``MyLogFormatter(x_unit)``. Extra keyword arguments go to ``imshow``;
    the image is returned.
    """
    target = plt.gca() if ax is None else ax

    log10x, majors, minors = _get_log_ticks(x)

    half_dx = (log10x[1] - log10x[0]) / 2
    half_dy = (y[1] - y[0]) / 2
    extent = (log10x[0] - half_dx, log10x[-1] + half_dx,
              y[0] - half_dy, y[-1] + half_dy)

    image = target.imshow(z, extent=extent, **kwargs)

    target.set_xticks(majors, minor=False)
    target.xaxis.set_major_formatter(MyLogFormatter(x_unit))
    target.set_xticks(minors, minor=True)

    return image
134
+
135
+
136
def imshow_logy(x, y, z, y_unit=None, ax=None, **kwargs):
    """Display ``z`` as an image with a logarithmic y axis.

    The image is placed in (x, log10 y) data coordinates, with each pixel
    centred on its sample. This assumes x is evenly spaced and y is evenly
    spaced in log10. Y ticks come from ``_get_log_ticks`` and major labels
    use ``MyLogFormatter(y_unit)``. Extra keyword arguments go to ``imshow``;
    the image is returned.
    """
    target = plt.gca() if ax is None else ax

    log10y, majors, minors = _get_log_ticks(y)

    half_dlogy = (log10y[1] - log10y[0]) / 2
    half_dx = (x[1] - x[0]) / 2
    extent = (x[0] - half_dx, x[-1] + half_dx,
              log10y[0] - half_dlogy, log10y[-1] + half_dlogy)

    image = target.imshow(z, extent=extent, **kwargs)

    target.set_yticks(majors, minor=False)
    target.yaxis.set_major_formatter(MyLogFormatter(y_unit))
    target.set_yticks(minors, minor=True)

    return image
153
+
154
+
155
def imshow_loglog(x, y, z, x_unit=None, y_unit=None, ax=None, **kwargs):
    """Display ``z`` as an image with logarithmic x and y axes.

    The image is placed in (log10 x, log10 y) data coordinates, with each
    pixel centred on its sample. This assumes both axes are evenly spaced
    in log10. Ticks come from ``_get_log_ticks`` and major labels use
    ``MyLogFormatter`` with the respective unit. Extra keyword arguments go
    to ``imshow``; the image is returned.
    """
    target = plt.gca() if ax is None else ax

    log10x, x_majors, x_minors = _get_log_ticks(x)
    log10y, y_majors, y_minors = _get_log_ticks(y)

    half_dlogx = (log10x[1] - log10x[0]) / 2
    half_dlogy = (log10y[1] - log10y[0]) / 2
    extent = (log10x[0] - half_dlogx, log10x[-1] + half_dlogx,
              log10y[0] - half_dlogy, log10y[-1] + half_dlogy)

    image = target.imshow(z, extent=extent, **kwargs)

    # Same three steps for each axis: majors, major formatter, minors.
    for set_ticks, axis, majors, minors, unit in (
            (target.set_xticks, target.xaxis, x_majors, x_minors, x_unit),
            (target.set_yticks, target.yaxis, y_majors, y_minors, y_unit)):
        set_ticks(majors, minor=False)
        axis.set_major_formatter(MyLogFormatter(unit))
        set_ticks(minors, minor=True)

    return image
177
+
178
+
179
def plot_lines(x,
               y,
               z,
               xlabel,
               ylabel,
               zlabel,
               x_unit,
               y_unit,
               z_unit,
               ax,
               xscale='linear',
               yscale='linear',
               zscale='linear',
               index=None,
               **kwds):
    """Plot each row of ``z`` as a separate line against ``x``.

    ``z[i, :]`` is the curve for ``y[i]``, and each line is labelled
    ``"ylabel=value"`` in the legend. If ``y`` is longer than ``x``, the two
    axes swap roles and ``z`` is transposed, so the shorter axis always
    enumerates the lines. The swap covers values, labels, units and scales.

    Args:
        x, y: 1-D coordinates; ``z`` has shape ``(len(y), len(x))``.
        xlabel, ylabel, zlabel: names for the horizontal axis, the legend
            entries and the vertical axis.
        x_unit, y_unit, z_unit: units appended as `` [unit]`` when non-empty.
        ax: matplotlib axes to draw on.
        xscale, zscale: scales of the horizontal and vertical axes.
        yscale: scale of the enumerated axis. It only matters through the
            swap, when it becomes the horizontal scale.
        index: optional int, slice or index array selecting which lines to
            draw (applied after any swap).
        **kwds: forwarded to every ``ax.plot`` call.
    """
    z = np.asarray(z)
    if len(y) > len(x):
        x, y = y, x
        xlabel, ylabel = ylabel, xlabel
        x_unit, y_unit = y_unit, x_unit  # units must follow their axis
        xscale, yscale = yscale, xscale
        z = z.T
    if index is not None:
        # Keep y 1-D and z 2-D even when a single integer index is given.
        y = np.atleast_1d(np.asarray(y)[index])
        z = np.atleast_2d(z[index, :])

    for i, value in enumerate(y):
        # Floats get 3 significant digits. Ints (which reject a precision
        # spec), strings and anything else are printed as-is.
        if isinstance(value, (float, np.floating)):
            text = f"{value:.3}"
        else:
            text = f"{value}"
        label = f"{ylabel}={text} [{y_unit}]" if y_unit else f"{ylabel}={text}"
        ax.plot(x, z[i, :], label=label, **kwds)
    ax.legend()
    xlabel = f"{xlabel} [{x_unit}]" if x_unit else xlabel
    zlabel = f"{zlabel} [{z_unit}]" if z_unit else zlabel
    ax.set_xlabel(xlabel)
    ax.set_ylabel(zlabel)
    ax.set_xscale(xscale)
    ax.set_yscale(zscale)
220
+
221
+
222
def plot_img(x,
             y,
             z,
             xlabel,
             ylabel,
             zlabel,
             x_unit,
             y_unit,
             z_unit,
             fig,
             ax,
             xscale='linear',
             yscale='linear',
             zscale='linear',
             resolution=None,
             **kwds):
    """Draw ``z`` as an image (or a quad mesh) with a colour bar.

    Three paths:

    * "Band" data, i.e. 1-D ``x`` with 2-D ``y`` of shape ``(n, len(x))``, or
      2-D ``x`` of shape ``(len(y), m)`` with 1-D ``y``. The 1-D coordinate
      is broadcast and the data is drawn with ``ax.pcolormesh``.
    * Otherwise, if ``z`` is 1-D (scattered points) or an axis is not evenly
      spaced for its scale, the data is first resampled onto a regular
      ``resolution`` grid with the matching ``griddata_*`` helper.
    * The regular grid is then drawn with ``ax.imshow`` (linear axes) or
      ``imshow_logx`` / ``imshow_logy`` / ``imshow_loglog``.

    Args:
        x, y, z: numpy arrays of coordinates and values.
        xlabel, ylabel, zlabel: axis and colour-bar names.
        x_unit, y_unit, z_unit: units appended as `` [unit]`` when non-empty.
            On a log axis the unit goes to the tick formatter instead.
        fig, ax: figure (for the colour bar) and axes to draw on.
        xscale, yscale: 'linear' or 'log'. The band path accepts any
            matplotlib scale name.
        zscale: 'log' builds a LogNorm from vmin/vmax (default: the NaN-aware
            data range) unless ``norm`` is already supplied.
        resolution: None (401 x 401), an int, or an ``(nx, ny)`` tuple.
        **kwds: forwarded to ``imshow`` / ``pcolormesh``.

    Raises:
        ValueError: for an xscale/yscale other than 'linear'/'log' on the
            image path.
    """
    kwds.setdefault('origin', 'lower')
    kwds.setdefault('aspect', 'auto')
    kwds.setdefault('interpolation', 'nearest')

    if zscale == 'log' and 'norm' not in kwds:
        # matplotlib rejects vmin/vmax passed together with a norm, so move
        # any user-supplied limits into the LogNorm. nanmin/nanmax keep NaN
        # pixels from poisoning the default range.
        vmin = kwds.pop('vmin', np.nanmin(z))
        vmax = kwds.pop('vmax', np.nanmax(z))
        kwds['norm'] = LogNorm(vmax=vmax, vmin=vmin)
    zlabel = f"{zlabel} [{z_unit}]" if z_unit else zlabel

    band_area = False
    if x.ndim == 1 and y.ndim == 2 and y.shape[1] == x.shape[0]:
        x = np.asarray([x] * y.shape[0])
        band_area = True
    elif x.ndim == 2 and y.ndim == 1 and x.shape[0] == y.shape[0]:
        y = np.asarray([y] * x.shape[1]).T
        band_area = True
    if band_area:
        # pcolormesh does not accept these imshow-only options.
        for key in ('origin', 'aspect', 'interpolation'):
            kwds.pop(key, None)
        pc = ax.pcolormesh(x, y, z, **kwds)
        xlabel = f"{xlabel} [{x_unit}]" if x_unit else xlabel
        ylabel = f"{ylabel} [{y_unit}]" if y_unit else ylabel
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        cb = fig.colorbar(pc, ax=ax)
        ax.set_xscale(xscale)
        ax.set_yscale(yscale)
        cb.set_label(zlabel)
        return

    if xscale not in ('linear', 'log') or yscale not in ('linear', 'log'):
        raise ValueError(f"unsupported scale combination: xscale={xscale!r}, "
                         f"yscale={yscale!r}")

    if resolution is None:
        resolution = (401, 401)
    elif isinstance(resolution, int):
        resolution = (resolution, resolution)

    if (z.ndim == 1 or (xscale == 'linear' and not equal_linspace(x))
            or (yscale == 'linear' and not equal_linspace(y))
            or (xscale == 'log' and not equal_logspace(x))
            or (yscale == 'log' and not equal_logspace(y))):
        # Not a regular grid for the requested scales: resample it.
        regrid = {
            ('log', 'log'): griddata_logx_logy,
            ('log', 'linear'): griddata_logx_linear_y,
            ('linear', 'log'): griddata_linear_x_logy,
            ('linear', 'linear'): griddata_linear_x_linear_y,
        }[(xscale, yscale)]
        x, y, z = regrid(x, y, z, resolution)

    if (xscale, yscale) == ('linear', 'linear'):
        # Pixel-centred extent: each sample sits in the middle of its cell.
        dx, dy = x[1] - x[0], y[1] - y[0]
        extent = (x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2)
        kwds.setdefault('extent', extent)
        img = ax.imshow(np.asarray(z), **kwds)
        xlabel = f"{xlabel} [{x_unit}]" if x_unit else xlabel
        ylabel = f"{ylabel} [{y_unit}]" if y_unit else ylabel
    elif (xscale, yscale) == ('log', 'linear'):
        ylabel = f"{ylabel} [{y_unit}]" if y_unit else ylabel
        img = imshow_logx(x, y, z, x_unit, ax, **kwds)
    elif (xscale, yscale) == ('linear', 'log'):
        xlabel = f"{xlabel} [{x_unit}]" if x_unit else xlabel
        img = imshow_logy(x, y, z, y_unit, ax, **kwds)
    else:  # ('log', 'log'), guaranteed by the validation above
        img = imshow_loglog(x, y, z, x_unit, y_unit, ax, **kwds)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    cb = fig.colorbar(img, ax=ax)
    cb.set_label(zlabel)
308
+
309
+
310
def plot_scatter(x,
                 y,
                 z,
                 xlabel,
                 ylabel,
                 zlabel,
                 x_unit,
                 y_unit,
                 z_unit,
                 fig,
                 ax,
                 xscale='linear',
                 yscale='linear',
                 zscale='linear',
                 **kwds):
    """Scatter-plot the points ``(x, y)``, encoding ``z`` in size and colour.

    Marker size is ``|z|``. Colour is the phase ``angle(z)`` when any element
    of ``z`` has a non-zero imaginary part, and ``z.real`` otherwise.
    ``zlabel``, ``z_unit``, ``fig`` and ``zscale`` are accepted so the
    signature matches the other plot_* helpers, but are not used.
    Extra keyword arguments go to ``ax.scatter``.
    """
    has_imaginary = np.any(np.iscomplex(z))
    sizes = np.abs(z)
    colours = np.angle(z) if has_imaginary else z.real
    ax.scatter(x, y, s=sizes, c=colours, **kwds)

    def _with_unit(name, unit):
        return f"{name} [{unit}]" if unit else name

    ax.set_xlabel(_with_unit(xlabel, x_unit))
    ax.set_ylabel(_with_unit(ylabel, y_unit))
    ax.set_xscale(xscale)
    ax.set_yscale(yscale)
338
+
339
+
340
def autoplot(x,
             y,
             z,
             xlabel='x',
             ylabel='y',
             zlabel='z',
             x_unit='',
             y_unit='',
             z_unit='',
             fig=None,
             ax=None,
             index=None,
             xscale='auto',
             yscale='auto',
             zscale='auto',
             max_lines=3,
             scatter_lim=1000,
             resolution=None,
             **kwds):
    """
    Plot ``z`` sampled on ``(x, y)`` as a scatter plot, a set of lines or an image.

    The representation is chosen automatically:

    * scatter (``plot_scatter``): ``x``, ``y`` and ``z`` share one shape and
      ``z`` has fewer than ``scatter_lim`` elements;
    * lines (``plot_lines``): ``z`` is 2-D and ``y`` or ``x`` has at most
      ``max_lines`` entries, or ``index`` is given;
    * image (``plot_img``): everything else.

    Parameters:
        x (array): x values
        y (array): y values
        z (array): z values
        xlabel (str): x label
        ylabel (str): y label
        zlabel (str): z label
        x_unit (str): x unit
        y_unit (str): y unit
        z_unit (str): z unit
        fig (Figure): figure to plot on; replaced by ``ax.figure`` when
            ``ax`` is given, and defaults to the current figure
        ax (Axes): axes to plot on; a new 111 subplot of ``fig`` if None
        index: forces line mode; passed to ``plot_lines`` to select lines
        xscale (str): x scale 'auto', 'linear' or 'log'
        yscale (str): y scale 'auto', 'linear' or 'log'
        zscale (str): z scale 'auto', 'linear' or 'log'; 'auto' resolves
            to 'log' when ``good_for_logscale`` accepts the data
        max_lines (int): maximum number of lines to plot
        scatter_lim (int): same-shaped data with fewer elements than this is
            drawn as a scatter plot
        resolution (int or tuple): forwarded to ``plot_img`` for resampling
        **kwds: keyword arguments passed to plot_scatter, plot_lines or plot_img
    """
    fig = ax.figure if ax is not None else fig
    fig = plt.gcf() if fig is None else fig
    ax = fig.add_subplot(111) if ax is None else ax

    x, y, z = (np.asarray(v) for v in (x, y, z))

    def _resolve(scale, values):
        # Explicit scales pass through; 'auto' is decided by the heuristic.
        if scale != 'auto':
            return scale
        return 'log' if good_for_logscale(values) else 'linear'

    xscale = _resolve(xscale, x)
    yscale = _resolve(yscale, y)
    zscale = _resolve(zscale, z)

    shared = dict(xlabel=xlabel,
                  ylabel=ylabel,
                  zlabel=zlabel,
                  x_unit=x_unit,
                  y_unit=y_unit,
                  z_unit=z_unit,
                  xscale=xscale,
                  yscale=yscale,
                  zscale=zscale)

    if x.shape == y.shape == z.shape and z.size < scatter_lim:
        plot_scatter(x, y, z, fig=fig, ax=ax, **shared, **kwds)
    elif z.ndim == 2 and (len(y) <= max_lines or len(x) <= max_lines
                          or index is not None):
        plot_lines(x, y, z, ax=ax, index=index, **shared, **kwds)
    else:
        plot_img(x,
                 y,
                 z,
                 fig=fig,
                 ax=ax,
                 resolution=resolution,
                 **shared,
                 **kwds)