modusa 0.2.23__py3-none-any.whl → 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. modusa/.DS_Store +0 -0
  2. modusa/__init__.py +8 -1
  3. modusa/devtools/{generate_doc_source.py → generate_docs_source.py} +5 -5
  4. modusa/devtools/generate_template.py +5 -5
  5. modusa/devtools/main.py +3 -3
  6. modusa/devtools/templates/generator.py +1 -1
  7. modusa/devtools/templates/io.py +1 -1
  8. modusa/devtools/templates/{signal.py → model.py} +18 -11
  9. modusa/devtools/templates/plugin.py +1 -1
  10. modusa/generators/__init__.py +11 -1
  11. modusa/generators/audio.py +188 -0
  12. modusa/generators/audio_waveforms.py +1 -1
  13. modusa/generators/base.py +1 -1
  14. modusa/generators/ftds.py +298 -0
  15. modusa/generators/s1d.py +270 -0
  16. modusa/generators/s2d.py +300 -0
  17. modusa/generators/s_ax.py +102 -0
  18. modusa/generators/t_ax.py +64 -0
  19. modusa/generators/tds.py +267 -0
  20. modusa/models/__init__.py +14 -0
  21. modusa/models/__pycache__/signal1D.cpython-312.pyc.4443461152 +0 -0
  22. modusa/models/audio.py +90 -0
  23. modusa/models/base.py +70 -0
  24. modusa/models/data.py +457 -0
  25. modusa/models/ftds.py +584 -0
  26. modusa/models/s1d.py +578 -0
  27. modusa/models/s2d.py +619 -0
  28. modusa/models/s_ax.py +448 -0
  29. modusa/models/t_ax.py +335 -0
  30. modusa/models/tds.py +465 -0
  31. modusa/plugins/__init__.py +3 -1
  32. modusa/tmp.py +98 -0
  33. modusa/tools/__init__.py +5 -0
  34. modusa/tools/audio_converter.py +56 -67
  35. modusa/tools/audio_loader.py +90 -0
  36. modusa/tools/audio_player.py +42 -67
  37. modusa/tools/math_ops.py +104 -1
  38. modusa/tools/plotter.py +305 -497
  39. modusa/tools/youtube_downloader.py +31 -98
  40. modusa/utils/excp.py +6 -0
  41. modusa/utils/np_func_cat.py +44 -0
  42. modusa/utils/plot.py +142 -0
  43. {modusa-0.2.23.dist-info → modusa-0.3.dist-info}/METADATA +5 -16
  44. modusa-0.3.dist-info/RECORD +60 -0
  45. modusa/devtools/docs/source/generators/audio_waveforms.rst +0 -8
  46. modusa/devtools/docs/source/generators/base.rst +0 -8
  47. modusa/devtools/docs/source/generators/index.rst +0 -8
  48. modusa/devtools/docs/source/io/audio_loader.rst +0 -8
  49. modusa/devtools/docs/source/io/base.rst +0 -8
  50. modusa/devtools/docs/source/io/index.rst +0 -8
  51. modusa/devtools/docs/source/plugins/base.rst +0 -8
  52. modusa/devtools/docs/source/plugins/index.rst +0 -7
  53. modusa/devtools/docs/source/signals/audio_signal.rst +0 -8
  54. modusa/devtools/docs/source/signals/base.rst +0 -8
  55. modusa/devtools/docs/source/signals/frequency_domain_signal.rst +0 -8
  56. modusa/devtools/docs/source/signals/index.rst +0 -11
  57. modusa/devtools/docs/source/signals/spectrogram.rst +0 -8
  58. modusa/devtools/docs/source/signals/time_domain_signal.rst +0 -8
  59. modusa/devtools/docs/source/tools/audio_converter.rst +0 -8
  60. modusa/devtools/docs/source/tools/audio_player.rst +0 -8
  61. modusa/devtools/docs/source/tools/base.rst +0 -8
  62. modusa/devtools/docs/source/tools/fourier_tranform.rst +0 -8
  63. modusa/devtools/docs/source/tools/index.rst +0 -13
  64. modusa/devtools/docs/source/tools/math_ops.rst +0 -8
  65. modusa/devtools/docs/source/tools/plotter.rst +0 -8
  66. modusa/devtools/docs/source/tools/youtube_downloader.rst +0 -8
  67. modusa/io/__init__.py +0 -5
  68. modusa/io/audio_loader.py +0 -184
  69. modusa/io/base.py +0 -43
  70. modusa/signals/__init__.py +0 -3
  71. modusa/signals/audio_signal.py +0 -540
  72. modusa/signals/base.py +0 -27
  73. modusa/signals/frequency_domain_signal.py +0 -376
  74. modusa/signals/spectrogram.py +0 -564
  75. modusa/signals/time_domain_signal.py +0 -412
  76. modusa/tools/fourier_tranform.py +0 -24
  77. modusa-0.2.23.dist-info/RECORD +0 -70
  78. {modusa-0.2.23.dist-info → modusa-0.3.dist-info}/WHEEL +0 -0
  79. {modusa-0.2.23.dist-info → modusa-0.3.dist-info}/entry_points.txt +0 -0
  80. {modusa-0.2.23.dist-info → modusa-0.3.dist-info}/licenses/LICENSE.md +0 -0
modusa/tools/plotter.py CHANGED
@@ -1,543 +1,351 @@
1
1
  #!/usr/bin/env python3
2
2
 
3
3
 
4
- from modusa import excp
5
- from modusa.decorators import validate_args_type
6
- from modusa.tools.base import ModusaTool
7
4
  import numpy as np
8
5
  import matplotlib.pyplot as plt
6
+ import matplotlib.gridspec as gridspec
9
7
  from matplotlib.patches import Rectangle
10
- from matplotlib.ticker import MaxNLocator
11
- import warnings
8
+ from mpl_toolkits.axes_grid1.inset_locator import inset_axes
12
9
 
13
- warnings.filterwarnings("ignore", message="Glyph .* missing from font.*") # To supress any font related warnings, TODO: Add support to Devnagri font
14
-
15
-
16
- class Plotter(ModusaTool):
17
- """
18
- Plots different kind of signals using `matplotlib`.
19
-
20
- Note
21
- ----
22
- - The class has `plot_` methods to plot different types of signals (1D, 2D).
23
-
24
- """
25
-
26
- #--------Meta Information----------
27
- _name = ""
28
- _description = ""
29
- _author_name = "Ankit Anand"
30
- _author_email = "ankit0.anand0@gmail.com"
31
- _created_at = "2025-07-06"
32
- #----------------------------------
33
-
34
- @staticmethod
35
- def plot_signal(
36
- y: np.ndarray,
37
- x: np.ndarray | None,
38
- ax: plt.Axes | None = None,
39
- fmt: str = "k",
40
- title: str | None = None,
41
- label: str | None = None,
42
- ylabel: str | None = None,
43
- xlabel: str | None = None,
44
- ylim: tuple[float, float] | None = None,
45
- xlim: tuple[float, float] | None = None,
46
- highlight: list[tuple[float, float], ...] | None = None,
47
- vlines: list[float] | None = None,
48
- hlines: list[float] | None = None,
49
- show_grid: bool = False,
50
- stem: bool = False,
51
- legend_loc: str = None,
52
- ) -> plt.Figure | None:
10
+ #======== 1D ===========
11
+ def plot1d(*args, ann=None, events=None, xlim=None, ylim=None, xlabel=None, ylabel=None, title=None, legend=None):
53
12
  """
54
- Plots 1D signal using `matplotlib` with various settings passed through the
55
- arguments.
13
+ Plots a 1D signal using matplotlib.
56
14
 
57
15
  .. code-block:: python
58
-
59
- from modusa.io import Plotter
16
+
17
+ import modusa as ms
60
18
  import numpy as np
61
19
 
62
- # Generate a sample sine wave
63
- x = np.linspace(0, 2 * np.pi, 100)
20
+ x = np.arange(100) / 100
64
21
  y = np.sin(x)
65
22
 
66
- # Plot the signal
67
- fig = Plotter.plot_signal(
68
- y=y,
69
- x=x,
70
- ax=None,
71
- color="blue",
72
- marker=None,
73
- linestyle="-",
74
- stem=False,
75
- labels=("Time", "Amplitude", "Sine Wave"),
76
- legend_loc="upper right",
77
- zoom=None,
78
- highlight=[(2, 4)]
79
- )
80
-
81
-
23
+ display(ms.plot1d(y, x))
24
+
25
+
82
26
  Parameters
83
27
  ----------
84
- y: np.ndarray
85
- The signal values to plot on the y-axis.
86
- x: np.ndarray | None
87
- The x-axis values. If None, indices of `y` are used.
88
- ax: plt.Axes | None
89
- matplotlib Axes object to draw on. If None, a new figure and axis are created. Return type depends on parameter value.
90
- color: str
91
- Color of the plotted line or markers. (e.g. "k")
92
- marker: str | None
93
- marker style for the plot (e.g., 'o', 'x'). If None, no marker is used.
94
- linestyle: str | None
95
- Line style for the plot (e.g., '-', '--'). If None, no line is drawn.
96
- stem: bool
97
- If True, plots a stem plot.
98
- labels: tuple[str, str, str] | None
99
- Tuple containing (title, xlabel, ylabel). If None, no labels are set.
100
- legend_loc: str | None
101
- Location string for legend placement (e.g., 'upper right'). If None, no legend is shown.
102
- zoom: tuple | None
103
- Tuple specifying x-axis limits for zoom as (start, end). If None, full x-range is shown.
104
- highlight: list[tuple[float, float], ...] | None
105
- List of (start, end) tuples to highlight regions on the plot. e.g. [(1, 2.5), (6, 10)]
106
-
28
+ *args : tuple[array-like, array-like] | tuple[array-like]
29
+ - The signal y and axis x to be plotted.
30
+ - If only values are provided, we generate the axis using arange.
31
+ - E.g. (y1, x1), (y2, x2), ...
32
+ ann : list[tuple[Number, Number, str] | None
33
+ - A list of annotations to mark specific points. Each tuple should be of the form (start, end, label).
34
+ - Default: None => No annotation.
35
+ events : list[Number] | None
36
+ - A list of x-values where vertical lines (event markers) will be drawn.
37
+ - Default: None
38
+ xlim : tuple[Number, Number] | None
39
+ - Limits for the x-axis as (xmin, xmax).
40
+ - Default: None
41
+ ylim : tuple[Number, Number] | None
42
+ - Limits for the y-axis as (ymin, ymax).
43
+ - Default: None
44
+ xlabel : str | None
45
+ - Label for the x-axis.
46
+ - - Default: None
47
+ ylabel : str | None
48
+ - Label for the y-axis.
49
+ - Default: None
50
+ title : str | None
51
+ - Title of the plot.
52
+ - Default: None
53
+ legend : list[str] | None
54
+ - List of legend labels corresponding to each signal if plotting multiple lines.
55
+ - Default: None
56
+
107
57
  Returns
108
58
  -------
109
- plt.Figure | None
110
- Figure if `ax` is None else None.
111
-
112
-
59
+ plt.Figure
60
+ Matplolib figure.
113
61
  """
62
+
63
+ for arg in args:
64
+ if len(arg) not in [1, 2]: # 1 if it just provides values, 2 if it provided axis as well
65
+ raise ValueError(f"1D signal needs to have max 2 arrays (y, x) or simply (y, )")
66
+ if isinstance(legend, str): legend = (legend, )
114
67
 
115
- # Validate the important args and get the signal that needs to be plotted
116
- if y.ndim != 1:
117
- raise excp.InputValueError(f"`y` must be of dimension 1 not {y.ndim}.")
118
- if y.shape[0] < 1:
119
- raise excp.InputValueError(f"`y` must not be empty.")
120
-
121
- if x is None:
122
- x = np.arange(y.shape[0])
123
- elif x.ndim != 1:
124
- raise excp.InputValueError(f"`x` must be of dimension 1 not {x.ndim}.")
125
- elif x.shape[0] < 1:
126
- raise excp.InputValueError(f"`x` must not be empty.")
127
-
128
- if x.shape[0] != y.shape[0]:
129
- raise excp.InputValueError(f"`y` and `x` must be of same shape")
68
+ if legend is not None:
69
+ if len(legend) < len(args):
70
+ raise ValueError(f"Legend should be provided for each signal.")
71
+
72
+ fig = plt.figure(figsize=(16, 2))
73
+ gs = gridspec.GridSpec(2, 1, height_ratios=[0.2, 1])
130
74
 
131
- # Create a figure
132
- if ax is None:
133
- fig, ax = plt.subplots(figsize=(15, 2))
134
- created_fig = True
135
- else:
136
- fig = ax.get_figure()
137
- created_fig = False
138
-
139
- # Add legend
140
- if label is not None:
141
- legend_loc = legend_loc or "best"
142
- # Plot the signal and attach the label
143
- if stem:
144
- ax.stem(x, y, linefmt="k", markerfmt='o', label=label)
145
- else:
146
- ax.plot(x, y, fmt, lw=1.5, ms=3, label=label)
147
- ax.legend(loc=legend_loc)
148
- else:
149
- # Plot the signal without label
150
- if stem:
151
- ax.stem(x, y, linefmt="k", markerfmt='o')
152
- else:
153
- ax.plot(x, y, fmt, lw=1.5, ms=3)
75
+ colors = plt.get_cmap('tab10').colors
154
76
 
77
+ signal_ax = fig.add_subplot(gs[1, 0])
78
+ annotation_ax = fig.add_subplot(gs[0, 0], sharex=signal_ax)
155
79
 
156
-
157
- # Set the labels
158
- if title is not None:
159
- ax.set_title(title)
160
- if ylabel is not None:
161
- ax.set_ylabel(ylabel)
162
- if xlabel is not None:
163
- ax.set_xlabel(xlabel)
164
-
165
- # Applying axes limits into a region
166
- if ylim is not None:
167
- ax.set_ylim(ylim)
80
+ # Set lim
168
81
  if xlim is not None:
169
- ax.set_xlim(xlim)
170
-
171
- if highlight is not None:
172
- y_min = np.min(y)
173
- y_max = np.max(y)
174
- y_range = y_max - y_min
175
- label_box_height = 0.20 * y_range
176
-
177
- for i, highlight_region in enumerate(highlight):
178
- if len(highlight_region) != 2:
179
- raise excp.InputValueError("`highlight` should be a list of tuple of 2 values (left, right) => [(1, 10.5)]")
180
-
181
- l, r = highlight_region
182
- l = x[0] if l is None else l
183
- r = x[-1] if r is None else r
184
-
185
- # Highlight rectangle (main background)
186
- ax.add_patch(Rectangle(
187
- (l, y_min),
188
- r - l,
189
- y_range,
190
- color='red',
191
- alpha=0.2,
192
- zorder=10
193
- ))
194
-
195
- # Label box inside the top of the highlight
196
- ax.add_patch(Rectangle(
197
- (l, y_max - label_box_height),
198
- r - l,
199
- label_box_height,
200
- color='red',
201
- alpha=0.4,
202
- zorder=11
203
- ))
204
-
205
- # Centered label inside that box
206
- ax.text(
207
- (l + r) / 2,
208
- y_max - label_box_height / 2,
209
- str(i + 1),
210
- ha='center',
211
- va='center',
212
- fontsize=10,
213
- color='white',
214
- fontweight='bold',
215
- zorder=12
216
- )
82
+ signal_ax.set_xlim(xlim)
217
83
 
218
- # Vertical lines
219
- if vlines:
220
- for xpos in vlines:
221
- ax.axvline(x=xpos, color='blue', linestyle='--', linewidth=2, zorder=5)
222
-
223
- # Horizontal lines
224
- if hlines:
225
- for ypos in hlines:
226
- ax.axhline(y=ypos, color='blue', linestyle='--', linewidth=2, zorder=5)
227
-
228
- # Show grid
229
- if show_grid:
230
- ax.grid(True, color="gray", linestyle="--", linewidth=0.5)
231
-
232
- # Show/Return the figure as per needed
233
- if created_fig:
234
- fig.tight_layout()
235
- if Plotter._in_notebook():
236
- plt.tight_layout()
237
- plt.close(fig)
238
- return fig
239
- else:
240
- plt.tight_layout()
241
- plt.show()
242
- return fig
243
-
244
- @staticmethod
245
- @validate_args_type()
246
- def plot_matrix(
247
- M: np.ndarray,
248
- r: np.ndarray | None = None,
249
- c: np.ndarray | None = None,
250
- ax: plt.Axes | None = None,
251
- cmap: str = "gray_r",
252
- title: str | None = None,
253
- Mlabel: str | None = None,
254
- rlabel: str | None = None,
255
- clabel: str | None = None,
256
- rlim: tuple[float, float] | None = None,
257
- clim: tuple[float, float] | None = None,
258
- highlight: list[tuple[float, float, float, float]] | None = None,
259
- vlines: list[float] | None = None,
260
- hlines: list[float] | None = None,
261
- origin: str = "lower", # or "lower"
262
- gamma: int | float | None = None,
263
- show_colorbar: bool = True,
264
- cax: plt.Axes | None = None,
265
- show_grid: bool = True,
266
- tick_mode: str = "center", # "center" or "edge"
267
- n_ticks: tuple[int, int] | None = None,
268
- ) -> plt.Figure:
269
- """
270
- Plot a 2D matrix with optional zooming, highlighting, and grid.
271
-
272
- .. code-block:: python
84
+ if ylim is not None:
85
+ signal_ax.set_ylim(ylim)
273
86
 
274
- from modusa.io import Plotter
275
- import numpy as np
276
- import matplotlib.pyplot as plt
277
-
278
- # Create a 50x50 random matrix
279
- M = np.random.rand(50, 50)
280
87
 
281
- # Coordinate axes
282
- r = np.linspace(0, 1, M.shape[0])
283
- c = np.linspace(0, 1, M.shape[1])
284
-
285
- # Plot the matrix
286
- fig = Plotter.plot_matrix(
287
- M=M,
288
- r=r,
289
- c=c,
290
- log_compression_factor=None,
291
- ax=None,
292
- labels=None,
293
- zoom=None,
294
- highlight=None,
295
- cmap="viridis",
296
- origin="lower",
297
- show_colorbar=True,
298
- cax=None,
299
- show_grid=False,
300
- tick_mode="center",
301
- n_ticks=(5, 5),
302
- )
303
-
88
+ # Add signal plot
89
+ for i, signal in enumerate(args):
90
+ if len(signal) == 1:
91
+ y = signal[0]
92
+ if legend is not None:
93
+ signal_ax.plot(y, label=legend[i])
94
+ else:
95
+ signal_ax.plot(y)
96
+ elif len(signal) == 2:
97
+ y, x = signal[0], signal[1]
98
+ if legend is not None:
99
+ signal_ax.plot(x, y, label=legend[i])
100
+ else:
101
+ signal_ax.plot(x, y)
304
102
 
305
- Parameters
306
- ----------
307
- M: np.ndarray
308
- 2D matrix to plot.
309
- r: np.ndarray
310
- Row coordinate axes.
311
- c: np.ndarray
312
- Column coordinate axes.
313
- log_compression_factor: int | float | None
314
- Apply log compression to enhance contrast (if provided).
315
- ax: plt.Axes | None
316
- Matplotlib axis to draw on (creates new if None).
317
- labels: tuple[str, str, str, str] | None
318
- Labels for the plot (title, Mlabel, xlabel, ylabel).
319
- zoom: tuple[float, float, float, float] | None
320
- Zoom to (r1, r2, c1, c2) in matrix coordinates.
321
- highlight: list[tuple[float, float, float, float]] | None
322
- List of rectangles (r1, r2, c1, c2) to highlight.
323
- cmap: str
324
- Colormap to use.
325
- origin: str
326
- Image origin, e.g., "upper" or "lower".
327
- show_colorbar: bool
328
- Whether to display colorbar.
329
- cax: plt.Axes | None
330
- Axis to draw colorbar on (ignored if show_colorbar is False).
331
- show_grid: bool
332
- Whether to show grid lines.
333
- tick_mode: str
334
- Tick alignment mode: "center" or "edge".
335
- n_ticks: tuple[int, int]
336
- Number of ticks on row and column axes.
337
-
338
- Returns
339
- -------
340
- plt.Figure
341
- Matplotlib figure containing the plot.
342
-
343
- """
344
-
345
- # Validate the important args and get the signal that needs to be plotted
346
- if M.ndim != 2:
347
- raise excp.InputValueError(f"`M` must have 2 dimension not {M.ndim}")
348
- if r is None:
349
- r = M.shape[0]
350
- if c is None:
351
- c = M.shape[1]
352
-
353
- if r.ndim != 1 and c.ndim != 1:
354
- raise excp.InputValueError(f"`r` and `c` must have 2 dimension not r:{r.ndim}, c:{c.ndim}")
355
-
356
- if r.shape[0] != M.shape[0]:
357
- raise excp.InputValueError(f"`r` must have shape as `M row` not {r.shape}")
358
- if c.shape[0] != M.shape[1]:
359
- raise excp.InputValueError(f"`c` must have shape as `M column` not {c.shape}")
360
-
361
- # Scale the signal if needed
362
- if gamma is not None:
363
- M = np.log1p(float(gamma) * M)
364
-
365
- # Create a figure
366
- if ax is None:
367
- fig, ax = plt.subplots(figsize=(15, 4))
368
- created_fig = True
369
- else:
370
- fig = ax.get_figure()
371
- created_fig = False
103
+ # Add annotations
104
+ if ann is not None:
105
+ annotation_ax.set_ylim(0, 1)
106
+ for i, (start, end, tag) in enumerate(ann):
107
+ if xlim is not None:
108
+ if end < xlim[0] or start > xlim[1]:
109
+ continue # Skip out-of-view regions
110
+ # Clip boundaries to xlim
111
+ start = max(start, xlim[0])
112
+ end = min(end, xlim[1])
113
+
114
+ color = colors[i % len(colors)]
115
+ width = end - start
116
+ rect = Rectangle((start, 0), width, 1, color=color, alpha=0.7)
117
+ annotation_ax.add_patch(rect)
118
+ annotation_ax.text((start + end) / 2, 0.5, tag,
119
+ ha='center', va='center',
120
+ fontsize=10, color='white', fontweight='bold', zorder=10)
121
+ # Add vlines
122
+ if events is not None:
123
+ for xpos in events:
124
+ if xlim is not None:
125
+ if xlim[0] <= xpos <= xlim[1]:
126
+ annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
127
+ else:
128
+ annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
129
+
130
+ # Add legend
131
+ if legend is not None:
132
+ handles, labels = signal_ax.get_legend_handles_labels()
133
+ fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 1.2), ncol=len(legend), frameon=False)
372
134
 
373
- # Plot the signal with right configurations
374
- # Compute extent
375
- extent = Plotter._compute_centered_extent(r, c, origin)
135
+ # Set title, labels
136
+ if title is not None:
137
+ annotation_ax.set_title(title, pad=10, size=11)
138
+ if xlabel is not None:
139
+ signal_ax.set_xlabel(xlabel)
140
+ if ylabel is not None:
141
+ signal_ax.set_ylabel(ylabel)
376
142
 
377
- # Plot image
378
- im = ax.imshow(
379
- M,
380
- aspect="auto",
381
- cmap=cmap,
382
- origin=origin,
383
- extent=extent
384
- )
143
+ # Decorating annotation axis thicker
144
+ if ann is not None:
145
+ annotation_ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
146
+ else:
147
+ annotation_ax.axis("off")
385
148
 
386
- # Set the ticks and labels
387
- if n_ticks is None:
388
- n_ticks = (10, 10)
389
149
 
390
- if tick_mode == "center":
391
- ax.yaxis.set_major_locator(MaxNLocator(nbins=n_ticks[0]))
392
- ax.xaxis.set_major_locator(MaxNLocator(nbins=n_ticks[1])) # limits ticks
393
-
394
- elif tick_mode == "edge":
395
- dr = np.diff(r).mean() if len(r) > 1 else 1
396
- dc = np.diff(c).mean() if len(c) > 1 else 1
150
+ fig.subplots_adjust(hspace=0.01, wspace=0.05)
151
+ plt.close()
152
+ return fig
153
+
154
+ #======== 2D ===========
155
+ def plot2d(*args, ann=None, events=None, xlim=None, ylim=None, origin="lower", Mlabel=None, xlabel=None, ylabel=None, title=None, legend=None, lm=False):
156
+ """
157
+ Plots a 2D matrix (e.g., spectrogram or heatmap) with optional annotations and events.
158
+
159
+ .. code-block:: python
160
+
161
+ import modusa as ms
162
+ import numpy as np
397
163
 
398
- # Edge tick positions (centered)
399
- xticks_all = np.append(c, c[-1] + dc) - dc / 2
400
- yticks_all = np.append(r, r[-1] + dr) - dr / 2
164
+ M = np.random.random((10, 30))
165
+ y = np.arange(M.shape[0])
166
+ x = np.arange(M.shape[1])
401
167
 
402
- # Determine number of ticks
403
- nr, nc = n_ticks
168
+ display(ms.plot2d(M, y, x))
169
+
170
+ Parameters
171
+ ----------
172
+ *args : tuple[array-like, array-like]
173
+ - The signal values to be plotted.
174
+ - E.g. (M1, y1, x1), (M2, y2, x2), ...
175
+ ann : list[tuple[Number, Number, str]] | None
176
+ - A list of annotation spans. Each tuple should be (start, end, label).
177
+ - Default: None (no annotations).
178
+ events : list[Number] | None
179
+ - X-values where vertical event lines will be drawn.
180
+ - Default: None.
181
+ xlim : tuple[Number, Number] | None
182
+ - Limits for the x-axis as (xmin, xmax).
183
+ - Default: None (auto-scaled).
184
+ ylim : tuple[Number, Number] | None
185
+ - Limits for the y-axis as (ymin, ymax).
186
+ - Default: None (auto-scaled).
187
+ origin : {'upper', 'lower'}
188
+ - Origin position for the image display. Used in `imshow`.
189
+ - Default: "lower".
190
+ Mlabel : str | None
191
+ - Label for the colorbar (e.g., "Magnitude", "Energy").
192
+ - Default: None.
193
+ xlabel : str | None
194
+ - Label for the x-axis.
195
+ - Default: None.
196
+ ylabel : str | None
197
+ - Label for the y-axis.
198
+ - Default: None.
199
+ title : str | None
200
+ - Title of the plot.
201
+ - Default: None.
202
+ legend : list[str] | None
203
+ - Legend labels for any overlaid lines or annotations.
204
+ - Default: None.
205
+ lm: bool
206
+ - Adds a circular marker for the line.
207
+ - Default: False
208
+ - Useful to show the data points.
209
+
210
+ Returns
211
+ -------
212
+ matplotlib.figure.Figure
213
+ The matplotlib Figure object.
214
+ """
215
+
216
+ for arg in args:
217
+ if len(arg) not in [1, 2, 3]: # Either provide just the matrix or with both axes info
218
+ raise ValueError(f"Data to plot needs to have 3 arrays (M, y, x)")
219
+ if isinstance(legend, str): legend = (legend, )
220
+
221
+ fig = plt.figure(figsize=(16, 4))
222
+ gs = gridspec.GridSpec(3, 1, height_ratios=[0.2, 0.1, 1]) # colorbar, annotation, signal
223
+
224
+ colors = plt.get_cmap('tab10').colors
225
+
226
+ signal_ax = fig.add_subplot(gs[2, 0])
227
+ annotation_ax = fig.add_subplot(gs[1, 0], sharex=signal_ax)
228
+
229
+ colorbar_ax = fig.add_subplot(gs[0, 0])
230
+ colorbar_ax.axis("off")
231
+
232
+
233
+ # Add lim
234
+ if xlim is not None:
235
+ signal_ax.set_xlim(xlim)
404
236
 
405
- # Choose evenly spaced tick indices
406
- xtick_idx = np.linspace(0, len(xticks_all) - 1, nc, dtype=int)
407
- ytick_idx = np.linspace(0, len(yticks_all) - 1, nr, dtype=int)
237
+ if ylim is not None:
238
+ signal_ax.set_ylim(ylim)
408
239
 
409
- ax.set_xticks(xticks_all[xtick_idx])
410
- ax.set_yticks(yticks_all[ytick_idx])
240
+ # Add signal plot
241
+ i = 0 # This is to track the legend for 1D plots
242
+ for signal in args:
411
243
 
412
- # Set the labels
413
- if title is not None:
414
- ax.set_title(title)
415
- if rlabel is not None:
416
- ax.set_ylabel(rlabel)
417
- if clabel is not None:
418
- ax.set_xlabel(clabel)
244
+ data = signal[0] # This can be 1D or 2D (1D meaning we have to overlay on the matrix)
419
245
 
420
- # Applying axes limits into a region
421
- if rlim is not None:
422
- ax.set_ylim(rlim)
423
- if clim is not None:
424
- ax.set_xlim(clim)
425
-
426
- # Applying axes limits into a region
427
- if rlim is not None:
428
- ax.set_ylim(rlim)
429
- if clim is not None:
430
- ax.set_xlim(clim)
246
+ if data.ndim == 1: # 1D
247
+ if len(signal) == 1: # It means that the axis was not passed
248
+ x = np.arange(data.shape[0])
431
249
 
432
- if highlight is not None:
433
- row_range = r.max() - r.min()
434
- label_box_height = 0.08 * row_range
250
+ if lm is False:
251
+ if legend is not None:
252
+ signal_ax.plot(x, data, label=legend[i])
253
+ signal_ax.legend(loc="upper right")
254
+ else:
255
+ signal_ax.plot(x, data)
256
+ else:
257
+ if legend is not None:
258
+ signal_ax.plot(x, data, marker="o", markersize=7, markerfacecolor='red', linestyle="--", linewidth=2, label=legend[i])
259
+ signal_ax.legend(loc="upper right")
260
+ else:
261
+ signal_ax.plot(x, data, marker="o", markersize=7, markerfacecolor='red', linestyle="--", linewidth=2)
262
+
263
+ i += 1
435
264
 
436
- for i, highlight_region in enumerate(highlight):
437
- if len(highlight_region) != 4 and len(highlight_region) != 2:
438
- raise excp.InputValueError(
439
- "`highlight` should be a list of tuple of 4 or 2 values (row_min, row_max, col_min, col_max) or (col_min, col_max) => [(1, 10.5, 2, 40)] or [(2, 40)] "
440
- )
441
-
442
- if len(highlight_region) == 2:
443
- r1, r2 = None, None
444
- c1, c2 = highlight_region
445
- elif len(highlight_region) == 4:
446
- r1, r2, c1, c2 = highlight_region
447
-
448
- r1 = r[0] if r1 is None else r1
449
- r2 = r[-1] if r2 is None else r2
450
- c1 = c[0] if c1 is None else c1
451
- c2 = c[-1] if c2 is None else c2
452
-
453
- row_min, row_max = min(r1, r2), max(r1, r2)
454
- col_min, col_max = min(c1, c2), max(c1, c2)
455
-
456
- width = col_max - col_min
457
- height = row_max - row_min
458
-
459
- # Main red highlight box
460
- ax.add_patch(Rectangle(
461
- (col_min, row_min),
462
- width,
463
- height,
464
- color='red',
465
- alpha=0.2,
466
- zorder=10
467
- ))
265
+ elif data.ndim == 2: # 2D
266
+ M = data
267
+ if len(signal) == 1: # It means that the axes were not passed
268
+ y = np.arange(M.shape[0])
269
+ x = np.arange(M.shape[1])
270
+ dx = x[1] - x[0]
271
+ dy = y[1] - y[0]
272
+ extent=[x[0] - dx/2, x[-1] + dx/2, y[0] - dy/2, y[-1] + dy/2]
273
+ im = signal_ax.imshow(M, aspect="auto", origin=origin, cmap="gray_r", extent=extent)
468
274
 
469
- # Label box inside top of highlight (just below row_max)
470
- ax.add_patch(Rectangle(
471
- (col_min, row_max - label_box_height),
472
- width,
473
- label_box_height,
474
- color='red',
475
- alpha=0.4,
476
- zorder=11
477
- ))
478
-
479
- # Centered label in that box
480
- ax.text(
481
- (col_min + col_max) / 2,
482
- row_max - (label_box_height / 2),
483
- str(i + 1),
484
- ha='center',
485
- va='center',
486
- fontsize=10,
487
- color='white',
488
- fontweight='bold',
489
- zorder=12
490
- )
491
-
492
- # Show colorbar
493
- if show_colorbar is not None:
494
- cbar = fig.colorbar(im, ax=ax, cax=cax)
495
- if Mlabel is not None:
496
- cbar.set_label(Mlabel)
497
-
498
- # Vertical lines
499
- if vlines:
500
- for xpos in vlines:
501
- ax.axvline(x=xpos, color='blue', linestyle='--', linewidth=2, zorder=5)
502
-
503
- # Horizontal lines
504
- if hlines:
505
- for ypos in hlines:
506
- ax.axhline(y=ypos, color='blue', linestyle='--', linewidth=2, zorder=5)
275
+ elif len(signal) == 3: # It means that the axes were passed
276
+ M, y, x = signal[0], signal[1], signal[2]
277
+ dx = x[1] - x[0]
278
+ dy = y[1] - y[0]
279
+ extent=[x[0] - dx/2, x[-1] + dx/2, y[0] - dy/2, y[-1] + dy/2]
280
+ im = signal_ax.imshow(M, aspect="auto", origin=origin, cmap="gray_r", extent=extent)
281
+
282
+ # Add annotations
283
+ if ann is not None:
284
+ annotation_ax.set_ylim(0, 1)
285
+ for i, (start, end, tag) in enumerate(ann):
286
+ if xlim is not None:
287
+ if end < xlim[0] or start > xlim[1]:
288
+ continue # Skip out-of-view regions
289
+ # Clip boundaries to xlim
290
+ start = max(start, xlim[0])
291
+ end = min(end, xlim[1])
507
292
 
508
- # Show grid
509
- if show_grid:
510
- ax.grid(True, color="gray", linestyle="--", linewidth=0.5)
511
-
512
- # Show/Return the figure as per needed
513
- if created_fig:
514
- fig.tight_layout()
515
- if Plotter._in_notebook():
516
- plt.close(fig)
517
- return fig
293
+ color = colors[i % len(colors)]
294
+ width = end - start
295
+ rect = Rectangle((start, 0), width, 1, color=color, alpha=0.7)
296
+ annotation_ax.add_patch(rect)
297
+ annotation_ax.text((start + end) / 2, 0.5, tag,
298
+ ha='center', va='center',
299
+ fontsize=10, color='white', fontweight='bold', zorder=10)
300
+ # Add vlines
301
+ if events is not None:
302
+ for xpos in events:
303
+ if xlim is not None:
304
+ if xlim[0] <= xpos <= xlim[1]:
305
+ annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
518
306
  else:
519
- plt.show()
520
- return fig
307
+ annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
308
+
309
+ # Add legend incase there are 1D overlays
310
+ if legend is not None:
311
+ handles, labels = signal_ax.get_legend_handles_labels()
312
+ if handles: # Only add legend if there's something to show
313
+ signal_ax.legend(handles, labels, loc="upper right")
314
+
315
+ # Add colorbar
316
+ # Create an inset axis on top-right of signal_ax
317
+ cax = inset_axes(
318
+ colorbar_ax,
319
+ width="20%", # percentage of parent width
320
+ height="20%", # height in percentage of parent height
321
+ loc='upper right',
322
+ bbox_to_anchor=(0, 0, 1, 1),
323
+ bbox_transform=colorbar_ax.transAxes,
324
+ borderpad=1
325
+ )
326
+
327
+ cbar = plt.colorbar(im, cax=cax, orientation='horizontal')
328
+ cbar.ax.xaxis.set_ticks_position('top')
329
+
330
+ if Mlabel is not None:
331
+ cbar.set_label(Mlabel, labelpad=5)
521
332
 
522
- @staticmethod
523
- def _compute_centered_extent(r: np.ndarray, c: np.ndarray, origin: str) -> list[float]:
524
- """
525
333
 
526
- """
527
- dc = np.diff(c).mean() if len(c) > 1 else 1
528
- dr = np.diff(r).mean() if len(r) > 1 else 1
529
- left = c[0] - dc / 2
530
- right = c[-1] + dc / 2
531
- bottom = r[0] - dr / 2
532
- top = r[-1] + dr / 2
533
- return [left, right, top, bottom] if origin == "upper" else [left, right, bottom, top]
334
+ # Set title, labels
335
+ if title is not None:
336
+ annotation_ax.set_title(title, pad=10, size=11)
337
+ if xlabel is not None:
338
+ signal_ax.set_xlabel(xlabel)
339
+ if ylabel is not None:
340
+ signal_ax.set_ylabel(ylabel)
341
+
534
342
 
535
- @staticmethod
536
- def _in_notebook() -> bool:
537
- try:
538
- from IPython import get_ipython
539
- shell = get_ipython()
540
- return shell and shell.__class__.__name__ == "ZMQInteractiveShell"
541
- except ImportError:
542
- return False
543
-
343
+ # Making annotation axis spines thicker
344
+ if ann is not None:
345
+ annotation_ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
346
+ else:
347
+ annotation_ax.axis("off")
348
+
349
+ fig.subplots_adjust(hspace=0.01, wspace=0.05)
350
+ plt.close()
351
+ return fig