code-loader 1.0.49.dev9__tar.gz → 1.0.49.dev100__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/PKG-INFO +1 -1
  2. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/contract/datasetclasses.py +7 -6
  3. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/contract/visualizer_classes.py +0 -320
  4. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/inner_leap_binder/leapbinder.py +23 -4
  5. code_loader-1.0.49.dev100/code_loader/inner_leap_binder/leapbinder_decorators.py +56 -0
  6. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/leaploader.py +2 -2
  7. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/utils.py +1 -1
  8. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/pyproject.toml +1 -1
  9. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/LICENSE +0 -0
  10. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/README.md +0 -0
  11. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/__init__.py +0 -0
  12. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/code_inegration_processes_manager.py +0 -0
  13. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/contract/__init__.py +0 -0
  14. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/contract/enums.py +0 -0
  15. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/contract/exceptions.py +0 -0
  16. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/contract/responsedataclasses.py +0 -0
  17. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/__init__.py +0 -0
  18. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/api.py +0 -0
  19. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/cli_config_utils.py +0 -0
  20. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/client.py +0 -0
  21. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/epoch.py +0 -0
  22. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/experiment.py +0 -0
  23. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/experiment_context.py +0 -0
  24. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/types.py +0 -0
  25. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/utils.py +0 -0
  26. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/experiment_api/workingspace_config_utils.py +0 -0
  27. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/inner_leap_binder/__init__.py +0 -0
  28. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/visualizers/__init__.py +0 -0
  29. {code_loader-1.0.49.dev9 → code_loader-1.0.49.dev100}/code_loader/visualizers/default_visualizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.49.dev9
3
+ Version: 1.0.49.dev100
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -36,9 +36,9 @@ class PreprocessResponse:
36
36
  data: Any = None
37
37
  sample_ids: Optional[Union[List[str], List[int]]] = None
38
38
  state: Optional[DataStateType] = None
39
- sample_id_type: Optional[Type] = None
39
+ sample_id_type: Optional[Union[Type[str], Type[int]]] = None
40
40
 
41
- def __post_init__(self):
41
+ def __post_init__(self) -> None:
42
42
  if self.length is not None and self.sample_ids is None:
43
43
  self.sample_ids = [i for i in range(self.length)]
44
44
  self.sample_id_type = int
@@ -47,9 +47,10 @@ class PreprocessResponse:
47
47
  if self.sample_id_type is None:
48
48
  self.sample_id_type = str
49
49
  else:
50
- raise Exception("length is deprecated. Please use sample_ids instead.")
50
+ raise Exception("length is deprecated.")
51
51
 
52
- def __len__(self):
52
+ def __len__(self) -> int:
53
+ assert self.sample_ids is not None
53
54
  return len(self.sample_ids)
54
55
 
55
56
 
@@ -90,8 +91,8 @@ VisualizerCallableInterface = Union[
90
91
  Callable[..., LeapImageWithHeatmap]
91
92
  ]
92
93
 
93
- VisualizerCallableReturnType = Union[LeapImage, LeapText, LeapGraph, LeapHorizontalBar,
94
- LeapImageMask, LeapTextMask, LeapImageWithBBox, LeapImageWithHeatmap]
94
+ LeapData = Union[LeapImage, LeapText, LeapGraph, LeapHorizontalBar, LeapImageMask, LeapTextMask, LeapImageWithBBox,
95
+ LeapImageWithHeatmap]
95
96
 
96
97
  CustomCallableInterface = Callable[..., Any]
97
98
 
@@ -4,8 +4,6 @@ import numpy as np
4
4
  import numpy.typing as npt
5
5
  from dataclasses import dataclass
6
6
 
7
- import matplotlib.pyplot as plt # type: ignore
8
-
9
7
  from code_loader.contract.enums import LeapDataType
10
8
  from code_loader.contract.responsedataclasses import BoundingBox
11
9
 
@@ -51,34 +49,6 @@ class LeapImage:
51
49
  validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
52
50
  validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
53
51
 
54
- def plot_visualizer(self) -> None:
55
- """
56
- Display the image contained in the LeapImage object.
57
-
58
- Returns:
59
- None
60
-
61
- Example:
62
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
63
- leap_image = LeapImage(data=image_data)
64
- leap_image.plot_visualizer()
65
- """
66
- image_data = self.data
67
-
68
- # If the image has one channel, convert it to a 3-channel image for display
69
- if image_data.shape[2] == 1:
70
- image_data = np.repeat(image_data, 3, axis=2)
71
-
72
- fig, ax = plt.subplots()
73
- fig.patch.set_facecolor('black')
74
- ax.set_facecolor('black')
75
-
76
- ax.imshow(image_data)
77
-
78
- plt.axis('off')
79
- plt.title('Leap Image Visualization', color='white')
80
- plt.show()
81
-
82
52
 
83
53
  @dataclass
84
54
  class LeapImageWithBBox:
@@ -106,60 +76,6 @@ class LeapImageWithBBox:
106
76
  validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
107
77
  validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
108
78
 
109
- def plot_visualizer(self) -> None:
110
- """
111
- Plot an image with overlaid bounding boxes.
112
-
113
- Returns:
114
- None
115
-
116
- Example:
117
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
118
- bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
119
- leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
120
- leap_image_with_bbox.plot_visualizer()
121
- """
122
-
123
- image = self.data
124
- bounding_boxes = self.bounding_boxes
125
-
126
- # Create figure and axes
127
- fig, ax = plt.subplots(1)
128
- fig.patch.set_facecolor('black')
129
- ax.set_facecolor('black')
130
-
131
- # Display the image
132
- ax.imshow(image)
133
- ax.set_title('Leap Image With BBox Visualization', color='white')
134
-
135
- # Draw bounding boxes on the image
136
- for bbox in bounding_boxes:
137
- x, y, width, height = bbox.x, bbox.y, bbox.width, bbox.height
138
- confidence, label = bbox.confidence, bbox.label
139
-
140
- # Convert relative coordinates to absolute coordinates
141
- abs_x = x * image.shape[1]
142
- abs_y = y * image.shape[0]
143
- abs_width = width * image.shape[1]
144
- abs_height = height * image.shape[0]
145
-
146
- # Create a rectangle patch
147
- rect = plt.Rectangle(
148
- (abs_x - abs_width / 2, abs_y - abs_height / 2),
149
- abs_width, abs_height,
150
- linewidth=3, edgecolor='r', facecolor='none'
151
- )
152
-
153
- # Add the rectangle to the axes
154
- ax.add_patch(rect)
155
-
156
- # Display label and confidence
157
- ax.text(abs_x - abs_width / 2, abs_y - abs_height / 2 - 5,
158
- f"{label} {confidence:.2f}", color='r', fontsize=8,
159
- bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
160
-
161
- # Show the image with bounding boxes
162
- plt.show()
163
79
 
164
80
  @dataclass
165
81
  class LeapGraph:
@@ -183,40 +99,6 @@ class LeapGraph:
183
99
  validate_type(self.data.dtype, np.float32)
184
100
  validate_type(len(self.data.shape), 2, 'Graph must be of shape 2')
185
101
 
186
- def plot_visualizer(self) -> None:
187
- """
188
- Display the line chart contained in the LeapGraph object.
189
-
190
- Returns:
191
- None
192
-
193
- Example:
194
- graph_data = np.random.rand(100, 3).astype(np.float32)
195
- leap_graph = LeapGraph(data=graph_data)
196
- leap_graph.plot_visualizer()
197
- """
198
- graph_data = self.data
199
- num_variables = graph_data.shape[1]
200
-
201
- fig, ax = plt.subplots(figsize=(10, 6))
202
-
203
- # Set the background color to black
204
- fig.patch.set_facecolor('black')
205
- ax.set_facecolor('black')
206
-
207
- for i in range(num_variables):
208
- plt.plot(graph_data[:, i], label=f'Variable {i + 1}')
209
-
210
- ax.set_xlabel('Data Points', color='white')
211
- ax.set_ylabel('Values', color='white')
212
- ax.set_title('Leap Graph Visualization', color='white')
213
- ax.legend()
214
- ax.grid(True, color='white')
215
-
216
- # Change the color of the tick labels to white
217
- ax.tick_params(colors='white')
218
-
219
- plt.show()
220
102
 
221
103
  @dataclass
222
104
  class LeapText:
@@ -242,41 +124,6 @@ class LeapText:
242
124
  for value in self.data:
243
125
  validate_type(type(value), str)
244
126
 
245
- def plot_visualizer(self) -> None:
246
- """
247
- Display the text contained in the LeapText object.
248
-
249
- Returns:
250
- None
251
-
252
- Example:
253
- text_data = ['I', 'ate', 'a', 'banana', '', '', '']
254
- leap_text = LeapText(data=text_data)
255
- leap_text.plot_visualizer()
256
- """
257
- text_data = self.data
258
- # Join the text tokens into a single string, ignoring empty strings
259
- display_text = ' '.join([token for token in text_data if token])
260
-
261
- # Create a black image using Matplotlib
262
- fig, ax = plt.subplots(figsize=(10, 5))
263
- fig.patch.set_facecolor('black')
264
- ax.set_facecolor('black')
265
-
266
- # Hide the axes
267
- ax.axis('off')
268
-
269
- # Set the text properties
270
- font_size = 20
271
- font_color = 'white'
272
-
273
- # Add the text to the image
274
- ax.text(0.5, 0.5, display_text, color=font_color, fontsize=font_size, ha='center', va='center')
275
- ax.set_title('Leap Text Visualization', color='white')
276
-
277
- # Display the image
278
- plt.show()
279
-
280
127
 
281
128
  @dataclass
282
129
  class LeapHorizontalBar:
@@ -309,39 +156,6 @@ class LeapHorizontalBar:
309
156
  for label in self.labels:
310
157
  validate_type(type(label), str)
311
158
 
312
- def plot_visualizer(self) -> None:
313
- """
314
- Display the horizontal bar chart contained in the LeapHorizontalBar object.
315
-
316
- Returns:
317
- None
318
-
319
- Example:
320
- body_data = np.random.rand(5).astype(np.float32)
321
- labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
322
- leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
323
- leap_horizontal_bar.plot_visualizer()
324
- """
325
- body_data = self.body
326
- labels = self.labels
327
-
328
- fig, ax = plt.subplots()
329
-
330
- fig.patch.set_facecolor('black')
331
- ax.set_facecolor('black')
332
-
333
- # Plot horizontal bar chart
334
- ax.barh(labels, body_data, color='green')
335
-
336
- # Set the color of the labels and title to white
337
- ax.set_xlabel('Scores', color='white')
338
- ax.set_title('Leap Horizontal Bar Visualization', color='white')
339
-
340
- # Set the color of the ticks to white
341
- ax.tick_params(axis='x', colors='white')
342
- ax.tick_params(axis='y', colors='white')
343
-
344
- plt.show()
345
159
 
346
160
  @dataclass
347
161
  class LeapImageMask:
@@ -379,51 +193,6 @@ class LeapImageMask:
379
193
  for label in self.labels:
380
194
  validate_type(type(label), str)
381
195
 
382
- def plot_visualizer(self) -> None:
383
- """
384
- Plots an image with overlaid masks given a LeapImageMask visualizer object.
385
-
386
- Returns:
387
- None
388
-
389
-
390
- Example:
391
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
392
- mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
393
- labels = ["background", "object"]
394
- leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
395
- leap_image_mask.plot_visualizer()
396
- """
397
-
398
- image = self.image
399
- mask = self.mask
400
- labels = self.labels
401
-
402
- # Create a color map for each label
403
- colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
404
-
405
- # Make a copy of the image to draw on
406
- overlayed_image = image.copy()
407
-
408
- # Iterate through the unique values in the mask (excluding 0)
409
- for i, label in enumerate(labels):
410
- # Extract binary mask for the current instance
411
- instance_mask = (mask == (i + 1))
412
-
413
- # fill the instance mask with a translucent color
414
- overlayed_image[instance_mask] = (
415
- overlayed_image[instance_mask] * (1 - 0.5) + np.array(colors[i][:3], dtype=np.uint8) * 0.5)
416
-
417
- # Display the result using matplotlib
418
- fig, ax = plt.subplots(1)
419
- fig.patch.set_facecolor('black') # Set the figure background to black
420
- ax.set_facecolor('black') # Set the axis background to black
421
-
422
- ax.imshow(overlayed_image)
423
- ax.set_title('Leap Image With Mask Visualization', color='white')
424
- plt.axis('off') # Hide the axis
425
- plt.show()
426
-
427
196
 
428
197
  @dataclass
429
198
  class LeapTextMask:
@@ -461,55 +230,6 @@ class LeapTextMask:
461
230
  for label in self.labels:
462
231
  validate_type(type(label), str)
463
232
 
464
- def plot_visualizer(self) -> None:
465
- """
466
- Plots text with overlaid masks given a LeapTextMask visualizer object.
467
-
468
- Returns:
469
- None
470
-
471
- Example:
472
- text_data = ['I', 'ate', 'a', 'banana', '', '', '']
473
- mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
474
- labels = ["object"]
475
- leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
476
- """
477
-
478
- text_data = self.text
479
- mask_data = self.mask
480
- labels = self.labels
481
-
482
- # Create a color map for each label
483
- colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
484
-
485
- # Create a figure and axis
486
- fig, ax = plt.subplots()
487
-
488
- # Set background to black
489
- fig.patch.set_facecolor('black')
490
- ax.set_facecolor('black')
491
- ax.set_title('Leap Text Mask Visualization', color='white')
492
- ax.axis('off')
493
-
494
- # Set initial position
495
- x_pos, y_pos = 0.01, 0.5 # Adjusted initial position for better visibility
496
-
497
- # Display the text with colors
498
- for token, mask_value in zip(text_data, mask_data):
499
- if mask_value > 0:
500
- color = colors[mask_value % len(colors)]
501
- bbox = dict(facecolor=color, edgecolor='none',
502
- boxstyle='round,pad=0.3') # Background color for masked tokens
503
- else:
504
- bbox = None
505
-
506
- ax.text(x_pos, y_pos, token, fontsize=12, color='white', ha='left', va='center', bbox=bbox)
507
-
508
- # Update the x position for the next token
509
- x_pos += len(token) * 0.03 + 0.02 # Adjust the spacing between tokens
510
-
511
- plt.show()
512
-
513
233
 
514
234
  @dataclass
515
235
  class LeapImageWithHeatmap:
@@ -547,46 +267,6 @@ class LeapImageWithHeatmap:
547
267
  raise LeapValidationError(
548
268
  'Number of heatmaps and labels must be equal')
549
269
 
550
- def plot_visualizer(self) -> None:
551
- """
552
- Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.
553
-
554
- Returns:
555
- None
556
-
557
- Example:
558
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
559
- heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
560
- labels = ["heatmap1", "heatmap2", "heatmap3"]
561
- leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
562
- leap_image_with_heatmap.plot_visualizer()
563
- """
564
- image = self.image
565
- heatmaps = self.heatmaps
566
- labels = self.labels
567
-
568
- # Plot the base image
569
- fig, ax = plt.subplots()
570
- fig.patch.set_facecolor('black') # Set the figure background to black
571
- ax.set_facecolor('black') # Set the axis background to black
572
- ax.imshow(image, cmap='gray')
573
-
574
- # Overlay each heatmap with some transparency
575
- for i in range(len(labels)):
576
- heatmap = heatmaps[i]
577
- ax.imshow(heatmap, cmap='jet', alpha=0.5) # Adjust alpha for transparency
578
- ax.set_title(f'Heatmap: {labels[i]}', color='white')
579
-
580
- # Display a colorbar for the heatmap
581
- cbar = plt.colorbar(ax.imshow(heatmap, cmap='jet', alpha=0.5))
582
- cbar.set_label(labels[i], color='white')
583
- cbar.ax.yaxis.set_tick_params(color='white') # Set color for the colorbar ticks
584
- plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white') # Set color for the colorbar labels
585
-
586
- plt.axis('off')
587
- plt.title('Leap Image With Heatmaps Visualization', color='white')
588
- plt.show()
589
-
590
270
 
591
271
  map_leap_data_type_to_visualizer_class = {
592
272
  LeapDataType.Image.value: LeapImage,
@@ -1,4 +1,5 @@
1
1
  import inspect
2
+ import sys
2
3
  from typing import Callable, List, Optional, Dict, Any, Type, Union
3
4
 
4
5
  import numpy as np
@@ -8,7 +9,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
8
9
  GroundTruthHandler, MetadataHandler, DatasetIntegrationSetup, VisualizerHandler, PreprocessResponse, \
9
10
  PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
10
11
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
- CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, VisualizerCallableReturnType, \
12
+ CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
12
13
  CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
13
14
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
14
15
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
@@ -37,6 +38,8 @@ class LeapBinder:
37
38
  self._encoder_names: List[str] = list()
38
39
  self._extend_with_default_visualizers()
39
40
 
41
+ self.batch_size_to_validate: Optional[int] = None
42
+
40
43
  def _extend_with_default_visualizers(self) -> None:
41
44
  self.set_visualizer(function=default_image_visualizer, name=DefaultVisualizer.Image.value,
42
45
  visualizer_type=LeapDataType.Image)
@@ -103,7 +106,7 @@ class LeapBinder:
103
106
  if visualizer_type.value not in map_leap_data_type_to_visualizer_class:
104
107
  raise Exception(
105
108
  f'The visualizer_type is invalid. current visualizer_type: {visualizer_type}, ' # type: ignore[attr-defined]
106
- f'should be one of : {", ".join([arg.__name__ for arg in VisualizerCallableReturnType.__args__])}')
109
+ f'should be one of : {", ".join([arg.__name__ for arg in LeapData.__args__])}')
107
110
 
108
111
  func_annotations = function.__annotations__
109
112
  if "return" not in func_annotations:
@@ -113,10 +116,10 @@ class LeapBinder:
113
116
  f"https://docs.python.org/3/library/typing.html")
114
117
  else:
115
118
  return_type = func_annotations["return"]
116
- if return_type not in VisualizerCallableReturnType.__args__: # type: ignore[attr-defined]
119
+ if return_type not in LeapData.__args__: # type: ignore[attr-defined]
117
120
  raise Exception(
118
121
  f'The return type of function {function.__name__} is invalid. current return type: {return_type}, ' # type: ignore[attr-defined]
119
- f'should be one of : {", ".join([arg.__name__ for arg in VisualizerCallableReturnType.__args__])}')
122
+ f'should be one of : {", ".join([arg.__name__ for arg in LeapData.__args__])}')
120
123
 
121
124
  expected_return_type = map_leap_data_type_to_visualizer_class[visualizer_type.value]
122
125
  if not issubclass(return_type, expected_return_type):
@@ -430,6 +433,7 @@ class LeapBinder:
430
433
  def check_handler(
431
434
  preprocess_response: PreprocessResponse, test_result: List[DatasetTestResultPayload],
432
435
  dataset_base_handler: Union[DatasetBaseHandler, MetadataHandler]) -> List[DatasetTestResultPayload]:
436
+ assert preprocess_response.sample_ids is not None
433
437
  raw_result = dataset_base_handler.function(preprocess_response.sample_ids[0], preprocess_response)
434
438
  handler_type = 'metadata' if isinstance(dataset_base_handler, MetadataHandler) else None
435
439
  if isinstance(dataset_base_handler, MetadataHandler) and isinstance(raw_result, dict):
@@ -471,4 +475,19 @@ class LeapBinder:
471
475
  self.check_handlers(preprocess_result)
472
476
  print("Successful!")
473
477
 
478
+ def set_batch_size_to_validate(self, batch_size: int):
479
+ self.batch_size_to_validate = batch_size
480
+
481
+ @staticmethod
482
+ def init():
483
+ available_functions = inspect.getmembers(sys.modules[__name__], inspect.isfunction)
484
+ for func_name, func in available_functions:
485
+ if 'tensorleap_custom_metric' in str(func):
486
+ try:
487
+ func()
488
+ except:
489
+ pass
490
+
491
+
492
+
474
493
 
@@ -0,0 +1,56 @@
1
+ from typing import Optional, Union
2
+
3
+ import numpy as np
4
+
5
+ from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
6
+ CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs
7
+ from code_loader.contract.enums import MetricDirection
8
+ from code_loader import leap_binder
9
+
10
+
11
+ def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = MetricDirection.Downward):
12
+ def decorating_function(
13
+ user_function: Union[CustomCallableInterfaceMultiArgs,
14
+ CustomMultipleReturnCallableInterfaceMultiArgs,
15
+ ConfusionMatrixCallableInterfaceMultiArgs]
16
+ ):
17
+
18
+ leap_binder.add_custom_metric(user_function, name, direction)
19
+
20
+ def _validate_custom_metric_input_args(*args, **kwargs):
21
+ for i, arg in enumerate(args):
22
+ assert isinstance(arg, np.ndarray), (f'tensorleap_custom_metric validation failed: '
23
+ f'Argument #{i} should be a numpy array. Got {type(arg)}.')
24
+ if leap_binder.batch_size_to_validate:
25
+ assert arg.shape[0] == leap_binder.batch_size_to_validate, \
26
+ (f'tensorleap_custom_metric validation failed: Argument #{i} '
27
+ f'first dim should be as the batch size. Got {arg.shape[0]} '
28
+ f'instead of {leap_binder.batch_size_to_validate}')
29
+
30
+ for _arg_name, arg in kwargs.items():
31
+ assert isinstance(arg, np.ndarray), (f'tensorleap_custom_metric validation failed: '
32
+ f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
33
+ if leap_binder.batch_size_to_validate:
34
+ assert arg.shape[0] == leap_binder.batch_size_to_validate, \
35
+ (f'tensorleap_custom_metric validation failed: Argument {_arg_name} '
36
+ f'first dim should be as the batch size. Got {arg.shape[0]} '
37
+ f'instead of {leap_binder.batch_size_to_validate}')
38
+
39
+ def _validate_custom_metric_result(result):
40
+ assert isinstance(result, np.ndarray), (f'tensorleap_custom_metric validation failed: '
41
+ f'The return type should be a numpy array. Got {type(result)}.')
42
+ assert len(result.shape) == 1, (f'tensorleap_custom_metric validation failed: '
43
+ f'The return shape should be 1D. Got {len(result.shape)}D.')
44
+ if leap_binder.batch_size_to_validate:
45
+ assert result.shape[0] == leap_binder.batch_size_to_validate, \
46
+ f'tensorleap_custom_metric validation failed: The return len should be as the batch size.'
47
+
48
+ def inner(*args, **kwargs):
49
+ _validate_custom_metric_input_args(*args, **kwargs)
50
+ result = user_function(*args, **kwargs)
51
+ _validate_custom_metric_result(result)
52
+ return result
53
+
54
+ return inner
55
+
56
+ return decorating_function
@@ -12,7 +12,7 @@ import numpy as np
12
12
  import numpy.typing as npt
13
13
 
14
14
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
15
- PreprocessResponse, VisualizerHandler, VisualizerCallableReturnType, CustomLossHandler, \
15
+ PreprocessResponse, VisualizerHandler, LeapData, CustomLossHandler, \
16
16
  PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
17
17
  from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
18
18
  from code_loader.contract.exceptions import DatasetScriptException
@@ -201,7 +201,7 @@ class LeapLoader:
201
201
  return all_dataset_base_handlers
202
202
 
203
203
  def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]],
204
- ) -> VisualizerCallableReturnType:
204
+ ) -> LeapData:
205
205
  return self.visualizer_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
206
206
 
207
207
  def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
@@ -10,7 +10,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, Prepro
10
10
 
11
11
 
12
12
  def to_numpy_return_wrapper(encoder_function: SectionCallableInterface) -> SectionCallableInterface:
13
- def numpy_encoder_function(idx: int, samples: PreprocessResponse) -> npt.NDArray[np.float32]:
13
+ def numpy_encoder_function(idx: Union[int, str], samples: PreprocessResponse) -> npt.NDArray[np.float32]:
14
14
  result = encoder_function(idx, samples)
15
15
  numpy_result: npt.NDArray[np.float32] = np.array(result)
16
16
  return numpy_result
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "code-loader"
3
- version = "1.0.49.dev9"
3
+ version = "1.0.49.dev100"
4
4
  description = ""
5
5
  authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
6
6
  license = "MIT"