code-loader 1.0.49.dev9__py3-none-any.whl → 1.0.49.dev100__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/contract/datasetclasses.py +7 -6
- code_loader/contract/visualizer_classes.py +0 -320
- code_loader/inner_leap_binder/leapbinder.py +23 -4
- code_loader/inner_leap_binder/leapbinder_decorators.py +56 -0
- code_loader/leaploader.py +2 -2
- code_loader/utils.py +1 -1
- {code_loader-1.0.49.dev9.dist-info → code_loader-1.0.49.dev100.dist-info}/METADATA +1 -1
- {code_loader-1.0.49.dev9.dist-info → code_loader-1.0.49.dev100.dist-info}/RECORD +10 -9
- {code_loader-1.0.49.dev9.dist-info → code_loader-1.0.49.dev100.dist-info}/LICENSE +0 -0
- {code_loader-1.0.49.dev9.dist-info → code_loader-1.0.49.dev100.dist-info}/WHEEL +0 -0
@@ -36,9 +36,9 @@ class PreprocessResponse:
|
|
36
36
|
data: Any = None
|
37
37
|
sample_ids: Optional[Union[List[str], List[int]]] = None
|
38
38
|
state: Optional[DataStateType] = None
|
39
|
-
sample_id_type: Optional[Type] = None
|
39
|
+
sample_id_type: Optional[Union[Type[str], Type[int]]] = None
|
40
40
|
|
41
|
-
def __post_init__(self):
|
41
|
+
def __post_init__(self) -> None:
|
42
42
|
if self.length is not None and self.sample_ids is None:
|
43
43
|
self.sample_ids = [i for i in range(self.length)]
|
44
44
|
self.sample_id_type = int
|
@@ -47,9 +47,10 @@ class PreprocessResponse:
|
|
47
47
|
if self.sample_id_type is None:
|
48
48
|
self.sample_id_type = str
|
49
49
|
else:
|
50
|
-
raise Exception("length is deprecated.
|
50
|
+
raise Exception("length is deprecated.")
|
51
51
|
|
52
|
-
def __len__(self):
|
52
|
+
def __len__(self) -> int:
|
53
|
+
assert self.sample_ids is not None
|
53
54
|
return len(self.sample_ids)
|
54
55
|
|
55
56
|
|
@@ -90,8 +91,8 @@ VisualizerCallableInterface = Union[
|
|
90
91
|
Callable[..., LeapImageWithHeatmap]
|
91
92
|
]
|
92
93
|
|
93
|
-
|
94
|
-
|
94
|
+
LeapData = Union[LeapImage, LeapText, LeapGraph, LeapHorizontalBar, LeapImageMask, LeapTextMask, LeapImageWithBBox,
|
95
|
+
LeapImageWithHeatmap]
|
95
96
|
|
96
97
|
CustomCallableInterface = Callable[..., Any]
|
97
98
|
|
@@ -4,8 +4,6 @@ import numpy as np
|
|
4
4
|
import numpy.typing as npt
|
5
5
|
from dataclasses import dataclass
|
6
6
|
|
7
|
-
import matplotlib.pyplot as plt # type: ignore
|
8
|
-
|
9
7
|
from code_loader.contract.enums import LeapDataType
|
10
8
|
from code_loader.contract.responsedataclasses import BoundingBox
|
11
9
|
|
@@ -51,34 +49,6 @@ class LeapImage:
|
|
51
49
|
validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
|
52
50
|
validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
|
53
51
|
|
54
|
-
def plot_visualizer(self) -> None:
|
55
|
-
"""
|
56
|
-
Display the image contained in the LeapImage object.
|
57
|
-
|
58
|
-
Returns:
|
59
|
-
None
|
60
|
-
|
61
|
-
Example:
|
62
|
-
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
63
|
-
leap_image = LeapImage(data=image_data)
|
64
|
-
leap_image.plot_visualizer()
|
65
|
-
"""
|
66
|
-
image_data = self.data
|
67
|
-
|
68
|
-
# If the image has one channel, convert it to a 3-channel image for display
|
69
|
-
if image_data.shape[2] == 1:
|
70
|
-
image_data = np.repeat(image_data, 3, axis=2)
|
71
|
-
|
72
|
-
fig, ax = plt.subplots()
|
73
|
-
fig.patch.set_facecolor('black')
|
74
|
-
ax.set_facecolor('black')
|
75
|
-
|
76
|
-
ax.imshow(image_data)
|
77
|
-
|
78
|
-
plt.axis('off')
|
79
|
-
plt.title('Leap Image Visualization', color='white')
|
80
|
-
plt.show()
|
81
|
-
|
82
52
|
|
83
53
|
@dataclass
|
84
54
|
class LeapImageWithBBox:
|
@@ -106,60 +76,6 @@ class LeapImageWithBBox:
|
|
106
76
|
validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
|
107
77
|
validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
|
108
78
|
|
109
|
-
def plot_visualizer(self) -> None:
|
110
|
-
"""
|
111
|
-
Plot an image with overlaid bounding boxes.
|
112
|
-
|
113
|
-
Returns:
|
114
|
-
None
|
115
|
-
|
116
|
-
Example:
|
117
|
-
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
118
|
-
bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
|
119
|
-
leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
|
120
|
-
leap_image_with_bbox.plot_visualizer()
|
121
|
-
"""
|
122
|
-
|
123
|
-
image = self.data
|
124
|
-
bounding_boxes = self.bounding_boxes
|
125
|
-
|
126
|
-
# Create figure and axes
|
127
|
-
fig, ax = plt.subplots(1)
|
128
|
-
fig.patch.set_facecolor('black')
|
129
|
-
ax.set_facecolor('black')
|
130
|
-
|
131
|
-
# Display the image
|
132
|
-
ax.imshow(image)
|
133
|
-
ax.set_title('Leap Image With BBox Visualization', color='white')
|
134
|
-
|
135
|
-
# Draw bounding boxes on the image
|
136
|
-
for bbox in bounding_boxes:
|
137
|
-
x, y, width, height = bbox.x, bbox.y, bbox.width, bbox.height
|
138
|
-
confidence, label = bbox.confidence, bbox.label
|
139
|
-
|
140
|
-
# Convert relative coordinates to absolute coordinates
|
141
|
-
abs_x = x * image.shape[1]
|
142
|
-
abs_y = y * image.shape[0]
|
143
|
-
abs_width = width * image.shape[1]
|
144
|
-
abs_height = height * image.shape[0]
|
145
|
-
|
146
|
-
# Create a rectangle patch
|
147
|
-
rect = plt.Rectangle(
|
148
|
-
(abs_x - abs_width / 2, abs_y - abs_height / 2),
|
149
|
-
abs_width, abs_height,
|
150
|
-
linewidth=3, edgecolor='r', facecolor='none'
|
151
|
-
)
|
152
|
-
|
153
|
-
# Add the rectangle to the axes
|
154
|
-
ax.add_patch(rect)
|
155
|
-
|
156
|
-
# Display label and confidence
|
157
|
-
ax.text(abs_x - abs_width / 2, abs_y - abs_height / 2 - 5,
|
158
|
-
f"{label} {confidence:.2f}", color='r', fontsize=8,
|
159
|
-
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
|
160
|
-
|
161
|
-
# Show the image with bounding boxes
|
162
|
-
plt.show()
|
163
79
|
|
164
80
|
@dataclass
|
165
81
|
class LeapGraph:
|
@@ -183,40 +99,6 @@ class LeapGraph:
|
|
183
99
|
validate_type(self.data.dtype, np.float32)
|
184
100
|
validate_type(len(self.data.shape), 2, 'Graph must be of shape 2')
|
185
101
|
|
186
|
-
def plot_visualizer(self) -> None:
|
187
|
-
"""
|
188
|
-
Display the line chart contained in the LeapGraph object.
|
189
|
-
|
190
|
-
Returns:
|
191
|
-
None
|
192
|
-
|
193
|
-
Example:
|
194
|
-
graph_data = np.random.rand(100, 3).astype(np.float32)
|
195
|
-
leap_graph = LeapGraph(data=graph_data)
|
196
|
-
leap_graph.plot_visualizer()
|
197
|
-
"""
|
198
|
-
graph_data = self.data
|
199
|
-
num_variables = graph_data.shape[1]
|
200
|
-
|
201
|
-
fig, ax = plt.subplots(figsize=(10, 6))
|
202
|
-
|
203
|
-
# Set the background color to black
|
204
|
-
fig.patch.set_facecolor('black')
|
205
|
-
ax.set_facecolor('black')
|
206
|
-
|
207
|
-
for i in range(num_variables):
|
208
|
-
plt.plot(graph_data[:, i], label=f'Variable {i + 1}')
|
209
|
-
|
210
|
-
ax.set_xlabel('Data Points', color='white')
|
211
|
-
ax.set_ylabel('Values', color='white')
|
212
|
-
ax.set_title('Leap Graph Visualization', color='white')
|
213
|
-
ax.legend()
|
214
|
-
ax.grid(True, color='white')
|
215
|
-
|
216
|
-
# Change the color of the tick labels to white
|
217
|
-
ax.tick_params(colors='white')
|
218
|
-
|
219
|
-
plt.show()
|
220
102
|
|
221
103
|
@dataclass
|
222
104
|
class LeapText:
|
@@ -242,41 +124,6 @@ class LeapText:
|
|
242
124
|
for value in self.data:
|
243
125
|
validate_type(type(value), str)
|
244
126
|
|
245
|
-
def plot_visualizer(self) -> None:
|
246
|
-
"""
|
247
|
-
Display the text contained in the LeapText object.
|
248
|
-
|
249
|
-
Returns:
|
250
|
-
None
|
251
|
-
|
252
|
-
Example:
|
253
|
-
text_data = ['I', 'ate', 'a', 'banana', '', '', '']
|
254
|
-
leap_text = LeapText(data=text_data)
|
255
|
-
leap_text.plot_visualizer()
|
256
|
-
"""
|
257
|
-
text_data = self.data
|
258
|
-
# Join the text tokens into a single string, ignoring empty strings
|
259
|
-
display_text = ' '.join([token for token in text_data if token])
|
260
|
-
|
261
|
-
# Create a black image using Matplotlib
|
262
|
-
fig, ax = plt.subplots(figsize=(10, 5))
|
263
|
-
fig.patch.set_facecolor('black')
|
264
|
-
ax.set_facecolor('black')
|
265
|
-
|
266
|
-
# Hide the axes
|
267
|
-
ax.axis('off')
|
268
|
-
|
269
|
-
# Set the text properties
|
270
|
-
font_size = 20
|
271
|
-
font_color = 'white'
|
272
|
-
|
273
|
-
# Add the text to the image
|
274
|
-
ax.text(0.5, 0.5, display_text, color=font_color, fontsize=font_size, ha='center', va='center')
|
275
|
-
ax.set_title('Leap Text Visualization', color='white')
|
276
|
-
|
277
|
-
# Display the image
|
278
|
-
plt.show()
|
279
|
-
|
280
127
|
|
281
128
|
@dataclass
|
282
129
|
class LeapHorizontalBar:
|
@@ -309,39 +156,6 @@ class LeapHorizontalBar:
|
|
309
156
|
for label in self.labels:
|
310
157
|
validate_type(type(label), str)
|
311
158
|
|
312
|
-
def plot_visualizer(self) -> None:
|
313
|
-
"""
|
314
|
-
Display the horizontal bar chart contained in the LeapHorizontalBar object.
|
315
|
-
|
316
|
-
Returns:
|
317
|
-
None
|
318
|
-
|
319
|
-
Example:
|
320
|
-
body_data = np.random.rand(5).astype(np.float32)
|
321
|
-
labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
|
322
|
-
leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
|
323
|
-
leap_horizontal_bar.plot_visualizer()
|
324
|
-
"""
|
325
|
-
body_data = self.body
|
326
|
-
labels = self.labels
|
327
|
-
|
328
|
-
fig, ax = plt.subplots()
|
329
|
-
|
330
|
-
fig.patch.set_facecolor('black')
|
331
|
-
ax.set_facecolor('black')
|
332
|
-
|
333
|
-
# Plot horizontal bar chart
|
334
|
-
ax.barh(labels, body_data, color='green')
|
335
|
-
|
336
|
-
# Set the color of the labels and title to white
|
337
|
-
ax.set_xlabel('Scores', color='white')
|
338
|
-
ax.set_title('Leap Horizontal Bar Visualization', color='white')
|
339
|
-
|
340
|
-
# Set the color of the ticks to white
|
341
|
-
ax.tick_params(axis='x', colors='white')
|
342
|
-
ax.tick_params(axis='y', colors='white')
|
343
|
-
|
344
|
-
plt.show()
|
345
159
|
|
346
160
|
@dataclass
|
347
161
|
class LeapImageMask:
|
@@ -379,51 +193,6 @@ class LeapImageMask:
|
|
379
193
|
for label in self.labels:
|
380
194
|
validate_type(type(label), str)
|
381
195
|
|
382
|
-
def plot_visualizer(self) -> None:
|
383
|
-
"""
|
384
|
-
Plots an image with overlaid masks given a LeapImageMask visualizer object.
|
385
|
-
|
386
|
-
Returns:
|
387
|
-
None
|
388
|
-
|
389
|
-
|
390
|
-
Example:
|
391
|
-
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
392
|
-
mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
|
393
|
-
labels = ["background", "object"]
|
394
|
-
leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
|
395
|
-
leap_image_mask.plot_visualizer()
|
396
|
-
"""
|
397
|
-
|
398
|
-
image = self.image
|
399
|
-
mask = self.mask
|
400
|
-
labels = self.labels
|
401
|
-
|
402
|
-
# Create a color map for each label
|
403
|
-
colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
|
404
|
-
|
405
|
-
# Make a copy of the image to draw on
|
406
|
-
overlayed_image = image.copy()
|
407
|
-
|
408
|
-
# Iterate through the unique values in the mask (excluding 0)
|
409
|
-
for i, label in enumerate(labels):
|
410
|
-
# Extract binary mask for the current instance
|
411
|
-
instance_mask = (mask == (i + 1))
|
412
|
-
|
413
|
-
# fill the instance mask with a translucent color
|
414
|
-
overlayed_image[instance_mask] = (
|
415
|
-
overlayed_image[instance_mask] * (1 - 0.5) + np.array(colors[i][:3], dtype=np.uint8) * 0.5)
|
416
|
-
|
417
|
-
# Display the result using matplotlib
|
418
|
-
fig, ax = plt.subplots(1)
|
419
|
-
fig.patch.set_facecolor('black') # Set the figure background to black
|
420
|
-
ax.set_facecolor('black') # Set the axis background to black
|
421
|
-
|
422
|
-
ax.imshow(overlayed_image)
|
423
|
-
ax.set_title('Leap Image With Mask Visualization', color='white')
|
424
|
-
plt.axis('off') # Hide the axis
|
425
|
-
plt.show()
|
426
|
-
|
427
196
|
|
428
197
|
@dataclass
|
429
198
|
class LeapTextMask:
|
@@ -461,55 +230,6 @@ class LeapTextMask:
|
|
461
230
|
for label in self.labels:
|
462
231
|
validate_type(type(label), str)
|
463
232
|
|
464
|
-
def plot_visualizer(self) -> None:
|
465
|
-
"""
|
466
|
-
Plots text with overlaid masks given a LeapTextMask visualizer object.
|
467
|
-
|
468
|
-
Returns:
|
469
|
-
None
|
470
|
-
|
471
|
-
Example:
|
472
|
-
text_data = ['I', 'ate', 'a', 'banana', '', '', '']
|
473
|
-
mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
|
474
|
-
labels = ["object"]
|
475
|
-
leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
|
476
|
-
"""
|
477
|
-
|
478
|
-
text_data = self.text
|
479
|
-
mask_data = self.mask
|
480
|
-
labels = self.labels
|
481
|
-
|
482
|
-
# Create a color map for each label
|
483
|
-
colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
|
484
|
-
|
485
|
-
# Create a figure and axis
|
486
|
-
fig, ax = plt.subplots()
|
487
|
-
|
488
|
-
# Set background to black
|
489
|
-
fig.patch.set_facecolor('black')
|
490
|
-
ax.set_facecolor('black')
|
491
|
-
ax.set_title('Leap Text Mask Visualization', color='white')
|
492
|
-
ax.axis('off')
|
493
|
-
|
494
|
-
# Set initial position
|
495
|
-
x_pos, y_pos = 0.01, 0.5 # Adjusted initial position for better visibility
|
496
|
-
|
497
|
-
# Display the text with colors
|
498
|
-
for token, mask_value in zip(text_data, mask_data):
|
499
|
-
if mask_value > 0:
|
500
|
-
color = colors[mask_value % len(colors)]
|
501
|
-
bbox = dict(facecolor=color, edgecolor='none',
|
502
|
-
boxstyle='round,pad=0.3') # Background color for masked tokens
|
503
|
-
else:
|
504
|
-
bbox = None
|
505
|
-
|
506
|
-
ax.text(x_pos, y_pos, token, fontsize=12, color='white', ha='left', va='center', bbox=bbox)
|
507
|
-
|
508
|
-
# Update the x position for the next token
|
509
|
-
x_pos += len(token) * 0.03 + 0.02 # Adjust the spacing between tokens
|
510
|
-
|
511
|
-
plt.show()
|
512
|
-
|
513
233
|
|
514
234
|
@dataclass
|
515
235
|
class LeapImageWithHeatmap:
|
@@ -547,46 +267,6 @@ class LeapImageWithHeatmap:
|
|
547
267
|
raise LeapValidationError(
|
548
268
|
'Number of heatmaps and labels must be equal')
|
549
269
|
|
550
|
-
def plot_visualizer(self) -> None:
|
551
|
-
"""
|
552
|
-
Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.
|
553
|
-
|
554
|
-
Returns:
|
555
|
-
None
|
556
|
-
|
557
|
-
Example:
|
558
|
-
image_data = np.random.rand(100, 100, 3).astype(np.float32)
|
559
|
-
heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
|
560
|
-
labels = ["heatmap1", "heatmap2", "heatmap3"]
|
561
|
-
leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
|
562
|
-
leap_image_with_heatmap.plot_visualizer()
|
563
|
-
"""
|
564
|
-
image = self.image
|
565
|
-
heatmaps = self.heatmaps
|
566
|
-
labels = self.labels
|
567
|
-
|
568
|
-
# Plot the base image
|
569
|
-
fig, ax = plt.subplots()
|
570
|
-
fig.patch.set_facecolor('black') # Set the figure background to black
|
571
|
-
ax.set_facecolor('black') # Set the axis background to black
|
572
|
-
ax.imshow(image, cmap='gray')
|
573
|
-
|
574
|
-
# Overlay each heatmap with some transparency
|
575
|
-
for i in range(len(labels)):
|
576
|
-
heatmap = heatmaps[i]
|
577
|
-
ax.imshow(heatmap, cmap='jet', alpha=0.5) # Adjust alpha for transparency
|
578
|
-
ax.set_title(f'Heatmap: {labels[i]}', color='white')
|
579
|
-
|
580
|
-
# Display a colorbar for the heatmap
|
581
|
-
cbar = plt.colorbar(ax.imshow(heatmap, cmap='jet', alpha=0.5))
|
582
|
-
cbar.set_label(labels[i], color='white')
|
583
|
-
cbar.ax.yaxis.set_tick_params(color='white') # Set color for the colorbar ticks
|
584
|
-
plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white') # Set color for the colorbar labels
|
585
|
-
|
586
|
-
plt.axis('off')
|
587
|
-
plt.title('Leap Image With Heatmaps Visualization', color='white')
|
588
|
-
plt.show()
|
589
|
-
|
590
270
|
|
591
271
|
map_leap_data_type_to_visualizer_class = {
|
592
272
|
LeapDataType.Image.value: LeapImage,
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import inspect
|
2
|
+
import sys
|
2
3
|
from typing import Callable, List, Optional, Dict, Any, Type, Union
|
3
4
|
|
4
5
|
import numpy as np
|
@@ -8,7 +9,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
|
|
8
9
|
GroundTruthHandler, MetadataHandler, DatasetIntegrationSetup, VisualizerHandler, PreprocessResponse, \
|
9
10
|
PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
|
10
11
|
MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
|
11
|
-
CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs,
|
12
|
+
CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
|
12
13
|
CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
|
13
14
|
from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
|
14
15
|
from code_loader.contract.responsedataclasses import DatasetTestResultPayload
|
@@ -37,6 +38,8 @@ class LeapBinder:
|
|
37
38
|
self._encoder_names: List[str] = list()
|
38
39
|
self._extend_with_default_visualizers()
|
39
40
|
|
41
|
+
self.batch_size_to_validate: Optional[int] = None
|
42
|
+
|
40
43
|
def _extend_with_default_visualizers(self) -> None:
|
41
44
|
self.set_visualizer(function=default_image_visualizer, name=DefaultVisualizer.Image.value,
|
42
45
|
visualizer_type=LeapDataType.Image)
|
@@ -103,7 +106,7 @@ class LeapBinder:
|
|
103
106
|
if visualizer_type.value not in map_leap_data_type_to_visualizer_class:
|
104
107
|
raise Exception(
|
105
108
|
f'The visualizer_type is invalid. current visualizer_type: {visualizer_type}, ' # type: ignore[attr-defined]
|
106
|
-
f'should be one of : {", ".join([arg.__name__ for arg in
|
109
|
+
f'should be one of : {", ".join([arg.__name__ for arg in LeapData.__args__])}')
|
107
110
|
|
108
111
|
func_annotations = function.__annotations__
|
109
112
|
if "return" not in func_annotations:
|
@@ -113,10 +116,10 @@ class LeapBinder:
|
|
113
116
|
f"https://docs.python.org/3/library/typing.html")
|
114
117
|
else:
|
115
118
|
return_type = func_annotations["return"]
|
116
|
-
if return_type not in
|
119
|
+
if return_type not in LeapData.__args__: # type: ignore[attr-defined]
|
117
120
|
raise Exception(
|
118
121
|
f'The return type of function {function.__name__} is invalid. current return type: {return_type}, ' # type: ignore[attr-defined]
|
119
|
-
f'should be one of : {", ".join([arg.__name__ for arg in
|
122
|
+
f'should be one of : {", ".join([arg.__name__ for arg in LeapData.__args__])}')
|
120
123
|
|
121
124
|
expected_return_type = map_leap_data_type_to_visualizer_class[visualizer_type.value]
|
122
125
|
if not issubclass(return_type, expected_return_type):
|
@@ -430,6 +433,7 @@ class LeapBinder:
|
|
430
433
|
def check_handler(
|
431
434
|
preprocess_response: PreprocessResponse, test_result: List[DatasetTestResultPayload],
|
432
435
|
dataset_base_handler: Union[DatasetBaseHandler, MetadataHandler]) -> List[DatasetTestResultPayload]:
|
436
|
+
assert preprocess_response.sample_ids is not None
|
433
437
|
raw_result = dataset_base_handler.function(preprocess_response.sample_ids[0], preprocess_response)
|
434
438
|
handler_type = 'metadata' if isinstance(dataset_base_handler, MetadataHandler) else None
|
435
439
|
if isinstance(dataset_base_handler, MetadataHandler) and isinstance(raw_result, dict):
|
@@ -471,4 +475,19 @@ class LeapBinder:
|
|
471
475
|
self.check_handlers(preprocess_result)
|
472
476
|
print("Successful!")
|
473
477
|
|
478
|
+
def set_batch_size_to_validate(self, batch_size: int):
|
479
|
+
self.batch_size_to_validate = batch_size
|
480
|
+
|
481
|
+
@staticmethod
|
482
|
+
def init():
|
483
|
+
available_functions = inspect.getmembers(sys.modules[__name__], inspect.isfunction)
|
484
|
+
for func_name, func in available_functions:
|
485
|
+
if 'tensorleap_custom_metric' in str(func):
|
486
|
+
try:
|
487
|
+
func()
|
488
|
+
except:
|
489
|
+
pass
|
490
|
+
|
491
|
+
|
492
|
+
|
474
493
|
|
@@ -0,0 +1,56 @@
|
|
1
|
+
from typing import Optional, Union
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
|
6
|
+
CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs
|
7
|
+
from code_loader.contract.enums import MetricDirection
|
8
|
+
from code_loader import leap_binder
|
9
|
+
|
10
|
+
|
11
|
+
def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = MetricDirection.Downward):
|
12
|
+
def decorating_function(
|
13
|
+
user_function: Union[CustomCallableInterfaceMultiArgs,
|
14
|
+
CustomMultipleReturnCallableInterfaceMultiArgs,
|
15
|
+
ConfusionMatrixCallableInterfaceMultiArgs]
|
16
|
+
):
|
17
|
+
|
18
|
+
leap_binder.add_custom_metric(user_function, name, direction)
|
19
|
+
|
20
|
+
def _validate_custom_metric_input_args(*args, **kwargs):
|
21
|
+
for i, arg in enumerate(args):
|
22
|
+
assert isinstance(arg, np.ndarray), (f'tensorleap_custom_metric validation failed: '
|
23
|
+
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
24
|
+
if leap_binder.batch_size_to_validate:
|
25
|
+
assert arg.shape[0] == leap_binder.batch_size_to_validate, \
|
26
|
+
(f'tensorleap_custom_metric validation failed: Argument #{i} '
|
27
|
+
f'first dim should be as the batch size. Got {arg.shape[0]} '
|
28
|
+
f'instead of {leap_binder.batch_size_to_validate}')
|
29
|
+
|
30
|
+
for _arg_name, arg in kwargs.items():
|
31
|
+
assert isinstance(arg, np.ndarray), (f'tensorleap_custom_metric validation failed: '
|
32
|
+
f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
|
33
|
+
if leap_binder.batch_size_to_validate:
|
34
|
+
assert arg.shape[0] == leap_binder.batch_size_to_validate, \
|
35
|
+
(f'tensorleap_custom_metric validation failed: Argument {_arg_name} '
|
36
|
+
f'first dim should be as the batch size. Got {arg.shape[0]} '
|
37
|
+
f'instead of {leap_binder.batch_size_to_validate}')
|
38
|
+
|
39
|
+
def _validate_custom_metric_result(result):
|
40
|
+
assert isinstance(result, np.ndarray), (f'tensorleap_custom_metric validation failed: '
|
41
|
+
f'The return type should be a numpy array. Got {type(result)}.')
|
42
|
+
assert len(result.shape) == 1, (f'tensorleap_custom_metric validation failed: '
|
43
|
+
f'The return shape should be 1D. Got {len(result.shape)}D.')
|
44
|
+
if leap_binder.batch_size_to_validate:
|
45
|
+
assert result.shape[0] == leap_binder.batch_size_to_validate, \
|
46
|
+
f'tensorleap_custom_metric validation failed: The return len should be as the batch size.'
|
47
|
+
|
48
|
+
def inner(*args, **kwargs):
|
49
|
+
_validate_custom_metric_input_args(*args, **kwargs)
|
50
|
+
result = user_function(*args, **kwargs)
|
51
|
+
_validate_custom_metric_result(result)
|
52
|
+
return result
|
53
|
+
|
54
|
+
return inner
|
55
|
+
|
56
|
+
return decorating_function
|
code_loader/leaploader.py
CHANGED
@@ -12,7 +12,7 @@ import numpy as np
|
|
12
12
|
import numpy.typing as npt
|
13
13
|
|
14
14
|
from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
|
15
|
-
PreprocessResponse, VisualizerHandler,
|
15
|
+
PreprocessResponse, VisualizerHandler, LeapData, CustomLossHandler, \
|
16
16
|
PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
|
17
17
|
from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
|
18
18
|
from code_loader.contract.exceptions import DatasetScriptException
|
@@ -201,7 +201,7 @@ class LeapLoader:
|
|
201
201
|
return all_dataset_base_handlers
|
202
202
|
|
203
203
|
def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]],
|
204
|
-
) ->
|
204
|
+
) -> LeapData:
|
205
205
|
return self.visualizer_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
|
206
206
|
|
207
207
|
def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
|
code_loader/utils.py
CHANGED
@@ -10,7 +10,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, Prepro
|
|
10
10
|
|
11
11
|
|
12
12
|
def to_numpy_return_wrapper(encoder_function: SectionCallableInterface) -> SectionCallableInterface:
|
13
|
-
def numpy_encoder_function(idx: int, samples: PreprocessResponse) -> npt.NDArray[np.float32]:
|
13
|
+
def numpy_encoder_function(idx: Union[int, str], samples: PreprocessResponse) -> npt.NDArray[np.float32]:
|
14
14
|
result = encoder_function(idx, samples)
|
15
15
|
numpy_result: npt.NDArray[np.float32] = np.array(result)
|
16
16
|
return numpy_result
|
@@ -2,11 +2,11 @@ LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
|
2
2
|
code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
|
3
3
|
code_loader/code_inegration_processes_manager.py,sha256=XslWOPeNQk4RAFJ_f3tP5Oe3EgcIR7BE7Y8r9Ty73-o,3261
|
4
4
|
code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
code_loader/contract/datasetclasses.py,sha256=
|
5
|
+
code_loader/contract/datasetclasses.py,sha256=lOIY-h9t4k9NxNsC9GrJhltmhpqRju3AuLA3WVQcCMs,6614
|
6
6
|
code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
|
7
7
|
code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
|
8
8
|
code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
|
9
|
-
code_loader/contract/visualizer_classes.py,sha256=
|
9
|
+
code_loader/contract/visualizer_classes.py,sha256=iIa_O2rKvPTwN5ILCTZvRpsGYiiFABKdwQwfIXGigDo,11928
|
10
10
|
code_loader/experiment_api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
11
|
code_loader/experiment_api/api.py,sha256=a7wh6Hhe7IaVxu46eV2soSz-yxnmXG3ipU1BBtsEAaQ,2493
|
12
12
|
code_loader/experiment_api/cli_config_utils.py,sha256=n6JMyNrquxql3KKxHhAP8jAzezlRT-PV2KWI95kKsm0,1140
|
@@ -18,12 +18,13 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
|
|
18
18
|
code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
|
19
19
|
code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
|
20
20
|
code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
|
21
|
-
code_loader/inner_leap_binder/leapbinder.py,sha256=
|
22
|
-
code_loader/
|
23
|
-
code_loader/
|
21
|
+
code_loader/inner_leap_binder/leapbinder.py,sha256=QXHXEXV5jBCqggDBD7hDHVcgveNb1jeL382iPTa9K-o,25425
|
22
|
+
code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=pZjIVP-zqdOPk785r6G4ycTTvlNiRB-UqQz9_gcPPKY,3133
|
23
|
+
code_loader/leaploader.py,sha256=POUgD6x1GH_iF_eDGz-VLX4DsIl2kddufKVDdrA_K-U,19491
|
24
|
+
code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
|
24
25
|
code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
26
|
code_loader/visualizers/default_visualizers.py,sha256=VoqO9FN84yXyMjRjHjUTOt2GdTkJRMbHbXJ1cJkREkk,2230
|
26
|
-
code_loader-1.0.49.
|
27
|
-
code_loader-1.0.49.
|
28
|
-
code_loader-1.0.49.
|
29
|
-
code_loader-1.0.49.
|
27
|
+
code_loader-1.0.49.dev100.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
28
|
+
code_loader-1.0.49.dev100.dist-info/METADATA,sha256=zgwAvViqFWeiGeHRUr64Ry-umEsndhv8pma3jKkCaoc,895
|
29
|
+
code_loader-1.0.49.dev100.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
30
|
+
code_loader-1.0.49.dev100.dist-info/RECORD,,
|
File without changes
|
File without changes
|