modusa 0.3.41__py3-none-any.whl → 0.3.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modusa/tools/plotter.py CHANGED
@@ -1,39 +1,26 @@
1
1
  #!/usr/bin/env python3
2
2
 
3
+ #---------------------------------
4
+ # Author: Ankit Anand
5
+ # Date: 26/08/25
6
+ # Email: ankit0.anand0@gmail.com
7
+ #---------------------------------
3
8
 
4
- import numpy as np
9
+ from pathlib import Path
10
+ import matplotlib as mpl
11
+ import matplotlib.font_manager as fm
5
12
  import matplotlib.pyplot as plt
6
13
  import matplotlib.gridspec as gridspec
7
14
  from matplotlib.patches import Rectangle
8
- from mpl_toolkits.axes_grid1.inset_locator import inset_axes
9
15
 
10
- # Helper for 2D plot
11
- def _calculate_extent(x, y):
12
- # Handle spacing safely
13
- if len(x) > 1:
14
- dx = x[1] - x[0]
15
- else:
16
- dx = 1 # Default spacing for single value
17
- if len(y) > 1:
18
- dy = y[1] - y[0]
19
- else:
20
- dy = 1 # Default spacing for single value
21
-
22
- return [
23
- x[0] - dx / 2,
24
- x[-1] + dx / 2,
25
- y[0] - dy / 2,
26
- y[-1] + dy / 2
27
- ]
28
-
29
- # Helper to load fonts (devnagri)
30
- def set_default_hindi_font():
16
+ import numpy as np
17
+
18
+ #===== Loading Devanagari font ========
19
+ def _load_devanagari_font():
31
20
  """
32
- Hindi fonts works for both english and hindi.
21
+ Load devanagari font as it works for both English and Hindi.
33
22
  """
34
- from pathlib import Path
35
- import matplotlib as mpl
36
- import matplotlib.font_manager as fm
23
+
37
24
  # Path to your bundled font
38
25
  font_path = Path(__file__).resolve().parents[1] / "fonts" / "NotoSansDevanagari-Regular.ttf"
39
26
 
@@ -45,432 +32,714 @@ def set_default_hindi_font():
45
32
 
46
33
  # Set as default rcParam
47
34
  mpl.rcParams['font.family'] = hindi_font.get_name()
35
+
36
+ _load_devanagari_font()
37
+ #==============
48
38
 
49
- set_default_hindi_font()
39
+ #--------------------------
40
+ # Figuure 1D
41
+ #--------------------------
42
+ class Figure1D:
43
+ """
44
+ A utility class that provides easy-to-use API
45
+ for plotting 1D signals along with clean
46
+ representations of annotations, events.
50
47
 
51
- #======== 1D ===========
52
- def plot1d(*args, ann=None, events=None, xlim=None, ylim=None, xlabel=None, ylabel=None, title=None, legend=None, fmt=None, show_grid=False, show_stem=False):
53
- """
54
- Plots a 1D signal using matplotlib.
48
+
55
49
 
56
- .. code-block:: python
50
+ Parameters
51
+ ----------
52
+ n_aux_subplots: int
53
+ - Total number of auxiliary subplots
54
+ - These include annotations and events subplots.
55
+ - Default: 0
56
+ title: str
57
+ - Title of the figure.
58
+ - Default: Title
59
+ xlim: tuple[number, number]
60
+ - xlim for the figure.
61
+ - All subplots will be automatically adjusted.
62
+ - Default: None
63
+ ylim: tuple[number, number]
64
+ - ylim for the signal.
65
+ - Default: None
66
+ """
57
67
 
58
- import modusa as ms
59
- import numpy as np
68
+ def __init__(self, n_aux_subplots=0, xlim=None, ylim=None):
69
+ self._n_aux_subplots: int = n_aux_subplots
70
+ self._active_subplot_idx: int = 1 # Any addition will happen on this subplot (0 is reserved for reference axis)
71
+ self._xlim = xlim # Many add functions depend on this, so we fix it while instantiating the class
72
+ self._ylim = ylim
73
+ self._subplots, self._fig = self._generate_subplots() # Will contain all the subplots (list, fig)
74
+
75
+ def _get_active_subplot(self):
76
+ """
77
+ Get the active subplot where you can add
78
+ either annotations or events.
79
+ """
80
+ active_subplot = self._subplots[self._active_subplot_idx]
81
+ self._active_subplot_idx += 1
82
+
83
+ return active_subplot
84
+
85
+ def _generate_subplots(self):
86
+ """
87
+ Generate subplots based on the configuration.
88
+ """
89
+
90
+ n_aux_subplots = self._n_aux_subplots
91
+
92
+ # Fixed heights per subplot type
93
+ ref_height = 0.0
94
+ aux_height = 0.4
95
+ signal_height = 2.0
96
+
97
+ # Total number of subplots
98
+ n_subplots = 1 + n_aux_subplots + 1
99
+
100
+ # Compute total fig height
101
+ fig_height = ref_height + n_aux_subplots * aux_height + signal_height
102
+
103
+ # Define height ratios
104
+ height_ratios = [ref_height] + [aux_height] * n_aux_subplots + [signal_height]
105
+
106
+ # Create figure and grid
107
+ fig = plt.figure(figsize=(16, fig_height))
108
+ gs = gridspec.GridSpec(n_subplots, 1, height_ratios=height_ratios)
109
+
110
+ # Add subplots
111
+ subplots_list = []
112
+ ref_subplot = fig.add_subplot(gs[0, 0])
113
+ ref_subplot.axis("off")
114
+ subplots_list.append(ref_subplot)
115
+
116
+ for i in range(1, n_subplots):
117
+ subplots_list.append(fig.add_subplot(gs[i, 0], sharex=ref_subplot))
60
118
 
61
- x = np.arange(100) / 100
62
- y = np.sin(x)
119
+ for i in range(n_subplots - 1):
120
+ subplots_list[i].tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
121
+
122
+ # Set xlim
123
+ if self._xlim is not None: # xlim should be applied on reference subplot, rest all sharex
124
+ ref_subplot.set_xlim(self._xlim)
63
125
 
64
- display(ms.plot1d(y, x))
126
+ # Set ylim
127
+ if self._ylim is not None: # ylim should be applied on the signal subplot
128
+ subplots_list[-1].set_ylim(self._ylim)
129
+
130
+ fig.subplots_adjust(hspace=0.01, wspace=0.05)
65
131
 
132
+ return subplots_list, fig
66
133
 
134
+
135
+ def add_events(self, events, c="k", ls="-", lw=1.5, label="Event Label"):
136
+ """
137
+ Add events to the figure.
138
+
67
139
  Parameters
68
140
  ----------
69
- *args : tuple[array-like, array-like] | tuple[array-like]
70
- - The signal y and axis x to be plotted.
71
- - If only values are provided, we generate the axis using arange.
72
- - E.g. (y1, x1), (y2, x2), ...
73
- ann : list[tuple[Number, Number, str] | None
74
- - A list of annotations to mark specific points. Each tuple should be of the form (start, end, label).
75
- - Default: None => No annotation.
76
- events : list[Number] | None
77
- - A list of x-values where vertical lines (event markers) will be drawn.
78
- - Default: None
79
- xlim : tuple[Number, Number] | None
80
- - Limits for the x-axis as (xmin, xmax).
81
- - Default: None
82
- ylim : tuple[Number, Number] | None
83
- - Limits for the y-axis as (ymin, ymax).
84
- - Default: None
85
- xlabel : str | None
86
- - Label for the x-axis.
87
- - - Default: None
88
- ylabel : str | None
89
- - Label for the y-axis.
90
- - Default: None
91
- title : str | None
92
- - Title of the plot.
93
- - Default: None
94
- legend : list[str] | None
95
- - List of legend labels corresponding to each signal if plotting multiple lines.
96
- - Default: None
97
- fmt: list[str] | None
98
- - linefmt for different line plots.
99
- - Default: None
100
- show_grid: bool
101
- - If you want to show the grid.
102
- - Default: False
103
- show_stem: bool:
104
- - If you want stem plot.
105
- - Default: False
106
-
141
+ events: np.ndarray
142
+ - All the event marker values.
143
+ c: str
144
+ - Color of the event marker.
145
+ - Default: "k"
146
+ ls: str
147
+ - Line style.
148
+ - Default: "-"
149
+ lw: float
150
+ - Linewidth.
151
+ - Default: 1.5
152
+ label: str
153
+ - Label for the event type.
154
+ - This will appear in the legend.
155
+ - Default: "Event label"
156
+
107
157
  Returns
108
158
  -------
109
- plt.Figure
110
- Matplolib figure.
159
+ None
111
160
  """
112
- for arg in args:
113
- if len(arg) not in [1, 2]: # 1 if it just provides values, 2 if it provided axis as well
114
- raise ValueError(f"1D signal needs to have max 2 arrays (y, x) or simply (y, )")
115
-
116
- if isinstance(legend, str): legend = (legend, )
117
- if legend is not None:
118
- if len(legend) < len(args):
119
- raise ValueError(f"`legend` should be provided for each signal.")
161
+ event_subplot = self._get_active_subplot()
162
+ xlim = self._xlim
120
163
 
121
- if isinstance(fmt, str): fmt = [fmt]
122
- if fmt is not None:
123
- if len(fmt) < len(args):
124
- raise ValueError(f"`fmt` should be provided for each signal.")
125
-
126
- colors = plt.get_cmap('tab10').colors
127
-
128
- fig = plt.figure(figsize=(16, 2))
129
- gs = gridspec.GridSpec(2, 1, height_ratios=[0.2, 1])
164
+ for i, event in enumerate(events):
165
+ if xlim is not None:
166
+ if xlim[0] <= event <= xlim[1]:
167
+ if i == 0: # Label should be set only once for all the events
168
+ event_subplot.axvline(x=event, color=c, linestyle=ls, linewidth=lw, label=label)
169
+ else:
170
+ event_subplot.axvline(x=event, color=c, linestyle=ls, linewidth=lw)
171
+ else:
172
+ if i == 0: # Label should be set only once for all the events
173
+ event_subplot.axvline(x=event, color=c, linestyle=ls, linewidth=lw, label=label)
174
+ else:
175
+ event_subplot.axvline(x=event, color=c, linestyle=ls, linewidth=lw)
130
176
 
131
- signal_ax = fig.add_subplot(gs[1, 0])
132
- annotation_ax = fig.add_subplot(gs[0, 0], sharex=signal_ax)
177
+ def add_annotation(self, ann):
178
+ """
179
+ Add annotation to the figure.
133
180
 
134
- # Set lim
135
- if xlim is not None:
136
- signal_ax.set_xlim(xlim)
181
+ Parameters
182
+ ----------
183
+ ann : list[tuple[Number, Number, str]] | None
184
+ - A list of annotation spans. Each tuple should be (start, end, label).
185
+ - Default: None (no annotations).
137
186
 
138
- if ylim is not None:
139
- signal_ax.set_ylim(ylim)
140
-
141
- # Add signal plot
142
- for i, signal in enumerate(args):
143
- if len(signal) == 1:
144
- y = signal[0]
145
- x = np.arange(y.size)
146
- if legend is not None:
147
- if show_stem is True:
148
- markerline, stemlines, baseline = signal_ax.stem(x, y, label=legend[i])
149
- markerline.set_color(colors[i])
150
- stemlines.set_color(colors[i])
151
- baseline.set_color("k")
152
- else:
153
- if fmt is not None:
154
- signal_ax.plot(x, y, fmt[i], markersize=4, label=legend[i])
155
- else:
156
- signal_ax.plot(x, y, color=colors[i], label=legend[i])
157
- else:
158
- if show_stem is True:
159
- markerline, stemlines, baseline = signal_ax.stem(x, y)
160
- markerline.set_color(colors[i])
161
- stemlines.set_color(colors[i])
162
- baseline.set_color("k")
163
- else:
164
- if fmt is not None:
165
- signal_ax.plot(x, y, fmt[i], markersize=4)
166
- else:
167
- signal_ax.plot(x, y, color=colors[i])
168
-
169
- elif len(signal) == 2:
170
- y, x = signal[0], signal[1]
171
- if legend is not None:
172
- if show_stem is True:
173
- markerline, stemlines, baseline = signal_ax.stem(x, y, label=legend[i])
174
- markerline.set_color(colors[i])
175
- stemlines.set_color(colors[i])
176
- baseline.set_color("k")
177
- else:
178
- if fmt is not None:
179
- signal_ax.plot(x, y, fmt[i], markersize=4, label=legend[i])
180
- else:
181
- signal_ax.plot(x, y, color=colors[i], label=legend[i])
182
- else:
183
- if show_stem is True:
184
- markerline, stemlines, baseline = signal_ax.stem(x, y)
185
- markerline.set_color(colors[i])
186
- stemlines.set_color(colors[i])
187
- baseline.set_color("k")
188
- else:
189
- if fmt is not None:
190
- signal_ax.plot(x, y, fmt[i], markersize=4)
191
- else:
192
- signal_ax.plot(x, y, color=colors[i])
193
-
187
+ Returns
188
+ -------
189
+ None
190
+ """
194
191
 
195
- # Add annotations
196
- if ann is not None:
197
- annotation_ax.set_ylim(0, 1) # For consistent layout
198
- # Determine visible x-range
199
- x_view_min = xlim[0] if xlim is not None else np.min(x)
200
- x_view_max = xlim[1] if xlim is not None else np.max(x)
192
+ ann_subplot = self._get_active_subplot()
193
+ xlim = self._xlim
194
+
195
+ for i, (start, end, tag) in enumerate(ann):
201
196
 
202
- for i, (start, end, tag) in enumerate(ann):
203
- # We make sure that we only plot annotation that are within the x range of the current view
204
- if start >= x_view_max or end <= x_view_min:
197
+ # We make sure that we only plot annotation that are within the x range of the current view
198
+ if xlim is not None:
199
+ if start >= xlim[1] or end <= xlim[0]:
205
200
  continue
206
201
 
207
202
  # Clip boundaries to xlim
208
- start = max(start, x_view_min)
209
- end = min(end, x_view_max)
210
-
211
- color = colors[i % len(colors)]
203
+ start = max(start, xlim[0])
204
+ end = min(end, xlim[1])
205
+
206
+ box_colors = ["gray", "lightgray"] # Alternates color between two
207
+ box_color = box_colors[i % 2]
208
+
212
209
  width = end - start
213
- rect = Rectangle((start, 0), width, 1, color=color, alpha=0.7)
214
- annotation_ax.add_patch(rect)
210
+ rect = Rectangle((start, 0), width, 1, facecolor=box_color, edgecolor="black", alpha=0.7)
211
+ ann_subplot.add_patch(rect)
215
212
 
216
- text_obj = annotation_ax.text(
213
+ text_obj = ann_subplot.text(
217
214
  (start + end) / 2, 0.5, tag,
218
215
  ha='center', va='center',
219
- fontsize=10, color='white', fontweight='bold', zorder=10, clip_on=True
216
+ fontsize=10, color="black", fontweight='bold', zorder=10, clip_on=True
220
217
  )
221
218
 
222
219
  text_obj.set_clip_path(rect)
220
+ else:
221
+ box_colors = ["gray", "lightgray"] # Alternates color between two
222
+ box_color = box_colors[i % 2]
223
223
 
224
+ width = end - start
225
+ rect = Rectangle((start, 0), width, 1, facecolor=box_color, edgecolor="black", alpha=0.7)
226
+ ann_subplot.add_patch(rect)
224
227
 
225
- # Add vlines
226
- if events is not None:
227
- for xpos in events:
228
- if xlim is not None:
229
- if xlim[0] <= xpos <= xlim[1]:
230
- annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
231
- else:
232
- annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
233
-
234
- # Add legend
235
- if legend is not None:
236
- handles, labels = signal_ax.get_legend_handles_labels()
237
- fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 1.2), ncol=len(legend), frameon=True)
238
-
239
- # Set title, labels
240
- if title is not None:
241
- annotation_ax.set_title(title, pad=10, size=11)
242
- if xlabel is not None:
243
- signal_ax.set_xlabel(xlabel)
244
- if ylabel is not None:
245
- signal_ax.set_ylabel(ylabel)
228
+ text_obj = ann_subplot.text(
229
+ (start + end) / 2, 0.5, tag,
230
+ ha='center', va='center',
231
+ fontsize=10, color="black", fontweight='bold', zorder=10, clip_on=True
232
+ )
233
+
234
+ text_obj.set_clip_path(rect)
235
+
236
+ def add_signal(self, y, x=None, c=None, ls="-", lw=1, m=None, ms=5, label="Signal"):
237
+ """
238
+ Add signal to the figure.
246
239
 
247
- # Add grid to the plot
248
- if show_grid is True:
249
- signal_ax.grid(True, linestyle=':', linewidth=0.7, color='gray', alpha=0.7)
240
+ Parameters
241
+ ----------
242
+ y: np.ndarray
243
+ - Signal y values.
244
+ x: np.ndarray | None
245
+ - Signal x values.
246
+ - Default: None (indices will be used)
247
+ c: str
248
+ - Color of the line.
249
+ - Default: None
250
+ ls: str
251
+ - Linestyle
252
+ - Default: "-"
253
+ lw: Number
254
+ - Linewidth
255
+ - Default: 1
256
+ m: str
257
+ - Marker
258
+ - Default: None
259
+ ms: number
260
+ - Markersize
261
+ - Default: 5
262
+ label: str
263
+ - Label for the plot.
264
+ - Legend will use this.
265
+ - Default: "Signal"
266
+
267
+ Returns
268
+ -------
269
+ None
270
+ """
271
+ if x is None:
272
+ x = np.arange(y.size)
273
+ signal_subplot = self._subplots[-1]
274
+ signal_subplot.plot(x, y, color=c, linestyle=ls, linewidth=lw, marker=m, markersize=ms, label=label)
250
275
 
251
- # Remove the boundaries and ticks from an axis
252
- if ann is not None:
253
- annotation_ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
254
- else:
255
- annotation_ax.axis("off")
276
+ def add_legend(self, ypos=1.3):
277
+ """
278
+ Add legend to the figure.
279
+
280
+ Parameters
281
+ ----------
282
+ ypos: float
283
+ - y position from the top.
284
+ - > 1 to push it higher, < 1 to push it lower
285
+ - Default: 1.3
286
+
287
+ Returns
288
+ -------
289
+ None
290
+ """
291
+ subplots: list = self._subplots
292
+ fig = self._fig
256
293
 
294
+ all_handles, all_labels = [], []
257
295
 
258
- fig.subplots_adjust(hspace=0.01, wspace=0.05)
259
- plt.close()
260
- return fig
261
-
262
- #======== 2D ===========
263
- def plot2d(*args, ann=None, events=None, xlim=None, ylim=None, origin="lower", Mlabel=None, xlabel=None, ylabel=None, title=None, legend=None, lm=False, show_grid=False):
264
- """
265
- Plots a 2D matrix (e.g., spectrogram or heatmap) with optional annotations and events.
296
+ for subplot in subplots:
297
+ handles, labels = subplot.get_legend_handles_labels()
298
+ all_handles.extend(handles)
299
+ all_labels.extend(labels)
300
+
301
+ # remove duplicates if needed
302
+ fig.legend(all_handles, all_labels, loc='upper right', bbox_to_anchor=(0.9, ypos), ncol=2, frameon=True, bbox_transform=fig.transFigure)
303
+
304
+ def add_meta_info(self, title="Title", ylabel="Y", xlabel="X"):
305
+ """
306
+ Add meta info to the figure.
266
307
 
267
- .. code-block:: python
308
+ Parameters
309
+ ----------
310
+ title: str
311
+ - Title of the figure.
312
+ - Default: "Title"
313
+ ylabel: str
314
+ - y label of the signal.
315
+ - It will only appear in the signal subplot.
316
+ - Default: "Y"
317
+ xlabel: str
318
+ - x label of the signal.
319
+ - It will only appear in the signal subplot.
320
+ - Default: "X"
268
321
 
269
- import modusa as ms
270
- import numpy as np
322
+ Returns
323
+ -------
324
+ None
325
+ """
326
+ subplots: list = self._subplots
327
+ fig = self._fig
328
+
329
+ ref_subplot = subplots[0]
330
+ signal_subplot = subplots[-1]
331
+
332
+ ref_subplot.set_title(title, pad=10, size=14)
333
+ signal_subplot.set_xlabel(xlabel, size=12)
334
+ signal_subplot.set_ylabel(ylabel, size=12)
271
335
 
272
- M = np.random.random((10, 30))
273
- y = np.arange(M.shape[0])
274
- x = np.arange(M.shape[1])
275
336
 
276
- display(ms.plot2d(M, y, x))
337
+ def save(self, path="./figure.png"):
338
+ """
339
+ Save the figure.
277
340
 
278
- Parameters
279
- ----------
280
- *args : tuple[array-like, array-like]
281
- - The signal values to be plotted.
282
- - E.g. (M1, y1, x1), (M2, y2, x2), ...
283
- ann : list[tuple[Number, Number, str]] | None
284
- - A list of annotation spans. Each tuple should be (start, end, label).
285
- - Default: None (no annotations).
286
- events : list[Number] | None
287
- - X-values where vertical event lines will be drawn.
288
- - Default: None.
289
- xlim : tuple[Number, Number] | None
290
- - Limits for the x-axis as (xmin, xmax).
291
- - Default: None (auto-scaled).
292
- ylim : tuple[Number, Number] | None
293
- - Limits for the y-axis as (ymin, ymax).
294
- - Default: None (auto-scaled).
295
- origin : {'upper', 'lower'}
296
- - Origin position for the image display. Used in `imshow`.
297
- - Default: "lower".
298
- Mlabel : str | None
299
- - Label for the colorbar (e.g., "Magnitude", "Energy").
300
- - Default: None.
301
- xlabel : str | None
302
- - Label for the x-axis.
303
- - Default: None.
304
- ylabel : str | None
305
- - Label for the y-axis.
306
- - Default: None.
307
- title : str | None
308
- - Title of the plot.
309
- - Default: None.
310
- legend : list[str] | None
311
- - Legend labels for any overlaid lines or annotations.
312
- - Default: None.
313
- lm: bool
314
- - Adds a circular marker for the line.
315
- - Default: False
316
- - Useful to show the data points.
317
- show_grid: bool
318
- - If you want to show the grid.
319
- - Default: False
341
+ Parameters
342
+ ----------
343
+ path: str
344
+ - Path to the output file.
345
+
346
+ Returns
347
+ -------
348
+ None
349
+ """
350
+ fig = self._fig
351
+ fig.savefig(path, bbox_inches="tight")
352
+
320
353
 
321
- Returns
322
- -------
323
- matplotlib.figure.Figure
324
- The matplotlib Figure object.
354
+ #--------------------------
355
+ # Figuure 2D
356
+ #--------------------------
357
+ class Figure2D:
325
358
  """
359
+ A utility class that provides easy-to-use API
360
+ for plotting 2D signals along with clean
361
+ representations of annotations, events.
362
+
326
363
 
327
- for arg in args:
328
- if len(arg) not in [1, 2, 3]: # Either provide just the matrix or with both axes info
329
- raise ValueError(f"Data to plot needs to have 3 arrays (M, y, x)")
330
- if isinstance(legend, str): legend = (legend, )
331
-
332
- fig = plt.figure(figsize=(16, 4))
333
- gs = gridspec.GridSpec(3, 1, height_ratios=[0.2, 0.1, 1]) # colorbar, annotation, signal
334
364
 
335
- colors = plt.get_cmap('tab10').colors
365
+ Parameters
366
+ ----------
367
+ n_aux_subplots: int
368
+ - Total number of auxiliary subplots
369
+ - These include annotations and events subplots.
370
+ - Default: 0
371
+ title: str
372
+ - Title of the figure.
373
+ - Default: Title
374
+ xlim: tuple[number, number]
375
+ - xlim for the figure.
376
+ - All subplots will be automatically adjusted.
377
+ - Default: None
378
+ ylim: tuple[number, number]
379
+ - ylim for the signal.
380
+ - Default: None
381
+ """
336
382
 
337
- signal_ax = fig.add_subplot(gs[2, 0])
338
- annotation_ax = fig.add_subplot(gs[1, 0], sharex=signal_ax)
383
+ def __init__(self, n_aux_subplots=0, xlim=None, ylim=None):
384
+ self._n_aux_subplots: int = n_aux_subplots
385
+ self._active_subplot_idx: int = 1 # Any addition will happen on this subplot (0 is reserved for reference axis)
386
+ self._xlim = xlim # Many add functions depend on this, so we fix it while instantiating the class
387
+ self._ylim = ylim
388
+ self._subplots, self._fig = self._generate_subplots() # Will contain all the subplots (list, fig)
389
+ self._im = None # Useful while creating colorbar for the image
390
+
391
+ def _get_active_subplot(self):
392
+ """
393
+ Get the active subplot where you can add
394
+ either annotations or events.
395
+ """
396
+ active_subplot = self._subplots[self._active_subplot_idx]
397
+ self._active_subplot_idx += 1
398
+
399
+ return active_subplot
339
400
 
340
- colorbar_ax = fig.add_subplot(gs[0, 0])
341
- colorbar_ax.axis("off")
401
+ def _calculate_extent(self, x, y):
402
+ # Handle spacing safely
403
+ if len(x) > 1:
404
+ dx = x[1] - x[0]
405
+ else:
406
+ dx = 1 # Default spacing for single value
407
+ if len(y) > 1:
408
+ dy = y[1] - y[0]
409
+ else:
410
+ dy = 1 # Default spacing for single value
411
+
412
+ return [x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2]
342
413
 
414
+ def _add_colorbar(self, im, label=None, width="20%", height="35%"):
415
+ from mpl_toolkits.axes_grid1.inset_locator import inset_axes
416
+
417
+ ref_subplot = self._subplots[0]
418
+
419
+ # Assume ref_subplot is your reference axes, im is the image
420
+ cax = inset_axes(
421
+ ref_subplot,
422
+ width=width, # width of colorbar
423
+ height=height, # height of colorbar
424
+ loc='right',
425
+ bbox_to_anchor=(0, 1, 1, 1), # move 0.9 right, 1.2 up from the subplot
426
+ bbox_transform=ref_subplot.transAxes, # important: use subplot coords
427
+ borderpad=0
428
+ )
429
+
430
+ cbar = plt.colorbar(im, cax=cax, orientation='horizontal')
431
+ cbar.ax.xaxis.set_ticks_position('top')
432
+ cbar.set_label(label, labelpad=5)
433
+
343
434
 
344
- # Add lim
345
- if xlim is not None:
346
- signal_ax.set_xlim(xlim)
347
435
 
348
- if ylim is not None:
349
- signal_ax.set_ylim(ylim)
436
+ def _generate_subplots(self):
437
+ """
438
+ Generate subplots based on the configuration.
439
+ """
440
+
441
+ n_aux_subplots = self._n_aux_subplots
442
+
443
+ # Fixed heights per subplot type
444
+ ref_height = 0.4
445
+ aux_height = 0.4
446
+ signal_height = 4.0
350
447
 
351
- # Add signal plot
352
- i = 0 # This is to track the legend for 1D plots
353
- for signal in args:
448
+ # Total number of subplots
449
+ n_subplots = 1 + n_aux_subplots + 1
354
450
 
355
- data = signal[0] # This can be 1D or 2D (1D meaning we have to overlay on the matrix)
451
+ # Compute total fig height
452
+ fig_height = ref_height + n_aux_subplots * aux_height + signal_height
453
+
454
+ # Define height ratios
455
+ height_ratios = [ref_height] + [aux_height] * n_aux_subplots + [signal_height]
456
+
457
+ # Create figure and grid
458
+ fig = plt.figure(figsize=(16, fig_height))
459
+ gs = gridspec.GridSpec(n_subplots, 1, height_ratios=height_ratios)
460
+
461
+ # Add subplots
462
+ subplots_list = []
463
+ ref_subplot = fig.add_subplot(gs[0, 0])
464
+ ref_subplot.axis("off")
465
+ subplots_list.append(ref_subplot)
466
+
467
+ for i in range(1, n_subplots):
468
+ subplots_list.append(fig.add_subplot(gs[i, 0], sharex=ref_subplot))
356
469
 
357
- if data.ndim == 1: # 1D
358
- if len(signal) == 1: # It means that the axis was not passed
359
- x = np.arange(data.shape[0])
360
- else:
361
- x = signal[1]
470
+ for i in range(n_subplots - 1):
471
+ subplots_list[i].tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
362
472
 
363
- if lm is False:
364
- if legend is not None:
365
- signal_ax.plot(x, data, label=legend[i])
366
- signal_ax.legend(loc="upper right")
367
- else:
368
- signal_ax.plot(x, data)
473
+ # Set xlim
474
+ if self._xlim is not None: # xlim should be applied on reference subplot, rest all sharex
475
+ ref_subplot.set_xlim(self._xlim)
476
+
477
+ # Set ylim
478
+ if self._ylim is not None: # ylim should be applied on the signal subplot
479
+ subplots_list[-1].set_ylim(self._ylim)
480
+
481
+ fig.subplots_adjust(hspace=0.01, wspace=0.05)
482
+
483
+ return subplots_list, fig
484
+
485
+
486
+ def add_events(self, events, c="k", ls="-", lw=1.5, label="Event Label"):
487
+ """
488
+ Add events to the figure.
489
+
490
+ Parameters
491
+ ----------
492
+ events: np.ndarray
493
+ - All the event marker values.
494
+ c: str
495
+ - Color of the event marker.
496
+ - Default: "k"
497
+ ls: str
498
+ - Line style.
499
+ - Default: "-"
500
+ lw: float
501
+ - Linewidth.
502
+ - Default: 1.5
503
+ label: str
504
+ - Label for the event type.
505
+ - This will appear in the legend.
506
+ - Default: "Event label"
507
+
508
+ Returns
509
+ -------
510
+ None
511
+ """
512
+ event_subplot = self._get_active_subplot()
513
+ xlim = self._xlim
514
+
515
+ for i, event in enumerate(events):
516
+ if xlim is not None:
517
+ if xlim[0] <= event <= xlim[1]:
518
+ if i == 0: # Label should be set only once for all the events
519
+ event_subplot.axvline(x=event, color=c, linestyle=ls, linewidth=lw, label=label)
520
+ else:
521
+ event_subplot.axvline(x=event, color=c, linestyle=ls, linewidth=lw)
369
522
  else:
370
- if legend is not None:
371
- signal_ax.plot(x, data, marker="o", markersize=7, markerfacecolor='red', linestyle="--", linewidth=2, label=legend[i])
372
- signal_ax.legend(loc="upper right")
523
+ if i == 0: # Label should be set only once for all the events
524
+ event_subplot.axvline(x=event, color=c, linestyle=ls, linewidth=lw, label=label)
373
525
  else:
374
- signal_ax.plot(x, data, marker="o", markersize=7, markerfacecolor='red', linestyle="--", linewidth=2)
526
+ event_subplot.axvline(x=event, color=c, linestyle=ls, linewidth=lw)
375
527
 
376
- i += 1
377
-
378
- elif data.ndim == 2: # 2D
379
- M = data
380
- if len(signal) == 1: # It means that the axes were not passed
381
- y = np.arange(M.shape[0])
382
- x = np.arange(M.shape[1])
383
- extent = _calculate_extent(x, y)
384
- im = signal_ax.imshow(M, aspect="auto", origin=origin, cmap="gray_r", extent=extent)
385
-
386
- elif len(signal) == 3: # It means that the axes were passed
387
- M, y, x = signal[0], signal[1], signal[2]
388
- extent = _calculate_extent(x, y)
389
- im = signal_ax.imshow(M, aspect="auto", origin=origin, cmap="gray_r", extent=extent)
390
-
391
- # Add annotations
392
- if ann is not None:
393
- annotation_ax.set_ylim(0, 1) # For consistent layout
394
- # Determine visible x-range
395
- x_view_min = xlim[0] if xlim is not None else np.min(x)
396
- x_view_max = xlim[1] if xlim is not None else np.max(x)
528
+ def add_annotation(self, ann):
529
+ """
530
+ Add annotation to the figure.
531
+
532
+ Parameters
533
+ ----------
534
+ ann : list[tuple[Number, Number, str]] | None
535
+ - A list of annotation spans. Each tuple should be (start, end, label).
536
+ - Default: None (no annotations).
537
+
538
+ Returns
539
+ -------
540
+ None
541
+ """
542
+ ann_subplot = self._get_active_subplot()
543
+ xlim = self._xlim
397
544
 
398
545
  for i, (start, end, tag) in enumerate(ann):
399
- # We make sure that we only plot annotation that are within the x range of the current view
400
- if start >= x_view_max or end <= x_view_min:
401
- continue
402
546
 
403
- # Clip boundaries to xlim
404
- start = max(start, x_view_min)
405
- end = min(end, x_view_max)
547
+ # We make sure that we only plot annotation that are within the x range of the current view
548
+ if xlim is not None:
549
+ if start >= xlim[1] or end <= xlim[0]:
550
+ continue
551
+
552
+ # Clip boundaries to xlim
553
+ start = max(start, xlim[0])
554
+ end = min(end, xlim[1])
555
+
556
+ box_colors = ["gray", "lightgray"] # Alternates color between two
557
+ box_color = box_colors[i % 2]
558
+
559
+ width = end - start
560
+ rect = Rectangle((start, 0), width, 1, facecolor=box_color, edgecolor="black", alpha=0.7)
561
+ ann_subplot.add_patch(rect)
562
+
563
+ text_obj = ann_subplot.text(
564
+ (start + end) / 2, 0.5, tag,
565
+ ha='center', va='center',
566
+ fontsize=10, color="black", fontweight='bold', zorder=10, clip_on=True
567
+ )
568
+
569
+ text_obj.set_clip_path(rect)
570
+ else:
571
+ box_colors = ["gray", "lightgray"] # Alternates color between two
572
+ box_color = box_colors[i % 2]
406
573
 
407
- color = colors[i % len(colors)]
408
- width = end - start
409
- rect = Rectangle((start, 0), width, 1, color=color, alpha=0.7)
410
- annotation_ax.add_patch(rect)
411
- text_obj = annotation_ax.text(
412
- (start + end) / 2, 0.5, tag,
413
- ha='center', va='center',
414
- fontsize=10, color='white', fontweight='bold', zorder=10, clip_on=True
415
- )
574
+ width = end - start
575
+ rect = Rectangle((start, 0), width, 1, facecolor=box_color, edgecolor="black", alpha=0.7)
576
+ ann_subplot.add_patch(rect)
577
+
578
+ text_obj = ann_subplot.text(
579
+ (start + end) / 2, 0.5, tag,
580
+ ha='center', va='center',
581
+ fontsize=10, color="black", fontweight='bold', zorder=10, clip_on=True
582
+ )
583
+
584
+ text_obj.set_clip_path(rect)
585
+
586
+ def add_matrix(self, M, y=None, x=None, c="gray_r", o="lower", label="Matrix"):
587
+ """
588
+ Add matrix to the figure.
416
589
 
417
- text_obj.set_clip_path(rect)
590
+ Parameters
591
+ ----------
592
+ M: np.ndarray
593
+ - Matrix (2D) array
594
+ y: np.ndarray | None
595
+ - y axis values.
596
+ x: np.ndarray | None (indices will be used)
597
+ - x axis values.
598
+ - Default: None (indices will be used)
599
+ c: str
600
+ - cmap for the matrix.
601
+ - Default: None
602
+ o: str
603
+ - origin
604
+ - Default: "lower"
605
+ label: str
606
+ - Label for the plot.
607
+ - Legend will use this.
608
+ - Default: "Signal"
609
+
610
+ Returns
611
+ -------
612
+ None
613
+ """
614
+ if x is None: x = np.arange(M.shape[1])
615
+ if y is None: y = np.arange(M.shape[0])
418
616
 
419
- # Add vlines
420
- if events is not None:
421
- for xpos in events:
422
- if xlim is not None:
423
- if xlim[0] <= xpos <= xlim[1]:
424
- annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
425
- else:
426
- annotation_ax.axvline(x=xpos, color='black', linestyle='--', linewidth=1.5)
427
-
428
- # Add legend incase there are 1D overlays
429
- if legend is not None:
430
- handles, labels = signal_ax.get_legend_handles_labels()
431
- if handles: # Only add legend if there's something to show
432
- signal_ax.legend(handles, labels, loc="upper right")
433
-
434
- # Add colorbar
435
- # Create an inset axis on top-right of signal_ax
436
- cax = inset_axes(
437
- colorbar_ax,
438
- width="20%", # percentage of parent width
439
- height="20%", # height in percentage of parent height
440
- loc='upper right',
441
- bbox_to_anchor=(0, 0, 1, 1),
442
- bbox_transform=colorbar_ax.transAxes,
443
- borderpad=1
444
- )
445
-
446
- cbar = plt.colorbar(im, cax=cax, orientation='horizontal')
447
- cbar.ax.xaxis.set_ticks_position('top')
448
-
449
- if Mlabel is not None:
450
- cbar.set_label(Mlabel, labelpad=5)
617
+ matrix_subplot = self._subplots[-1]
618
+ extent = self._calculate_extent(x, y)
619
+ im = matrix_subplot.imshow(M, aspect="auto", origin=o, cmap=c, extent=extent)
620
+
621
+ self._add_colorbar(im=im, label=label)
451
622
 
623
+ def add_signal(self, y, x=None, c=None, ls="-", lw=1, m="o", ms=5, label="Signal"):
624
+ """
625
+ Add signal on the matrix.
626
+
627
+ Parameters
628
+ ----------
629
+ y: np.ndarray
630
+ - Signal y values.
631
+ x: np.ndarray | None
632
+ - Signal x values.
633
+ - Default: None (indices will be used)
634
+ c: str
635
+ - Color of the line.
636
+ - Default: None
637
+ ls: str
638
+ - Linestyle
639
+ - Default: "-"
640
+ lw: Number
641
+ - Linewidth
642
+ - Default: 1
643
+ m: str
644
+ - Marker
645
+ - Default: None
646
+ ms: number
647
+ - Markersize
648
+ - Default: 5
649
+ label: str
650
+ - Label for the plot.
651
+ - Legend will use this.
652
+ - Default: "Signal"
653
+
654
+ Returns
655
+ -------
656
+ None
657
+ """
658
+ if x is None:
659
+ x = np.arange(y.size)
660
+ matrix_subplot = self._subplots[-1]
661
+ matrix_subplot.plot(x, y, color=c, linestyle=ls, linewidth=lw, marker=m, markersize=ms, label=label)
452
662
 
453
- # Set title, labels
454
- if title is not None:
455
- annotation_ax.set_title(title, pad=10, size=11)
456
- if xlabel is not None:
457
- signal_ax.set_xlabel(xlabel)
458
- if ylabel is not None:
459
- signal_ax.set_ylabel(ylabel)
460
-
461
- # Add grid to the plot
462
- if show_grid is True:
463
- signal_ax.grid(True, linestyle=':', linewidth=0.7, color='gray', alpha=0.7)
464
663
 
465
- # Making annotation axis spines thicker
466
- if ann is not None:
467
- annotation_ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
468
- else:
469
- annotation_ax.axis("off")
664
+
665
+ def add_legend(self, ypos=1.1):
666
+ """
667
+ Add legend to the figure.
668
+
669
+ Parameters
670
+ ----------
671
+ ypos: float
672
+ - y position from the top.
673
+ - > 1 to push it higher, < 1 to push it lower
674
+ - Default: 1.3
675
+
676
+ Returns
677
+ -------
678
+ None
679
+ """
680
+ subplots: list = self._subplots
681
+ fig = self._fig
682
+
683
+ all_handles, all_labels = [], []
684
+
685
+ for subplot in subplots:
686
+ handles, labels = subplot.get_legend_handles_labels()
687
+ all_handles.extend(handles)
688
+ all_labels.extend(labels)
689
+
690
+ # remove duplicates if needed
691
+ fig.legend(all_handles, all_labels, loc='upper right', bbox_to_anchor=(0.9, ypos), ncol=2, frameon=True, bbox_transform=fig.transFigure)
692
+
693
+ def add_meta_info(self, title="Title", ylabel="Y", xlabel="X"):
694
+ """
695
+ Add meta info to the figure.
696
+
697
+ Parameters
698
+ ----------
699
+ title: str
700
+ - Title of the figure.
701
+ - Default: "Title"
702
+ ylabel: str
703
+ - y label of the signal.
704
+ - It will only appear in the signal subplot.
705
+ - Default: "Y"
706
+ xlabel: str
707
+ - x label of the signal.
708
+ - It will only appear in the signal subplot.
709
+ - Default: "X"
710
+
711
+ Returns
712
+ -------
713
+ None
714
+ """
715
+ subplots: list = self._subplots
716
+ fig = self._fig
717
+
718
+ ref_subplot = subplots[0]
719
+ signal_subplot = subplots[-1]
720
+
721
+ ref_subplot.set_title(title, pad=10, size=14)
722
+ signal_subplot.set_xlabel(xlabel, size=12)
723
+ signal_subplot.set_ylabel(ylabel, size=12)
724
+
725
+
726
+ def save(self, path="./figure.png"):
727
+ """
728
+ Save the figure.
729
+
730
+ Parameters
731
+ ----------
732
+ path: str
733
+ - Path to the output file.
734
+
735
+ Returns
736
+ -------
737
+ None
738
+ """
739
+ fig = self._fig
740
+ fig.savefig(path, bbox_inches="tight")
741
+
470
742
 
471
- fig.subplots_adjust(hspace=0.01, wspace=0.05)
472
- plt.close()
473
- return fig
474
743
 
475
744
  #======== Plot distribution ===========
476
745
  def plot_dist(*args, ann=None, xlim=None, ylim=None, ylabel=None, xlabel=None, title=None, legend=None, show_hist=True, npoints=200, bins=30):
@@ -536,7 +805,7 @@ def plot_dist(*args, ann=None, xlim=None, ylim=None, ylabel=None, xlabel=None, t
536
805
  if legend is not None:
537
806
  if len(legend) < len(args):
538
807
  raise ValueError(f"Legend should be provided for each signal.")
539
-
808
+
540
809
  # Create figure
541
810
  fig = plt.figure(figsize=(16, 4))
542
811
  gs = gridspec.GridSpec(2, 1, height_ratios=[0.1, 1])
@@ -557,11 +826,11 @@ def plot_dist(*args, ann=None, xlim=None, ylim=None, ylabel=None, xlabel=None, t
557
826
  for i, data in enumerate(args):
558
827
  # Fit gaussian to the data
559
828
  kde = gaussian_kde(data)
560
-
829
+
561
830
  # Create points to evaluate KDE
562
831
  x = np.linspace(np.min(data), np.max(data), npoints)
563
832
  y = kde(x)
564
-
833
+
565
834
  if legend is not None:
566
835
  dist_ax.plot(x, y, color=colors[i], label=legend[i])
567
836
  if show_hist is True:
@@ -570,7 +839,7 @@ def plot_dist(*args, ann=None, xlim=None, ylim=None, ylabel=None, xlabel=None, t
570
839
  dist_ax.plot(x, y, color=colors[i])
571
840
  if show_hist is True:
572
841
  dist_ax.hist(data, bins=bins, density=True, alpha=0.3, facecolor=colors[i], edgecolor='black')
573
-
842
+
574
843
  # Add annotations
575
844
  if ann is not None:
576
845
  annotation_ax.set_ylim(0, 1) # For consistent layout
@@ -585,15 +854,15 @@ def plot_dist(*args, ann=None, xlim=None, ylim=None, ylabel=None, xlabel=None, t
585
854
  # Clip boundaries to xlim
586
855
  start = max(start, x_view_min)
587
856
  end = min(end, x_view_max)
588
-
857
+
589
858
  color = colors[i % len(colors)]
590
859
  width = end - start
591
860
  rect = Rectangle((start, 0), width, 1, color=color, alpha=0.7)
592
861
  annotation_ax.add_patch(rect)
593
-
862
+
594
863
  text_obj = annotation_ax.text((start + end) / 2, 0.5, tag, ha='center', va='center', fontsize=10, color='white', fontweight='bold', zorder=10, clip_on=True)
595
864
  text_obj.set_clip_path(rect)
596
-
865
+
597
866
  # Add legend
598
867
  if legend is not None:
599
868
  handles, labels = dist_ax.get_legend_handles_labels()