modusa 0.2.23__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modusa/.DS_Store +0 -0
- modusa/__init__.py +8 -1
- modusa/devtools/{generate_doc_source.py → generate_docs_source.py} +5 -5
- modusa/devtools/generate_template.py +5 -5
- modusa/devtools/main.py +3 -3
- modusa/devtools/templates/generator.py +1 -1
- modusa/devtools/templates/io.py +1 -1
- modusa/devtools/templates/{signal.py → model.py} +18 -11
- modusa/devtools/templates/plugin.py +1 -1
- modusa/generators/__init__.py +11 -1
- modusa/generators/audio.py +188 -0
- modusa/generators/audio_waveforms.py +1 -1
- modusa/generators/base.py +1 -1
- modusa/generators/ftds.py +298 -0
- modusa/generators/s1d.py +270 -0
- modusa/generators/s2d.py +300 -0
- modusa/generators/s_ax.py +102 -0
- modusa/generators/t_ax.py +64 -0
- modusa/generators/tds.py +267 -0
- modusa/models/__init__.py +14 -0
- modusa/models/__pycache__/signal1D.cpython-312.pyc.4443461152 +0 -0
- modusa/models/audio.py +90 -0
- modusa/models/base.py +70 -0
- modusa/models/data.py +457 -0
- modusa/models/ftds.py +584 -0
- modusa/models/s1d.py +578 -0
- modusa/models/s2d.py +619 -0
- modusa/models/s_ax.py +448 -0
- modusa/models/t_ax.py +335 -0
- modusa/models/tds.py +465 -0
- modusa/plugins/__init__.py +3 -1
- modusa/tmp.py +98 -0
- modusa/tools/__init__.py +5 -0
- modusa/tools/audio_converter.py +56 -67
- modusa/tools/audio_loader.py +90 -0
- modusa/tools/audio_player.py +42 -67
- modusa/tools/math_ops.py +104 -1
- modusa/tools/plotter.py +305 -497
- modusa/tools/youtube_downloader.py +31 -98
- modusa/utils/excp.py +6 -0
- modusa/utils/np_func_cat.py +44 -0
- modusa/utils/plot.py +142 -0
- {modusa-0.2.23.dist-info → modusa-0.3.1.dist-info}/METADATA +24 -19
- modusa-0.3.1.dist-info/RECORD +60 -0
- modusa/devtools/docs/source/generators/audio_waveforms.rst +0 -8
- modusa/devtools/docs/source/generators/base.rst +0 -8
- modusa/devtools/docs/source/generators/index.rst +0 -8
- modusa/devtools/docs/source/io/audio_loader.rst +0 -8
- modusa/devtools/docs/source/io/base.rst +0 -8
- modusa/devtools/docs/source/io/index.rst +0 -8
- modusa/devtools/docs/source/plugins/base.rst +0 -8
- modusa/devtools/docs/source/plugins/index.rst +0 -7
- modusa/devtools/docs/source/signals/audio_signal.rst +0 -8
- modusa/devtools/docs/source/signals/base.rst +0 -8
- modusa/devtools/docs/source/signals/frequency_domain_signal.rst +0 -8
- modusa/devtools/docs/source/signals/index.rst +0 -11
- modusa/devtools/docs/source/signals/spectrogram.rst +0 -8
- modusa/devtools/docs/source/signals/time_domain_signal.rst +0 -8
- modusa/devtools/docs/source/tools/audio_converter.rst +0 -8
- modusa/devtools/docs/source/tools/audio_player.rst +0 -8
- modusa/devtools/docs/source/tools/base.rst +0 -8
- modusa/devtools/docs/source/tools/fourier_tranform.rst +0 -8
- modusa/devtools/docs/source/tools/index.rst +0 -13
- modusa/devtools/docs/source/tools/math_ops.rst +0 -8
- modusa/devtools/docs/source/tools/plotter.rst +0 -8
- modusa/devtools/docs/source/tools/youtube_downloader.rst +0 -8
- modusa/io/__init__.py +0 -5
- modusa/io/audio_loader.py +0 -184
- modusa/io/base.py +0 -43
- modusa/signals/__init__.py +0 -3
- modusa/signals/audio_signal.py +0 -540
- modusa/signals/base.py +0 -27
- modusa/signals/frequency_domain_signal.py +0 -376
- modusa/signals/spectrogram.py +0 -564
- modusa/signals/time_domain_signal.py +0 -412
- modusa/tools/fourier_tranform.py +0 -24
- modusa-0.2.23.dist-info/RECORD +0 -70
- {modusa-0.2.23.dist-info → modusa-0.3.1.dist-info}/WHEEL +0 -0
- {modusa-0.2.23.dist-info → modusa-0.3.1.dist-info}/entry_points.txt +0 -0
- {modusa-0.2.23.dist-info → modusa-0.3.1.dist-info}/licenses/LICENSE.md +0 -0
modusa/tools/plotter.py
CHANGED
|
@@ -1,543 +1,351 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
|
|
3
3
|
|
|
4
|
-
from modusa import excp
|
|
5
|
-
from modusa.decorators import validate_args_type
|
|
6
|
-
from modusa.tools.base import ModusaTool
|
|
7
4
|
import numpy as np
|
|
8
5
|
import matplotlib.pyplot as plt
|
|
6
|
+
import matplotlib.gridspec as gridspec
|
|
9
7
|
from matplotlib.patches import Rectangle
|
|
10
|
-
from
|
|
11
|
-
import warnings
|
|
8
|
+
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
|
|
12
9
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class Plotter(ModusaTool):
|
|
17
|
-
"""
|
|
18
|
-
Plots different kind of signals using `matplotlib`.
|
|
19
|
-
|
|
20
|
-
Note
|
|
21
|
-
----
|
|
22
|
-
- The class has `plot_` methods to plot different types of signals (1D, 2D).
|
|
23
|
-
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
#--------Meta Information----------
|
|
27
|
-
_name = ""
|
|
28
|
-
_description = ""
|
|
29
|
-
_author_name = "Ankit Anand"
|
|
30
|
-
_author_email = "ankit0.anand0@gmail.com"
|
|
31
|
-
_created_at = "2025-07-06"
|
|
32
|
-
#----------------------------------
|
|
33
|
-
|
|
34
|
-
@staticmethod
|
|
35
|
-
def plot_signal(
|
|
36
|
-
y: np.ndarray,
|
|
37
|
-
x: np.ndarray | None,
|
|
38
|
-
ax: plt.Axes | None = None,
|
|
39
|
-
fmt: str = "k",
|
|
40
|
-
title: str | None = None,
|
|
41
|
-
label: str | None = None,
|
|
42
|
-
ylabel: str | None = None,
|
|
43
|
-
xlabel: str | None = None,
|
|
44
|
-
ylim: tuple[float, float] | None = None,
|
|
45
|
-
xlim: tuple[float, float] | None = None,
|
|
46
|
-
highlight: list[tuple[float, float], ...] | None = None,
|
|
47
|
-
vlines: list[float] | None = None,
|
|
48
|
-
hlines: list[float] | None = None,
|
|
49
|
-
show_grid: bool = False,
|
|
50
|
-
stem: bool = False,
|
|
51
|
-
legend_loc: str = None,
|
|
52
|
-
) -> plt.Figure | None:
|
|
10
|
+
#======== 1D ===========
|
|
11
|
+
def plot1d(*args, ann=None, events=None, xlim=None, ylim=None, xlabel=None, ylabel=None, title=None, legend=None):
|
|
53
12
|
"""
|
|
54
|
-
Plots 1D signal using
|
|
55
|
-
arguments.
|
|
13
|
+
Plots a 1D signal using matplotlib.
|
|
56
14
|
|
|
57
15
|
.. code-block:: python
|
|
58
|
-
|
|
59
|
-
|
|
16
|
+
|
|
17
|
+
import modusa as ms
|
|
60
18
|
import numpy as np
|
|
61
19
|
|
|
62
|
-
|
|
63
|
-
x = np.linspace(0, 2 * np.pi, 100)
|
|
20
|
+
x = np.arange(100) / 100
|
|
64
21
|
y = np.sin(x)
|
|
65
22
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
x=x,
|
|
70
|
-
ax=None,
|
|
71
|
-
color="blue",
|
|
72
|
-
marker=None,
|
|
73
|
-
linestyle="-",
|
|
74
|
-
stem=False,
|
|
75
|
-
labels=("Time", "Amplitude", "Sine Wave"),
|
|
76
|
-
legend_loc="upper right",
|
|
77
|
-
zoom=None,
|
|
78
|
-
highlight=[(2, 4)]
|
|
79
|
-
)
|
|
80
|
-
|
|
81
|
-
|
|
23
|
+
display(ms.plot1d(y, x))
|
|
24
|
+
|
|
25
|
+
|
|
82
26
|
Parameters
|
|
83
27
|
----------
|
|
84
|
-
|
|
85
|
-
The signal
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
28
|
+
*args : tuple[array-like, array-like] | tuple[array-like]
|
|
29
|
+
- The signal y and axis x to be plotted.
|
|
30
|
+
- If only values are provided, we generate the axis using arange.
|
|
31
|
+
- E.g. (y1, x1), (y2, x2), ...
|
|
32
|
+
ann : list[tuple[Number, Number, str] | None
|
|
33
|
+
- A list of annotations to mark specific points. Each tuple should be of the form (start, end, label).
|
|
34
|
+
- Default: None => No annotation.
|
|
35
|
+
events : list[Number] | None
|
|
36
|
+
- A list of x-values where vertical lines (event markers) will be drawn.
|
|
37
|
+
- Default: None
|
|
38
|
+
xlim : tuple[Number, Number] | None
|
|
39
|
+
- Limits for the x-axis as (xmin, xmax).
|
|
40
|
+
- Default: None
|
|
41
|
+
ylim : tuple[Number, Number] | None
|
|
42
|
+
- Limits for the y-axis as (ymin, ymax).
|
|
43
|
+
- Default: None
|
|
44
|
+
xlabel : str | None
|
|
45
|
+
- Label for the x-axis.
|
|
46
|
+
- - Default: None
|
|
47
|
+
ylabel : str | None
|
|
48
|
+
- Label for the y-axis.
|
|
49
|
+
- Default: None
|
|
50
|
+
title : str | None
|
|
51
|
+
- Title of the plot.
|
|
52
|
+
- Default: None
|
|
53
|
+
legend : list[str] | None
|
|
54
|
+
- List of legend labels corresponding to each signal if plotting multiple lines.
|
|
55
|
+
- Default: None
|
|
56
|
+
|
|
107
57
|
Returns
|
|
108
58
|
-------
|
|
109
|
-
plt.Figure
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
59
|
+
plt.Figure
|
|
60
|
+
Matplolib figure.
|
|
113
61
|
"""
|
|
62
|
+
|
|
63
|
+
for arg in args:
|
|
64
|
+
if len(arg) not in [1, 2]: # 1 if it just provides values, 2 if it provided axis as well
|
|
65
|
+
raise ValueError(f"1D signal needs to have max 2 arrays (y, x) or simply (y, )")
|
|
66
|
+
if isinstance(legend, str): legend = (legend, )
|
|
114
67
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
if x is None:
|
|
122
|
-
x = np.arange(y.shape[0])
|
|
123
|
-
elif x.ndim != 1:
|
|
124
|
-
raise excp.InputValueError(f"`x` must be of dimension 1 not {x.ndim}.")
|
|
125
|
-
elif x.shape[0] < 1:
|
|
126
|
-
raise excp.InputValueError(f"`x` must not be empty.")
|
|
127
|
-
|
|
128
|
-
if x.shape[0] != y.shape[0]:
|
|
129
|
-
raise excp.InputValueError(f"`y` and `x` must be of same shape")
|
|
68
|
+
if legend is not None:
|
|
69
|
+
if len(legend) < len(args):
|
|
70
|
+
raise ValueError(f"Legend should be provided for each signal.")
|
|
71
|
+
|
|
72
|
+
fig = plt.figure(figsize=(16, 2))
|
|
73
|
+
gs = gridspec.GridSpec(2, 1, height_ratios=[0.2, 1])
|
|
130
74
|
|
|
131
|
-
|
|
132
|
-
if ax is None:
|
|
133
|
-
fig, ax = plt.subplots(figsize=(15, 2))
|
|
134
|
-
created_fig = True
|
|
135
|
-
else:
|
|
136
|
-
fig = ax.get_figure()
|
|
137
|
-
created_fig = False
|
|
138
|
-
|
|
139
|
-
# Add legend
|
|
140
|
-
if label is not None:
|
|
141
|
-
legend_loc = legend_loc or "best"
|
|
142
|
-
# Plot the signal and attach the label
|
|
143
|
-
if stem:
|
|
144
|
-
ax.stem(x, y, linefmt="k", markerfmt='o', label=label)
|
|
145
|
-
else:
|
|
146
|
-
ax.plot(x, y, fmt, lw=1.5, ms=3, label=label)
|
|
147
|
-
ax.legend(loc=legend_loc)
|
|
148
|
-
else:
|
|
149
|
-
# Plot the signal without label
|
|
150
|
-
if stem:
|
|
151
|
-
ax.stem(x, y, linefmt="k", markerfmt='o')
|
|
152
|
-
else:
|
|
153
|
-
ax.plot(x, y, fmt, lw=1.5, ms=3)
|
|
75
|
+
colors = plt.get_cmap('tab10').colors
|
|
154
76
|
|
|
77
|
+
signal_ax = fig.add_subplot(gs[1, 0])
|
|
78
|
+
annotation_ax = fig.add_subplot(gs[0, 0], sharex=signal_ax)
|
|
155
79
|
|
|
156
|
-
|
|
157
|
-
# Set the labels
|
|
158
|
-
if title is not None:
|
|
159
|
-
ax.set_title(title)
|
|
160
|
-
if ylabel is not None:
|
|
161
|
-
ax.set_ylabel(ylabel)
|
|
162
|
-
if xlabel is not None:
|
|
163
|
-
ax.set_xlabel(xlabel)
|
|
164
|
-
|
|
165
|
-
# Applying axes limits into a region
|
|
166
|
-
if ylim is not None:
|
|
167
|
-
ax.set_ylim(ylim)
|
|
80
|
+
# Set lim
|
|
168
81
|
if xlim is not None:
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
if highlight is not None:
|
|
172
|
-
y_min = np.min(y)
|
|
173
|
-
y_max = np.max(y)
|
|
174
|
-
y_range = y_max - y_min
|
|
175
|
-
label_box_height = 0.20 * y_range
|
|
176
|
-
|
|
177
|
-
for i, highlight_region in enumerate(highlight):
|
|
178
|
-
if len(highlight_region) != 2:
|
|
179
|
-
raise excp.InputValueError("`highlight` should be a list of tuple of 2 values (left, right) => [(1, 10.5)]")
|
|
180
|
-
|
|
181
|
-
l, r = highlight_region
|
|
182
|
-
l = x[0] if l is None else l
|
|
183
|
-
r = x[-1] if r is None else r
|
|
184
|
-
|
|
185
|
-
# Highlight rectangle (main background)
|
|
186
|
-
ax.add_patch(Rectangle(
|
|
187
|
-
(l, y_min),
|
|
188
|
-
r - l,
|
|
189
|
-
y_range,
|
|
190
|
-
color='red',
|
|
191
|
-
alpha=0.2,
|
|
192
|
-
zorder=10
|
|
193
|
-
))
|
|
194
|
-
|
|
195
|
-
# Label box inside the top of the highlight
|
|
196
|
-
ax.add_patch(Rectangle(
|
|
197
|
-
(l, y_max - label_box_height),
|
|
198
|
-
r - l,
|
|
199
|
-
label_box_height,
|
|
200
|
-
color='red',
|
|
201
|
-
alpha=0.4,
|
|
202
|
-
zorder=11
|
|
203
|
-
))
|
|
204
|
-
|
|
205
|
-
# Centered label inside that box
|
|
206
|
-
ax.text(
|
|
207
|
-
(l + r) / 2,
|
|
208
|
-
y_max - label_box_height / 2,
|
|
209
|
-
str(i + 1),
|
|
210
|
-
ha='center',
|
|
211
|
-
va='center',
|
|
212
|
-
fontsize=10,
|
|
213
|
-
color='white',
|
|
214
|
-
fontweight='bold',
|
|
215
|
-
zorder=12
|
|
216
|
-
)
|
|
82
|
+
signal_ax.set_xlim(xlim)
|
|
217
83
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
for xpos in vlines:
|
|
221
|
-
ax.axvline(x=xpos, color='blue', linestyle='--', linewidth=2, zorder=5)
|
|
222
|
-
|
|
223
|
-
# Horizontal lines
|
|
224
|
-
if hlines:
|
|
225
|
-
for ypos in hlines:
|
|
226
|
-
ax.axhline(y=ypos, color='blue', linestyle='--', linewidth=2, zorder=5)
|
|
227
|
-
|
|
228
|
-
# Show grid
|
|
229
|
-
if show_grid:
|
|
230
|
-
ax.grid(True, color="gray", linestyle="--", linewidth=0.5)
|
|
231
|
-
|
|
232
|
-
# Show/Return the figure as per needed
|
|
233
|
-
if created_fig:
|
|
234
|
-
fig.tight_layout()
|
|
235
|
-
if Plotter._in_notebook():
|
|
236
|
-
plt.tight_layout()
|
|
237
|
-
plt.close(fig)
|
|
238
|
-
return fig
|
|
239
|
-
else:
|
|
240
|
-
plt.tight_layout()
|
|
241
|
-
plt.show()
|
|
242
|
-
return fig
|
|
243
|
-
|
|
244
|
-
@staticmethod
|
|
245
|
-
@validate_args_type()
|
|
246
|
-
def plot_matrix(
|
|
247
|
-
M: np.ndarray,
|
|
248
|
-
r: np.ndarray | None = None,
|
|
249
|
-
c: np.ndarray | None = None,
|
|
250
|
-
ax: plt.Axes | None = None,
|
|
251
|
-
cmap: str = "gray_r",
|
|
252
|
-
title: str | None = None,
|
|
253
|
-
Mlabel: str | None = None,
|
|
254
|
-
rlabel: str | None = None,
|
|
255
|
-
clabel: str | None = None,
|
|
256
|
-
rlim: tuple[float, float] | None = None,
|
|
257
|
-
clim: tuple[float, float] | None = None,
|
|
258
|
-
highlight: list[tuple[float, float, float, float]] | None = None,
|
|
259
|
-
vlines: list[float] | None = None,
|
|
260
|
-
hlines: list[float] | None = None,
|
|
261
|
-
origin: str = "lower", # or "lower"
|
|
262
|
-
gamma: int | float | None = None,
|
|
263
|
-
show_colorbar: bool = True,
|
|
264
|
-
cax: plt.Axes | None = None,
|
|
265
|
-
show_grid: bool = True,
|
|
266
|
-
tick_mode: str = "center", # "center" or "edge"
|
|
267
|
-
n_ticks: tuple[int, int] | None = None,
|
|
268
|
-
) -> plt.Figure:
|
|
269
|
-
"""
|
|
270
|
-
Plot a 2D matrix with optional zooming, highlighting, and grid.
|
|
271
|
-
|
|
272
|
-
.. code-block:: python
|
|
84
|
+
if ylim is not None:
|
|
85
|
+
signal_ax.set_ylim(ylim)
|
|
273
86
|
|
|
274
|
-
from modusa.io import Plotter
|
|
275
|
-
import numpy as np
|
|
276
|
-
import matplotlib.pyplot as plt
|
|
277
|
-
|
|
278
|
-
# Create a 50x50 random matrix
|
|
279
|
-
M = np.random.rand(50, 50)
|
|
280
87
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
cmap="viridis",
|
|
296
|
-
origin="lower",
|
|
297
|
-
show_colorbar=True,
|
|
298
|
-
cax=None,
|
|
299
|
-
show_grid=False,
|
|
300
|
-
tick_mode="center",
|
|
301
|
-
n_ticks=(5, 5),
|
|
302
|
-
)
|
|
303
|
-
|
|
88
|
+
# Add signal plot
|
|
89
|
+
for i, signal in enumerate(args):
|
|
90
|
+
if len(signal) == 1:
|
|
91
|
+
y = signal[0]
|
|
92
|
+
if legend is not None:
|
|
93
|
+
signal_ax.plot(y, label=legend[i])
|
|
94
|
+
else:
|
|
95
|
+
signal_ax.plot(y)
|
|
96
|
+
elif len(signal) == 2:
|
|
97
|
+
y, x = signal[0], signal[1]
|
|
98
|
+
if legend is not None:
|
|
99
|
+
signal_ax.plot(x, y, label=legend[i])
|
|
100
|
+
else:
|
|
101
|
+
signal_ax.plot(x, y)
|
|
304
102
|
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
Number of ticks on row and column axes.
|
|
337
|
-
|
|
338
|
-
Returns
|
|
339
|
-
-------
|
|
340
|
-
plt.Figure
|
|
341
|
-
Matplotlib figure containing the plot.
|
|
342
|
-
|
|
343
|
-
"""
|
|
344
|
-
|
|
345
|
-
# Validate the important args and get the signal that needs to be plotted
|
|
346
|
-
if M.ndim != 2:
|
|
347
|
-
raise excp.InputValueError(f"`M` must have 2 dimension not {M.ndim}")
|
|
348
|
-
if r is None:
|
|
349
|
-
r = M.shape[0]
|
|
350
|
-
if c is None:
|
|
351
|
-
c = M.shape[1]
|
|
352
|
-
|
|
353
|
-
if r.ndim != 1 and c.ndim != 1:
|
|
354
|
-
raise excp.InputValueError(f"`r` and `c` must have 2 dimension not r:{r.ndim}, c:{c.ndim}")
|
|
355
|
-
|
|
356
|
-
if r.shape[0] != M.shape[0]:
|
|
357
|
-
raise excp.InputValueError(f"`r` must have shape as `M row` not {r.shape}")
|
|
358
|
-
if c.shape[0] != M.shape[1]:
|
|
359
|
-
raise excp.InputValueError(f"`c` must have shape as `M column` not {c.shape}")
|
|
360
|
-
|
|
361
|
-
# Scale the signal if needed
|
|
362
|
-
if gamma is not None:
|
|
363
|
-
M = np.log1p(float(gamma) * M)
|
|
364
|
-
|
|
365
|
-
# Create a figure
|
|
366
|
-
if ax is None:
|
|
367
|
-
fig, ax = plt.subplots(figsize=(15, 4))
|
|
368
|
-
created_fig = True
|
|
369
|
-
else:
|
|
370
|
-
fig = ax.get_figure()
|
|
371
|
-
created_fig = False
|
|
103
|
+
# Add annotations
|
|
104
|
+
if ann is not None:
|
|
105
|
+
annotation_ax.set_ylim(0, 1)
|
|
106
|
+
for i, (start, end, tag) in enumerate(ann):
|
|
107
|
+
if xlim is not None:
|
|
108
|
+
if end < xlim[0] or start > xlim[1]:
|
|
109
|
+
continue # Skip out-of-view regions
|
|
110
|
+
# Clip boundaries to xlim
|
|
111
|
+
start = max(start, xlim[0])
|
|
112
|
+
end = min(end, xlim[1])
|
|
113
|
+
|
|
114
|
+
color = colors[i % len(colors)]
|
|
115
|
+
width = end - start
|
|
116
|
+
rect = Rectangle((start, 0), width, 1, color=color, alpha=0.7)
|
|
117
|
+
annotation_ax.add_patch(rect)
|
|
118
|
+
annotation_ax.text((start + end) / 2, 0.5, tag,
|
|
119
|
+
ha='center', va='center',
|
|
120
|
+
fontsize=10, color='white', fontweight='bold', zorder=10)
|
|
121
|
+
# Add vlines
|
|
122
|
+
if events is not None:
|
|
123
|
+
for xpos in events:
|
|
124
|
+
if xlim is not None:
|
|
125
|
+
if xlim[0] <= xpos <= xlim[1]:
|
|
126
|
+
annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
|
|
127
|
+
else:
|
|
128
|
+
annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
|
|
129
|
+
|
|
130
|
+
# Add legend
|
|
131
|
+
if legend is not None:
|
|
132
|
+
handles, labels = signal_ax.get_legend_handles_labels()
|
|
133
|
+
fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 1.2), ncol=len(legend), frameon=False)
|
|
372
134
|
|
|
373
|
-
#
|
|
374
|
-
|
|
375
|
-
|
|
135
|
+
# Set title, labels
|
|
136
|
+
if title is not None:
|
|
137
|
+
annotation_ax.set_title(title, pad=10, size=11)
|
|
138
|
+
if xlabel is not None:
|
|
139
|
+
signal_ax.set_xlabel(xlabel)
|
|
140
|
+
if ylabel is not None:
|
|
141
|
+
signal_ax.set_ylabel(ylabel)
|
|
376
142
|
|
|
377
|
-
#
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
origin=origin,
|
|
383
|
-
extent=extent
|
|
384
|
-
)
|
|
143
|
+
# Decorating annotation axis thicker
|
|
144
|
+
if ann is not None:
|
|
145
|
+
annotation_ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
|
146
|
+
else:
|
|
147
|
+
annotation_ax.axis("off")
|
|
385
148
|
|
|
386
|
-
# Set the ticks and labels
|
|
387
|
-
if n_ticks is None:
|
|
388
|
-
n_ticks = (10, 10)
|
|
389
149
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
150
|
+
fig.subplots_adjust(hspace=0.01, wspace=0.05)
|
|
151
|
+
plt.close()
|
|
152
|
+
return fig
|
|
153
|
+
|
|
154
|
+
#======== 2D ===========
|
|
155
|
+
def plot2d(*args, ann=None, events=None, xlim=None, ylim=None, origin="lower", Mlabel=None, xlabel=None, ylabel=None, title=None, legend=None, lm=False):
|
|
156
|
+
"""
|
|
157
|
+
Plots a 2D matrix (e.g., spectrogram or heatmap) with optional annotations and events.
|
|
158
|
+
|
|
159
|
+
.. code-block:: python
|
|
160
|
+
|
|
161
|
+
import modusa as ms
|
|
162
|
+
import numpy as np
|
|
397
163
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
164
|
+
M = np.random.random((10, 30))
|
|
165
|
+
y = np.arange(M.shape[0])
|
|
166
|
+
x = np.arange(M.shape[1])
|
|
401
167
|
|
|
402
|
-
|
|
403
|
-
|
|
168
|
+
display(ms.plot2d(M, y, x))
|
|
169
|
+
|
|
170
|
+
Parameters
|
|
171
|
+
----------
|
|
172
|
+
*args : tuple[array-like, array-like]
|
|
173
|
+
- The signal values to be plotted.
|
|
174
|
+
- E.g. (M1, y1, x1), (M2, y2, x2), ...
|
|
175
|
+
ann : list[tuple[Number, Number, str]] | None
|
|
176
|
+
- A list of annotation spans. Each tuple should be (start, end, label).
|
|
177
|
+
- Default: None (no annotations).
|
|
178
|
+
events : list[Number] | None
|
|
179
|
+
- X-values where vertical event lines will be drawn.
|
|
180
|
+
- Default: None.
|
|
181
|
+
xlim : tuple[Number, Number] | None
|
|
182
|
+
- Limits for the x-axis as (xmin, xmax).
|
|
183
|
+
- Default: None (auto-scaled).
|
|
184
|
+
ylim : tuple[Number, Number] | None
|
|
185
|
+
- Limits for the y-axis as (ymin, ymax).
|
|
186
|
+
- Default: None (auto-scaled).
|
|
187
|
+
origin : {'upper', 'lower'}
|
|
188
|
+
- Origin position for the image display. Used in `imshow`.
|
|
189
|
+
- Default: "lower".
|
|
190
|
+
Mlabel : str | None
|
|
191
|
+
- Label for the colorbar (e.g., "Magnitude", "Energy").
|
|
192
|
+
- Default: None.
|
|
193
|
+
xlabel : str | None
|
|
194
|
+
- Label for the x-axis.
|
|
195
|
+
- Default: None.
|
|
196
|
+
ylabel : str | None
|
|
197
|
+
- Label for the y-axis.
|
|
198
|
+
- Default: None.
|
|
199
|
+
title : str | None
|
|
200
|
+
- Title of the plot.
|
|
201
|
+
- Default: None.
|
|
202
|
+
legend : list[str] | None
|
|
203
|
+
- Legend labels for any overlaid lines or annotations.
|
|
204
|
+
- Default: None.
|
|
205
|
+
lm: bool
|
|
206
|
+
- Adds a circular marker for the line.
|
|
207
|
+
- Default: False
|
|
208
|
+
- Useful to show the data points.
|
|
209
|
+
|
|
210
|
+
Returns
|
|
211
|
+
-------
|
|
212
|
+
matplotlib.figure.Figure
|
|
213
|
+
The matplotlib Figure object.
|
|
214
|
+
"""
|
|
215
|
+
|
|
216
|
+
for arg in args:
|
|
217
|
+
if len(arg) not in [1, 2, 3]: # Either provide just the matrix or with both axes info
|
|
218
|
+
raise ValueError(f"Data to plot needs to have 3 arrays (M, y, x)")
|
|
219
|
+
if isinstance(legend, str): legend = (legend, )
|
|
220
|
+
|
|
221
|
+
fig = plt.figure(figsize=(16, 4))
|
|
222
|
+
gs = gridspec.GridSpec(3, 1, height_ratios=[0.2, 0.1, 1]) # colorbar, annotation, signal
|
|
223
|
+
|
|
224
|
+
colors = plt.get_cmap('tab10').colors
|
|
225
|
+
|
|
226
|
+
signal_ax = fig.add_subplot(gs[2, 0])
|
|
227
|
+
annotation_ax = fig.add_subplot(gs[1, 0], sharex=signal_ax)
|
|
228
|
+
|
|
229
|
+
colorbar_ax = fig.add_subplot(gs[0, 0])
|
|
230
|
+
colorbar_ax.axis("off")
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# Add lim
|
|
234
|
+
if xlim is not None:
|
|
235
|
+
signal_ax.set_xlim(xlim)
|
|
404
236
|
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
ytick_idx = np.linspace(0, len(yticks_all) - 1, nr, dtype=int)
|
|
237
|
+
if ylim is not None:
|
|
238
|
+
signal_ax.set_ylim(ylim)
|
|
408
239
|
|
|
409
|
-
|
|
410
|
-
|
|
240
|
+
# Add signal plot
|
|
241
|
+
i = 0 # This is to track the legend for 1D plots
|
|
242
|
+
for signal in args:
|
|
411
243
|
|
|
412
|
-
#
|
|
413
|
-
if title is not None:
|
|
414
|
-
ax.set_title(title)
|
|
415
|
-
if rlabel is not None:
|
|
416
|
-
ax.set_ylabel(rlabel)
|
|
417
|
-
if clabel is not None:
|
|
418
|
-
ax.set_xlabel(clabel)
|
|
244
|
+
data = signal[0] # This can be 1D or 2D (1D meaning we have to overlay on the matrix)
|
|
419
245
|
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
if clim is not None:
|
|
424
|
-
ax.set_xlim(clim)
|
|
425
|
-
|
|
426
|
-
# Applying axes limits into a region
|
|
427
|
-
if rlim is not None:
|
|
428
|
-
ax.set_ylim(rlim)
|
|
429
|
-
if clim is not None:
|
|
430
|
-
ax.set_xlim(clim)
|
|
246
|
+
if data.ndim == 1: # 1D
|
|
247
|
+
if len(signal) == 1: # It means that the axis was not passed
|
|
248
|
+
x = np.arange(data.shape[0])
|
|
431
249
|
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
250
|
+
if lm is False:
|
|
251
|
+
if legend is not None:
|
|
252
|
+
signal_ax.plot(x, data, label=legend[i])
|
|
253
|
+
signal_ax.legend(loc="upper right")
|
|
254
|
+
else:
|
|
255
|
+
signal_ax.plot(x, data)
|
|
256
|
+
else:
|
|
257
|
+
if legend is not None:
|
|
258
|
+
signal_ax.plot(x, data, marker="o", markersize=7, markerfacecolor='red', linestyle="--", linewidth=2, label=legend[i])
|
|
259
|
+
signal_ax.legend(loc="upper right")
|
|
260
|
+
else:
|
|
261
|
+
signal_ax.plot(x, data, marker="o", markersize=7, markerfacecolor='red', linestyle="--", linewidth=2)
|
|
262
|
+
|
|
263
|
+
i += 1
|
|
435
264
|
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
elif len(highlight_region) == 4:
|
|
446
|
-
r1, r2, c1, c2 = highlight_region
|
|
447
|
-
|
|
448
|
-
r1 = r[0] if r1 is None else r1
|
|
449
|
-
r2 = r[-1] if r2 is None else r2
|
|
450
|
-
c1 = c[0] if c1 is None else c1
|
|
451
|
-
c2 = c[-1] if c2 is None else c2
|
|
452
|
-
|
|
453
|
-
row_min, row_max = min(r1, r2), max(r1, r2)
|
|
454
|
-
col_min, col_max = min(c1, c2), max(c1, c2)
|
|
455
|
-
|
|
456
|
-
width = col_max - col_min
|
|
457
|
-
height = row_max - row_min
|
|
458
|
-
|
|
459
|
-
# Main red highlight box
|
|
460
|
-
ax.add_patch(Rectangle(
|
|
461
|
-
(col_min, row_min),
|
|
462
|
-
width,
|
|
463
|
-
height,
|
|
464
|
-
color='red',
|
|
465
|
-
alpha=0.2,
|
|
466
|
-
zorder=10
|
|
467
|
-
))
|
|
265
|
+
elif data.ndim == 2: # 2D
|
|
266
|
+
M = data
|
|
267
|
+
if len(signal) == 1: # It means that the axes were not passed
|
|
268
|
+
y = np.arange(M.shape[0])
|
|
269
|
+
x = np.arange(M.shape[1])
|
|
270
|
+
dx = x[1] - x[0]
|
|
271
|
+
dy = y[1] - y[0]
|
|
272
|
+
extent=[x[0] - dx/2, x[-1] + dx/2, y[0] - dy/2, y[-1] + dy/2]
|
|
273
|
+
im = signal_ax.imshow(M, aspect="auto", origin=origin, cmap="gray_r", extent=extent)
|
|
468
274
|
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
fontsize=10,
|
|
487
|
-
color='white',
|
|
488
|
-
fontweight='bold',
|
|
489
|
-
zorder=12
|
|
490
|
-
)
|
|
491
|
-
|
|
492
|
-
# Show colorbar
|
|
493
|
-
if show_colorbar is not None:
|
|
494
|
-
cbar = fig.colorbar(im, ax=ax, cax=cax)
|
|
495
|
-
if Mlabel is not None:
|
|
496
|
-
cbar.set_label(Mlabel)
|
|
497
|
-
|
|
498
|
-
# Vertical lines
|
|
499
|
-
if vlines:
|
|
500
|
-
for xpos in vlines:
|
|
501
|
-
ax.axvline(x=xpos, color='blue', linestyle='--', linewidth=2, zorder=5)
|
|
502
|
-
|
|
503
|
-
# Horizontal lines
|
|
504
|
-
if hlines:
|
|
505
|
-
for ypos in hlines:
|
|
506
|
-
ax.axhline(y=ypos, color='blue', linestyle='--', linewidth=2, zorder=5)
|
|
275
|
+
elif len(signal) == 3: # It means that the axes were passed
|
|
276
|
+
M, y, x = signal[0], signal[1], signal[2]
|
|
277
|
+
dx = x[1] - x[0]
|
|
278
|
+
dy = y[1] - y[0]
|
|
279
|
+
extent=[x[0] - dx/2, x[-1] + dx/2, y[0] - dy/2, y[-1] + dy/2]
|
|
280
|
+
im = signal_ax.imshow(M, aspect="auto", origin=origin, cmap="gray_r", extent=extent)
|
|
281
|
+
|
|
282
|
+
# Add annotations
|
|
283
|
+
if ann is not None:
|
|
284
|
+
annotation_ax.set_ylim(0, 1)
|
|
285
|
+
for i, (start, end, tag) in enumerate(ann):
|
|
286
|
+
if xlim is not None:
|
|
287
|
+
if end < xlim[0] or start > xlim[1]:
|
|
288
|
+
continue # Skip out-of-view regions
|
|
289
|
+
# Clip boundaries to xlim
|
|
290
|
+
start = max(start, xlim[0])
|
|
291
|
+
end = min(end, xlim[1])
|
|
507
292
|
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
293
|
+
color = colors[i % len(colors)]
|
|
294
|
+
width = end - start
|
|
295
|
+
rect = Rectangle((start, 0), width, 1, color=color, alpha=0.7)
|
|
296
|
+
annotation_ax.add_patch(rect)
|
|
297
|
+
annotation_ax.text((start + end) / 2, 0.5, tag,
|
|
298
|
+
ha='center', va='center',
|
|
299
|
+
fontsize=10, color='white', fontweight='bold', zorder=10)
|
|
300
|
+
# Add vlines
|
|
301
|
+
if events is not None:
|
|
302
|
+
for xpos in events:
|
|
303
|
+
if xlim is not None:
|
|
304
|
+
if xlim[0] <= xpos <= xlim[1]:
|
|
305
|
+
annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
|
|
518
306
|
else:
|
|
519
|
-
|
|
520
|
-
|
|
307
|
+
annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
|
|
308
|
+
|
|
309
|
+
# Add legend incase there are 1D overlays
|
|
310
|
+
if legend is not None:
|
|
311
|
+
handles, labels = signal_ax.get_legend_handles_labels()
|
|
312
|
+
if handles: # Only add legend if there's something to show
|
|
313
|
+
signal_ax.legend(handles, labels, loc="upper right")
|
|
314
|
+
|
|
315
|
+
# Add colorbar
|
|
316
|
+
# Create an inset axis on top-right of signal_ax
|
|
317
|
+
cax = inset_axes(
|
|
318
|
+
colorbar_ax,
|
|
319
|
+
width="20%", # percentage of parent width
|
|
320
|
+
height="20%", # height in percentage of parent height
|
|
321
|
+
loc='upper right',
|
|
322
|
+
bbox_to_anchor=(0, 0, 1, 1),
|
|
323
|
+
bbox_transform=colorbar_ax.transAxes,
|
|
324
|
+
borderpad=1
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
cbar = plt.colorbar(im, cax=cax, orientation='horizontal')
|
|
328
|
+
cbar.ax.xaxis.set_ticks_position('top')
|
|
329
|
+
|
|
330
|
+
if Mlabel is not None:
|
|
331
|
+
cbar.set_label(Mlabel, labelpad=5)
|
|
521
332
|
|
|
522
|
-
@staticmethod
|
|
523
|
-
def _compute_centered_extent(r: np.ndarray, c: np.ndarray, origin: str) -> list[float]:
|
|
524
|
-
"""
|
|
525
333
|
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
334
|
+
# Set title, labels
|
|
335
|
+
if title is not None:
|
|
336
|
+
annotation_ax.set_title(title, pad=10, size=11)
|
|
337
|
+
if xlabel is not None:
|
|
338
|
+
signal_ax.set_xlabel(xlabel)
|
|
339
|
+
if ylabel is not None:
|
|
340
|
+
signal_ax.set_ylabel(ylabel)
|
|
341
|
+
|
|
534
342
|
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
343
|
+
# Making annotation axis spines thicker
|
|
344
|
+
if ann is not None:
|
|
345
|
+
annotation_ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
|
346
|
+
else:
|
|
347
|
+
annotation_ax.axis("off")
|
|
348
|
+
|
|
349
|
+
fig.subplots_adjust(hspace=0.01, wspace=0.05)
|
|
350
|
+
plt.close()
|
|
351
|
+
return fig
|