code-loader 1.0.42a0__py3-none-any.whl → 1.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,7 @@ from typing import Any, Callable, List, Optional, Dict, Union, Type
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
6
 
7
- from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection, InstanceAnalysisType
7
+ from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
8
8
  from code_loader.contract.visualizer_classes import LeapImage, LeapText, LeapGraph, LeapHorizontalBar, \
9
9
  LeapTextMask, LeapImageMask, LeapImageWithBBox, LeapImageWithHeatmap
10
10
 
@@ -36,16 +36,8 @@ class PreprocessResponse:
36
36
  data: Any
37
37
 
38
38
 
39
- @dataclass
40
- class ElementInstance:
41
- name: str
42
- mask: npt.NDArray[np.float32]
43
-
44
-
45
39
  SectionCallableInterface = Callable[[int, PreprocessResponse], npt.NDArray[np.float32]]
46
40
 
47
- InstanceCallableInterface = Callable[[int, PreprocessResponse], List[ElementInstance]]
48
-
49
41
  MetadataSectionCallableInterface = Union[
50
42
  Callable[[int, PreprocessResponse], int],
51
43
  Callable[[int, PreprocessResponse], Dict[str, int]],
@@ -117,7 +109,6 @@ class MetricHandler:
117
109
  arg_names: List[str]
118
110
  direction: Optional[MetricDirection] = MetricDirection.Downward
119
111
 
120
-
121
112
  @dataclass
122
113
  class RawInputsForHeatmap:
123
114
  raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
@@ -143,14 +134,6 @@ class InputHandler(DatasetBaseHandler):
143
134
  shape: Optional[List[int]] = None
144
135
 
145
136
 
146
- @dataclass
147
- class ElementInstanceHandler:
148
- input_name: str
149
- instance_function: InstanceCallableInterface
150
- analysis_type: InstanceAnalysisType
151
-
152
-
153
-
154
137
  @dataclass
155
138
  class GroundTruthHandler(DatasetBaseHandler):
156
139
  shape: Optional[List[int]] = None
@@ -184,7 +167,6 @@ class DatasetIntegrationSetup:
184
167
  unlabeled_data_preprocess: Optional[UnlabeledDataPreprocessHandler] = None
185
168
  visualizers: List[VisualizerHandler] = field(default_factory=list)
186
169
  inputs: List[InputHandler] = field(default_factory=list)
187
- element_instances: List[ElementInstanceHandler] = field(default_factory=list)
188
170
  ground_truths: List[GroundTruthHandler] = field(default_factory=list)
189
171
  metadata: List[MetadataHandler] = field(default_factory=list)
190
172
  prediction_types: List[PredictionTypeHandler] = field(default_factory=list)
@@ -64,9 +64,3 @@ class ConfusionMatrixValue(Enum):
64
64
  class TestingSectionEnum(Enum):
65
65
  Warnings = "Warnings"
66
66
  Errors = "Errors"
67
-
68
-
69
-
70
- class InstanceAnalysisType(Enum):
71
- MaskInput = "MaskInput"
72
- MaskLatentSpace = "MaskLatentSpace"
@@ -4,6 +4,8 @@ import numpy as np
4
4
  import numpy.typing as npt
5
5
  from dataclasses import dataclass
6
6
 
7
+ import matplotlib.pyplot as plt # type: ignore
8
+
7
9
  from code_loader.contract.enums import LeapDataType
8
10
  from code_loader.contract.responsedataclasses import BoundingBox
9
11
 
@@ -34,6 +36,10 @@ class LeapImage:
34
36
  Attributes:
35
37
  data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data.
36
38
  type (LeapDataType): The data type, default is LeapDataType.Image.
39
+
40
+ Example:
41
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
42
+ leap_image = LeapImage(data=image_data)
37
43
  """
38
44
  data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
39
45
  type: LeapDataType = LeapDataType.Image
@@ -45,6 +51,34 @@ class LeapImage:
45
51
  validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
46
52
  validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
47
53
 
54
+ def plot_visualizer(self) -> None:
55
+ """
56
+ Display the image contained in the LeapImage object.
57
+
58
+ Returns:
59
+ None
60
+
61
+ Example:
62
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
63
+ leap_image = LeapImage(data=image_data)
64
+ leap_image.plot_visualizer()
65
+ """
66
+ image_data = self.data
67
+
68
+ # If the image has one channel, convert it to a 3-channel image for display
69
+ if image_data.shape[2] == 1:
70
+ image_data = np.repeat(image_data, 3, axis=2)
71
+
72
+ fig, ax = plt.subplots()
73
+ fig.patch.set_facecolor('black')
74
+ ax.set_facecolor('black')
75
+
76
+ ax.imshow(image_data)
77
+
78
+ plt.axis('off')
79
+ plt.title('Leap Image Visualization', color='white')
80
+ plt.show()
81
+
48
82
 
49
83
  @dataclass
50
84
  class LeapImageWithBBox:
@@ -55,6 +89,11 @@ class LeapImageWithBBox:
55
89
  data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or [H, W, 1].
56
90
  bounding_boxes (List[BoundingBox]): List of Tensorleap bounding boxes objects in relative size to image size.
57
91
  type (LeapDataType): The data type, default is LeapDataType.ImageWithBBox.
92
+
93
+ Example:
94
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
95
+ bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
96
+ leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
58
97
  """
59
98
  data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
60
99
  bounding_boxes: List[BoundingBox]
@@ -67,6 +106,60 @@ class LeapImageWithBBox:
67
106
  validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
68
107
  validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
69
108
 
109
+ def plot_visualizer(self) -> None:
110
+ """
111
+ Plot an image with overlaid bounding boxes.
112
+
113
+ Returns:
114
+ None
115
+
116
+ Example:
117
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
118
+ bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
119
+ leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
120
+ leap_image_with_bbox.plot_visualizer()
121
+ """
122
+
123
+ image = self.data
124
+ bounding_boxes = self.bounding_boxes
125
+
126
+ # Create figure and axes
127
+ fig, ax = plt.subplots(1)
128
+ fig.patch.set_facecolor('black')
129
+ ax.set_facecolor('black')
130
+
131
+ # Display the image
132
+ ax.imshow(image)
133
+ ax.set_title('Leap Image With BBox Visualization', color='white')
134
+
135
+ # Draw bounding boxes on the image
136
+ for bbox in bounding_boxes:
137
+ x, y, width, height = bbox.x, bbox.y, bbox.width, bbox.height
138
+ confidence, label = bbox.confidence, bbox.label
139
+
140
+ # Convert relative coordinates to absolute coordinates
141
+ abs_x = x * image.shape[1]
142
+ abs_y = y * image.shape[0]
143
+ abs_width = width * image.shape[1]
144
+ abs_height = height * image.shape[0]
145
+
146
+ # Create a rectangle patch
147
+ rect = plt.Rectangle(
148
+ (abs_x - abs_width / 2, abs_y - abs_height / 2),
149
+ abs_width, abs_height,
150
+ linewidth=3, edgecolor='r', facecolor='none'
151
+ )
152
+
153
+ # Add the rectangle to the axes
154
+ ax.add_patch(rect)
155
+
156
+ # Display label and confidence
157
+ ax.text(abs_x - abs_width / 2, abs_y - abs_height / 2 - 5,
158
+ f"{label} {confidence:.2f}", color='r', fontsize=8,
159
+ bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
160
+
161
+ # Show the image with bounding boxes
162
+ plt.show()
70
163
 
71
164
  @dataclass
72
165
  class LeapGraph:
@@ -76,6 +169,10 @@ class LeapGraph:
76
169
  Attributes:
77
170
  data (npt.NDArray[np.float32]): The array data, shaped [M, N] where M is the number of data points and N is the number of variables.
78
171
  type (LeapDataType): The data type, default is LeapDataType.Graph.
172
+
173
+ Example:
174
+ graph_data = np.random.rand(100, 3).astype(np.float32)
175
+ leap_graph = LeapGraph(data=graph_data)
79
176
  """
80
177
  data: npt.NDArray[np.float32]
81
178
  type: LeapDataType = LeapDataType.Graph
@@ -86,6 +183,40 @@ class LeapGraph:
86
183
  validate_type(self.data.dtype, np.float32)
87
184
  validate_type(len(self.data.shape), 2, 'Graph must be of shape 2')
88
185
 
186
+ def plot_visualizer(self) -> None:
187
+ """
188
+ Display the line chart contained in the LeapGraph object.
189
+
190
+ Returns:
191
+ None
192
+
193
+ Example:
194
+ graph_data = np.random.rand(100, 3).astype(np.float32)
195
+ leap_graph = LeapGraph(data=graph_data)
196
+ leap_graph.plot_visualizer()
197
+ """
198
+ graph_data = self.data
199
+ num_variables = graph_data.shape[1]
200
+
201
+ fig, ax = plt.subplots(figsize=(10, 6))
202
+
203
+ # Set the background color to black
204
+ fig.patch.set_facecolor('black')
205
+ ax.set_facecolor('black')
206
+
207
+ for i in range(num_variables):
208
+ plt.plot(graph_data[:, i], label=f'Variable {i + 1}')
209
+
210
+ ax.set_xlabel('Data Points', color='white')
211
+ ax.set_ylabel('Values', color='white')
212
+ ax.set_title('Leap Graph Visualization', color='white')
213
+ ax.legend()
214
+ ax.grid(True, color='white')
215
+
216
+ # Change the color of the tick labels to white
217
+ ax.tick_params(colors='white')
218
+
219
+ plt.show()
89
220
 
90
221
  @dataclass
91
222
  class LeapText:
@@ -96,6 +227,11 @@ class LeapText:
96
227
  data (List[str]): The text data, consisting of a list of text tokens. If the model requires fixed-length inputs,
97
228
  it is recommended to maintain the fixed length, using empty strings ('') instead of padding tokens ('PAD') e.g., ['I', 'ate', 'a', 'banana', '', '', '', ...]
98
229
  type (LeapDataType): The data type, default is LeapDataType.Text.
230
+
231
+ Example:
232
+ text_data = ['I', 'ate', 'a', 'banana', '', '', '']
233
+ leap_text = LeapText(data=text_data) # Create LeapText object
234
+ LeapText(leap_text)
99
235
  """
100
236
  data: List[str]
101
237
  type: LeapDataType = LeapDataType.Text
@@ -106,6 +242,41 @@ class LeapText:
106
242
  for value in self.data:
107
243
  validate_type(type(value), str)
108
244
 
245
+ def plot_visualizer(self) -> None:
246
+ """
247
+ Display the text contained in the LeapText object.
248
+
249
+ Returns:
250
+ None
251
+
252
+ Example:
253
+ text_data = ['I', 'ate', 'a', 'banana', '', '', '']
254
+ leap_text = LeapText(data=text_data)
255
+ leap_text.plot_visualizer()
256
+ """
257
+ text_data = self.data
258
+ # Join the text tokens into a single string, ignoring empty strings
259
+ display_text = ' '.join([token for token in text_data if token])
260
+
261
+ # Create a black image using Matplotlib
262
+ fig, ax = plt.subplots(figsize=(10, 5))
263
+ fig.patch.set_facecolor('black')
264
+ ax.set_facecolor('black')
265
+
266
+ # Hide the axes
267
+ ax.axis('off')
268
+
269
+ # Set the text properties
270
+ font_size = 20
271
+ font_color = 'white'
272
+
273
+ # Add the text to the image
274
+ ax.text(0.5, 0.5, display_text, color=font_color, fontsize=font_size, ha='center', va='center')
275
+ ax.set_title('Leap Text Visualization', color='white')
276
+
277
+ # Display the image
278
+ plt.show()
279
+
109
280
 
110
281
  @dataclass
111
282
  class LeapHorizontalBar:
@@ -118,6 +289,11 @@ class LeapHorizontalBar:
118
289
  labels (List[str]): Labels for the horizontal bar; e.g., when visualizing the model's classification output, labels are the class names.
119
290
  Length of `body` should match the length of `labels`, C.
120
291
  type (LeapDataType): The data type, default is LeapDataType.HorizontalBar.
292
+
293
+ Example:
294
+ body_data = np.random.rand(5).astype(np.float32)
295
+ labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
296
+ leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
121
297
  """
122
298
  body: npt.NDArray[np.float32]
123
299
  labels: List[str]
@@ -133,6 +309,39 @@ class LeapHorizontalBar:
133
309
  for label in self.labels:
134
310
  validate_type(type(label), str)
135
311
 
312
+ def plot_visualizer(self) -> None:
313
+ """
314
+ Display the horizontal bar chart contained in the LeapHorizontalBar object.
315
+
316
+ Returns:
317
+ None
318
+
319
+ Example:
320
+ body_data = np.random.rand(5).astype(np.float32)
321
+ labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
322
+ leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
323
+ leap_horizontal_bar.plot_visualizer()
324
+ """
325
+ body_data = self.body
326
+ labels = self.labels
327
+
328
+ fig, ax = plt.subplots()
329
+
330
+ fig.patch.set_facecolor('black')
331
+ ax.set_facecolor('black')
332
+
333
+ # Plot horizontal bar chart
334
+ ax.barh(labels, body_data, color='green')
335
+
336
+ # Set the color of the labels and title to white
337
+ ax.set_xlabel('Scores', color='white')
338
+ ax.set_title('Leap Horizontal Bar Visualization', color='white')
339
+
340
+ # Set the color of the ticks to white
341
+ ax.tick_params(axis='x', colors='white')
342
+ ax.tick_params(axis='y', colors='white')
343
+
344
+ plt.show()
136
345
 
137
346
  @dataclass
138
347
  class LeapImageMask:
@@ -145,6 +354,12 @@ class LeapImageMask:
145
354
  image (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or shaped [H, W, 1].
146
355
  labels (List[str]): Labels associated with the mask regions; e.g., class names for segmented objects. The length of `labels` should match the number of unique values in `mask`.
147
356
  type (LeapDataType): The data type, default is LeapDataType.ImageMask.
357
+
358
+ Example:
359
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
360
+ mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
361
+ labels = ["background", "object"]
362
+ leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
148
363
  """
149
364
  mask: npt.NDArray[np.uint8]
150
365
  image: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
@@ -164,6 +379,51 @@ class LeapImageMask:
164
379
  for label in self.labels:
165
380
  validate_type(type(label), str)
166
381
 
382
+ def plot_visualizer(self) -> None:
383
+ """
384
+ Plots an image with overlaid masks given a LeapImageMask visualizer object.
385
+
386
+ Returns:
387
+ None
388
+
389
+
390
+ Example:
391
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
392
+ mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
393
+ labels = ["background", "object"]
394
+ leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
395
+ leap_image_mask.plot_visualizer()
396
+ """
397
+
398
+ image = self.image
399
+ mask = self.mask
400
+ labels = self.labels
401
+
402
+ # Create a color map for each label
403
+ colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
404
+
405
+ # Make a copy of the image to draw on
406
+ overlayed_image = image.copy()
407
+
408
+ # Iterate through the unique values in the mask (excluding 0)
409
+ for i, label in enumerate(labels):
410
+ # Extract binary mask for the current instance
411
+ instance_mask = (mask == (i + 1))
412
+
413
+ # fill the instance mask with a translucent color
414
+ overlayed_image[instance_mask] = (
415
+ overlayed_image[instance_mask] * (1 - 0.5) + np.array(colors[i][:3], dtype=np.uint8) * 0.5)
416
+
417
+ # Display the result using matplotlib
418
+ fig, ax = plt.subplots(1)
419
+ fig.patch.set_facecolor('black') # Set the figure background to black
420
+ ax.set_facecolor('black') # Set the axis background to black
421
+
422
+ ax.imshow(overlayed_image)
423
+ ax.set_title('Leap Image With Mask Visualization', color='white')
424
+ plt.axis('off') # Hide the axis
425
+ plt.show()
426
+
167
427
 
168
428
  @dataclass
169
429
  class LeapTextMask:
@@ -176,6 +436,13 @@ class LeapTextMask:
176
436
  text (List[str]): The text data, consisting of a list of text tokens, length of L.
177
437
  labels (List[str]): Labels associated with the masked tokens; e.g., named entities or sentiment categories. The length of `labels` should match the number of unique values in `mask`.
178
438
  type (LeapDataType): The data type, default is LeapDataType.TextMask.
439
+
440
+ Example:
441
+ text_data = ['I', 'ate', 'a', 'banana', '', '', '']
442
+ mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
443
+ labels = ["object"]
444
+ leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
445
+ leap_text_mask.plot_visualizer()
179
446
  """
180
447
  mask: npt.NDArray[np.uint8]
181
448
  text: List[str]
@@ -194,6 +461,55 @@ class LeapTextMask:
194
461
  for label in self.labels:
195
462
  validate_type(type(label), str)
196
463
 
464
+ def plot_visualizer(self) -> None:
465
+ """
466
+ Plots text with overlaid masks given a LeapTextMask visualizer object.
467
+
468
+ Returns:
469
+ None
470
+
471
+ Example:
472
+ text_data = ['I', 'ate', 'a', 'banana', '', '', '']
473
+ mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
474
+ labels = ["object"]
475
+ leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
476
+ """
477
+
478
+ text_data = self.text
479
+ mask_data = self.mask
480
+ labels = self.labels
481
+
482
+ # Create a color map for each label
483
+ colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
484
+
485
+ # Create a figure and axis
486
+ fig, ax = plt.subplots()
487
+
488
+ # Set background to black
489
+ fig.patch.set_facecolor('black')
490
+ ax.set_facecolor('black')
491
+ ax.set_title('Leap Text Mask Visualization', color='white')
492
+ ax.axis('off')
493
+
494
+ # Set initial position
495
+ x_pos, y_pos = 0.01, 0.5 # Adjusted initial position for better visibility
496
+
497
+ # Display the text with colors
498
+ for token, mask_value in zip(text_data, mask_data):
499
+ if mask_value > 0:
500
+ color = colors[mask_value % len(colors)]
501
+ bbox = dict(facecolor=color, edgecolor='none',
502
+ boxstyle='round,pad=0.3') # Background color for masked tokens
503
+ else:
504
+ bbox = None
505
+
506
+ ax.text(x_pos, y_pos, token, fontsize=12, color='white', ha='left', va='center', bbox=bbox)
507
+
508
+ # Update the x position for the next token
509
+ x_pos += len(token) * 0.03 + 0.02 # Adjust the spacing between tokens
510
+
511
+ plt.show()
512
+
197
513
 
198
514
  @dataclass
199
515
  class LeapImageWithHeatmap:
@@ -206,6 +522,12 @@ class LeapImageWithHeatmap:
206
522
  heatmaps (npt.NDArray[np.float32]): The heatmap data, shaped [N, H, W], where N is the number of heatmaps.
207
523
  labels (List[str]): Labels associated with the heatmaps; e.g., feature names or attention regions. The length of `labels` should match the number of heatmaps, N.
208
524
  type (LeapDataType): The data type, default is LeapDataType.ImageWithHeatmap.
525
+
526
+ Example:
527
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
528
+ heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
529
+ labels = ["heatmap1", "heatmap2", "heatmap3"]
530
+ leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
209
531
  """
210
532
  image: npt.NDArray[np.float32]
211
533
  heatmaps: npt.NDArray[np.float32]
@@ -225,6 +547,46 @@ class LeapImageWithHeatmap:
225
547
  raise LeapValidationError(
226
548
  'Number of heatmaps and labels must be equal')
227
549
 
550
+ def plot_visualizer(self) -> None:
551
+ """
552
+ Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.
553
+
554
+ Returns:
555
+ None
556
+
557
+ Example:
558
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
559
+ heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
560
+ labels = ["heatmap1", "heatmap2", "heatmap3"]
561
+ leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
562
+ leap_image_with_heatmap.plot_visualizer()
563
+ """
564
+ image = self.image
565
+ heatmaps = self.heatmaps
566
+ labels = self.labels
567
+
568
+ # Plot the base image
569
+ fig, ax = plt.subplots()
570
+ fig.patch.set_facecolor('black') # Set the figure background to black
571
+ ax.set_facecolor('black') # Set the axis background to black
572
+ ax.imshow(image, cmap='gray')
573
+
574
+ # Overlay each heatmap with some transparency
575
+ for i in range(len(labels)):
576
+ heatmap = heatmaps[i]
577
+ ax.imshow(heatmap, cmap='jet', alpha=0.5) # Adjust alpha for transparency
578
+ ax.set_title(f'Heatmap: {labels[i]}', color='white')
579
+
580
+ # Display a colorbar for the heatmap
581
+ cbar = plt.colorbar(ax.imshow(heatmap, cmap='jet', alpha=0.5))
582
+ cbar.set_label(labels[i], color='white')
583
+ cbar.ax.yaxis.set_tick_params(color='white') # Set color for the colorbar ticks
584
+ plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white') # Set color for the colorbar labels
585
+
586
+ plt.axis('off')
587
+ plt.title('Leap Image With Heatmaps Visualization', color='white')
588
+ plt.show()
589
+
228
590
 
229
591
  map_leap_data_type_to_visualizer_class = {
230
592
  LeapDataType.Image.value: LeapImage,
@@ -9,9 +9,8 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
9
9
  PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
10
10
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
11
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, VisualizerCallableReturnType, \
12
- CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
13
- RawInputsForHeatmap, InstanceCallableInterface, ElementInstanceHandler
14
- from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, InstanceAnalysisType
12
+ CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
13
+ from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
15
14
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
16
15
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
17
16
  from code_loader.utils import to_numpy_return_wrapper, get_shape
@@ -203,10 +202,6 @@ class LeapBinder:
203
202
 
204
203
  self._encoder_names.append(name)
205
204
 
206
- def set_instance_element(self, input_name: str, instance_function: Optional[InstanceCallableInterface] = None,
207
- analysis_type: InstanceAnalysisType = InstanceAnalysisType.MaskInput) -> None:
208
- self.setup_container.element_instances.append(ElementInstanceHandler(input_name, instance_function, analysis_type))
209
-
210
205
  def add_custom_loss(self, function: CustomCallableInterface, name: str) -> None:
211
206
  """
212
207
  Add a custom loss function to the setup.
@@ -366,8 +361,7 @@ class LeapBinder:
366
361
  custom_layer.kernel_index = kernel_index
367
362
 
368
363
  if use_custom_latent_space and not hasattr(custom_layer, custom_latent_space_attribute):
369
- raise Exception(
370
- f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
364
+ raise Exception(f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
371
365
 
372
366
  init_args = inspect.getfullargspec(custom_layer.__init__)[0][1:]
373
367
  call_args = inspect.getfullargspec(custom_layer.call)[0][1:]
@@ -458,3 +452,5 @@ class LeapBinder:
458
452
  self.check_preprocess(preprocess_result)
459
453
  self.check_handlers(preprocess_result)
460
454
  print("Successful!")
455
+
456
+
code_loader/leaploader.py CHANGED
@@ -5,16 +5,15 @@ import sys
5
5
  from contextlib import redirect_stdout
6
6
  from functools import lru_cache
7
7
  from pathlib import Path
8
- from typing import Dict, List, Iterable, Union, Any, Optional, Tuple
8
+ from typing import Dict, List, Iterable, Union, Any
9
9
 
10
10
  import numpy as np
11
11
  import numpy.typing as npt
12
12
 
13
13
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
14
14
  PreprocessResponse, VisualizerHandler, VisualizerCallableReturnType, CustomLossHandler, \
15
- PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, ElementInstance
16
- from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType, \
17
- InstanceAnalysisType
15
+ PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
16
+ from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
18
17
  from code_loader.contract.exceptions import DatasetScriptException
19
18
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
20
19
  DatasetPreprocess, DatasetSetup, DatasetInputInstance, DatasetOutputInstance, DatasetMetadataInstance, \
@@ -299,17 +298,6 @@ class LeapLoader:
299
298
  def _get_inputs(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
300
299
  return self._get_dataset_handlers(global_leap_binder.setup_container.inputs, state, idx)
301
300
 
302
- def get_instance_elements(self, state: DataStateEnum, idx: int, input_name: str) \
303
- -> Tuple[Optional[List[ElementInstance]], Optional[InstanceAnalysisType]]:
304
- preprocess_result = self._preprocess_result()
305
- preprocess_state = preprocess_result[state]
306
- for element_instance in global_leap_binder.setup_container.element_instances:
307
- if element_instance.input_name == input_name:
308
- return element_instance.instance_function(idx, preprocess_state), element_instance.analysis_type
309
-
310
- return None, None
311
-
312
-
313
301
  def _get_gt(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
314
302
  return self._get_dataset_handlers(global_leap_binder.setup_container.ground_truths, state, idx)
315
303
 
@@ -1,5 +1,4 @@
1
1
  from enum import Enum
2
- from typing import List
3
2
 
4
3
  import numpy as np
5
4
  import numpy.typing as npt
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.42a0
3
+ Version: 1.0.43
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3.8
13
13
  Classifier: Programming Language :: Python :: 3.9
14
14
  Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Programming Language :: Python :: 3.11
16
+ Requires-Dist: matplotlib (>=3.3,<3.4)
16
17
  Requires-Dist: numpy (>=1.22.3,<2.0.0)
17
18
  Requires-Dist: psutil (>=5.9.5,<6.0.0)
18
19
  Project-URL: Repository, https://github.com/tensorleap/code-loader
@@ -0,0 +1,18 @@
1
+ LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
+ code_loader/__init__.py,sha256=V3DEXSN6Ie6PlGeSAbzjp9ufRj0XPJLpD7pDLLYxk6M,122
3
+ code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ code_loader/contract/datasetclasses.py,sha256=81TmCcVol7768lzKUp70MatLLipR3ftcR9jgE1r8Yqo,5698
5
+ code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
6
+ code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
+ code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
8
+ code_loader/contract/visualizer_classes.py,sha256=Ka8fJSVKrOeZ12Eg8-dOBGMW0UswYCIkuhnNMd-7z9s,22948
9
+ code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
10
+ code_loader/inner_leap_binder/leapbinder.py,sha256=fHND8ayIXDlKHufHHq36u5rKNnZq64MmMuI42GctHvQ,24022
11
+ code_loader/leaploader.py,sha256=pUySweZetJ6SsubCcZlDCJpvWmUrm5YlPlkZWQxY1hQ,17289
12
+ code_loader/utils.py,sha256=61I4PgSl-ZBIe4DifLxMNlBELE-HQR2pB9efVYPceIU,2230
13
+ code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ code_loader/visualizers/default_visualizers.py,sha256=VoqO9FN84yXyMjRjHjUTOt2GdTkJRMbHbXJ1cJkREkk,2230
15
+ code_loader-1.0.43.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
16
+ code_loader-1.0.43.dist-info/METADATA,sha256=FO-vqcxPvPCvmVbDXmgc5ne0o2pSnKF9NsP2w0j_Ub8,807
17
+ code_loader-1.0.43.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
18
+ code_loader-1.0.43.dist-info/RECORD,,
@@ -1,18 +0,0 @@
1
- LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
- code_loader/__init__.py,sha256=V3DEXSN6Ie6PlGeSAbzjp9ufRj0XPJLpD7pDLLYxk6M,122
3
- code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- code_loader/contract/datasetclasses.py,sha256=eoDxCGZzvogp440dhWZBBm8oTs6XiqcHxE-JXD4txeA,6128
5
- code_loader/contract/enums.py,sha256=2tKXamabooQjCe3EBkeBcZMA8vBbeGjDAbACvUTGqLE,1707
6
- code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
- code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
8
- code_loader/contract/visualizer_classes.py,sha256=t0EKxTGFJoFBnKuMmjtFpj-L32xEFF3NV-E5XutVXa0,10118
9
- code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
10
- code_loader/inner_leap_binder/leapbinder.py,sha256=7_ffUiGW54Y27f7eKnvpJ2XLOqfGfTCW5cMNaH21iIk,24464
11
- code_loader/leaploader.py,sha256=gQaEpHyQx50FP-d0LS2PjQdm9aOUf3jok8ZeLbWkGZc,17917
12
- code_loader/utils.py,sha256=61I4PgSl-ZBIe4DifLxMNlBELE-HQR2pB9efVYPceIU,2230
13
- code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- code_loader/visualizers/default_visualizers.py,sha256=F2qRT6rNy7SOmr4QfqVxIkLlYEa00CwDkLJuA45lfSI,2254
15
- code_loader-1.0.42a0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
16
- code_loader-1.0.42a0.dist-info/METADATA,sha256=HrRgFfW68Nr0KIBzK-C9Fwnxcw9SYotWTzWfwW2PF4I,770
17
- code_loader-1.0.42a0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
18
- code_loader-1.0.42a0.dist-info/RECORD,,