code-loader 1.0.42a0__tar.gz → 1.0.43a0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/PKG-INFO +4 -1
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/__init__.py +1 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/contract/datasetclasses.py +1 -19
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/contract/enums.py +0 -6
- code_loader-1.0.43a0/code_loader/contract/visualizer_classes.py +600 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/inner_leap_binder/leapbinder.py +5 -9
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/leaploader.py +3 -15
- code_loader-1.0.43a0/code_loader/rt_api/__init__.py +12 -0
- code_loader-1.0.43a0/code_loader/rt_api/api_client.py +107 -0
- code_loader-1.0.43a0/code_loader/rt_api/cli_config_utils.py +41 -0
- code_loader-1.0.43a0/code_loader/rt_api/epoch.py +63 -0
- code_loader-1.0.43a0/code_loader/rt_api/experiment.py +47 -0
- code_loader-1.0.43a0/code_loader/rt_api/experiment_context.py +10 -0
- code_loader-1.0.43a0/code_loader/rt_api/types.py +25 -0
- code_loader-1.0.43a0/code_loader/rt_api/utils.py +34 -0
- code_loader-1.0.43a0/code_loader/rt_api/workingspace_config_utils.py +27 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/visualizers/default_visualizers.py +0 -1
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/pyproject.toml +4 -1
- code_loader-1.0.42a0/code_loader/contract/visualizer_classes.py +0 -238
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/LICENSE +0 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/README.md +0 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/contract/__init__.py +0 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/contract/exceptions.py +0 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/contract/responsedataclasses.py +0 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/inner_leap_binder/__init__.py +0 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/utils.py +0 -0
- {code_loader-1.0.42a0 → code_loader-1.0.43a0}/code_loader/visualizers/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: code-loader
|
3
|
-
Version: 1.0.42a0
|
3
|
+
Version: 1.0.43a0
|
4
4
|
Summary:
|
5
5
|
Home-page: https://github.com/tensorleap/code-loader
|
6
6
|
License: MIT
|
@@ -13,8 +13,11 @@ Classifier: Programming Language :: Python :: 3.8
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.9
|
14
14
|
Classifier: Programming Language :: Python :: 3.10
|
15
15
|
Classifier: Programming Language :: Python :: 3.11
|
16
|
+
Requires-Dist: matplotlib (>=3.3,<3.4)
|
16
17
|
Requires-Dist: numpy (>=1.22.3,<2.0.0)
|
17
18
|
Requires-Dist: psutil (>=5.9.5,<6.0.0)
|
19
|
+
Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
|
20
|
+
Requires-Dist: requests (>=2.32.3,<3.0.0)
|
18
21
|
Project-URL: Repository, https://github.com/tensorleap/code-loader
|
19
22
|
Description-Content-Type: text/markdown
|
20
23
|
|
@@ -4,7 +4,7 @@ from typing import Any, Callable, List, Optional, Dict, Union, Type
|
|
4
4
|
import numpy as np
|
5
5
|
import numpy.typing as npt
|
6
6
|
|
7
|
-
from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
|
7
|
+
from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
|
8
8
|
from code_loader.contract.visualizer_classes import LeapImage, LeapText, LeapGraph, LeapHorizontalBar, \
|
9
9
|
LeapTextMask, LeapImageMask, LeapImageWithBBox, LeapImageWithHeatmap
|
10
10
|
|
@@ -36,16 +36,8 @@ class PreprocessResponse:
|
|
36
36
|
data: Any
|
37
37
|
|
38
38
|
|
39
|
-
@dataclass
|
40
|
-
class ElementInstance:
|
41
|
-
name: str
|
42
|
-
mask: npt.NDArray[np.float32]
|
43
|
-
|
44
|
-
|
45
39
|
SectionCallableInterface = Callable[[int, PreprocessResponse], npt.NDArray[np.float32]]
|
46
40
|
|
47
|
-
InstanceCallableInterface = Callable[[int, PreprocessResponse], List[ElementInstance]]
|
48
|
-
|
49
41
|
MetadataSectionCallableInterface = Union[
|
50
42
|
Callable[[int, PreprocessResponse], int],
|
51
43
|
Callable[[int, PreprocessResponse], Dict[str, int]],
|
@@ -117,7 +109,6 @@ class MetricHandler:
|
|
117
109
|
arg_names: List[str]
|
118
110
|
direction: Optional[MetricDirection] = MetricDirection.Downward
|
119
111
|
|
120
|
-
|
121
112
|
@dataclass
|
122
113
|
class RawInputsForHeatmap:
|
123
114
|
raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
|
@@ -143,14 +134,6 @@ class InputHandler(DatasetBaseHandler):
|
|
143
134
|
shape: Optional[List[int]] = None
|
144
135
|
|
145
136
|
|
146
|
-
@dataclass
|
147
|
-
class ElementInstanceHandler:
|
148
|
-
input_name: str
|
149
|
-
instance_function: InstanceCallableInterface
|
150
|
-
analysis_type: InstanceAnalysisType
|
151
|
-
|
152
|
-
|
153
|
-
|
154
137
|
@dataclass
|
155
138
|
class GroundTruthHandler(DatasetBaseHandler):
|
156
139
|
shape: Optional[List[int]] = None
|
@@ -184,7 +167,6 @@ class DatasetIntegrationSetup:
|
|
184
167
|
unlabeled_data_preprocess: Optional[UnlabeledDataPreprocessHandler] = None
|
185
168
|
visualizers: List[VisualizerHandler] = field(default_factory=list)
|
186
169
|
inputs: List[InputHandler] = field(default_factory=list)
|
187
|
-
element_instances: List[ElementInstanceHandler] = field(default_factory=list)
|
188
170
|
ground_truths: List[GroundTruthHandler] = field(default_factory=list)
|
189
171
|
metadata: List[MetadataHandler] = field(default_factory=list)
|
190
172
|
prediction_types: List[PredictionTypeHandler] = field(default_factory=list)
|
@@ -0,0 +1,600 @@
|
|
1
|
+
from typing import List, Any, Union
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import numpy.typing as npt
|
5
|
+
from dataclasses import dataclass
|
6
|
+
|
7
|
+
import matplotlib.pyplot as plt # type: ignore
|
8
|
+
|
9
|
+
from code_loader.contract.enums import LeapDataType
|
10
|
+
from code_loader.contract.responsedataclasses import BoundingBox
|
11
|
+
|
12
|
+
|
13
|
+
class LeapValidationError(Exception):
    """Raised when a visualizer object fails its construction-time validation."""
|
15
|
+
|
16
|
+
|
17
|
+
def validate_type(actual: Any, expected: Any, prefix_message: str = '') -> None:
    """
    Check that ``actual`` is one of the accepted values, raising otherwise.

    ``expected`` is either a single accepted value or a list of accepted values.
    Membership is tested with ``in`` (equality), so the same helper is used for
    Python types, numpy dtypes, array ranks and channel counts.

    Args:
        actual: The value observed on the visualizer object.
        expected: One accepted value, or a list of accepted values.
        prefix_message: Context text placed on the first line of the error message.

    Raises:
        LeapValidationError: If ``actual`` is not among the accepted values.
    """
    allowed = expected if isinstance(expected, list) else [expected]
    if actual in allowed:
        return

    # Single-option failures name the one expected value; otherwise list them all.
    if len(allowed) == 1:
        detail = f'got {actual}, instead of {allowed[0]}'
    else:
        detail = f'got {actual}, allowed is one of {allowed}'
    raise LeapValidationError(
        f'{prefix_message}.\n'
        f'visualizer returned unexpected type. {detail}')
|
29
|
+
|
30
|
+
|
31
|
+
@dataclass
class LeapImage:
    """
    Visualizer representing an image for Tensorleap.

    Attributes:
        data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] (RGB)
            or [H, W, 1] (grayscale); other ranks, channel counts and dtypes are rejected.
        type (LeapDataType): The data type, default is LeapDataType.Image.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        leap_image = LeapImage(data=image_data)
    """
    data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
    type: LeapDataType = LeapDataType.Image

    def __post_init__(self) -> None:
        # Every check raises LeapValidationError on failure, so each later check can
        # rely on the earlier ones (e.g. shape[2] only exists once the rank is 3).
        validate_type(self.type, LeapDataType.Image)
        validate_type(type(self.data), np.ndarray)
        validate_type(self.data.dtype, [np.uint8, np.float32])
        validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
        validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')

    def plot_visualizer(self) -> None:
        """
        Show the image in a matplotlib window on a black background.

        Grayscale images are expanded to three identical channels first, because
        matplotlib's imshow does not accept [H, W, 1] arrays.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            leap_image = LeapImage(data=image_data)
            leap_image.plot_visualizer()
        """
        pixels = self.data if self.data.shape[2] != 1 else np.repeat(self.data, 3, axis=2)

        figure, axes = plt.subplots()
        for surface in (figure.patch, axes):
            surface.set_facecolor('black')

        axes.imshow(pixels)

        plt.axis('off')
        plt.title('Leap Image Visualization', color='white')
        plt.show()
|
81
|
+
|
82
|
+
|
83
|
+
@dataclass
class LeapImageWithBBox:
    """
    Visualizer representing an image with bounding boxes for Tensorleap, used for object detection tasks.

    Attributes:
        data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or [H, W, 1].
        bounding_boxes (List[BoundingBox]): List of Tensorleap bounding boxes objects in relative size to image size.
            When plotted, each box's (x, y) is treated as its centre; x/width are scaled by the image
            width and y/height by the image height.
        type (LeapDataType): The data type, default is LeapDataType.ImageWithBBox.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
        leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
    """
    data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
    bounding_boxes: List[BoundingBox]
    type: LeapDataType = LeapDataType.ImageWithBBox

    def __post_init__(self) -> None:
        validate_type(self.type, LeapDataType.ImageWithBBox)
        validate_type(type(self.data), np.ndarray)
        validate_type(self.data.dtype, [np.uint8, np.float32])
        validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
        validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')

    def plot_visualizer(self) -> None:
        """
        Plot an image with overlaid bounding boxes and "label confidence" captions.

        Grayscale [H, W, 1] images are expanded to three channels before display,
        since matplotlib's imshow does not accept a trailing singleton channel.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
            leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
            leap_image_with_bbox.plot_visualizer()
        """
        image = self.data
        # imshow rejects [H, W, 1]; repeat the gray channel (same handling as LeapImage.plot_visualizer).
        if image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        image_height, image_width = image.shape[0], image.shape[1]

        # Create figure and axes
        fig, ax = plt.subplots(1)
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        # Display the image
        ax.imshow(image)
        ax.set_title('Leap Image With BBox Visualization', color='white')

        # Draw bounding boxes on the image
        for bbox in self.bounding_boxes:
            # Convert relative coordinates to absolute pixel coordinates
            abs_x = bbox.x * image_width
            abs_y = bbox.y * image_height
            abs_width = bbox.width * image_width
            abs_height = bbox.height * image_height

            # (x, y) is the box centre; matplotlib rectangles are anchored at a corner
            corner_x = abs_x - abs_width / 2
            corner_y = abs_y - abs_height / 2

            rect = plt.Rectangle(
                (corner_x, corner_y),
                abs_width, abs_height,
                linewidth=3, edgecolor='r', facecolor='none'
            )
            ax.add_patch(rect)

            # Display label and confidence next to the box corner
            ax.text(corner_x, corner_y - 5,
                    f"{bbox.label} {bbox.confidence:.2f}", color='r', fontsize=8,
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))

        # Show the image with bounding boxes
        plt.show()
|
163
|
+
|
164
|
+
@dataclass
class LeapGraph:
    """
    Visualizer representing line-chart data for Tensorleap.

    Attributes:
        data (npt.NDArray[np.float32]): A 2-D float32 array shaped [M, N]; each of the N columns
            is one variable (one line) sampled at M points.
        type (LeapDataType): The data type, default is LeapDataType.Graph.

    Example:
        graph_data = np.random.rand(100, 3).astype(np.float32)
        leap_graph = LeapGraph(data=graph_data)
    """
    data: npt.NDArray[np.float32]
    type: LeapDataType = LeapDataType.Graph

    def __post_init__(self) -> None:
        # Checks run in order and raise LeapValidationError on the first failure.
        validate_type(self.type, LeapDataType.Graph)
        validate_type(type(self.data), np.ndarray)
        validate_type(self.data.dtype, np.float32)
        validate_type(len(self.data.shape), 2, 'Graph must be of shape 2')

    def plot_visualizer(self) -> None:
        """
        Draw one line per column of ``data`` on a black, white-gridded chart.

        Returns:
            None

        Example:
            graph_data = np.random.rand(100, 3).astype(np.float32)
            leap_graph = LeapGraph(data=graph_data)
            leap_graph.plot_visualizer()
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        # Iterating the transpose yields the columns, i.e. data[:, 0], data[:, 1], ...
        for column_index, column in enumerate(self.data.T):
            ax.plot(column, label=f'Variable {column_index + 1}')

        ax.set_xlabel('Data Points', color='white')
        ax.set_ylabel('Values', color='white')
        ax.set_title('Leap Graph Visualization', color='white')
        ax.legend()
        ax.grid(True, color='white')
        ax.tick_params(colors='white')

        plt.show()
|
220
|
+
|
221
|
+
@dataclass
class LeapText:
    """
    Visualizer representing text data for Tensorleap.

    Attributes:
        data (List[str]): The text data, consisting of a list of text tokens. If the model requires fixed-length inputs,
            it is recommended to maintain the fixed length, using empty strings ('') instead of padding tokens ('PAD') e.g., ['I', 'ate', 'a', 'banana', '', '', '', ...]
        type (LeapDataType): The data type, default is LeapDataType.Text.

    Example:
        text_data = ['I', 'ate', 'a', 'banana', '', '', '']
        leap_text = LeapText(data=text_data)  # Create LeapText object
        leap_text.plot_visualizer()
    """
    data: List[str]
    type: LeapDataType = LeapDataType.Text

    def __post_init__(self) -> None:
        validate_type(self.type, LeapDataType.Text)
        validate_type(type(self.data), list)
        # Every token must be a plain str (empty strings are allowed as padding).
        for value in self.data:
            validate_type(type(value), str)

    def plot_visualizer(self) -> None:
        """
        Display the text contained in the LeapText object.

        Non-empty tokens are joined with single spaces and drawn centred on a black figure;
        empty-string padding tokens are skipped.

        Returns:
            None

        Example:
            text_data = ['I', 'ate', 'a', 'banana', '', '', '']
            leap_text = LeapText(data=text_data)
            leap_text.plot_visualizer()
        """
        # Join the text tokens into a single string, ignoring empty strings
        display_text = ' '.join(token for token in self.data if token)

        # Create a black canvas using Matplotlib
        fig, ax = plt.subplots(figsize=(10, 5))
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        # Hide the axes
        ax.axis('off')

        # Set the text properties
        font_size = 20
        font_color = 'white'

        # Add the text to the canvas
        ax.text(0.5, 0.5, display_text, color=font_color, fontsize=font_size, ha='center', va='center')
        ax.set_title('Leap Text Visualization', color='white')

        # Display the figure
        plt.show()
|
279
|
+
|
280
|
+
|
281
|
+
@dataclass
class LeapHorizontalBar:
    """
    Visualizer representing horizontal bar data for Tensorleap, e.g. the per-class
    prediction scores of a classification model.

    Attributes:
        body (npt.NDArray[np.float32]): 1-D float32 array of bar values, shaped [C].
        labels (List[str]): One name per bar, e.g. the class names of a classifier's output.
            Its length should equal C; note that this is not checked at construction time.
        type (LeapDataType): The data type, default is LeapDataType.HorizontalBar.

    Example:
        body_data = np.random.rand(5).astype(np.float32)
        labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
        leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
    """
    body: npt.NDArray[np.float32]
    labels: List[str]
    type: LeapDataType = LeapDataType.HorizontalBar

    def __post_init__(self) -> None:
        # The body checks go first, and each one raises on failure, so the later
        # attribute accesses (dtype, shape) are only reached on a real ndarray.
        validate_type(self.type, LeapDataType.HorizontalBar)
        validate_type(type(self.body), np.ndarray)
        validate_type(self.body.dtype, np.float32)
        validate_type(len(self.body.shape), 1, 'HorizontalBar body must be of shape 1')

        # Labels: a list whose every entry is a str.
        validate_type(type(self.labels), list)
        for label in self.labels:
            validate_type(type(label), str)

    def plot_visualizer(self) -> None:
        """
        Render the bars as a green horizontal bar chart on a black background.

        Returns:
            None

        Example:
            body_data = np.random.rand(5).astype(np.float32)
            labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
            leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
            leap_horizontal_bar.plot_visualizer()
        """
        fig, ax = plt.subplots()

        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')

        ax.barh(self.labels, self.body, color='green')

        ax.set_xlabel('Scores', color='white')
        ax.set_title('Leap Horizontal Bar Visualization', color='white')

        # White tick marks/labels on both axes so they stay visible on black.
        for axis_name in ('x', 'y'):
            ax.tick_params(axis=axis_name, colors='white')

        plt.show()
|
345
|
+
|
346
|
+
@dataclass
class LeapImageMask:
    """
    Visualizer representing an image with a mask for Tensorleap.
    This can be used for tasks such as segmentation, and other applications where it is important to highlight specific regions within an image.

    Attributes:
        mask (npt.NDArray[np.uint8]): The mask data, shaped [H, W].
        image (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or shaped [H, W, 1].
        labels (List[str]): Labels associated with the mask regions; e.g., class names for segmented objects. The length of `labels` should match the number of unique values in `mask`.
        type (LeapDataType): The data type, default is LeapDataType.ImageMask.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
        labels = ["background", "object"]
        leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
    """
    mask: npt.NDArray[np.uint8]
    image: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
    labels: List[str]
    type: LeapDataType = LeapDataType.ImageMask

    def __post_init__(self) -> None:
        validate_type(self.type, LeapDataType.ImageMask)
        validate_type(type(self.mask), np.ndarray)
        validate_type(self.mask.dtype, np.uint8)
        validate_type(len(self.mask.shape), 2, 'image mask must be of shape 2')
        validate_type(type(self.image), np.ndarray)
        validate_type(self.image.dtype, [np.uint8, np.float32])
        validate_type(len(self.image.shape), 3, 'Image must be of shape 3')
        validate_type(self.image.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
        validate_type(type(self.labels), list)
        for label in self.labels:
            validate_type(type(label), str)

    def plot_visualizer(self) -> None:
        """
        Plots an image with overlaid masks given a LeapImageMask visualizer object.

        Pixels whose mask value is ``i + 1`` are blended 50/50 with the colour assigned to
        ``labels[i]``; mask value 0 (and values with no label) are left untouched.
        Grayscale images are expanded to RGB so they can be tinted and displayed.

        Returns:
            None


        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
            labels = ["background", "object"]
            leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
            leap_image_mask.plot_visualizer()
        """
        mask = self.mask
        labels = self.labels

        # One colour per label, sampled evenly from the jet colormap (RGBA floats in [0, 1]).
        colors = plt.cm.jet(np.linspace(0, 1, len(labels)))

        # Draw on a 3-channel copy: a grayscale [H, W, 1] image cannot be blended with an
        # RGB colour in place, and imshow rejects [H, W, 1]. np.repeat already returns a copy.
        if self.image.shape[2] == 1:
            overlayed_image = np.repeat(self.image, 3, axis=2)
        else:
            overlayed_image = self.image.copy()

        # Colormap components are in [0, 1]; a uint8 image lives in [0, 255], so scale the
        # colour to the image's range instead of truncating it to 0/1 with a uint8 cast.
        color_scale = 255.0 if overlayed_image.dtype == np.uint8 else 1.0
        alpha = 0.5

        # NOTE(review): label i is drawn for mask value i + 1 (value 0 is never tinted), yet the
        # class docstring example pairs ["background", "object"] with mask values {0, 1};
        # confirm the intended value -> label convention.
        for i in range(len(labels)):
            # Extract binary mask for the current instance
            instance_mask = (mask == (i + 1))

            # Fill the instance mask with a translucent color
            color = np.asarray(colors[i][:3], dtype=np.float64) * color_scale
            overlayed_image[instance_mask] = (
                overlayed_image[instance_mask] * (1 - alpha) + color * alpha)

        # Display the result using matplotlib
        fig, ax = plt.subplots(1)
        fig.patch.set_facecolor('black')  # Set the figure background to black
        ax.set_facecolor('black')  # Set the axis background to black

        ax.imshow(overlayed_image)
        ax.set_title('Leap Image With Mask Visualization', color='white')
        plt.axis('off')  # Hide the axis
        plt.show()
|
426
|
+
|
427
|
+
|
428
|
+
@dataclass
class LeapTextMask:
    """
    Visualizer representing text data with a mask for Tensorleap.
    This can be used for tasks such as named entity recognition (NER), sentiment analysis, and other applications where it is important to highlight specific tokens or parts of the text.

    Attributes:
        mask (npt.NDArray[np.uint8]): The mask data, shaped [L]. In plot_visualizer, 0 leaves a token
            unhighlighted and a value v > 0 highlights it with the colour of labels[v - 1].
        text (List[str]): The text data, consisting of a list of text tokens, length of L.
        labels (List[str]): Labels associated with the masked tokens; e.g., named entities or sentiment categories. The length of `labels` should match the number of unique values in `mask`.
        type (LeapDataType): The data type, default is LeapDataType.TextMask.

    Example:
        text_data = ['I', 'ate', 'a', 'banana', '', '', '']
        mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
        labels = ["object"]
        leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
        leap_text_mask.plot_visualizer()
    """
    mask: npt.NDArray[np.uint8]
    text: List[str]
    labels: List[str]
    type: LeapDataType = LeapDataType.TextMask

    def __post_init__(self) -> None:
        validate_type(self.type, LeapDataType.TextMask)
        validate_type(type(self.mask), np.ndarray)
        validate_type(self.mask.dtype, np.uint8)
        validate_type(len(self.mask.shape), 1, 'text mask must be of shape 1')
        validate_type(type(self.text), list)
        for t in self.text:
            validate_type(type(t), str)
        validate_type(type(self.labels), list)
        for label in self.labels:
            validate_type(type(label), str)

    def plot_visualizer(self) -> None:
        """
        Plots text with overlaid masks given a LeapTextMask visualizer object.

        Tokens are laid out left to right; a token with mask value v > 0 gets a rounded
        background in the colour of labels[v - 1] (wrapping if v exceeds the number of labels).
        When ``labels`` is empty, no token is highlighted.

        Returns:
            None

        Example:
            text_data = ['I', 'ate', 'a', 'banana', '', '', '']
            mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
            labels = ["object"]
            leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
            leap_text_mask.plot_visualizer()
        """
        text_data = self.text
        mask_data = self.mask
        labels = self.labels

        # Create a color map for each label
        colors = plt.cm.jet(np.linspace(0, 1, len(labels)))

        # Create a figure and axis
        fig, ax = plt.subplots()

        # Set background to black
        fig.patch.set_facecolor('black')
        ax.set_facecolor('black')
        ax.set_title('Leap Text Mask Visualization', color='white')
        ax.axis('off')

        # Set initial position
        x_pos, y_pos = 0.01, 0.5  # Adjusted initial position for better visibility

        # Display the text with colors
        for token, mask_value in zip(text_data, mask_data):
            # Mask value v > 0 maps to labels[v - 1], the same convention LeapImageMask uses
            # (value i + 1 -> label i). int() avoids uint8 arithmetic; the len() guard avoids
            # a modulo-by-zero when no labels were given.
            label_index = int(mask_value) - 1
            if label_index >= 0 and len(colors) > 0:
                color = colors[label_index % len(colors)]
                bbox = dict(facecolor=color, edgecolor='none',
                            boxstyle='round,pad=0.3')  # Background color for masked tokens
            else:
                bbox = None

            ax.text(x_pos, y_pos, token, fontsize=12, color='white', ha='left', va='center', bbox=bbox)

            # Update the x position for the next token
            x_pos += len(token) * 0.03 + 0.02  # Adjust the spacing between tokens

        plt.show()
|
512
|
+
|
513
|
+
|
514
|
+
@dataclass
class LeapImageWithHeatmap:
    """
    Visualizer representing an image with heatmaps for Tensorleap.
    This can be used for tasks such as highlighting important regions in an image, visualizing attention maps, and other applications where it is important to overlay heatmaps on images.

    Attributes:
        image (npt.NDArray[np.float32]): The image data, shaped [H, W, C], where C is the number of channels.
        heatmaps (npt.NDArray[np.float32]): The heatmap data, shaped [N, H, W], where N is the number of heatmaps.
        labels (List[str]): Labels associated with the heatmaps; e.g., feature names or attention regions. The length of `labels` should match the number of heatmaps, N.
        type (LeapDataType): The data type, default is LeapDataType.ImageWithHeatmap.

    Example:
        image_data = np.random.rand(100, 100, 3).astype(np.float32)
        heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
        labels = ["heatmap1", "heatmap2", "heatmap3"]
        leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
    """
    image: npt.NDArray[np.float32]
    heatmaps: npt.NDArray[np.float32]
    labels: List[str]
    type: LeapDataType = LeapDataType.ImageWithHeatmap

    def __post_init__(self) -> None:
        validate_type(self.type, LeapDataType.ImageWithHeatmap)
        validate_type(type(self.heatmaps), np.ndarray)
        validate_type(self.heatmaps.dtype, np.float32)
        validate_type(type(self.image), np.ndarray)
        validate_type(self.image.dtype, np.float32)
        validate_type(type(self.labels), list)
        for label in self.labels:
            validate_type(type(label), str)
        # One label per heatmap (first axis of `heatmaps`).
        if self.heatmaps.shape[0] != len(self.labels):
            raise LeapValidationError(
                'Number of heatmaps and labels must be equal')

    def plot_visualizer(self) -> None:
        """
        Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.

        Each heatmap is drawn once on top of the image with 50% transparency and gets its own
        colorbar titled with its label. A single-channel [H, W, 1] image is shown as a 2-D
        grayscale image.

        Returns:
            None

        Example:
            image_data = np.random.rand(100, 100, 3).astype(np.float32)
            heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
            labels = ["heatmap1", "heatmap2", "heatmap3"]
            leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
            leap_image_with_heatmap.plot_visualizer()
        """
        image = self.image
        # imshow rejects [H, W, 1]; drop the singleton channel so the 'gray' colormap applies.
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]

        # Plot the base image
        fig, ax = plt.subplots()
        fig.patch.set_facecolor('black')  # Set the figure background to black
        ax.set_facecolor('black')  # Set the axis background to black
        ax.imshow(image, cmap='gray')

        # Overlay each heatmap exactly once with some transparency; the colorbar reuses that
        # same image artist rather than drawing the heatmap a second time.
        for heatmap, label in zip(self.heatmaps, self.labels):
            heatmap_artist = ax.imshow(heatmap, cmap='jet', alpha=0.5)  # Adjust alpha for transparency

            # Display a colorbar for the heatmap
            cbar = plt.colorbar(heatmap_artist)
            cbar.set_label(label, color='white')
            cbar.ax.yaxis.set_tick_params(color='white')  # Set color for the colorbar ticks
            plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')  # Set color for the colorbar labels

        plt.axis('off')
        plt.title('Leap Image With Heatmaps Visualization', color='white')
        plt.show()
|
589
|
+
|
590
|
+
|
591
|
+
# Lookup table from a LeapDataType's raw `.value` to the visualizer dataclass that represents it.
map_leap_data_type_to_visualizer_class = {
    data_type.value: visualizer_cls
    for data_type, visualizer_cls in (
        (LeapDataType.Image, LeapImage),
        (LeapDataType.Graph, LeapGraph),
        (LeapDataType.Text, LeapText),
        (LeapDataType.HorizontalBar, LeapHorizontalBar),
        (LeapDataType.ImageMask, LeapImageMask),
        (LeapDataType.TextMask, LeapTextMask),
        (LeapDataType.ImageWithBBox, LeapImageWithBBox),
        (LeapDataType.ImageWithHeatmap, LeapImageWithHeatmap),
    )
}
|