code-loader 1.0.42a0__tar.gz → 1.0.43__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.42a0
3
+ Version: 1.0.43
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3.8
13
13
  Classifier: Programming Language :: Python :: 3.9
14
14
  Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Programming Language :: Python :: 3.11
16
+ Requires-Dist: matplotlib (>=3.3,<3.4)
16
17
  Requires-Dist: numpy (>=1.22.3,<2.0.0)
17
18
  Requires-Dist: psutil (>=5.9.5,<6.0.0)
18
19
  Project-URL: Repository, https://github.com/tensorleap/code-loader
@@ -4,7 +4,7 @@ from typing import Any, Callable, List, Optional, Dict, Union, Type
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
6
 
7
- from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection, InstanceAnalysisType
7
+ from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
8
8
  from code_loader.contract.visualizer_classes import LeapImage, LeapText, LeapGraph, LeapHorizontalBar, \
9
9
  LeapTextMask, LeapImageMask, LeapImageWithBBox, LeapImageWithHeatmap
10
10
 
@@ -36,16 +36,8 @@ class PreprocessResponse:
36
36
  data: Any
37
37
 
38
38
 
39
- @dataclass
40
- class ElementInstance:
41
- name: str
42
- mask: npt.NDArray[np.float32]
43
-
44
-
45
39
  SectionCallableInterface = Callable[[int, PreprocessResponse], npt.NDArray[np.float32]]
46
40
 
47
- InstanceCallableInterface = Callable[[int, PreprocessResponse], List[ElementInstance]]
48
-
49
41
  MetadataSectionCallableInterface = Union[
50
42
  Callable[[int, PreprocessResponse], int],
51
43
  Callable[[int, PreprocessResponse], Dict[str, int]],
@@ -117,7 +109,6 @@ class MetricHandler:
117
109
  arg_names: List[str]
118
110
  direction: Optional[MetricDirection] = MetricDirection.Downward
119
111
 
120
-
121
112
  @dataclass
122
113
  class RawInputsForHeatmap:
123
114
  raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
@@ -143,14 +134,6 @@ class InputHandler(DatasetBaseHandler):
143
134
  shape: Optional[List[int]] = None
144
135
 
145
136
 
146
- @dataclass
147
- class ElementInstanceHandler:
148
- input_name: str
149
- instance_function: InstanceCallableInterface
150
- analysis_type: InstanceAnalysisType
151
-
152
-
153
-
154
137
  @dataclass
155
138
  class GroundTruthHandler(DatasetBaseHandler):
156
139
  shape: Optional[List[int]] = None
@@ -184,7 +167,6 @@ class DatasetIntegrationSetup:
184
167
  unlabeled_data_preprocess: Optional[UnlabeledDataPreprocessHandler] = None
185
168
  visualizers: List[VisualizerHandler] = field(default_factory=list)
186
169
  inputs: List[InputHandler] = field(default_factory=list)
187
- element_instances: List[ElementInstanceHandler] = field(default_factory=list)
188
170
  ground_truths: List[GroundTruthHandler] = field(default_factory=list)
189
171
  metadata: List[MetadataHandler] = field(default_factory=list)
190
172
  prediction_types: List[PredictionTypeHandler] = field(default_factory=list)
@@ -64,9 +64,3 @@ class ConfusionMatrixValue(Enum):
64
64
  class TestingSectionEnum(Enum):
65
65
  Warnings = "Warnings"
66
66
  Errors = "Errors"
67
-
68
-
69
-
70
- class InstanceAnalysisType(Enum):
71
- MaskInput = "MaskInput"
72
- MaskLatentSpace = "MaskLatentSpace"
@@ -0,0 +1,600 @@
1
+ from typing import List, Any, Union
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+ from dataclasses import dataclass
6
+
7
+ import matplotlib.pyplot as plt # type: ignore
8
+
9
+ from code_loader.contract.enums import LeapDataType
10
+ from code_loader.contract.responsedataclasses import BoundingBox
11
+
12
+
13
class LeapValidationError(Exception):
    """Raised when a visualizer field fails one of the checks performed by `validate_type`."""


def validate_type(actual: Any, expected: Any, prefix_message: str = '') -> None:
    """
    Check that `actual` equals the expected value, or one of the expected values.

    Despite the name, this is a plain membership test (`actual in expected`); callers use it for
    types, dtypes, ranks and channel counts alike.

    Args:
        actual: The value to check.
        expected: A single allowed value, or a list of allowed values.
        prefix_message: Optional context placed on its own line before the generic message.
            When empty, no prefix line is emitted.

    Raises:
        LeapValidationError: If `actual` is not among the allowed values.
    """
    allowed = expected if isinstance(expected, list) else [expected]
    if actual in allowed:
        return

    if len(allowed) == 1:
        detail = f'visualizer returned unexpected type. got {actual}, instead of {allowed[0]}'
    else:
        detail = f'visualizer returned unexpected type. got {actual}, allowed is one of {allowed}'

    # Only prepend the prefix when there is one; otherwise the message would start with a stray ".\n".
    message = f'{prefix_message}.\n{detail}' if prefix_message else detail
    raise LeapValidationError(message)
29
+
30
+
31
@dataclass
class LeapImage:
    """
    Image visualizer for Tensorleap.

    Attributes:
        data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image array. Construction fails
            unless it is a 3-dimensional uint8/float32 ndarray whose last axis has size 1 or 3.
        type (LeapDataType): Must be LeapDataType.Image, which is the default.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        leap_image = LeapImage(data=image_data)
    """
    data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
    type: LeapDataType = LeapDataType.Image

    def __post_init__(self) -> None:
        """Validate the visualizer type, then the array's class, dtype, rank and channel count."""
        validate_type(self.type, LeapDataType.Image)
        validate_type(type(self.data), np.ndarray)
        validate_type(self.data.dtype, [np.uint8, np.float32])
        validate_type(self.data.ndim, 3, 'Image must be of shape 3')
        validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')

    def plot_visualizer(self) -> None:
        """
        Show the image in a matplotlib figure with a black background.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            leap_image = LeapImage(data=image_data)
            leap_image.plot_visualizer()
        """
        # A single-channel image is replicated to three channels before display.
        single_channel = self.data.shape[2] == 1
        pixels = np.repeat(self.data, 3, axis=2) if single_channel else self.data

        fig, ax = plt.subplots()
        for surface in (fig.patch, ax):
            surface.set_facecolor('black')

        ax.imshow(pixels)
        ax.axis('off')
        ax.set_title('Leap Image Visualization', color='white')
        plt.show()
81
+
82
+
83
@dataclass
class LeapImageWithBBox:
    """
    Visualizer representing an image with bounding boxes for Tensorleap, used for object detection tasks.

    Attributes:
        data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or [H, W, 1].
        bounding_boxes (List[BoundingBox]): List of Tensorleap bounding boxes objects in relative size to image size.
            When plotting, (x, y) is treated as the box centre.
        type (LeapDataType): The data type, default is LeapDataType.ImageWithBBox.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
        leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
    """
    data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
    bounding_boxes: List[BoundingBox]
    type: LeapDataType = LeapDataType.ImageWithBBox

    def __post_init__(self) -> None:
        """Validate the visualizer type and the image's class, dtype, rank and channel count."""
        validate_type(self.type, LeapDataType.ImageWithBBox)
        validate_type(type(self.data), np.ndarray)
        validate_type(self.data.dtype, [np.uint8, np.float32])
        validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
        validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')

    def plot_visualizer(self) -> None:
        """
        Plot an image with overlaid bounding boxes.

        Each box is outlined in red and annotated with its label and confidence.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
            leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
            leap_image_with_bbox.plot_visualizer()
        """
        image = self.data

        # [H, W, 1] passes validation but is not an RGB array for imshow; replicate the channel
        # so grayscale images are shown the same way LeapImage shows them.
        if image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)

        image_height, image_width = image.shape[0], image.shape[1]

        # Create figure and axes
        fig, ax = plt.subplots(1)
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        # Display the image
        ax.imshow(image)
        ax.set_title('Leap Image With BBox Visualization', color='white')

        # Draw bounding boxes on the image
        for bbox in self.bounding_boxes:
            # Convert relative coordinates to absolute (pixel) coordinates
            abs_width = bbox.width * image_width
            abs_height = bbox.height * image_height
            left = bbox.x * image_width - abs_width / 2
            top = bbox.y * image_height - abs_height / 2

            ax.add_patch(plt.Rectangle(
                (left, top),
                abs_width, abs_height,
                linewidth=3, edgecolor='r', facecolor='none'
            ))

            # Label and confidence, placed 5 pixels above the top-left corner of the box
            ax.text(left, top - 5,
                    f"{bbox.label} {bbox.confidence:.2f}", color='r', fontsize=8,
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))

        # Show the image with bounding boxes
        plt.show()
163
+
164
@dataclass
class LeapGraph:
    """
    Line-chart visualizer for Tensorleap.

    Attributes:
        data (npt.NDArray[np.float32]): A 2-dimensional float32 array shaped [M, N]; each of the N columns
            is plotted as one line over M data points.
        type (LeapDataType): Must be LeapDataType.Graph, which is the default.

    Example:
        graph_data = np.random.rand(100, 3).astype(np.float32)
        leap_graph = LeapGraph(data=graph_data)
    """
    data: npt.NDArray[np.float32]
    type: LeapDataType = LeapDataType.Graph

    def __post_init__(self) -> None:
        """Validate the visualizer type and that `data` is a 2-dimensional float32 ndarray."""
        validate_type(self.type, LeapDataType.Graph)
        validate_type(type(self.data), np.ndarray)
        validate_type(self.data.dtype, np.float32)
        validate_type(self.data.ndim, 2, 'Graph must be of shape 2')

    def plot_visualizer(self) -> None:
        """
        Plot every column of `data` as a separate line on a black background.

        Returns:
            None

        Example:
            graph_data = np.random.rand(100, 3).astype(np.float32)
            leap_graph = LeapGraph(data=graph_data)
            leap_graph.plot_visualizer()
        """
        series = self.data

        fig, ax = plt.subplots(figsize=(10, 6))
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        # One line per column, labelled "Variable 1", "Variable 2", ...
        for column in range(series.shape[1]):
            ax.plot(series[:, column], label=f'Variable {column + 1}')

        ax.set_xlabel('Data Points', color='white')
        ax.set_ylabel('Values', color='white')
        ax.set_title('Leap Graph Visualization', color='white')
        ax.legend()
        ax.grid(True, color='white')

        # White tick labels so they remain readable on the black background
        ax.tick_params(colors='white')

        plt.show()
220
+
221
@dataclass
class LeapText:
    """
    Text visualizer for Tensorleap.

    Attributes:
        data (List[str]): The text as a list of string tokens. If the model requires fixed-length inputs,
            it is recommended to keep the fixed length and pad with empty strings ('') rather than
            padding tokens ('PAD'), e.g. ['I', 'ate', 'a', 'banana', '', '', '', ...].
        type (LeapDataType): Must be LeapDataType.Text, which is the default.

    Example:
        text_data = ['I', 'ate', 'a', 'banana', '', '', '']
        leap_text = LeapText(data=text_data)
    """
    data: List[str]
    type: LeapDataType = LeapDataType.Text

    def __post_init__(self) -> None:
        """Validate the visualizer type and that `data` is a list whose items are all `str`."""
        validate_type(self.type, LeapDataType.Text)
        validate_type(type(self.data), list)
        for token in self.data:
            validate_type(type(token), str)

    def plot_visualizer(self) -> None:
        """
        Render the non-empty tokens, joined by spaces, as white text centred on a black figure.

        Returns:
            None

        Example:
            text_data = ['I', 'ate', 'a', 'banana', '', '', '']
            leap_text = LeapText(data=text_data)
            leap_text.plot_visualizer()
        """
        # Empty-string padding tokens are dropped before joining
        sentence = ' '.join(filter(None, self.data))

        fig, ax = plt.subplots(figsize=(10, 5))
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')
        ax.axis('off')

        ax.text(0.5, 0.5, sentence, color='white', fontsize=20, ha='center', va='center')
        ax.set_title('Leap Text Visualization', color='white')

        plt.show()
279
+
280
+
281
@dataclass
class LeapHorizontalBar:
    """
    Horizontal-bar visualizer for Tensorleap.
    For example, it can show a model's prediction scores in a classification problem.

    Attributes:
        body (npt.NDArray[np.float32]): 1-dimensional float32 array of bar values, shaped [C].
        labels (List[str]): One label per bar, e.g. the class names of a classifier output.
            The length of `labels` should match the length of `body`, C (not checked on construction).
        type (LeapDataType): Must be LeapDataType.HorizontalBar, which is the default.

    Example:
        body_data = np.random.rand(5).astype(np.float32)
        labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
        leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
    """
    body: npt.NDArray[np.float32]
    labels: List[str]
    type: LeapDataType = LeapDataType.HorizontalBar

    def __post_init__(self) -> None:
        """Validate the visualizer type, the `body` array (class, dtype, rank) and the `labels` list."""
        validate_type(self.type, LeapDataType.HorizontalBar)
        validate_type(type(self.body), np.ndarray)
        validate_type(self.body.dtype, np.float32)
        validate_type(self.body.ndim, 1, 'HorizontalBar body must be of shape 1')

        validate_type(type(self.labels), list)
        for name in self.labels:
            validate_type(type(name), str)

    def plot_visualizer(self) -> None:
        """
        Draw `body` as green horizontal bars, one per label, on a black background.

        Returns:
            None

        Example:
            body_data = np.random.rand(5).astype(np.float32)
            labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
            leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
            leap_horizontal_bar.plot_visualizer()
        """
        fig, ax = plt.subplots()
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        ax.barh(self.labels, self.body, color='green')

        # White axis label, title and ticks so they stay visible on the black background
        ax.set_xlabel('Scores', color='white')
        ax.set_title('Leap Horizontal Bar Visualization', color='white')
        for axis_name in ('x', 'y'):
            ax.tick_params(axis=axis_name, colors='white')

        plt.show()
345
+
346
@dataclass
class LeapImageMask:
    """
    Visualizer representing an image with a mask for Tensorleap.
    This can be used for tasks such as segmentation, and other applications where it is important to highlight specific regions within an image.

    Attributes:
        mask (npt.NDArray[np.uint8]): The mask data, shaped [H, W].
        image (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or shaped [H, W, 1].
        labels (List[str]): Labels associated with the mask regions; e.g., class names for segmented objects. The length of `labels` should match the number of unique values in `mask`.
        type (LeapDataType): The data type, default is LeapDataType.ImageMask.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
        labels = ["background", "object"]
        leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
    """
    mask: npt.NDArray[np.uint8]
    image: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
    labels: List[str]
    type: LeapDataType = LeapDataType.ImageMask

    def __post_init__(self) -> None:
        """Validate the visualizer type, the mask array, the image array and the labels list."""
        validate_type(self.type, LeapDataType.ImageMask)
        validate_type(type(self.mask), np.ndarray)
        validate_type(self.mask.dtype, np.uint8)
        validate_type(len(self.mask.shape), 2, 'image mask must be of shape 2')
        validate_type(type(self.image), np.ndarray)
        validate_type(self.image.dtype, [np.uint8, np.float32])
        validate_type(len(self.image.shape), 3, 'Image must be of shape 3')
        validate_type(self.image.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
        validate_type(type(self.labels), list)
        for label in self.labels:
            validate_type(type(label), str)

    def plot_visualizer(self) -> None:
        """
        Plots an image with overlaid masks given a LeapImageMask visualizer object.

        Pixels whose mask value is `i + 1` are blended 50/50 with the colour assigned to `labels[i]`;
        pixels whose mask value matches no label (including 0) are left unchanged.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
            labels = ["background", "object"]
            leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
            leap_image_mask.plot_visualizer()
        """
        image = self.image
        mask = self.mask
        labels = self.labels

        # One RGBA colour per label; components are floats in [0, 1]
        colors = plt.cm.jet(np.linspace(0, 1, len(labels)))

        # Blend in float RGB. astype() returns a copy, so the stored image is never modified.
        # uint8 images are rescaled to [0, 1] so they share a scale with the colormap colours;
        # casting the colours to uint8 instead would truncate them all to black.
        overlayed_image = image.astype(np.float32)
        if image.dtype == np.uint8:
            overlayed_image /= 255.0
        # Replicate a single gray channel so an RGB colour can be blended into it
        if overlayed_image.shape[2] == 1:
            overlayed_image = np.repeat(overlayed_image, 3, axis=2)

        for i in range(len(labels)):
            # Binary mask of the pixels belonging to label i (mask value i + 1)
            instance_mask = (mask == (i + 1))

            # Fill the instance mask with a translucent colour
            overlayed_image[instance_mask] = overlayed_image[instance_mask] * 0.5 + colors[i][:3] * 0.5

        # Display the result using matplotlib
        fig, ax = plt.subplots(1)
        fig.patch.set_facecolor('black')  # Set the figure background to black
        ax.set_facecolor('black')  # Set the axis background to black

        ax.imshow(overlayed_image)
        ax.set_title('Leap Image With Mask Visualization', color='white')
        ax.axis('off')  # Hide the axis
        plt.show()
426
+
427
+
428
@dataclass
class LeapTextMask:
    """
    Text-with-mask visualizer for Tensorleap.
    Useful for tasks such as named entity recognition (NER) or sentiment analysis, where specific tokens
    or parts of the text need to be highlighted.

    Attributes:
        mask (npt.NDArray[np.uint8]): 1-dimensional uint8 mask, shaped [L].
        text (List[str]): The text as a list of string tokens, of length L.
        labels (List[str]): Labels associated with the masked tokens, e.g. named entities or sentiment
            categories. The length of `labels` should match the number of unique values in `mask`.
        type (LeapDataType): Must be LeapDataType.TextMask, which is the default.

    Example:
        text_data = ['I', 'ate', 'a', 'banana', '', '', '']
        mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
        labels = ["object"]
        leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
    """
    mask: npt.NDArray[np.uint8]
    text: List[str]
    labels: List[str]
    type: LeapDataType = LeapDataType.TextMask

    def __post_init__(self) -> None:
        """Validate the visualizer type, the mask array, and that `text` and `labels` are lists of `str`."""
        validate_type(self.type, LeapDataType.TextMask)
        validate_type(type(self.mask), np.ndarray)
        validate_type(self.mask.dtype, np.uint8)
        validate_type(self.mask.ndim, 1, 'text mask must be of shape 1')
        validate_type(type(self.text), list)
        for token in self.text:
            validate_type(type(token), str)
        validate_type(type(self.labels), list)
        for name in self.labels:
            validate_type(type(name), str)

    def plot_visualizer(self) -> None:
        """
        Draw the tokens left to right on a black figure, highlighting tokens whose mask value is non-zero.

        Returns:
            None

        Example:
            text_data = ['I', 'ate', 'a', 'banana', '', '', '']
            mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
            labels = ["object"]
            leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
            leap_text_mask.plot_visualizer()
        """
        # One colour per label, sampled evenly from the "jet" colormap
        palette = plt.cm.jet(np.linspace(0, 1, len(self.labels)))

        fig, ax = plt.subplots()
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')
        ax.set_title('Leap Text Mask Visualization', color='white')
        ax.axis('off')

        # Running horizontal position of the next token; the vertical position is fixed
        cursor_x, baseline_y = 0.01, 0.5

        for token, mask_value in zip(self.text, self.mask):
            highlight = None
            if mask_value > 0:
                # Masked tokens get a coloured rounded box; the colour index wraps around the palette
                highlight = dict(facecolor=palette[mask_value % len(palette)], edgecolor='none',
                                 boxstyle='round,pad=0.3')

            ax.text(cursor_x, baseline_y, token, fontsize=12, color='white', ha='left', va='center',
                    bbox=highlight)

            # Advance by an amount proportional to the token length, plus a fixed gap
            cursor_x += len(token) * 0.03 + 0.02

        plt.show()
512
+
513
+
514
@dataclass
class LeapImageWithHeatmap:
    """
    Visualizer representing an image with heatmaps for Tensorleap.
    This can be used for tasks such as highlighting important regions in an image, visualizing attention maps, and other applications where it is important to overlay heatmaps on images.

    Attributes:
        image (npt.NDArray[np.float32]): The image data, shaped [H, W, C], where C is the number of channels.
        heatmaps (npt.NDArray[np.float32]): The heatmap data, shaped [N, H, W], where N is the number of heatmaps.
        labels (List[str]): Labels associated with the heatmaps; e.g., feature names or attention regions. The length of `labels` should match the number of heatmaps, N.
        type (LeapDataType): The data type, default is LeapDataType.ImageWithHeatmap.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
        labels = ["heatmap1", "heatmap2", "heatmap3"]
        leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
    """
    image: npt.NDArray[np.float32]
    heatmaps: npt.NDArray[np.float32]
    labels: List[str]
    type: LeapDataType = LeapDataType.ImageWithHeatmap

    def __post_init__(self) -> None:
        """
        Validate the visualizer type, both arrays (class and dtype) and the labels list.

        Raises:
            LeapValidationError: If any check fails, including when the number of heatmaps
                (first axis of `heatmaps`) differs from the number of labels.
        """
        validate_type(self.type, LeapDataType.ImageWithHeatmap)
        validate_type(type(self.heatmaps), np.ndarray)
        validate_type(self.heatmaps.dtype, np.float32)
        validate_type(type(self.image), np.ndarray)
        validate_type(self.image.dtype, np.float32)
        validate_type(type(self.labels), list)
        for label in self.labels:
            validate_type(type(label), str)
        if self.heatmaps.shape[0] != len(self.labels):
            raise LeapValidationError(
                'Number of heatmaps and labels must be equal')

    def plot_visualizer(self) -> None:
        """
        Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.

        Each heatmap is drawn once, semi-transparently, over the base image and gets its own
        colorbar captioned with the heatmap's label.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
            labels = ["heatmap1", "heatmap2", "heatmap3"]
            leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
            leap_image_with_heatmap.plot_visualizer()
        """
        # Plot the base image
        fig, ax = plt.subplots()
        fig.patch.set_facecolor('black')  # Set the figure background to black
        ax.set_facecolor('black')  # Set the axis background to black
        ax.imshow(self.image, cmap='gray')

        # Overlay each heatmap with some transparency
        for label, heatmap in zip(self.labels, self.heatmaps):
            # Keep the handle so the colorbar reuses this layer instead of drawing the heatmap again
            overlay = ax.imshow(heatmap, cmap='jet', alpha=0.5)

            # Display a colorbar for the heatmap
            cbar = fig.colorbar(overlay, ax=ax)
            cbar.set_label(label, color='white')
            cbar.ax.yaxis.set_tick_params(color='white')  # Set color for the colorbar ticks
            plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')  # Set color for the colorbar labels

        ax.axis('off')
        ax.set_title('Leap Image With Heatmaps Visualization', color='white')
        plt.show()
589
+
590
+
591
# Lookup table from each LeapDataType member's `.value` to the visualizer class defined for it in this module.
map_leap_data_type_to_visualizer_class = {
    data_type.value: visualizer_class
    for data_type, visualizer_class in (
        (LeapDataType.Image, LeapImage),
        (LeapDataType.Graph, LeapGraph),
        (LeapDataType.Text, LeapText),
        (LeapDataType.HorizontalBar, LeapHorizontalBar),
        (LeapDataType.ImageMask, LeapImageMask),
        (LeapDataType.TextMask, LeapTextMask),
        (LeapDataType.ImageWithBBox, LeapImageWithBBox),
        (LeapDataType.ImageWithHeatmap, LeapImageWithHeatmap),
    )
}
@@ -9,9 +9,8 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
9
9
  PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
10
10
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
11
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, VisualizerCallableReturnType, \
12
- CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
13
- RawInputsForHeatmap, InstanceCallableInterface, ElementInstanceHandler
14
- from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, InstanceAnalysisType
12
+ CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
13
+ from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
15
14
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
16
15
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
17
16
  from code_loader.utils import to_numpy_return_wrapper, get_shape
@@ -203,10 +202,6 @@ class LeapBinder:
203
202
 
204
203
  self._encoder_names.append(name)
205
204
 
206
- def set_instance_element(self, input_name: str, instance_function: Optional[InstanceCallableInterface] = None,
207
- analysis_type: InstanceAnalysisType = InstanceAnalysisType.MaskInput) -> None:
208
- self.setup_container.element_instances.append(ElementInstanceHandler(input_name, instance_function, analysis_type))
209
-
210
205
  def add_custom_loss(self, function: CustomCallableInterface, name: str) -> None:
211
206
  """
212
207
  Add a custom loss function to the setup.
@@ -366,8 +361,7 @@ class LeapBinder:
366
361
  custom_layer.kernel_index = kernel_index
367
362
 
368
363
  if use_custom_latent_space and not hasattr(custom_layer, custom_latent_space_attribute):
369
- raise Exception(
370
- f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
364
+ raise Exception(f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
371
365
 
372
366
  init_args = inspect.getfullargspec(custom_layer.__init__)[0][1:]
373
367
  call_args = inspect.getfullargspec(custom_layer.call)[0][1:]
@@ -458,3 +452,5 @@ class LeapBinder:
458
452
  self.check_preprocess(preprocess_result)
459
453
  self.check_handlers(preprocess_result)
460
454
  print("Successful!")
455
+
456
+
@@ -5,16 +5,15 @@ import sys
5
5
  from contextlib import redirect_stdout
6
6
  from functools import lru_cache
7
7
  from pathlib import Path
8
- from typing import Dict, List, Iterable, Union, Any, Optional, Tuple
8
+ from typing import Dict, List, Iterable, Union, Any
9
9
 
10
10
  import numpy as np
11
11
  import numpy.typing as npt
12
12
 
13
13
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
14
14
  PreprocessResponse, VisualizerHandler, VisualizerCallableReturnType, CustomLossHandler, \
15
- PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, ElementInstance
16
- from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType, \
17
- InstanceAnalysisType
15
+ PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
16
+ from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
18
17
  from code_loader.contract.exceptions import DatasetScriptException
19
18
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
20
19
  DatasetPreprocess, DatasetSetup, DatasetInputInstance, DatasetOutputInstance, DatasetMetadataInstance, \
@@ -299,17 +298,6 @@ class LeapLoader:
299
298
  def _get_inputs(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
300
299
  return self._get_dataset_handlers(global_leap_binder.setup_container.inputs, state, idx)
301
300
 
302
- def get_instance_elements(self, state: DataStateEnum, idx: int, input_name: str) \
303
- -> Tuple[Optional[List[ElementInstance]], Optional[InstanceAnalysisType]]:
304
- preprocess_result = self._preprocess_result()
305
- preprocess_state = preprocess_result[state]
306
- for element_instance in global_leap_binder.setup_container.element_instances:
307
- if element_instance.input_name == input_name:
308
- return element_instance.instance_function(idx, preprocess_state), element_instance.analysis_type
309
-
310
- return None, None
311
-
312
-
313
301
  def _get_gt(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
314
302
  return self._get_dataset_handlers(global_leap_binder.setup_container.ground_truths, state, idx)
315
303
 
@@ -1,5 +1,4 @@
1
1
  from enum import Enum
2
- from typing import List
3
2
 
4
3
  import numpy as np
5
4
  import numpy.typing as npt
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "code-loader"
3
- version = "1.0.42a0"
3
+ version = "1.0.43"
4
4
  description = ""
5
5
  authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
6
6
  license = "MIT"
@@ -15,6 +15,7 @@ include = [
15
15
  python = ">=3.8,<3.12"
16
16
  numpy = "^1.22.3"
17
17
  psutil = "^5.9.5"
18
+ matplotlib = ">=3.3,<3.4"
18
19
 
19
20
  [tool.poetry.dev-dependencies]
20
21
  pytest = "^7.1.1"
@@ -1,238 +0,0 @@
1
- from typing import List, Any, Union
2
-
3
- import numpy as np
4
- import numpy.typing as npt
5
- from dataclasses import dataclass
6
-
7
- from code_loader.contract.enums import LeapDataType
8
- from code_loader.contract.responsedataclasses import BoundingBox
9
-
10
-
11
- class LeapValidationError(Exception):
12
- pass
13
-
14
-
15
- def validate_type(actual: Any, expected: Any, prefix_message: str = '') -> None:
16
- if not isinstance(expected, list):
17
- expected = [expected]
18
- if actual not in expected:
19
- if len(expected) == 1:
20
- raise LeapValidationError(
21
- f'{prefix_message}.\n'
22
- f'visualizer returned unexpected type. got {actual}, instead of {expected[0]}')
23
- else:
24
- raise LeapValidationError(
25
- f'{prefix_message}.\n'
26
- f'visualizer returned unexpected type. got {actual}, allowed is one of {expected}')
27
-
28
-
29
- @dataclass
30
- class LeapImage:
31
- """
32
- Visualizer representing an image for Tensorleap.
33
-
34
- Attributes:
35
- data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data.
36
- type (LeapDataType): The data type, default is LeapDataType.Image.
37
- """
38
- data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
39
- type: LeapDataType = LeapDataType.Image
40
-
41
- def __post_init__(self) -> None:
42
- validate_type(self.type, LeapDataType.Image)
43
- validate_type(type(self.data), np.ndarray)
44
- validate_type(self.data.dtype, [np.uint8, np.float32])
45
- validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
46
- validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
47
-
48
-
49
- @dataclass
50
- class LeapImageWithBBox:
51
- """
52
- Visualizer representing an image with bounding boxes for Tensorleap, used for object detection tasks.
53
-
54
- Attributes:
55
- data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or [H, W, 1].
56
- bounding_boxes (List[BoundingBox]): List of Tensorleap bounding boxes objects in relative size to image size.
57
- type (LeapDataType): The data type, default is LeapDataType.ImageWithBBox.
58
- """
59
- data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
60
- bounding_boxes: List[BoundingBox]
61
- type: LeapDataType = LeapDataType.ImageWithBBox
62
-
63
- def __post_init__(self) -> None:
64
- validate_type(self.type, LeapDataType.ImageWithBBox)
65
- validate_type(type(self.data), np.ndarray)
66
- validate_type(self.data.dtype, [np.uint8, np.float32])
67
- validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
68
- validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
69
-
70
-
71
- @dataclass
72
- class LeapGraph:
73
- """
74
- Visualizer representing a line chart data for Tensorleap.
75
-
76
- Attributes:
77
- data (npt.NDArray[np.float32]): The array data, shaped [M, N] where M is the number of data points and N is the number of variables.
78
- type (LeapDataType): The data type, default is LeapDataType.Graph.
79
- """
80
- data: npt.NDArray[np.float32]
81
- type: LeapDataType = LeapDataType.Graph
82
-
83
- def __post_init__(self) -> None:
84
- validate_type(self.type, LeapDataType.Graph)
85
- validate_type(type(self.data), np.ndarray)
86
- validate_type(self.data.dtype, np.float32)
87
- validate_type(len(self.data.shape), 2, 'Graph must be of shape 2')
88
-
89
-
90
- @dataclass
91
- class LeapText:
92
- """
93
- Visualizer representing text data for Tensorleap.
94
-
95
- Attributes:
96
- data (List[str]): The text data, consisting of a list of text tokens. If the model requires fixed-length inputs,
97
- it is recommended to maintain the fixed length, using empty strings ('') instead of padding tokens ('PAD') e.g., ['I', 'ate', 'a', 'banana', '', '', '', ...]
98
- type (LeapDataType): The data type, default is LeapDataType.Text.
99
- """
100
- data: List[str]
101
- type: LeapDataType = LeapDataType.Text
102
-
103
- def __post_init__(self) -> None:
104
- validate_type(self.type, LeapDataType.Text)
105
- validate_type(type(self.data), list)
106
- for value in self.data:
107
- validate_type(type(value), str)
108
-
109
-
110
- @dataclass
111
- class LeapHorizontalBar:
112
- """
113
- Visualizer representing horizontal bar data for Tensorleap.
114
- For example, this can be used to visualize the model's prediction scores in a classification problem.
115
-
116
- Attributes:
117
- body (npt.NDArray[np.float32]): The data for the bar, shaped [C], where C is the number of data points.
118
- labels (List[str]): Labels for the horizontal bar; e.g., when visualizing the model's classification output, labels are the class names.
119
- Length of `body` should match the length of `labels`, C.
120
- type (LeapDataType): The data type, default is LeapDataType.HorizontalBar.
121
- """
122
- body: npt.NDArray[np.float32]
123
- labels: List[str]
124
- type: LeapDataType = LeapDataType.HorizontalBar
125
-
126
- def __post_init__(self) -> None:
127
- validate_type(self.type, LeapDataType.HorizontalBar)
128
- validate_type(type(self.body), np.ndarray)
129
- validate_type(self.body.dtype, np.float32)
130
- validate_type(len(self.body.shape), 1, 'HorizontalBar body must be of shape 1')
131
-
132
- validate_type(type(self.labels), list)
133
- for label in self.labels:
134
- validate_type(type(label), str)
135
-
136
-
137
- @dataclass
138
- class LeapImageMask:
139
- """
140
- Visualizer representing an image with a mask for Tensorleap.
141
- This can be used for tasks such as segmentation, and other applications where it is important to highlight specific regions within an image.
142
-
143
- Attributes:
144
- mask (npt.NDArray[np.uint8]): The mask data, shaped [H, W].
145
- image (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or shaped [H, W, 1].
146
- labels (List[str]): Labels associated with the mask regions; e.g., class names for segmented objects. The length of `labels` should match the number of unique values in `mask`.
147
- type (LeapDataType): The data type, default is LeapDataType.ImageMask.
148
- """
149
- mask: npt.NDArray[np.uint8]
150
- image: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
151
- labels: List[str]
152
- type: LeapDataType = LeapDataType.ImageMask
153
-
154
- def __post_init__(self) -> None:
155
- validate_type(self.type, LeapDataType.ImageMask)
156
- validate_type(type(self.mask), np.ndarray)
157
- validate_type(self.mask.dtype, np.uint8)
158
- validate_type(len(self.mask.shape), 2, 'image mask must be of shape 2')
159
- validate_type(type(self.image), np.ndarray)
160
- validate_type(self.image.dtype, [np.uint8, np.float32])
161
- validate_type(len(self.image.shape), 3, 'Image must be of shape 3')
162
- validate_type(self.image.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
163
- validate_type(type(self.labels), list)
164
- for label in self.labels:
165
- validate_type(type(label), str)
166
-
167
-
168
- @dataclass
169
- class LeapTextMask:
170
- """
171
- Visualizer representing text data with a mask for Tensorleap.
172
- This can be used for tasks such as named entity recognition (NER), sentiment analysis, and other applications where it is important to highlight specific tokens or parts of the text.
173
-
174
- Attributes:
175
- mask (npt.NDArray[np.uint8]): The mask data, shaped [L].
176
- text (List[str]): The text data, consisting of a list of text tokens, length of L.
177
- labels (List[str]): Labels associated with the masked tokens; e.g., named entities or sentiment categories. The length of `labels` should match the number of unique values in `mask`.
178
- type (LeapDataType): The data type, default is LeapDataType.TextMask.
179
- """
180
- mask: npt.NDArray[np.uint8]
181
- text: List[str]
182
- labels: List[str]
183
- type: LeapDataType = LeapDataType.TextMask
184
-
185
- def __post_init__(self) -> None:
186
- validate_type(self.type, LeapDataType.TextMask)
187
- validate_type(type(self.mask), np.ndarray)
188
- validate_type(self.mask.dtype, np.uint8)
189
- validate_type(len(self.mask.shape), 1, 'text mask must be of shape 1')
190
- validate_type(type(self.text), list)
191
- for t in self.text:
192
- validate_type(type(t), str)
193
- validate_type(type(self.labels), list)
194
- for label in self.labels:
195
- validate_type(type(label), str)
196
-
197
-
198
- @dataclass
199
- class LeapImageWithHeatmap:
200
- """
201
- Visualizer representing an image with heatmaps for Tensorleap.
202
- This can be used for tasks such as highlighting important regions in an image, visualizing attention maps, and other applications where it is important to overlay heatmaps on images.
203
-
204
- Attributes:
205
- image (npt.NDArray[np.float32]): The image data, shaped [H, W, C], where C is the number of channels.
206
- heatmaps (npt.NDArray[np.float32]): The heatmap data, shaped [N, H, W], where N is the number of heatmaps.
207
- labels (List[str]): Labels associated with the heatmaps; e.g., feature names or attention regions. The length of `labels` should match the number of heatmaps, N.
208
- type (LeapDataType): The data type, default is LeapDataType.ImageWithHeatmap.
209
- """
210
- image: npt.NDArray[np.float32]
211
- heatmaps: npt.NDArray[np.float32]
212
- labels: List[str]
213
- type: LeapDataType = LeapDataType.ImageWithHeatmap
214
-
215
- def __post_init__(self) -> None:
216
- validate_type(self.type, LeapDataType.ImageWithHeatmap)
217
- validate_type(type(self.heatmaps), np.ndarray)
218
- validate_type(self.heatmaps.dtype, np.float32)
219
- validate_type(type(self.image), np.ndarray)
220
- validate_type(self.image.dtype, np.float32)
221
- validate_type(type(self.labels), list)
222
- for label in self.labels:
223
- validate_type(type(label), str)
224
- if self.heatmaps.shape[0] != len(self.labels):
225
- raise LeapValidationError(
226
- 'Number of heatmaps and labels must be equal')
227
-
228
-
229
- map_leap_data_type_to_visualizer_class = {
230
- LeapDataType.Image.value: LeapImage,
231
- LeapDataType.Graph.value: LeapGraph,
232
- LeapDataType.Text.value: LeapText,
233
- LeapDataType.HorizontalBar.value: LeapHorizontalBar,
234
- LeapDataType.ImageMask.value: LeapImageMask,
235
- LeapDataType.TextMask.value: LeapTextMask,
236
- LeapDataType.ImageWithBBox.value: LeapImageWithBBox,
237
- LeapDataType.ImageWithHeatmap.value: LeapImageWithHeatmap
238
- }
File without changes
File without changes