modusa 0.2.22__py3-none-any.whl → 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. modusa/.DS_Store +0 -0
  2. modusa/__init__.py +8 -1
  3. modusa/decorators.py +4 -4
  4. modusa/devtools/generate_docs_source.py +96 -0
  5. modusa/devtools/generate_template.py +13 -13
  6. modusa/devtools/main.py +4 -3
  7. modusa/devtools/templates/generator.py +1 -1
  8. modusa/devtools/templates/io.py +1 -1
  9. modusa/devtools/templates/{signal.py → model.py} +18 -11
  10. modusa/devtools/templates/plugin.py +1 -1
  11. modusa/devtools/templates/test.py +2 -3
  12. modusa/devtools/templates/{engine.py → tool.py} +3 -8
  13. modusa/generators/__init__.py +9 -1
  14. modusa/generators/audio.py +188 -0
  15. modusa/generators/audio_waveforms.py +22 -13
  16. modusa/generators/base.py +1 -1
  17. modusa/generators/ftds.py +298 -0
  18. modusa/generators/s1d.py +270 -0
  19. modusa/generators/s2d.py +300 -0
  20. modusa/generators/s_ax.py +102 -0
  21. modusa/generators/t_ax.py +64 -0
  22. modusa/generators/tds.py +267 -0
  23. modusa/main.py +0 -30
  24. modusa/models/__init__.py +14 -0
  25. modusa/models/__pycache__/signal1D.cpython-312.pyc.4443461152 +0 -0
  26. modusa/models/audio.py +90 -0
  27. modusa/models/base.py +70 -0
  28. modusa/models/data.py +457 -0
  29. modusa/models/ftds.py +584 -0
  30. modusa/models/s1d.py +578 -0
  31. modusa/models/s2d.py +619 -0
  32. modusa/models/s_ax.py +448 -0
  33. modusa/models/t_ax.py +335 -0
  34. modusa/models/tds.py +465 -0
  35. modusa/plugins/__init__.py +3 -1
  36. modusa/tmp.py +98 -0
  37. modusa/tools/__init__.py +7 -0
  38. modusa/tools/audio_converter.py +73 -0
  39. modusa/tools/audio_loader.py +90 -0
  40. modusa/tools/audio_player.py +89 -0
  41. modusa/tools/base.py +43 -0
  42. modusa/tools/math_ops.py +335 -0
  43. modusa/tools/plotter.py +351 -0
  44. modusa/tools/youtube_downloader.py +72 -0
  45. modusa/utils/excp.py +15 -42
  46. modusa/utils/np_func_cat.py +44 -0
  47. modusa/utils/plot.py +142 -0
  48. {modusa-0.2.22.dist-info → modusa-0.3.dist-info}/METADATA +5 -16
  49. modusa-0.3.dist-info/RECORD +60 -0
  50. modusa/engines/.DS_Store +0 -0
  51. modusa/engines/__init__.py +0 -3
  52. modusa/engines/base.py +0 -14
  53. modusa/io/__init__.py +0 -9
  54. modusa/io/audio_converter.py +0 -76
  55. modusa/io/audio_loader.py +0 -214
  56. modusa/io/audio_player.py +0 -72
  57. modusa/io/base.py +0 -43
  58. modusa/io/plotter.py +0 -430
  59. modusa/io/youtube_downloader.py +0 -139
  60. modusa/signals/__init__.py +0 -7
  61. modusa/signals/audio_signal.py +0 -483
  62. modusa/signals/base.py +0 -34
  63. modusa/signals/frequency_domain_signal.py +0 -329
  64. modusa/signals/signal_ops.py +0 -158
  65. modusa/signals/spectrogram.py +0 -465
  66. modusa/signals/time_domain_signal.py +0 -309
  67. modusa-0.2.22.dist-info/RECORD +0 -47
  68. {modusa-0.2.22.dist-info → modusa-0.3.dist-info}/WHEEL +0 -0
  69. {modusa-0.2.22.dist-info → modusa-0.3.dist-info}/entry_points.txt +0 -0
  70. {modusa-0.2.22.dist-info → modusa-0.3.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,351 @@
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.gridspec as gridspec
7
+ from matplotlib.patches import Rectangle
8
+ from mpl_toolkits.axes_grid1.inset_locator import inset_axes
9
+
10
+ #======== 1D ===========
11
def plot1d(*args, ann=None, events=None, xlim=None, ylim=None, xlabel=None, ylabel=None, title=None, legend=None):
	"""
	Plots one or more 1D signals using matplotlib.

	.. code-block:: python

		import modusa as ms
		import numpy as np

		x = np.arange(100) / 100
		y = np.sin(x)

		# Each signal is passed as one tuple: (y, x) or (y, )
		display(ms.plot1d((y, x)))


	Parameters
	----------
	*args : tuple[array-like, array-like] | tuple[array-like]
		- Each positional argument is one signal, given as ``(y, x)`` or ``(y, )``.
		- If only values are provided, they are plotted against their sample index.
		- E.g. ``plot1d((y1, x1), (y2, x2), ...)``
	ann : list[tuple[Number, Number, str]] | None
		- A list of annotations to mark specific regions. Each tuple should be of the form (start, end, label).
		- When ``xlim`` is given, regions outside it are skipped and the rest are clipped to it.
		- Default: None => No annotation.
	events : list[Number] | None
		- A list of x-values where vertical lines (event markers) will be drawn.
		- When ``xlim`` is given, only events inside it are drawn.
		- Default: None
	xlim : tuple[Number, Number] | None
		- Limits for the x-axis as (xmin, xmax).
		- Default: None
	ylim : tuple[Number, Number] | None
		- Limits for the y-axis as (ymin, ymax).
		- Default: None
	xlabel : str | None
		- Label for the x-axis.
		- Default: None
	ylabel : str | None
		- Label for the y-axis.
		- Default: None
	title : str | None
		- Title of the plot.
		- Default: None
	legend : str | list[str] | None
		- Legend labels, one per signal (a single string is accepted for a single signal).
		- Default: None

	Returns
	-------
	plt.Figure
		Matplotlib figure. It is closed before being returned, so it is not shown
		automatically; display it explicitly (e.g. ``display(fig)``).

	Raises
	------
	ValueError
		If a signal is not a 1- or 2-element tuple, or if fewer legend labels
		than signals are provided.
	"""

	# Validate inputs before creating any figure so a bad call leaks nothing.
	for arg in args:
		if len(arg) not in [1, 2]: # 1 if it just provides values, 2 if it provided axis as well
			raise ValueError("1D signal needs to have max 2 arrays (y, x) or simply (y, )")
	if isinstance(legend, str):
		legend = (legend, )

	if legend is not None and len(legend) < len(args):
		raise ValueError("Legend should be provided for each signal.")

	fig = plt.figure(figsize=(16, 2))
	gs = gridspec.GridSpec(2, 1, height_ratios=[0.2, 1]) # annotation, signal

	colors = plt.get_cmap('tab10').colors

	signal_ax = fig.add_subplot(gs[1, 0])
	annotation_ax = fig.add_subplot(gs[0, 0], sharex=signal_ax)

	# Set lim
	if xlim is not None:
		signal_ax.set_xlim(xlim)
	if ylim is not None:
		signal_ax.set_ylim(ylim)

	# Add signal plot
	for i, signal in enumerate(args):
		# Only pass `label` when a legend was requested, exactly like an unlabeled plot otherwise.
		plot_kwargs = {"label": legend[i]} if legend is not None else {}
		if len(signal) == 1:
			signal_ax.plot(signal[0], **plot_kwargs)
		else: # (y, x)
			y, x = signal[0], signal[1]
			signal_ax.plot(x, y, **plot_kwargs)

	# Add annotations
	if ann is not None:
		annotation_ax.set_ylim(0, 1)
		for i, (start, end, tag) in enumerate(ann):
			if xlim is not None:
				if end < xlim[0] or start > xlim[1]:
					continue # Skip out-of-view regions
				# Clip boundaries to xlim
				start = max(start, xlim[0])
				end = min(end, xlim[1])

			color = colors[i % len(colors)]
			rect = Rectangle((start, 0), end - start, 1, color=color, alpha=0.7)
			annotation_ax.add_patch(rect)
			annotation_ax.text((start + end) / 2, 0.5, tag,
								ha='center', va='center',
								fontsize=10, color='white', fontweight='bold', zorder=10)

	# Add vlines (only those inside xlim, when xlim is given)
	if events is not None:
		for xpos in events:
			if xlim is None or xlim[0] <= xpos <= xlim[1]:
				annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)

	# Add legend
	if legend is not None:
		handles, labels = signal_ax.get_legend_handles_labels()
		fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 1.2), ncol=len(legend), frameon=False)

	# Set title, labels
	if title is not None:
		annotation_ax.set_title(title, pad=10, size=11)
	if xlabel is not None:
		signal_ax.set_xlabel(xlabel)
	if ylabel is not None:
		signal_ax.set_ylabel(ylabel)

	# Hide the ticks of the annotation strip; hide it entirely when there are no annotations
	if ann is not None:
		annotation_ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
	else:
		annotation_ax.axis("off")

	fig.subplots_adjust(hspace=0.01, wspace=0.05)
	# Close so the figure is not auto-rendered (e.g. in notebooks); the caller displays it.
	plt.close(fig)
	return fig
153
+
154
+ #======== 2D ===========
155
def plot2d(*args, ann=None, events=None, xlim=None, ylim=None, origin="lower", Mlabel=None, xlabel=None, ylabel=None, title=None, legend=None, lm=False):
	"""
	Plots a 2D matrix (e.g., spectrogram or heatmap) with optional 1D overlays, annotations and events.

	.. code-block:: python

		import modusa as ms
		import numpy as np

		M = np.random.random((10, 30))
		y = np.arange(M.shape[0])
		x = np.arange(M.shape[1])

		# Each item is passed as one tuple: (M, y, x) or (M, )
		display(ms.plot2d((M, y, x)))

	Parameters
	----------
	*args : tuple
		- Each positional argument is one item to draw on the same axes:
		  ``(M, )`` or ``(M, y, x)`` for a 2D matrix (drawn with ``imshow``),
		  ``(y, )`` or ``(y, x)`` for a 1D line overlaid on the matrix.
		- When axes are omitted, sample indices are used.
	ann : list[tuple[Number, Number, str]] | None
		- A list of annotation spans. Each tuple should be (start, end, label).
		- Default: None (no annotations).
	events : list[Number] | None
		- X-values where vertical event lines will be drawn.
		- Default: None.
	xlim : tuple[Number, Number] | None
		- Limits for the x-axis as (xmin, xmax).
		- Default: None (auto-scaled).
	ylim : tuple[Number, Number] | None
		- Limits for the y-axis as (ymin, ymax).
		- Default: None (auto-scaled).
	origin : {'upper', 'lower'}
		- Origin position for the image display. Used in `imshow`.
		- Default: "lower".
	Mlabel : str | None
		- Label for the colorbar (e.g., "Magnitude", "Energy").
		- Default: None.
	xlabel : str | None
		- Label for the x-axis.
		- Default: None.
	ylabel : str | None
		- Label for the y-axis.
		- Default: None.
	title : str | None
		- Title of the plot.
		- Default: None.
	legend : str | list[str] | None
		- Legend labels, one per 1D overlay (in the order the overlays are passed).
		- Default: None.
	lm: bool
		- Draws 1D overlays with circular markers and a dashed line.
		- Default: False
		- Useful to show the data points.

	Returns
	-------
	matplotlib.figure.Figure
		The matplotlib Figure object (closed before returning; display it explicitly).

	Raises
	------
	ValueError
		If an item has an unsupported number of arrays or dimensionality, or if
		fewer legend labels than 1D overlays are provided.
	"""

	# Validate everything before creating the figure so a bad call leaks no open figure.
	for arg in args:
		if len(arg) not in [1, 2, 3]: # Either provide just the matrix or with both axes info
			raise ValueError("Data to plot needs to have 3 arrays (M, y, x)")
		ndim = np.ndim(arg[0])
		if ndim == 1 and len(arg) == 3:
			raise ValueError("1D data to overlay must be given as (y, ) or (y, x)")
		if ndim == 2 and len(arg) == 2:
			raise ValueError("2D data must be given as (M, ) or (M, y, x)")
		if ndim not in (1, 2):
			raise ValueError(f"Only 1D or 2D data can be plotted, got {ndim}D")
	if isinstance(legend, str):
		legend = (legend, )
	if legend is not None:
		n_overlays = sum(1 for arg in args if np.ndim(arg[0]) == 1)
		if len(legend) < n_overlays:
			raise ValueError("Legend should be provided for each 1D overlay.")

	fig = plt.figure(figsize=(16, 4))
	gs = gridspec.GridSpec(3, 1, height_ratios=[0.2, 0.1, 1]) # colorbar, annotation, signal

	colors = plt.get_cmap('tab10').colors

	signal_ax = fig.add_subplot(gs[2, 0])
	annotation_ax = fig.add_subplot(gs[1, 0], sharex=signal_ax)

	colorbar_ax = fig.add_subplot(gs[0, 0])
	colorbar_ax.axis("off")

	# Add lim
	if xlim is not None:
		signal_ax.set_xlim(xlim)
	if ylim is not None:
		signal_ax.set_ylim(ylim)

	# Add signal plot
	im = None # Last drawn image; the colorbar is only added when at least one matrix was drawn
	overlay_idx = 0 # Index into `legend` for 1D overlays
	for signal in args:
		data = np.asarray(signal[0]) # 1D => overlay line, 2D => matrix

		if data.ndim == 1:
			x = np.arange(data.shape[0]) if len(signal) == 1 else np.asarray(signal[1])

			line_kwargs = {}
			if lm:
				line_kwargs.update(marker="o", markersize=7, markerfacecolor='red', linestyle="--", linewidth=2)
			if legend is not None:
				line_kwargs["label"] = legend[overlay_idx]
			signal_ax.plot(x, data, **line_kwargs)
			overlay_idx += 1

		else: # 2D
			if len(signal) == 1: # The axes were not passed
				y = np.arange(data.shape[0])
				x = np.arange(data.shape[1])
			else: # (M, y, x)
				y, x = np.asarray(signal[1]), np.asarray(signal[2])
			# Pad by half a bin so each pixel is centred on its axis value;
			# fall back to a unit step when an axis has a single element.
			dx = x[1] - x[0] if len(x) > 1 else 1
			dy = y[1] - y[0] if len(y) > 1 else 1
			extent = [x[0] - dx/2, x[-1] + dx/2, y[0] - dy/2, y[-1] + dy/2]
			im = signal_ax.imshow(data, aspect="auto", origin=origin, cmap="gray_r", extent=extent)

	# Add annotations
	if ann is not None:
		annotation_ax.set_ylim(0, 1)
		for i, (start, end, tag) in enumerate(ann):
			if xlim is not None:
				if end < xlim[0] or start > xlim[1]:
					continue # Skip out-of-view regions
				# Clip boundaries to xlim
				start = max(start, xlim[0])
				end = min(end, xlim[1])

			color = colors[i % len(colors)]
			rect = Rectangle((start, 0), end - start, 1, color=color, alpha=0.7)
			annotation_ax.add_patch(rect)
			annotation_ax.text((start + end) / 2, 0.5, tag,
								ha='center', va='center',
								fontsize=10, color='white', fontweight='bold', zorder=10)

	# Add vlines (only those inside xlim, when xlim is given)
	if events is not None:
		for xpos in events:
			if xlim is None or xlim[0] <= xpos <= xlim[1]:
				annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)

	# Add legend incase there are 1D overlays
	if legend is not None:
		handles, labels = signal_ax.get_legend_handles_labels()
		if handles: # Only add legend if there's something to show
			signal_ax.legend(handles, labels, loc="upper right")

	# Add colorbar in an inset at the top-right of the colorbar strip
	if im is not None:
		cax = inset_axes(
			colorbar_ax,
			width="20%",  # percentage of parent width
			height="20%", # height in percentage of parent height
			loc='upper right',
			bbox_to_anchor=(0, 0, 1, 1),
			bbox_transform=colorbar_ax.transAxes,
			borderpad=1
		)
		cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
		cbar.ax.xaxis.set_ticks_position('top')

		if Mlabel is not None:
			cbar.set_label(Mlabel, labelpad=5)

	# Set title, labels
	if title is not None:
		annotation_ax.set_title(title, pad=10, size=11)
	if xlabel is not None:
		signal_ax.set_xlabel(xlabel)
	if ylabel is not None:
		signal_ax.set_ylabel(ylabel)

	# Hide the ticks of the annotation strip; hide it entirely when there are no annotations
	if ann is not None:
		annotation_ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
	else:
		annotation_ax.axis("off")

	fig.subplots_adjust(hspace=0.01, wspace=0.05)
	# Close so the figure is not auto-rendered (e.g. in notebooks); the caller displays it.
	plt.close(fig)
	return fig
@@ -0,0 +1,72 @@
1
+ #!/usr/bin/env python3
2
+
3
+
4
+ from modusa import excp
5
+ from modusa.decorators import validate_args_type
6
+ from modusa.tools.base import ModusaTool
7
+ from typing import Any
8
+ from pathlib import Path
9
+ import yt_dlp
10
+
11
+
12
def download(url, content_type, output_dir):
	"""
	Downloads audio/video from YouTube.

	.. code-block:: python

		# To download audio
		import modusa as ms
		audio_fp = ms.download(
			url="https://www.youtube.com/watch?v=lIpw9-Y_N0g",
			content_type="audio",
			output_dir=".")

	Parameters
	----------
	url: str
		Link to the YouTube video.
	content_type: str
		"audio" (best available audio stream) or "video"
		(best video + best audio, merged into an mp4 container).
	output_dir: str | Path
		Directory to save the YouTube content. Created (with parents) if missing.

	Returns
	-------
	Path
		File path of the downloaded content, as reported by yt-dlp.

	Raises
	------
	excp.InputValueError
		If `content_type` is neither "audio" nor "video".
	"""
	# Validate first so an invalid call does not create directories as a side effect.
	if content_type == "audio":
		ydl_opts = {
			'format': 'bestaudio/best',
		}
	elif content_type == "video":
		ydl_opts = {
			'format': 'bestvideo+bestaudio/best', # High quality
			'merge_output_format': 'mp4',
		}
	else:
		raise excp.InputValueError(f"`content_type` can either take 'audio' or 'video' not {content_type}")

	output_dir = Path(output_dir)
	output_dir.mkdir(parents=True, exist_ok=True)

	# yt-dlp parses `%` in `outtmpl` as template syntax, so a literal `%` in the
	# directory name must be escaped as `%%`; only the file-name part is a template.
	escaped_dir = str(output_dir).replace("%", "%%")
	ydl_opts['outtmpl'] = str(Path(escaped_dir) / '%(title)s.%(ext)s')
	ydl_opts['quiet'] = True # Hide verbose output

	with yt_dlp.YoutubeDL(ydl_opts) as ydl:
		info = ydl.extract_info(url, download=True)
	return Path(info['requested_downloads'][0]['filepath'])
70
+
71
+
72
+
modusa/utils/excp.py CHANGED
@@ -4,73 +4,46 @@
4
4
#----------------------------------------
# Base class errors
#----------------------------------------
class ModusaBaseError(Exception):
	"""
	Ultimate base class for any kind of custom errors.
	Catch this to handle every modusa-specific error at once.
	"""
	pass


class TypeError(ModusaBaseError):
	"""
	Modusa-specific type error.

	NOTE: this name shadows the builtin ``TypeError`` inside this module and is
	*not* a subclass of it; refer to it as ``excp.TypeError`` elsewhere.
	"""
	pass


class InputError(ModusaBaseError):
	"""
	Generic error for invalid user input.
	"""


class InputTypeError(ModusaBaseError):
	"""
	Raised when an input argument has an unsupported type.
	"""


class InputValueError(ModusaBaseError):
	"""
	Raised when an input argument has an unsupported value
	(e.g. an option outside the accepted set of choices).
	"""


class OperationNotPossibleError(ModusaBaseError):
	"""
	Raised when a requested operation cannot be performed.
	"""


class ImmutableAttributeError(ModusaBaseError):
	"""Raised when attempting to modify an immutable attribute."""
	pass


class FileNotFoundError(ModusaBaseError):
	"""
	Raised when file does not exist.

	NOTE: this name shadows the builtin ``FileNotFoundError`` inside this module
	and is *not* a subclass of it (nor of ``OSError``); an ``except`` clause on
	the builtin will not catch it.
	"""
	pass


class PluginInputError(ModusaBaseError):
	"""Error related to the input of a plugin."""
	pass


class PluginOutputError(ModusaBaseError):
	"""Error related to the output of a plugin."""
	pass
@@ -0,0 +1,44 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # This contains categorised NumPy functions
4
+ # This is useful while handling np operations on modusa signals
5
+
6
+ import numpy as np
7
+
8
# Functions treated as shape-preserving: element-wise ufuncs and element-wise
# helpers. NOTE: the binary ufuncs (add, subtract, ...) broadcast, so their
# output shape follows broadcasting rules when operand shapes differ.
SHAPE_PRESERVING_FUNCS = {
	np.sin, np.cos, np.tan,
	np.sinh, np.cosh, np.tanh,
	np.arcsin, np.arccos, np.arctan,
	np.exp, np.expm1, np.log, np.log2, np.log10, np.log1p,
	np.abs, np.negative, np.positive,
	np.square, np.sqrt, np.cbrt,
	np.floor, np.ceil, np.round, np.rint, np.trunc,
	np.clip, np.sign,
	np.add, np.subtract, np.multiply, np.true_divide, np.floor_divide
}

# Functions that reduce along one or more axes (or the whole array).
REDUCTION_FUNCS = {
	np.sum, np.prod,
	np.mean, np.std, np.var,
	np.min, np.max,
	np.argmin, np.argmax,
	np.median, np.percentile,
	np.all, np.any,
	np.nanmean, np.nanstd, np.nanvar, np.nansum
}

# Functions that join several arrays into one.
# `np.row_stack` is a deprecated alias of `np.vstack` (NumPy 2.0), so it is
# looked up defensively: this module keeps importing if a later NumPy drops it.
CONCAT_FUNCS = {
	func for func in (
		np.concatenate, np.stack, np.hstack, np.vstack, np.dstack,
		np.column_stack, getattr(np, "row_stack", None),
	)
	if func is not None
}

# Functions whose output length/ordering no longer matches the input samples,
# so the accompanying x-axis needs to be adjusted or recomputed.
X_NEEDS_ADJUSTMENT_FUNCS = {
	np.diff,
	np.gradient,
	np.trim_zeros,
	np.unwrap,
	np.fft.fft, np.fft.ifft, np.fft.fftshift, np.fft.ifftshift,
	np.correlate, np.convolve
}
43
+
44
+
modusa/utils/plot.py ADDED
@@ -0,0 +1,142 @@
1
+ #!/usr/bin/env python3
2
+
3
+ from modusa.models.s1d import S1D
4
+ from modusa.models.s2d import S2D
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from collections import defaultdict
8
+ import itertools
9
+
10
+ def _in_notebook() -> bool:
11
+ """
12
+ To check if we are in jupyter notebook environment.
13
+ """
14
+ try:
15
+ from IPython import get_ipython
16
+ shell = get_ipython()
17
+ return shell and shell.__class__.__name__ == "ZMQInteractiveShell"
18
+ except ImportError:
19
+ return False
20
+
21
def plot_multiple_signals(
	*args,
	loc = None,
	x_lim: tuple[float, float] | None = None,
	highlight_regions: list[tuple[float, float, str]] | None = None,
	vlines: list[float] | None = None,
) -> plt.Figure:
	"""
	Plots multiple instances of uniform `S1D` and `S2D` signals
	with proper formatting and time aligned (shared x-limits).

	Parameters
	----------
	*args: S1D | S2D
		- Signals to plot. At least one is required.
	loc: tuple[int]
		- The len should match the number of signals sent as args.
		- e.g. (0, 1, 1) => First plot at ax 0, second and third plot at ax 1
		- Values must start at 0 and must not skip any integer.
		- Default: None => (0, 1, 2, ...) all plots on a new ax.
	x_lim: tuple[float, float] | None
		- x-axis limits applied to every subplot.
		- Default: None => from the smallest to the largest x value over all signals.
	highlight_regions: list[tuple[float, float, str]] | None
		- (start, end, label) regions forwarded to ``signal.plot``.
		- Only applied to subplots that contain a single signal.
	vlines: list[float] | None
		- x positions forwarded to every ``signal.plot`` call.

	Returns
	-------
	plt.Figure
		The figure. In a Jupyter notebook it is closed (not auto-rendered) before
		being returned; otherwise ``plt.show()`` is called first.

	Raises
	------
	ValueError
		If no signal is given, `loc` has the wrong length, or `loc` does not
		start at 0 / skips an integer.
	TypeError
		If any argument is not an `S1D` or `S2D`.
	"""

	# Explicit checks instead of `assert`, which is stripped under `python -O`.
	if len(args) < 1:
		raise ValueError("No signal provided to plot")

	for signal in args:
		if not isinstance(signal, (S1D, S2D)):
			raise TypeError(f"Invalid signal type {type(signal)}")

	if loc is None: # => We will plot all the signals on different subplots
		loc = tuple(range(len(args))) # (0, 1, 2, ...) serves as ax number for plots
	elif len(args) != len(loc): # Incase loc is provided, it must cover each of the signals
		raise ValueError(f"`loc` has {len(loc)} entries but {len(args)} signals were provided.")

	# Make sure that loc starts at 0 and does not miss any number in between (0, 1, 1, 2, 4) -> Not allowed
	if min(loc) != 0:
		raise ValueError("Invalid `loc` values, it should start at 0.")
	for i in range(max(loc)):
		if i not in loc:
			raise ValueError(f"Invalid `loc` values, it should not have any missing integer in between.")

	# Map subplot index to the signals drawn on it, e.g. {0: [signal1, signal3], ...}
	subplot_signal_map = defaultdict(list)
	for signal, i in zip(args, loc):
		subplot_signal_map[i].append(signal)

	# Iterate subplots in index order: dict insertion order follows the first
	# appearance in `loc` (e.g. loc=(1, 0)), which would misalign height_ratios with rows.
	subplot_ids = sorted(subplot_signal_map)

	# A subplot is 2D if it contains any 2D signal, otherwise all its signals are 1D.
	is_2d_subplot = {l: any(isinstance(s, S2D) for s in subplot_signal_map[l]) for l in subplot_ids}

	height_1d_subplot = 0.4
	height_2d_subplot = 1
	height_ratios = [height_2d_subplot if is_2d_subplot[l] else height_1d_subplot for l in subplot_ids]
	n_2d_subplots = sum(is_2d_subplot.values())
	n_1d_subplots = len(subplot_ids) - n_2d_subplots

	n_subplots = n_1d_subplots + n_2d_subplots
	fig_width = 15
	fig_height = n_1d_subplots * 2 + n_2d_subplots * 4 # This is as per the figsize height set in the plotter tool
	fig, axs = plt.subplots(n_subplots, 2, figsize=(fig_width, fig_height), width_ratios=[1, 0.01], height_ratios=height_ratios) # 2nd column for cbar

	if n_subplots == 1:
		axs = [axs] # axs becomes list of one pair [ (ax, cbar_ax) ]

	# Default x-limits cover every signal so that all of them are visible.
	if x_lim is None:
		x_min = min(np.min(signal.x.values) for signal in args)
		x_max = max(np.max(signal.x.values) for signal in args)
		x_lim = (x_min, x_max)

	for l in subplot_ids:
		signals = subplot_signal_map[l]
		ax, cax = axs[l][0], axs[l][1]
		# Incase we plot multiple signals in the same subplot, we change the color
		fmt_cycle = itertools.cycle(['k-', 'r-', 'g-', 'b-', 'm-', 'c-', 'y-'])

		if not is_2d_subplot[l]: # All the signals are 1D
			for signal in signals:
				fmt = next(fmt_cycle)
				if len(signals) == 1: # highlight region works properly only if there is one signal for a subplot
					signal.plot(ax, x_lim=x_lim, highlight_regions=highlight_regions, show_grid=True, vlines=vlines, fmt=fmt, legend=signal._title)
				else:
					signal.plot(ax, x_lim=x_lim, show_grid=True, vlines=vlines, fmt=fmt, legend=signal.title, y_label=signal.y.label, x_label=signal.x.label, title="")

			# Remove the colorbar column (the subplot is 1D)
			cax.remove()

		else: # At least 1 signal is 2D, so we have a 2D subplot
			for signal in signals:
				if len(signals) == 1: # Only one 2D signal is to be plotted
					signal.plot(ax, x_lim=x_lim, show_colorbar=True, cax=cax, highlight_regions=highlight_regions, vlines=vlines)
				elif isinstance(signal, S1D):
					fmt = next(fmt_cycle)
					signal.plot(ax, x_lim=x_lim, show_grid=True, vlines=vlines, fmt=fmt, legend=signal.title, y_label=signal.y.label, x_label=signal.x.label, title="")
				elif isinstance(signal, S2D):
					signal.plot(ax, x_lim=x_lim, show_colorbar=True, cax=cax, vlines=vlines, x_label=signal.x.label, y_label=signal.y.label, title="")

	# Setting the same xlim on every row aligns all the signals
	for row in axs:
		row[0].set_xlim(x_lim)

	# Lay out *this* figure explicitly rather than whatever pyplot considers current.
	fig.tight_layout()
	if _in_notebook():
		plt.close(fig) # The notebook renders the returned figure itself
	else:
		plt.show()
	return fig
+ return fig