code-loader 1.0.42a0__py3-none-any.whl → 1.0.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/contract/datasetclasses.py +1 -19
- code_loader/contract/enums.py +0 -6
- code_loader/contract/visualizer_classes.py +362 -0
- code_loader/inner_leap_binder/leapbinder.py +5 -9
- code_loader/leaploader.py +3 -15
- code_loader/visualizers/default_visualizers.py +0 -1
- {code_loader-1.0.42a0.dist-info → code_loader-1.0.43.dist-info}/METADATA +2 -1
- code_loader-1.0.43.dist-info/RECORD +18 -0
- code_loader-1.0.42a0.dist-info/RECORD +0 -18
- {code_loader-1.0.42a0.dist-info → code_loader-1.0.43.dist-info}/LICENSE +0 -0
- {code_loader-1.0.42a0.dist-info → code_loader-1.0.43.dist-info}/WHEEL +0 -0
code_loader/contract/datasetclasses.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Any, Callable, List, Optional, Dict, Union, Type
|
|
4
4
|
import numpy as np
|
5
5
|
import numpy.typing as npt
|
6
6
|
|
7
|
-
from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
|
7
|
+
from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
|
8
8
|
from code_loader.contract.visualizer_classes import LeapImage, LeapText, LeapGraph, LeapHorizontalBar, \
|
9
9
|
LeapTextMask, LeapImageMask, LeapImageWithBBox, LeapImageWithHeatmap
|
10
10
|
|
@@ -36,16 +36,8 @@ class PreprocessResponse:
|
|
36
36
|
data: Any
|
37
37
|
|
38
38
|
|
39
|
-
@dataclass
|
40
|
-
class ElementInstance:
|
41
|
-
name: str
|
42
|
-
mask: npt.NDArray[np.float32]
|
43
|
-
|
44
|
-
|
45
39
|
SectionCallableInterface = Callable[[int, PreprocessResponse], npt.NDArray[np.float32]]
|
46
40
|
|
47
|
-
InstanceCallableInterface = Callable[[int, PreprocessResponse], List[ElementInstance]]
|
48
|
-
|
49
41
|
MetadataSectionCallableInterface = Union[
|
50
42
|
Callable[[int, PreprocessResponse], int],
|
51
43
|
Callable[[int, PreprocessResponse], Dict[str, int]],
|
@@ -117,7 +109,6 @@ class MetricHandler:
|
|
117
109
|
arg_names: List[str]
|
118
110
|
direction: Optional[MetricDirection] = MetricDirection.Downward
|
119
111
|
|
120
|
-
|
121
112
|
@dataclass
|
122
113
|
class RawInputsForHeatmap:
|
123
114
|
raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
|
@@ -143,14 +134,6 @@ class InputHandler(DatasetBaseHandler):
|
|
143
134
|
shape: Optional[List[int]] = None
|
144
135
|
|
145
136
|
|
146
|
-
@dataclass
|
147
|
-
class ElementInstanceHandler:
|
148
|
-
input_name: str
|
149
|
-
instance_function: InstanceCallableInterface
|
150
|
-
analysis_type: InstanceAnalysisType
|
151
|
-
|
152
|
-
|
153
|
-
|
154
137
|
@dataclass
|
155
138
|
class GroundTruthHandler(DatasetBaseHandler):
|
156
139
|
shape: Optional[List[int]] = None
|
@@ -184,7 +167,6 @@ class DatasetIntegrationSetup:
|
|
184
167
|
unlabeled_data_preprocess: Optional[UnlabeledDataPreprocessHandler] = None
|
185
168
|
visualizers: List[VisualizerHandler] = field(default_factory=list)
|
186
169
|
inputs: List[InputHandler] = field(default_factory=list)
|
187
|
-
element_instances: List[ElementInstanceHandler] = field(default_factory=list)
|
188
170
|
ground_truths: List[GroundTruthHandler] = field(default_factory=list)
|
189
171
|
metadata: List[MetadataHandler] = field(default_factory=list)
|
190
172
|
prediction_types: List[PredictionTypeHandler] = field(default_factory=list)
|
code_loader/contract/visualizer_classes.py
CHANGED
@@ -4,6 +4,8 @@ import numpy as np
|
|
4
4
|
import numpy.typing as npt
|
5
5
|
from dataclasses import dataclass
|
6
6
|
|
7
|
+
import matplotlib.pyplot as plt # type: ignore
|
8
|
+
|
7
9
|
from code_loader.contract.enums import LeapDataType
|
8
10
|
from code_loader.contract.responsedataclasses import BoundingBox
|
9
11
|
|
@@ -34,6 +36,10 @@ class LeapImage:
|
|
34
36
|
Attributes:
|
35
37
|
data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data.
|
36
38
|
type (LeapDataType): The data type, default is LeapDataType.Image.
|
39
|
+
|
40
|
+
Example:
|
41
|
+
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
42
|
+
leap_image = LeapImage(data=image_data)
|
37
43
|
"""
|
38
44
|
data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
|
39
45
|
type: LeapDataType = LeapDataType.Image
|
@@ -45,6 +51,34 @@ class LeapImage:
|
|
45
51
|
validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
|
46
52
|
validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
|
47
53
|
|
54
|
+
def plot_visualizer(self) -> None:
|
55
|
+
"""
|
56
|
+
Display the image contained in the LeapImage object.
|
57
|
+
|
58
|
+
Returns:
|
59
|
+
None
|
60
|
+
|
61
|
+
Example:
|
62
|
+
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
63
|
+
leap_image = LeapImage(data=image_data)
|
64
|
+
leap_image.plot_visualizer()
|
65
|
+
"""
|
66
|
+
image_data = self.data
|
67
|
+
|
68
|
+
# If the image has one channel, convert it to a 3-channel image for display
|
69
|
+
if image_data.shape[2] == 1:
|
70
|
+
image_data = np.repeat(image_data, 3, axis=2)
|
71
|
+
|
72
|
+
fig, ax = plt.subplots()
|
73
|
+
fig.patch.set_facecolor('black')
|
74
|
+
ax.set_facecolor('black')
|
75
|
+
|
76
|
+
ax.imshow(image_data)
|
77
|
+
|
78
|
+
plt.axis('off')
|
79
|
+
plt.title('Leap Image Visualization', color='white')
|
80
|
+
plt.show()
|
81
|
+
|
48
82
|
|
49
83
|
@dataclass
|
50
84
|
class LeapImageWithBBox:
|
@@ -55,6 +89,11 @@ class LeapImageWithBBox:
|
|
55
89
|
data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or [H, W, 1].
|
56
90
|
bounding_boxes (List[BoundingBox]): List of Tensorleap bounding boxes objects in relative size to image size.
|
57
91
|
type (LeapDataType): The data type, default is LeapDataType.ImageWithBBox.
|
92
|
+
|
93
|
+
Example:
|
94
|
+
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
95
|
+
bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
|
96
|
+
leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
|
58
97
|
"""
|
59
98
|
data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
|
60
99
|
bounding_boxes: List[BoundingBox]
|
@@ -67,6 +106,60 @@ class LeapImageWithBBox:
|
|
67
106
|
validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
|
68
107
|
validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
|
69
108
|
|
109
|
+
def plot_visualizer(self) -> None:
|
110
|
+
"""
|
111
|
+
Plot an image with overlaid bounding boxes.
|
112
|
+
|
113
|
+
Returns:
|
114
|
+
None
|
115
|
+
|
116
|
+
Example:
|
117
|
+
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
118
|
+
bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
|
119
|
+
leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
|
120
|
+
leap_image_with_bbox.plot_visualizer()
|
121
|
+
"""
|
122
|
+
|
123
|
+
image = self.data
|
124
|
+
bounding_boxes = self.bounding_boxes
|
125
|
+
|
126
|
+
# Create figure and axes
|
127
|
+
fig, ax = plt.subplots(1)
|
128
|
+
fig.patch.set_facecolor('black')
|
129
|
+
ax.set_facecolor('black')
|
130
|
+
|
131
|
+
# Display the image
|
132
|
+
ax.imshow(image)
|
133
|
+
ax.set_title('Leap Image With BBox Visualization', color='white')
|
134
|
+
|
135
|
+
# Draw bounding boxes on the image
|
136
|
+
for bbox in bounding_boxes:
|
137
|
+
x, y, width, height = bbox.x, bbox.y, bbox.width, bbox.height
|
138
|
+
confidence, label = bbox.confidence, bbox.label
|
139
|
+
|
140
|
+
# Convert relative coordinates to absolute coordinates
|
141
|
+
abs_x = x * image.shape[1]
|
142
|
+
abs_y = y * image.shape[0]
|
143
|
+
abs_width = width * image.shape[1]
|
144
|
+
abs_height = height * image.shape[0]
|
145
|
+
|
146
|
+
# Create a rectangle patch
|
147
|
+
rect = plt.Rectangle(
|
148
|
+
(abs_x - abs_width / 2, abs_y - abs_height / 2),
|
149
|
+
abs_width, abs_height,
|
150
|
+
linewidth=3, edgecolor='r', facecolor='none'
|
151
|
+
)
|
152
|
+
|
153
|
+
# Add the rectangle to the axes
|
154
|
+
ax.add_patch(rect)
|
155
|
+
|
156
|
+
# Display label and confidence
|
157
|
+
ax.text(abs_x - abs_width / 2, abs_y - abs_height / 2 - 5,
|
158
|
+
f"{label} {confidence:.2f}", color='r', fontsize=8,
|
159
|
+
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
|
160
|
+
|
161
|
+
# Show the image with bounding boxes
|
162
|
+
plt.show()
|
70
163
|
|
71
164
|
@dataclass
|
72
165
|
class LeapGraph:
|
@@ -76,6 +169,10 @@ class LeapGraph:
|
|
76
169
|
Attributes:
|
77
170
|
data (npt.NDArray[np.float32]): The array data, shaped [M, N] where M is the number of data points and N is the number of variables.
|
78
171
|
type (LeapDataType): The data type, default is LeapDataType.Graph.
|
172
|
+
|
173
|
+
Example:
|
174
|
+
graph_data = np.random.rand(100, 3).astype(np.float32)
|
175
|
+
leap_graph = LeapGraph(data=graph_data)
|
79
176
|
"""
|
80
177
|
data: npt.NDArray[np.float32]
|
81
178
|
type: LeapDataType = LeapDataType.Graph
|
@@ -86,6 +183,40 @@ class LeapGraph:
|
|
86
183
|
validate_type(self.data.dtype, np.float32)
|
87
184
|
validate_type(len(self.data.shape), 2, 'Graph must be of shape 2')
|
88
185
|
|
186
|
+
def plot_visualizer(self) -> None:
|
187
|
+
"""
|
188
|
+
Display the line chart contained in the LeapGraph object.
|
189
|
+
|
190
|
+
Returns:
|
191
|
+
None
|
192
|
+
|
193
|
+
Example:
|
194
|
+
graph_data = np.random.rand(100, 3).astype(np.float32)
|
195
|
+
leap_graph = LeapGraph(data=graph_data)
|
196
|
+
leap_graph.plot_visualizer()
|
197
|
+
"""
|
198
|
+
graph_data = self.data
|
199
|
+
num_variables = graph_data.shape[1]
|
200
|
+
|
201
|
+
fig, ax = plt.subplots(figsize=(10, 6))
|
202
|
+
|
203
|
+
# Set the background color to black
|
204
|
+
fig.patch.set_facecolor('black')
|
205
|
+
ax.set_facecolor('black')
|
206
|
+
|
207
|
+
for i in range(num_variables):
|
208
|
+
plt.plot(graph_data[:, i], label=f'Variable {i + 1}')
|
209
|
+
|
210
|
+
ax.set_xlabel('Data Points', color='white')
|
211
|
+
ax.set_ylabel('Values', color='white')
|
212
|
+
ax.set_title('Leap Graph Visualization', color='white')
|
213
|
+
ax.legend()
|
214
|
+
ax.grid(True, color='white')
|
215
|
+
|
216
|
+
# Change the color of the tick labels to white
|
217
|
+
ax.tick_params(colors='white')
|
218
|
+
|
219
|
+
plt.show()
|
89
220
|
|
90
221
|
@dataclass
|
91
222
|
class LeapText:
|
@@ -96,6 +227,11 @@ class LeapText:
|
|
96
227
|
data (List[str]): The text data, consisting of a list of text tokens. If the model requires fixed-length inputs,
|
97
228
|
it is recommended to maintain the fixed length, using empty strings ('') instead of padding tokens ('PAD') e.g., ['I', 'ate', 'a', 'banana', '', '', '', ...]
|
98
229
|
type (LeapDataType): The data type, default is LeapDataType.Text.
|
230
|
+
|
231
|
+
Example:
|
232
|
+
text_data = ['I', 'ate', 'a', 'banana', '', '', '']
|
233
|
+
leap_text = LeapText(data=text_data) # Create LeapText object
|
234
|
+
LeapText(leap_text)
|
99
235
|
"""
|
100
236
|
data: List[str]
|
101
237
|
type: LeapDataType = LeapDataType.Text
|
@@ -106,6 +242,41 @@ class LeapText:
|
|
106
242
|
for value in self.data:
|
107
243
|
validate_type(type(value), str)
|
108
244
|
|
245
|
+
def plot_visualizer(self) -> None:
|
246
|
+
"""
|
247
|
+
Display the text contained in the LeapText object.
|
248
|
+
|
249
|
+
Returns:
|
250
|
+
None
|
251
|
+
|
252
|
+
Example:
|
253
|
+
text_data = ['I', 'ate', 'a', 'banana', '', '', '']
|
254
|
+
leap_text = LeapText(data=text_data)
|
255
|
+
leap_text.plot_visualizer()
|
256
|
+
"""
|
257
|
+
text_data = self.data
|
258
|
+
# Join the text tokens into a single string, ignoring empty strings
|
259
|
+
display_text = ' '.join([token for token in text_data if token])
|
260
|
+
|
261
|
+
# Create a black image using Matplotlib
|
262
|
+
fig, ax = plt.subplots(figsize=(10, 5))
|
263
|
+
fig.patch.set_facecolor('black')
|
264
|
+
ax.set_facecolor('black')
|
265
|
+
|
266
|
+
# Hide the axes
|
267
|
+
ax.axis('off')
|
268
|
+
|
269
|
+
# Set the text properties
|
270
|
+
font_size = 20
|
271
|
+
font_color = 'white'
|
272
|
+
|
273
|
+
# Add the text to the image
|
274
|
+
ax.text(0.5, 0.5, display_text, color=font_color, fontsize=font_size, ha='center', va='center')
|
275
|
+
ax.set_title('Leap Text Visualization', color='white')
|
276
|
+
|
277
|
+
# Display the image
|
278
|
+
plt.show()
|
279
|
+
|
109
280
|
|
110
281
|
@dataclass
|
111
282
|
class LeapHorizontalBar:
|
@@ -118,6 +289,11 @@ class LeapHorizontalBar:
|
|
118
289
|
labels (List[str]): Labels for the horizontal bar; e.g., when visualizing the model's classification output, labels are the class names.
|
119
290
|
Length of `body` should match the length of `labels`, C.
|
120
291
|
type (LeapDataType): The data type, default is LeapDataType.HorizontalBar.
|
292
|
+
|
293
|
+
Example:
|
294
|
+
body_data = np.random.rand(5).astype(np.float32)
|
295
|
+
labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
|
296
|
+
leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
|
121
297
|
"""
|
122
298
|
body: npt.NDArray[np.float32]
|
123
299
|
labels: List[str]
|
@@ -133,6 +309,39 @@ class LeapHorizontalBar:
|
|
133
309
|
for label in self.labels:
|
134
310
|
validate_type(type(label), str)
|
135
311
|
|
312
|
+
def plot_visualizer(self) -> None:
|
313
|
+
"""
|
314
|
+
Display the horizontal bar chart contained in the LeapHorizontalBar object.
|
315
|
+
|
316
|
+
Returns:
|
317
|
+
None
|
318
|
+
|
319
|
+
Example:
|
320
|
+
body_data = np.random.rand(5).astype(np.float32)
|
321
|
+
labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
|
322
|
+
leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
|
323
|
+
leap_horizontal_bar.plot_visualizer()
|
324
|
+
"""
|
325
|
+
body_data = self.body
|
326
|
+
labels = self.labels
|
327
|
+
|
328
|
+
fig, ax = plt.subplots()
|
329
|
+
|
330
|
+
fig.patch.set_facecolor('black')
|
331
|
+
ax.set_facecolor('black')
|
332
|
+
|
333
|
+
# Plot horizontal bar chart
|
334
|
+
ax.barh(labels, body_data, color='green')
|
335
|
+
|
336
|
+
# Set the color of the labels and title to white
|
337
|
+
ax.set_xlabel('Scores', color='white')
|
338
|
+
ax.set_title('Leap Horizontal Bar Visualization', color='white')
|
339
|
+
|
340
|
+
# Set the color of the ticks to white
|
341
|
+
ax.tick_params(axis='x', colors='white')
|
342
|
+
ax.tick_params(axis='y', colors='white')
|
343
|
+
|
344
|
+
plt.show()
|
136
345
|
|
137
346
|
@dataclass
|
138
347
|
class LeapImageMask:
|
@@ -145,6 +354,12 @@ class LeapImageMask:
|
|
145
354
|
image (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or shaped [H, W, 1].
|
146
355
|
labels (List[str]): Labels associated with the mask regions; e.g., class names for segmented objects. The length of `labels` should match the number of unique values in `mask`.
|
147
356
|
type (LeapDataType): The data type, default is LeapDataType.ImageMask.
|
357
|
+
|
358
|
+
Example:
|
359
|
+
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
360
|
+
mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
|
361
|
+
labels = ["background", "object"]
|
362
|
+
leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
|
148
363
|
"""
|
149
364
|
mask: npt.NDArray[np.uint8]
|
150
365
|
image: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
|
@@ -164,6 +379,51 @@ class LeapImageMask:
|
|
164
379
|
for label in self.labels:
|
165
380
|
validate_type(type(label), str)
|
166
381
|
|
382
|
+
def plot_visualizer(self) -> None:
|
383
|
+
"""
|
384
|
+
Plots an image with overlaid masks given a LeapImageMask visualizer object.
|
385
|
+
|
386
|
+
Returns:
|
387
|
+
None
|
388
|
+
|
389
|
+
|
390
|
+
Example:
|
391
|
+
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
392
|
+
mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
|
393
|
+
labels = ["background", "object"]
|
394
|
+
leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
|
395
|
+
leap_image_mask.plot_visualizer()
|
396
|
+
"""
|
397
|
+
|
398
|
+
image = self.image
|
399
|
+
mask = self.mask
|
400
|
+
labels = self.labels
|
401
|
+
|
402
|
+
# Create a color map for each label
|
403
|
+
colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
|
404
|
+
|
405
|
+
# Make a copy of the image to draw on
|
406
|
+
overlayed_image = image.copy()
|
407
|
+
|
408
|
+
# Iterate through the unique values in the mask (excluding 0)
|
409
|
+
for i, label in enumerate(labels):
|
410
|
+
# Extract binary mask for the current instance
|
411
|
+
instance_mask = (mask == (i + 1))
|
412
|
+
|
413
|
+
# fill the instance mask with a translucent color
|
414
|
+
overlayed_image[instance_mask] = (
|
415
|
+
overlayed_image[instance_mask] * (1 - 0.5) + np.array(colors[i][:3], dtype=np.uint8) * 0.5)
|
416
|
+
|
417
|
+
# Display the result using matplotlib
|
418
|
+
fig, ax = plt.subplots(1)
|
419
|
+
fig.patch.set_facecolor('black') # Set the figure background to black
|
420
|
+
ax.set_facecolor('black') # Set the axis background to black
|
421
|
+
|
422
|
+
ax.imshow(overlayed_image)
|
423
|
+
ax.set_title('Leap Image With Mask Visualization', color='white')
|
424
|
+
plt.axis('off') # Hide the axis
|
425
|
+
plt.show()
|
426
|
+
|
167
427
|
|
168
428
|
@dataclass
|
169
429
|
class LeapTextMask:
|
@@ -176,6 +436,13 @@ class LeapTextMask:
|
|
176
436
|
text (List[str]): The text data, consisting of a list of text tokens, length of L.
|
177
437
|
labels (List[str]): Labels associated with the masked tokens; e.g., named entities or sentiment categories. The length of `labels` should match the number of unique values in `mask`.
|
178
438
|
type (LeapDataType): The data type, default is LeapDataType.TextMask.
|
439
|
+
|
440
|
+
Example:
|
441
|
+
text_data = ['I', 'ate', 'a', 'banana', '', '', '']
|
442
|
+
mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
|
443
|
+
labels = ["object"]
|
444
|
+
leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
|
445
|
+
leap_text_mask.plot_visualizer()
|
179
446
|
"""
|
180
447
|
mask: npt.NDArray[np.uint8]
|
181
448
|
text: List[str]
|
@@ -194,6 +461,55 @@ class LeapTextMask:
|
|
194
461
|
for label in self.labels:
|
195
462
|
validate_type(type(label), str)
|
196
463
|
|
464
|
+
def plot_visualizer(self) -> None:
|
465
|
+
"""
|
466
|
+
Plots text with overlaid masks given a LeapTextMask visualizer object.
|
467
|
+
|
468
|
+
Returns:
|
469
|
+
None
|
470
|
+
|
471
|
+
Example:
|
472
|
+
text_data = ['I', 'ate', 'a', 'banana', '', '', '']
|
473
|
+
mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
|
474
|
+
labels = ["object"]
|
475
|
+
leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
|
476
|
+
"""
|
477
|
+
|
478
|
+
text_data = self.text
|
479
|
+
mask_data = self.mask
|
480
|
+
labels = self.labels
|
481
|
+
|
482
|
+
# Create a color map for each label
|
483
|
+
colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
|
484
|
+
|
485
|
+
# Create a figure and axis
|
486
|
+
fig, ax = plt.subplots()
|
487
|
+
|
488
|
+
# Set background to black
|
489
|
+
fig.patch.set_facecolor('black')
|
490
|
+
ax.set_facecolor('black')
|
491
|
+
ax.set_title('Leap Text Mask Visualization', color='white')
|
492
|
+
ax.axis('off')
|
493
|
+
|
494
|
+
# Set initial position
|
495
|
+
x_pos, y_pos = 0.01, 0.5 # Adjusted initial position for better visibility
|
496
|
+
|
497
|
+
# Display the text with colors
|
498
|
+
for token, mask_value in zip(text_data, mask_data):
|
499
|
+
if mask_value > 0:
|
500
|
+
color = colors[mask_value % len(colors)]
|
501
|
+
bbox = dict(facecolor=color, edgecolor='none',
|
502
|
+
boxstyle='round,pad=0.3') # Background color for masked tokens
|
503
|
+
else:
|
504
|
+
bbox = None
|
505
|
+
|
506
|
+
ax.text(x_pos, y_pos, token, fontsize=12, color='white', ha='left', va='center', bbox=bbox)
|
507
|
+
|
508
|
+
# Update the x position for the next token
|
509
|
+
x_pos += len(token) * 0.03 + 0.02 # Adjust the spacing between tokens
|
510
|
+
|
511
|
+
plt.show()
|
512
|
+
|
197
513
|
|
198
514
|
@dataclass
|
199
515
|
class LeapImageWithHeatmap:
|
@@ -206,6 +522,12 @@ class LeapImageWithHeatmap:
|
|
206
522
|
heatmaps (npt.NDArray[np.float32]): The heatmap data, shaped [N, H, W], where N is the number of heatmaps.
|
207
523
|
labels (List[str]): Labels associated with the heatmaps; e.g., feature names or attention regions. The length of `labels` should match the number of heatmaps, N.
|
208
524
|
type (LeapDataType): The data type, default is LeapDataType.ImageWithHeatmap.
|
525
|
+
|
526
|
+
Example:
|
527
|
+
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
528
|
+
heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
|
529
|
+
labels = ["heatmap1", "heatmap2", "heatmap3"]
|
530
|
+
leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
|
209
531
|
"""
|
210
532
|
image: npt.NDArray[np.float32]
|
211
533
|
heatmaps: npt.NDArray[np.float32]
|
@@ -225,6 +547,46 @@ class LeapImageWithHeatmap:
|
|
225
547
|
raise LeapValidationError(
|
226
548
|
'Number of heatmaps and labels must be equal')
|
227
549
|
|
550
|
+
def plot_visualizer(self) -> None:
|
551
|
+
"""
|
552
|
+
Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.
|
553
|
+
|
554
|
+
Returns:
|
555
|
+
None
|
556
|
+
|
557
|
+
Example:
|
558
|
+
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
559
|
+
heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
|
560
|
+
labels = ["heatmap1", "heatmap2", "heatmap3"]
|
561
|
+
leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
|
562
|
+
leap_image_with_heatmap.plot_visualizer()
|
563
|
+
"""
|
564
|
+
image = self.image
|
565
|
+
heatmaps = self.heatmaps
|
566
|
+
labels = self.labels
|
567
|
+
|
568
|
+
# Plot the base image
|
569
|
+
fig, ax = plt.subplots()
|
570
|
+
fig.patch.set_facecolor('black') # Set the figure background to black
|
571
|
+
ax.set_facecolor('black') # Set the axis background to black
|
572
|
+
ax.imshow(image, cmap='gray')
|
573
|
+
|
574
|
+
# Overlay each heatmap with some transparency
|
575
|
+
for i in range(len(labels)):
|
576
|
+
heatmap = heatmaps[i]
|
577
|
+
ax.imshow(heatmap, cmap='jet', alpha=0.5) # Adjust alpha for transparency
|
578
|
+
ax.set_title(f'Heatmap: {labels[i]}', color='white')
|
579
|
+
|
580
|
+
# Display a colorbar for the heatmap
|
581
|
+
cbar = plt.colorbar(ax.imshow(heatmap, cmap='jet', alpha=0.5))
|
582
|
+
cbar.set_label(labels[i], color='white')
|
583
|
+
cbar.ax.yaxis.set_tick_params(color='white') # Set color for the colorbar ticks
|
584
|
+
plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white') # Set color for the colorbar labels
|
585
|
+
|
586
|
+
plt.axis('off')
|
587
|
+
plt.title('Leap Image With Heatmaps Visualization', color='white')
|
588
|
+
plt.show()
|
589
|
+
|
228
590
|
|
229
591
|
map_leap_data_type_to_visualizer_class = {
|
230
592
|
LeapDataType.Image.value: LeapImage,
|
code_loader/inner_leap_binder/leapbinder.py
CHANGED
@@ -9,9 +9,8 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
|
|
9
9
|
PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
|
10
10
|
MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
|
11
11
|
CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, VisualizerCallableReturnType, \
|
12
|
-
CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute,
|
13
|
-
|
14
|
-
from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, InstanceAnalysisType
|
12
|
+
CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
|
13
|
+
from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
|
15
14
|
from code_loader.contract.responsedataclasses import DatasetTestResultPayload
|
16
15
|
from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
|
17
16
|
from code_loader.utils import to_numpy_return_wrapper, get_shape
|
@@ -203,10 +202,6 @@ class LeapBinder:
|
|
203
202
|
|
204
203
|
self._encoder_names.append(name)
|
205
204
|
|
206
|
-
def set_instance_element(self, input_name: str, instance_function: Optional[InstanceCallableInterface] = None,
|
207
|
-
analysis_type: InstanceAnalysisType = InstanceAnalysisType.MaskInput) -> None:
|
208
|
-
self.setup_container.element_instances.append(ElementInstanceHandler(input_name, instance_function, analysis_type))
|
209
|
-
|
210
205
|
def add_custom_loss(self, function: CustomCallableInterface, name: str) -> None:
|
211
206
|
"""
|
212
207
|
Add a custom loss function to the setup.
|
@@ -366,8 +361,7 @@ class LeapBinder:
|
|
366
361
|
custom_layer.kernel_index = kernel_index
|
367
362
|
|
368
363
|
if use_custom_latent_space and not hasattr(custom_layer, custom_latent_space_attribute):
|
369
|
-
raise Exception(
|
370
|
-
f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
|
364
|
+
raise Exception(f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
|
371
365
|
|
372
366
|
init_args = inspect.getfullargspec(custom_layer.__init__)[0][1:]
|
373
367
|
call_args = inspect.getfullargspec(custom_layer.call)[0][1:]
|
@@ -458,3 +452,5 @@ class LeapBinder:
|
|
458
452
|
self.check_preprocess(preprocess_result)
|
459
453
|
self.check_handlers(preprocess_result)
|
460
454
|
print("Successful!")
|
455
|
+
|
456
|
+
|
code_loader/leaploader.py
CHANGED
@@ -5,16 +5,15 @@ import sys
|
|
5
5
|
from contextlib import redirect_stdout
|
6
6
|
from functools import lru_cache
|
7
7
|
from pathlib import Path
|
8
|
-
from typing import Dict, List, Iterable, Union, Any
|
8
|
+
from typing import Dict, List, Iterable, Union, Any
|
9
9
|
|
10
10
|
import numpy as np
|
11
11
|
import numpy.typing as npt
|
12
12
|
|
13
13
|
from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
|
14
14
|
PreprocessResponse, VisualizerHandler, VisualizerCallableReturnType, CustomLossHandler, \
|
15
|
-
PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
|
16
|
-
from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
|
17
|
-
InstanceAnalysisType
|
15
|
+
PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
|
16
|
+
from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
|
18
17
|
from code_loader.contract.exceptions import DatasetScriptException
|
19
18
|
from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
|
20
19
|
DatasetPreprocess, DatasetSetup, DatasetInputInstance, DatasetOutputInstance, DatasetMetadataInstance, \
|
@@ -299,17 +298,6 @@ class LeapLoader:
|
|
299
298
|
def _get_inputs(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
|
300
299
|
return self._get_dataset_handlers(global_leap_binder.setup_container.inputs, state, idx)
|
301
300
|
|
302
|
-
def get_instance_elements(self, state: DataStateEnum, idx: int, input_name: str) \
|
303
|
-
-> Tuple[Optional[List[ElementInstance]], Optional[InstanceAnalysisType]]:
|
304
|
-
preprocess_result = self._preprocess_result()
|
305
|
-
preprocess_state = preprocess_result[state]
|
306
|
-
for element_instance in global_leap_binder.setup_container.element_instances:
|
307
|
-
if element_instance.input_name == input_name:
|
308
|
-
return element_instance.instance_function(idx, preprocess_state), element_instance.analysis_type
|
309
|
-
|
310
|
-
return None, None
|
311
|
-
|
312
|
-
|
313
301
|
def _get_gt(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
|
314
302
|
return self._get_dataset_handlers(global_leap_binder.setup_container.ground_truths, state, idx)
|
315
303
|
|
{code_loader-1.0.42a0.dist-info → code_loader-1.0.43.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: code-loader
|
3
|
-
Version: 1.0.42a0
|
3
|
+
Version: 1.0.43
|
4
4
|
Summary:
|
5
5
|
Home-page: https://github.com/tensorleap/code-loader
|
6
6
|
License: MIT
|
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3.8
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.9
|
14
14
|
Classifier: Programming Language :: Python :: 3.10
|
15
15
|
Classifier: Programming Language :: Python :: 3.11
|
16
|
+
Requires-Dist: matplotlib (>=3.3,<3.4)
|
16
17
|
Requires-Dist: numpy (>=1.22.3,<2.0.0)
|
17
18
|
Requires-Dist: psutil (>=5.9.5,<6.0.0)
|
18
19
|
Project-URL: Repository, https://github.com/tensorleap/code-loader
|
code_loader-1.0.43.dist-info/RECORD
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
2
|
+
code_loader/__init__.py,sha256=V3DEXSN6Ie6PlGeSAbzjp9ufRj0XPJLpD7pDLLYxk6M,122
|
3
|
+
code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
code_loader/contract/datasetclasses.py,sha256=81TmCcVol7768lzKUp70MatLLipR3ftcR9jgE1r8Yqo,5698
|
5
|
+
code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
|
6
|
+
code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
|
7
|
+
code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
|
8
|
+
code_loader/contract/visualizer_classes.py,sha256=Ka8fJSVKrOeZ12Eg8-dOBGMW0UswYCIkuhnNMd-7z9s,22948
|
9
|
+
code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
|
10
|
+
code_loader/inner_leap_binder/leapbinder.py,sha256=fHND8ayIXDlKHufHHq36u5rKNnZq64MmMuI42GctHvQ,24022
|
11
|
+
code_loader/leaploader.py,sha256=pUySweZetJ6SsubCcZlDCJpvWmUrm5YlPlkZWQxY1hQ,17289
|
12
|
+
code_loader/utils.py,sha256=61I4PgSl-ZBIe4DifLxMNlBELE-HQR2pB9efVYPceIU,2230
|
13
|
+
code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
+
code_loader/visualizers/default_visualizers.py,sha256=VoqO9FN84yXyMjRjHjUTOt2GdTkJRMbHbXJ1cJkREkk,2230
|
15
|
+
code_loader-1.0.43.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
16
|
+
code_loader-1.0.43.dist-info/METADATA,sha256=FO-vqcxPvPCvmVbDXmgc5ne0o2pSnKF9NsP2w0j_Ub8,807
|
17
|
+
code_loader-1.0.43.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
18
|
+
code_loader-1.0.43.dist-info/RECORD,,
|
code_loader-1.0.42a0.dist-info/RECORD
REMOVED
@@ -1,18 +0,0 @@
|
|
1
|
-
LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
2
|
-
code_loader/__init__.py,sha256=V3DEXSN6Ie6PlGeSAbzjp9ufRj0XPJLpD7pDLLYxk6M,122
|
3
|
-
code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
code_loader/contract/datasetclasses.py,sha256=eoDxCGZzvogp440dhWZBBm8oTs6XiqcHxE-JXD4txeA,6128
|
5
|
-
code_loader/contract/enums.py,sha256=2tKXamabooQjCe3EBkeBcZMA8vBbeGjDAbACvUTGqLE,1707
|
6
|
-
code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
|
7
|
-
code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
|
8
|
-
code_loader/contract/visualizer_classes.py,sha256=t0EKxTGFJoFBnKuMmjtFpj-L32xEFF3NV-E5XutVXa0,10118
|
9
|
-
code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
|
10
|
-
code_loader/inner_leap_binder/leapbinder.py,sha256=7_ffUiGW54Y27f7eKnvpJ2XLOqfGfTCW5cMNaH21iIk,24464
|
11
|
-
code_loader/leaploader.py,sha256=gQaEpHyQx50FP-d0LS2PjQdm9aOUf3jok8ZeLbWkGZc,17917
|
12
|
-
code_loader/utils.py,sha256=61I4PgSl-ZBIe4DifLxMNlBELE-HQR2pB9efVYPceIU,2230
|
13
|
-
code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
14
|
-
code_loader/visualizers/default_visualizers.py,sha256=F2qRT6rNy7SOmr4QfqVxIkLlYEa00CwDkLJuA45lfSI,2254
|
15
|
-
code_loader-1.0.42a0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
16
|
-
code_loader-1.0.42a0.dist-info/METADATA,sha256=HrRgFfW68Nr0KIBzK-C9Fwnxcw9SYotWTzWfwW2PF4I,770
|
17
|
-
code_loader-1.0.42a0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
18
|
-
code_loader-1.0.42a0.dist-info/RECORD,,
|
File without changes
|
File without changes
|