code-loader 1.0.42__tar.gz → 1.0.43__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.42
3
+ Version: 1.0.43
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3.8
13
13
  Classifier: Programming Language :: Python :: 3.9
14
14
  Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Programming Language :: Python :: 3.11
16
+ Requires-Dist: matplotlib (>=3.3,<3.4)
16
17
  Requires-Dist: numpy (>=1.22.3,<2.0.0)
17
18
  Requires-Dist: psutil (>=5.9.5,<6.0.0)
18
19
  Project-URL: Repository, https://github.com/tensorleap/code-loader
@@ -0,0 +1,600 @@
1
+ from typing import List, Any, Union
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+ from dataclasses import dataclass
6
+
7
+ import matplotlib.pyplot as plt # type: ignore
8
+
9
+ from code_loader.contract.enums import LeapDataType
10
+ from code_loader.contract.responsedataclasses import BoundingBox
11
+
12
+
13
class LeapValidationError(Exception):
    """Raised when a visualizer object fails one of its construction-time checks."""


def validate_type(actual: Any, expected: Any, prefix_message: str = '') -> None:
    """
    Check that `actual` is one of the allowed `expected` values.

    The check is a plain membership test (`actual in expected`), so it is used in this
    module both for types/dtypes and for plain values such as array rank or channel count.

    Args:
        actual: The value to check.
        expected: A single allowed value, or a list of allowed values.
        prefix_message: Optional context line placed before the generic error text.

    Raises:
        LeapValidationError: If `actual` is not among the allowed values.
    """
    allowed = expected if isinstance(expected, list) else [expected]
    if actual in allowed:
        return

    # Emit the prefix line only when one was supplied; an empty prefix used to
    # produce a message that started with a stray '.' followed by a newline.
    prefix = f'{prefix_message}.\n' if prefix_message else ''
    if len(allowed) == 1:
        detail = f'got {actual}, instead of {allowed[0]}'
    else:
        detail = f'got {actual}, allowed is one of {allowed}'
    raise LeapValidationError(f'{prefix}visualizer returned unexpected type. {detail}')
29
+
30
+
31
@dataclass
class LeapImage:
    """
    Image visualizer for Tensorleap.

    Attributes:
        data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image array. Construction
            requires a rank-3 array whose last axis has size 1 (gray) or 3 (rgb).
        type (LeapDataType): Must be LeapDataType.Image, which is the default.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        leap_image = LeapImage(data=image_data)
    """
    data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
    type: LeapDataType = LeapDataType.Image

    def __post_init__(self) -> None:
        """Validate the visualizer type tag and the array's class, dtype, rank and channel count."""
        validate_type(self.type, LeapDataType.Image)
        validate_type(type(self.data), np.ndarray)
        validate_type(self.data.dtype, [np.uint8, np.float32])
        rank = len(self.data.shape)
        validate_type(rank, 3, 'Image must be of shape 3')
        channel_count = self.data.shape[2]
        validate_type(channel_count, [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')

    def plot_visualizer(self) -> None:
        """
        Show the stored image in a matplotlib window with a black background.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            leap_image = LeapImage(data=image_data)
            leap_image.plot_visualizer()
        """
        pixels = self.data
        if pixels.shape[2] == 1:
            # Replicate the single gray channel so the image is shown as 3-channel.
            pixels = np.repeat(pixels, 3, axis=2)

        figure, axes = plt.subplots()
        for surface in (figure.patch, axes):
            surface.set_facecolor('black')

        axes.imshow(pixels)
        axes.axis('off')
        axes.set_title('Leap Image Visualization', color='white')
        plt.show()
81
+
82
+
83
@dataclass
class LeapImageWithBBox:
    """
    Visualizer representing an image with bounding boxes for Tensorleap, used for object detection tasks.

    Attributes:
        data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or [H, W, 1].
        bounding_boxes (List[BoundingBox]): List of Tensorleap bounding boxes objects in relative size to image size.
            When plotted, (x, y) is treated as the box centre.
        type (LeapDataType): The data type, default is LeapDataType.ImageWithBBox.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
        leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
    """
    data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
    bounding_boxes: List[BoundingBox]
    type: LeapDataType = LeapDataType.ImageWithBBox

    def __post_init__(self) -> None:
        """Validate the type tag and the image array's class, dtype, rank and channel count."""
        validate_type(self.type, LeapDataType.ImageWithBBox)
        validate_type(type(self.data), np.ndarray)
        validate_type(self.data.dtype, [np.uint8, np.float32])
        validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
        validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')

    def plot_visualizer(self) -> None:
        """
        Plot an image with overlaid bounding boxes.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
            leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
            leap_image_with_bbox.plot_visualizer()
        """
        image = self.data

        # The constructor accepts [H, W, 1] images; replicate the gray channel to
        # 3 channels before display, the same way LeapImage.plot_visualizer does.
        if image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        image_height, image_width = image.shape[:2]

        # Create figure and axes
        fig, ax = plt.subplots(1)
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        # Display the image
        ax.imshow(image)
        ax.set_title('Leap Image With BBox Visualization', color='white')

        # Draw bounding boxes on the image
        for bbox in self.bounding_boxes:
            # Convert relative coordinates to absolute pixel coordinates
            abs_width = bbox.width * image_width
            abs_height = bbox.height * image_height
            # (x, y) is the box centre, so shift by half the size to get the top-left corner
            left = bbox.x * image_width - abs_width / 2
            top = bbox.y * image_height - abs_height / 2

            ax.add_patch(plt.Rectangle(
                (left, top),
                abs_width, abs_height,
                linewidth=3, edgecolor='r', facecolor='none'
            ))

            # Display label and confidence just above the box
            ax.text(left, top - 5,
                    f"{bbox.label} {bbox.confidence:.2f}", color='r', fontsize=8,
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))

        # Show the image with bounding boxes
        plt.show()
163
+
164
@dataclass
class LeapGraph:
    """
    Line-chart visualizer for Tensorleap.

    Attributes:
        data (npt.NDArray[np.float32]): A rank-2 float32 array, shaped [M, N] where M is the number of
            data points and N is the number of variables. Each column is plotted as one line.
        type (LeapDataType): Must be LeapDataType.Graph, which is the default.

    Example:
        graph_data = np.random.rand(100, 3).astype(np.float32)
        leap_graph = LeapGraph(data=graph_data)
    """
    data: npt.NDArray[np.float32]
    type: LeapDataType = LeapDataType.Graph

    def __post_init__(self) -> None:
        """Validate the type tag and that `data` is a rank-2 float32 ndarray."""
        validate_type(self.type, LeapDataType.Graph)
        validate_type(type(self.data), np.ndarray)
        validate_type(self.data.dtype, np.float32)
        rank = len(self.data.shape)
        validate_type(rank, 2, 'Graph must be of shape 2')

    def plot_visualizer(self) -> None:
        """
        Show the stored data as a line chart, one line per column.

        Returns:
            None

        Example:
            graph_data = np.random.rand(100, 3).astype(np.float32)
            leap_graph = LeapGraph(data=graph_data)
            leap_graph.plot_visualizer()
        """
        series_matrix = self.data

        figure, axes = plt.subplots(figsize=(10, 6))

        # Black background for both the figure and the plotting area
        figure.patch.set_facecolor('black')
        axes.set_facecolor('black')

        # Rows of the transpose are the columns of the data; legend numbering starts at 1
        for number, column in enumerate(series_matrix.T, start=1):
            axes.plot(column, label=f'Variable {number}')

        axes.set_xlabel('Data Points', color='white')
        axes.set_ylabel('Values', color='white')
        axes.set_title('Leap Graph Visualization', color='white')
        axes.legend()
        axes.grid(True, color='white')

        # White tick labels so they stay readable on the black background
        axes.tick_params(colors='white')

        plt.show()
220
+
221
@dataclass
class LeapText:
    """
    Text visualizer for Tensorleap.

    Attributes:
        data (List[str]): The text data, consisting of a list of text tokens. If the model requires
            fixed-length inputs, it is recommended to maintain the fixed length, using empty strings ('')
            instead of padding tokens ('PAD') e.g., ['I', 'ate', 'a', 'banana', '', '', '', ...]
        type (LeapDataType): Must be LeapDataType.Text, which is the default.

    Example:
        text_data = ['I', 'ate', 'a', 'banana', '', '', '']
        leap_text = LeapText(data=text_data)
    """
    data: List[str]
    type: LeapDataType = LeapDataType.Text

    def __post_init__(self) -> None:
        """Validate the type tag and that `data` is a list whose items are all exactly `str`."""
        validate_type(self.type, LeapDataType.Text)
        validate_type(type(self.data), list)
        for token in self.data:
            validate_type(type(token), str)

    def plot_visualizer(self) -> None:
        """
        Show the tokens as one centred line of white text on a black figure.

        Returns:
            None

        Example:
            text_data = ['I', 'ate', 'a', 'banana', '', '', '']
            leap_text = LeapText(data=text_data)
            leap_text.plot_visualizer()
        """
        # Empty-string tokens (padding) are dropped before joining
        sentence = ' '.join(filter(None, self.data))

        # Black canvas with the axes hidden
        figure, axes = plt.subplots(figsize=(10, 5))
        figure.patch.set_facecolor('black')
        axes.set_facecolor('black')
        axes.axis('off')

        # Place the text at the centre of the axes
        axes.text(0.5, 0.5, sentence, color='white', fontsize=20, ha='center', va='center')
        axes.set_title('Leap Text Visualization', color='white')

        plt.show()
279
+
280
+
281
@dataclass
class LeapHorizontalBar:
    """
    Horizontal bar chart visualizer for Tensorleap.
    For example, this can be used to visualize the model's prediction scores in a classification problem.

    Attributes:
        body (npt.NDArray[np.float32]): The data for the bar, shaped [C], where C is the number of data points.
        labels (List[str]): Labels for the horizontal bar; e.g., when visualizing the model's classification
            output, labels are the class names. Length of `body` should match the length of `labels`, C
            (this is not checked at construction).
        type (LeapDataType): Must be LeapDataType.HorizontalBar, which is the default.

    Example:
        body_data = np.random.rand(5).astype(np.float32)
        labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
        leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
    """
    body: npt.NDArray[np.float32]
    labels: List[str]
    type: LeapDataType = LeapDataType.HorizontalBar

    def __post_init__(self) -> None:
        """Validate the type tag, that `body` is a rank-1 float32 ndarray and that `labels` is a list of str."""
        validate_type(self.type, LeapDataType.HorizontalBar)
        validate_type(type(self.body), np.ndarray)
        validate_type(self.body.dtype, np.float32)
        rank = len(self.body.shape)
        validate_type(rank, 1, 'HorizontalBar body must be of shape 1')

        validate_type(type(self.labels), list)
        for name in self.labels:
            validate_type(type(name), str)

    def plot_visualizer(self) -> None:
        """
        Show the values as green horizontal bars, one per label, on a black background.

        Returns:
            None

        Example:
            body_data = np.random.rand(5).astype(np.float32)
            labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
            leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
            leap_horizontal_bar.plot_visualizer()
        """
        figure, axes = plt.subplots()
        figure.patch.set_facecolor('black')
        axes.set_facecolor('black')

        # One bar per label
        axes.barh(self.labels, self.body, color='green')

        # White text so it stays readable on the black background
        axes.set_xlabel('Scores', color='white')
        axes.set_title('Leap Horizontal Bar Visualization', color='white')
        for axis_name in ('x', 'y'):
            axes.tick_params(axis=axis_name, colors='white')

        plt.show()
345
+
346
@dataclass
class LeapImageMask:
    """
    Visualizer representing an image with a mask for Tensorleap.
    This can be used for tasks such as segmentation, and other applications where it is important to highlight specific regions within an image.

    Attributes:
        mask (npt.NDArray[np.uint8]): The mask data, shaped [H, W].
        image (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or shaped [H, W, 1].
        labels (List[str]): Labels associated with the mask regions; e.g., class names for segmented objects. The length of `labels` should match the number of unique values in `mask`.
        type (LeapDataType): The data type, default is LeapDataType.ImageMask.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
        labels = ["background", "object"]
        leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
    """
    mask: npt.NDArray[np.uint8]
    image: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
    labels: List[str]
    type: LeapDataType = LeapDataType.ImageMask

    def __post_init__(self) -> None:
        """Validate the type tag, the mask (rank-2 uint8), the image (rank-3, 1 or 3 channels) and the labels."""
        validate_type(self.type, LeapDataType.ImageMask)
        validate_type(type(self.mask), np.ndarray)
        validate_type(self.mask.dtype, np.uint8)
        validate_type(len(self.mask.shape), 2, 'image mask must be of shape 2')
        validate_type(type(self.image), np.ndarray)
        validate_type(self.image.dtype, [np.uint8, np.float32])
        validate_type(len(self.image.shape), 3, 'Image must be of shape 3')
        validate_type(self.image.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
        validate_type(type(self.labels), list)
        for label in self.labels:
            validate_type(type(label), str)

    def plot_visualizer(self) -> None:
        """
        Plots an image with overlaid masks given a LeapImageMask visualizer object.

        Mask value `i + 1` is tinted with the colour assigned to `labels[i]`; pixels whose
        mask value is 0 are left untouched.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
            labels = ["background", "object"]
            leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
            leap_image_mask.plot_visualizer()
        """
        image = self.image
        mask = self.mask
        labels = self.labels

        # Work on a copy so the stored image is never modified. A single gray channel
        # cannot hold an RGB tint, so it is replicated to 3 channels first
        # (np.repeat already returns a new array).
        if image.shape[2] == 1:
            overlayed_image = np.repeat(image, 3, axis=2)
        else:
            overlayed_image = image.copy()

        # One RGB colour per label. The colormap returns RGBA floats in [0, 1]; casting
        # those to uint8 (as was done before) truncated nearly every component to 0.
        colors = plt.cm.jet(np.linspace(0, 1, len(labels)))[:, :3]
        if overlayed_image.dtype == np.uint8:
            # uint8 images use the 0-255 range
            colors = colors * 255
        # NOTE(review): float32 images are assumed to be in [0, 1] (the range imshow
        # expects for float RGB data) - confirm against callers.

        # Iterate through the mask values 1..len(labels) (0 is skipped)
        for i in range(len(labels)):
            # Extract binary mask for the current instance
            instance_mask = (mask == (i + 1))

            # Fill the instance mask with a translucent color (50/50 blend)
            overlayed_image[instance_mask] = overlayed_image[instance_mask] * 0.5 + colors[i] * 0.5

        # Display the result using matplotlib
        fig, ax = plt.subplots(1)
        fig.patch.set_facecolor('black')  # Set the figure background to black
        ax.set_facecolor('black')  # Set the axis background to black

        ax.imshow(overlayed_image)
        ax.set_title('Leap Image With Mask Visualization', color='white')
        ax.axis('off')  # Hide the axis
        plt.show()
426
+
427
+
428
@dataclass
class LeapTextMask:
    """
    Visualizer representing text data with a mask for Tensorleap.
    This can be used for tasks such as named entity recognition (NER), sentiment analysis, and other applications where it is important to highlight specific tokens or parts of the text.

    Attributes:
        mask (npt.NDArray[np.uint8]): The mask data, shaped [L].
        text (List[str]): The text data, consisting of a list of text tokens, length of L.
        labels (List[str]): Labels associated with the masked tokens; e.g., named entities or sentiment categories. The length of `labels` should match the number of unique values in `mask`.
        type (LeapDataType): The data type, default is LeapDataType.TextMask.

    Example:
        text_data = ['I', 'ate', 'a', 'banana', '', '', '']
        mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
        labels = ["object"]
        leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
    """
    mask: npt.NDArray[np.uint8]
    text: List[str]
    labels: List[str]
    type: LeapDataType = LeapDataType.TextMask

    def __post_init__(self) -> None:
        """Validate the type tag, the mask (rank-1 uint8) and that `text` and `labels` are lists of str."""
        validate_type(self.type, LeapDataType.TextMask)
        validate_type(type(self.mask), np.ndarray)
        validate_type(self.mask.dtype, np.uint8)
        validate_type(len(self.mask.shape), 1, 'text mask must be of shape 1')
        validate_type(type(self.text), list)
        for t in self.text:
            validate_type(type(t), str)
        validate_type(type(self.labels), list)
        for label in self.labels:
            validate_type(type(label), str)

    def plot_visualizer(self) -> None:
        """
        Plots text with overlaid masks given a LeapTextMask visualizer object.

        Tokens whose mask value is greater than 0 get a coloured background; mask value `v`
        uses the colour assigned to `labels[v - 1]` (wrapping around if `v` exceeds the
        number of labels). Tokens and mask values are paired with `zip`, so any surplus
        on either side is ignored.

        Returns:
            None

        Example:
            text_data = ['I', 'ate', 'a', 'banana', '', '', '']
            mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
            labels = ["object"]
            leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
            leap_text_mask.plot_visualizer()
        """
        text_data = self.text
        mask_data = self.mask
        labels = self.labels

        # Create a color map for each label
        colors = plt.cm.jet(np.linspace(0, 1, len(labels)))

        # Create a figure and axis
        fig, ax = plt.subplots()

        # Set background to black
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')
        ax.set_title('Leap Text Mask Visualization', color='white')
        ax.axis('off')

        # Set initial position
        x_pos, y_pos = 0.01, 0.5  # Adjusted initial position for better visibility

        # Display the text with colors
        for token, mask_value in zip(text_data, mask_data):
            highlight = None
            # With no labels there is no colour to pick, so the token is drawn without
            # a background instead of failing on an empty colour table.
            if mask_value > 0 and len(colors) > 0:
                # Mask value 0 means "no label", so value v uses the colour of labels[v - 1].
                # int() avoids doing the arithmetic in uint8.
                color = colors[(int(mask_value) - 1) % len(colors)]
                highlight = dict(facecolor=color, edgecolor='none',
                                 boxstyle='round,pad=0.3')  # Background color for masked tokens

            ax.text(x_pos, y_pos, token, fontsize=12, color='white', ha='left', va='center', bbox=highlight)

            # Update the x position for the next token
            x_pos += len(token) * 0.03 + 0.02  # Adjust the spacing between tokens

        plt.show()
512
+
513
+
514
@dataclass
class LeapImageWithHeatmap:
    """
    Visualizer representing an image with heatmaps for Tensorleap.
    This can be used for tasks such as highlighting important regions in an image, visualizing attention maps, and other applications where it is important to overlay heatmaps on images.

    Attributes:
        image (npt.NDArray[np.float32]): The image data, shaped [H, W, C], where C is the number of channels.
        heatmaps (npt.NDArray[np.float32]): The heatmap data, shaped [N, H, W], where N is the number of heatmaps.
        labels (List[str]): Labels associated with the heatmaps; e.g., feature names or attention regions. The length of `labels` should match the number of heatmaps, N.
        type (LeapDataType): The data type, default is LeapDataType.ImageWithHeatmap.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
        labels = ["heatmap1", "heatmap2", "heatmap3"]
        leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
    """
    image: npt.NDArray[np.float32]
    heatmaps: npt.NDArray[np.float32]
    labels: List[str]
    type: LeapDataType = LeapDataType.ImageWithHeatmap

    def __post_init__(self) -> None:
        """
        Validate the type tag, that `image` and `heatmaps` are float32 ndarrays, that `labels`
        is a list of str, and that there is exactly one label per heatmap.

        Raises:
            LeapValidationError: If any of the checks above fails.
        """
        validate_type(self.type, LeapDataType.ImageWithHeatmap)
        validate_type(type(self.heatmaps), np.ndarray)
        validate_type(self.heatmaps.dtype, np.float32)
        validate_type(type(self.image), np.ndarray)
        validate_type(self.image.dtype, np.float32)
        validate_type(type(self.labels), list)
        for label in self.labels:
            validate_type(type(label), str)
        if self.heatmaps.shape[0] != len(self.labels):
            raise LeapValidationError(
                'Number of heatmaps and labels must be equal')

    def plot_visualizer(self) -> None:
        """
        Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.

        Every heatmap is drawn once, half-transparent, on top of the image, and gets its own
        colorbar labelled with the matching entry of `labels`.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
            labels = ["heatmap1", "heatmap2", "heatmap3"]
            leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
            leap_image_with_heatmap.plot_visualizer()
        """
        image = self.image
        heatmaps = self.heatmaps
        labels = self.labels

        # Plot the base image
        fig, ax = plt.subplots()
        fig.patch.set_facecolor('black')  # Set the figure background to black
        ax.set_facecolor('black')  # Set the axis background to black
        ax.imshow(image, cmap='gray')

        # Overlay each heatmap with some transparency
        for heatmap, label in zip(heatmaps, labels):
            # Draw the heatmap exactly once and reuse the returned image for the colorbar;
            # drawing it a second time just to feed the colorbar doubled its opacity.
            overlay = ax.imshow(heatmap, cmap='jet', alpha=0.5)

            # Display a colorbar for the heatmap
            cbar = fig.colorbar(overlay, ax=ax)
            cbar.set_label(label, color='white')
            cbar.ax.yaxis.set_tick_params(color='white')  # Set color for the colorbar ticks
            plt.setp(cbar.ax.get_yticklabels(), color='white')  # Set color for the colorbar labels

        ax.axis('off')
        ax.set_title('Leap Image With Heatmaps Visualization', color='white')
        plt.show()
589
+
590
+
591
# Lookup from a LeapDataType value to the visualizer class that represents it.
# Each visualizer declares its own LeapDataType as the default of its `type` field,
# which the dataclass keeps as a class attribute, so the key is read from the class.
map_leap_data_type_to_visualizer_class = {
    visualizer_class.type.value: visualizer_class
    for visualizer_class in (
        LeapImage,
        LeapGraph,
        LeapText,
        LeapHorizontalBar,
        LeapImageMask,
        LeapTextMask,
        LeapImageWithBBox,
        LeapImageWithHeatmap,
    )
}
@@ -452,3 +452,5 @@ class LeapBinder:
452
452
  self.check_preprocess(preprocess_result)
453
453
  self.check_handlers(preprocess_result)
454
454
  print("Successful!")
455
+
456
+
@@ -1,5 +1,4 @@
1
1
  from enum import Enum
2
- from typing import List
3
2
 
4
3
  import numpy as np
5
4
  import numpy.typing as npt
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "code-loader"
3
- version = "1.0.42"
3
+ version = "1.0.43"
4
4
  description = ""
5
5
  authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
6
6
  license = "MIT"
@@ -15,6 +15,7 @@ include = [
15
15
  python = ">=3.8,<3.12"
16
16
  numpy = "^1.22.3"
17
17
  psutil = "^5.9.5"
18
+ matplotlib = ">=3.3,<3.4"
18
19
 
19
20
  [tool.poetry.dev-dependencies]
20
21
  pytest = "^7.1.1"
@@ -1,238 +0,0 @@
1
- from typing import List, Any, Union
2
-
3
- import numpy as np
4
- import numpy.typing as npt
5
- from dataclasses import dataclass
6
-
7
- from code_loader.contract.enums import LeapDataType
8
- from code_loader.contract.responsedataclasses import BoundingBox
9
-
10
-
11
- class LeapValidationError(Exception):
12
- pass
13
-
14
-
15
- def validate_type(actual: Any, expected: Any, prefix_message: str = '') -> None:
16
- if not isinstance(expected, list):
17
- expected = [expected]
18
- if actual not in expected:
19
- if len(expected) == 1:
20
- raise LeapValidationError(
21
- f'{prefix_message}.\n'
22
- f'visualizer returned unexpected type. got {actual}, instead of {expected[0]}')
23
- else:
24
- raise LeapValidationError(
25
- f'{prefix_message}.\n'
26
- f'visualizer returned unexpected type. got {actual}, allowed is one of {expected}')
27
-
28
-
29
- @dataclass
30
- class LeapImage:
31
- """
32
- Visualizer representing an image for Tensorleap.
33
-
34
- Attributes:
35
- data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data.
36
- type (LeapDataType): The data type, default is LeapDataType.Image.
37
- """
38
- data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
39
- type: LeapDataType = LeapDataType.Image
40
-
41
- def __post_init__(self) -> None:
42
- validate_type(self.type, LeapDataType.Image)
43
- validate_type(type(self.data), np.ndarray)
44
- validate_type(self.data.dtype, [np.uint8, np.float32])
45
- validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
46
- validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
47
-
48
-
49
- @dataclass
50
- class LeapImageWithBBox:
51
- """
52
- Visualizer representing an image with bounding boxes for Tensorleap, used for object detection tasks.
53
-
54
- Attributes:
55
- data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or [H, W, 1].
56
- bounding_boxes (List[BoundingBox]): List of Tensorleap bounding boxes objects in relative size to image size.
57
- type (LeapDataType): The data type, default is LeapDataType.ImageWithBBox.
58
- """
59
- data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
60
- bounding_boxes: List[BoundingBox]
61
- type: LeapDataType = LeapDataType.ImageWithBBox
62
-
63
- def __post_init__(self) -> None:
64
- validate_type(self.type, LeapDataType.ImageWithBBox)
65
- validate_type(type(self.data), np.ndarray)
66
- validate_type(self.data.dtype, [np.uint8, np.float32])
67
- validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
68
- validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
69
-
70
-
71
- @dataclass
72
- class LeapGraph:
73
- """
74
- Visualizer representing a line chart data for Tensorleap.
75
-
76
- Attributes:
77
- data (npt.NDArray[np.float32]): The array data, shaped [M, N] where M is the number of data points and N is the number of variables.
78
- type (LeapDataType): The data type, default is LeapDataType.Graph.
79
- """
80
- data: npt.NDArray[np.float32]
81
- type: LeapDataType = LeapDataType.Graph
82
-
83
- def __post_init__(self) -> None:
84
- validate_type(self.type, LeapDataType.Graph)
85
- validate_type(type(self.data), np.ndarray)
86
- validate_type(self.data.dtype, np.float32)
87
- validate_type(len(self.data.shape), 2, 'Graph must be of shape 2')
88
-
89
-
90
- @dataclass
91
- class LeapText:
92
- """
93
- Visualizer representing text data for Tensorleap.
94
-
95
- Attributes:
96
- data (List[str]): The text data, consisting of a list of text tokens. If the model requires fixed-length inputs,
97
- it is recommended to maintain the fixed length, using empty strings ('') instead of padding tokens ('PAD') e.g., ['I', 'ate', 'a', 'banana', '', '', '', ...]
98
- type (LeapDataType): The data type, default is LeapDataType.Text.
99
- """
100
- data: List[str]
101
- type: LeapDataType = LeapDataType.Text
102
-
103
- def __post_init__(self) -> None:
104
- validate_type(self.type, LeapDataType.Text)
105
- validate_type(type(self.data), list)
106
- for value in self.data:
107
- validate_type(type(value), str)
108
-
109
-
110
- @dataclass
111
- class LeapHorizontalBar:
112
- """
113
- Visualizer representing horizontal bar data for Tensorleap.
114
- For example, this can be used to visualize the model's prediction scores in a classification problem.
115
-
116
- Attributes:
117
- body (npt.NDArray[np.float32]): The data for the bar, shaped [C], where C is the number of data points.
118
- labels (List[str]): Labels for the horizontal bar; e.g., when visualizing the model's classification output, labels are the class names.
119
- Length of `body` should match the length of `labels`, C.
120
- type (LeapDataType): The data type, default is LeapDataType.HorizontalBar.
121
- """
122
- body: npt.NDArray[np.float32]
123
- labels: List[str]
124
- type: LeapDataType = LeapDataType.HorizontalBar
125
-
126
- def __post_init__(self) -> None:
127
- validate_type(self.type, LeapDataType.HorizontalBar)
128
- validate_type(type(self.body), np.ndarray)
129
- validate_type(self.body.dtype, np.float32)
130
- validate_type(len(self.body.shape), 1, 'HorizontalBar body must be of shape 1')
131
-
132
- validate_type(type(self.labels), list)
133
- for label in self.labels:
134
- validate_type(type(label), str)
135
-
136
-
137
- @dataclass
138
- class LeapImageMask:
139
- """
140
- Visualizer representing an image with a mask for Tensorleap.
141
- This can be used for tasks such as segmentation, and other applications where it is important to highlight specific regions within an image.
142
-
143
- Attributes:
144
- mask (npt.NDArray[np.uint8]): The mask data, shaped [H, W].
145
- image (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or shaped [H, W, 1].
146
- labels (List[str]): Labels associated with the mask regions; e.g., class names for segmented objects. The length of `labels` should match the number of unique values in `mask`.
147
- type (LeapDataType): The data type, default is LeapDataType.ImageMask.
148
- """
149
- mask: npt.NDArray[np.uint8]
150
- image: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
151
- labels: List[str]
152
- type: LeapDataType = LeapDataType.ImageMask
153
-
154
- def __post_init__(self) -> None:
155
- validate_type(self.type, LeapDataType.ImageMask)
156
- validate_type(type(self.mask), np.ndarray)
157
- validate_type(self.mask.dtype, np.uint8)
158
- validate_type(len(self.mask.shape), 2, 'image mask must be of shape 2')
159
- validate_type(type(self.image), np.ndarray)
160
- validate_type(self.image.dtype, [np.uint8, np.float32])
161
- validate_type(len(self.image.shape), 3, 'Image must be of shape 3')
162
- validate_type(self.image.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
163
- validate_type(type(self.labels), list)
164
- for label in self.labels:
165
- validate_type(type(label), str)
166
-
167
-
168
- @dataclass
169
- class LeapTextMask:
170
- """
171
- Visualizer representing text data with a mask for Tensorleap.
172
- This can be used for tasks such as named entity recognition (NER), sentiment analysis, and other applications where it is important to highlight specific tokens or parts of the text.
173
-
174
- Attributes:
175
- mask (npt.NDArray[np.uint8]): The mask data, shaped [L].
176
- text (List[str]): The text data, consisting of a list of text tokens, length of L.
177
- labels (List[str]): Labels associated with the masked tokens; e.g., named entities or sentiment categories. The length of `labels` should match the number of unique values in `mask`.
178
- type (LeapDataType): The data type, default is LeapDataType.TextMask.
179
- """
180
- mask: npt.NDArray[np.uint8]
181
- text: List[str]
182
- labels: List[str]
183
- type: LeapDataType = LeapDataType.TextMask
184
-
185
- def __post_init__(self) -> None:
186
- validate_type(self.type, LeapDataType.TextMask)
187
- validate_type(type(self.mask), np.ndarray)
188
- validate_type(self.mask.dtype, np.uint8)
189
- validate_type(len(self.mask.shape), 1, 'text mask must be of shape 1')
190
- validate_type(type(self.text), list)
191
- for t in self.text:
192
- validate_type(type(t), str)
193
- validate_type(type(self.labels), list)
194
- for label in self.labels:
195
- validate_type(type(label), str)
196
-
197
-
198
- @dataclass
199
- class LeapImageWithHeatmap:
200
- """
201
- Visualizer representing an image with heatmaps for Tensorleap.
202
- This can be used for tasks such as highlighting important regions in an image, visualizing attention maps, and other applications where it is important to overlay heatmaps on images.
203
-
204
- Attributes:
205
- image (npt.NDArray[np.float32]): The image data, shaped [H, W, C], where C is the number of channels.
206
- heatmaps (npt.NDArray[np.float32]): The heatmap data, shaped [N, H, W], where N is the number of heatmaps.
207
- labels (List[str]): Labels associated with the heatmaps; e.g., feature names or attention regions. The length of `labels` should match the number of heatmaps, N.
208
- type (LeapDataType): The data type, default is LeapDataType.ImageWithHeatmap.
209
- """
210
- image: npt.NDArray[np.float32]
211
- heatmaps: npt.NDArray[np.float32]
212
- labels: List[str]
213
- type: LeapDataType = LeapDataType.ImageWithHeatmap
214
-
215
- def __post_init__(self) -> None:
216
- validate_type(self.type, LeapDataType.ImageWithHeatmap)
217
- validate_type(type(self.heatmaps), np.ndarray)
218
- validate_type(self.heatmaps.dtype, np.float32)
219
- validate_type(type(self.image), np.ndarray)
220
- validate_type(self.image.dtype, np.float32)
221
- validate_type(type(self.labels), list)
222
- for label in self.labels:
223
- validate_type(type(label), str)
224
- if self.heatmaps.shape[0] != len(self.labels):
225
- raise LeapValidationError(
226
- 'Number of heatmaps and labels must be equal')
227
-
228
-
229
- map_leap_data_type_to_visualizer_class = {
230
- LeapDataType.Image.value: LeapImage,
231
- LeapDataType.Graph.value: LeapGraph,
232
- LeapDataType.Text.value: LeapText,
233
- LeapDataType.HorizontalBar.value: LeapHorizontalBar,
234
- LeapDataType.ImageMask.value: LeapImageMask,
235
- LeapDataType.TextMask.value: LeapTextMask,
236
- LeapDataType.ImageWithBBox.value: LeapImageWithBBox,
237
- LeapDataType.ImageWithHeatmap.value: LeapImageWithHeatmap
238
- }
File without changes
File without changes