code-loader 1.0.49.dev8__py3-none-any.whl → 1.0.50__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -32,38 +32,21 @@ class PreprocessResponse:
32
32
  }
33
33
  response = PreprocessResponse(length=len(preprocessed_data), data=preprocessed_data)
34
34
  """
35
- length: Optional[int] = None # Deprecated. Please use sample_ids instead
36
- data: Any = None
37
- sample_ids: Optional[Union[List[str], List[int]]] = None
38
- state: Optional[DataStateType] = None
39
- sample_id_type: Optional[Type] = None
35
+ length: int
36
+ data: Any
40
37
 
41
- def __post_init__(self):
42
- if self.length is not None and self.sample_ids is None:
43
- self.sample_ids = [i for i in range(self.length)]
44
- self.sample_id_type = int
45
- elif self.length is None and self.sample_ids is not None:
46
- self.length = len(self.sample_ids)
47
- if self.sample_id_type is None:
48
- self.sample_id_type = str
49
- else:
50
- raise Exception("length is deprecated. Please use sample_ids instead.")
51
38
 
52
- def __len__(self):
53
- return len(self.sample_ids)
54
-
55
-
56
- SectionCallableInterface = Callable[[Union[int, str], PreprocessResponse], npt.NDArray[np.float32]]
39
+ SectionCallableInterface = Callable[[int, PreprocessResponse], npt.NDArray[np.float32]]
57
40
 
58
41
  MetadataSectionCallableInterface = Union[
59
- Callable[[Union[int, str], PreprocessResponse], int],
60
- Callable[[Union[int, str], PreprocessResponse], Dict[str, int]],
61
- Callable[[Union[int, str], PreprocessResponse], str],
62
- Callable[[Union[int, str], PreprocessResponse], Dict[str, str]],
63
- Callable[[Union[int, str], PreprocessResponse], bool],
64
- Callable[[Union[int, str], PreprocessResponse], Dict[str, bool]],
65
- Callable[[Union[int, str], PreprocessResponse], float],
66
- Callable[[Union[int, str], PreprocessResponse], Dict[str, float]]
42
+ Callable[[int, PreprocessResponse], int],
43
+ Callable[[int, PreprocessResponse], Dict[str, int]],
44
+ Callable[[int, PreprocessResponse], str],
45
+ Callable[[int, PreprocessResponse], Dict[str, str]],
46
+ Callable[[int, PreprocessResponse], bool],
47
+ Callable[[int, PreprocessResponse], Dict[str, bool]],
48
+ Callable[[int, PreprocessResponse], float],
49
+ Callable[[int, PreprocessResponse], Dict[str, float]]
67
50
  ]
68
51
 
69
52
 
@@ -90,8 +73,8 @@ VisualizerCallableInterface = Union[
90
73
  Callable[..., LeapImageWithHeatmap]
91
74
  ]
92
75
 
93
- VisualizerCallableReturnType = Union[LeapImage, LeapText, LeapGraph, LeapHorizontalBar,
94
- LeapImageMask, LeapTextMask, LeapImageWithBBox, LeapImageWithHeatmap]
76
+ LeapData = Union[LeapImage, LeapText, LeapGraph, LeapHorizontalBar, LeapImageMask, LeapTextMask, LeapImageWithBBox,
77
+ LeapImageWithHeatmap]
95
78
 
96
79
  CustomCallableInterface = Callable[..., Any]
97
80
 
@@ -198,5 +181,5 @@ class DatasetSample:
198
181
  inputs: Dict[str, npt.NDArray[np.float32]]
199
182
  gt: Optional[Dict[str, npt.NDArray[np.float32]]]
200
183
  metadata: Dict[str, Union[str, int, bool, float]]
201
- index: Union[int, str]
184
+ index: int
202
185
  state: DataStateEnum
@@ -4,8 +4,6 @@ import numpy as np
4
4
  import numpy.typing as npt
5
5
  from dataclasses import dataclass
6
6
 
7
- import matplotlib.pyplot as plt # type: ignore
8
-
9
7
  from code_loader.contract.enums import LeapDataType
10
8
  from code_loader.contract.responsedataclasses import BoundingBox
11
9
 
@@ -51,34 +49,6 @@ class LeapImage:
51
49
  validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
52
50
  validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
53
51
 
54
- def plot_visualizer(self) -> None:
55
- """
56
- Display the image contained in the LeapImage object.
57
-
58
- Returns:
59
- None
60
-
61
- Example:
62
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
63
- leap_image = LeapImage(data=image_data)
64
- leap_image.plot_visualizer()
65
- """
66
- image_data = self.data
67
-
68
- # If the image has one channel, convert it to a 3-channel image for display
69
- if image_data.shape[2] == 1:
70
- image_data = np.repeat(image_data, 3, axis=2)
71
-
72
- fig, ax = plt.subplots()
73
- fig.patch.set_facecolor('black')
74
- ax.set_facecolor('black')
75
-
76
- ax.imshow(image_data)
77
-
78
- plt.axis('off')
79
- plt.title('Leap Image Visualization', color='white')
80
- plt.show()
81
-
82
52
 
83
53
  @dataclass
84
54
  class LeapImageWithBBox:
@@ -106,60 +76,6 @@ class LeapImageWithBBox:
106
76
  validate_type(len(self.data.shape), 3, 'Image must be of shape 3')
107
77
  validate_type(self.data.shape[2], [1, 3], 'Image channel must be either 3(rgb) or 1(gray)')
108
78
 
109
- def plot_visualizer(self) -> None:
110
- """
111
- Plot an image with overlaid bounding boxes.
112
-
113
- Returns:
114
- None
115
-
116
- Example:
117
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
118
- bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
119
- leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
120
- leap_image_with_bbox.plot_visualizer()
121
- """
122
-
123
- image = self.data
124
- bounding_boxes = self.bounding_boxes
125
-
126
- # Create figure and axes
127
- fig, ax = plt.subplots(1)
128
- fig.patch.set_facecolor('black')
129
- ax.set_facecolor('black')
130
-
131
- # Display the image
132
- ax.imshow(image)
133
- ax.set_title('Leap Image With BBox Visualization', color='white')
134
-
135
- # Draw bounding boxes on the image
136
- for bbox in bounding_boxes:
137
- x, y, width, height = bbox.x, bbox.y, bbox.width, bbox.height
138
- confidence, label = bbox.confidence, bbox.label
139
-
140
- # Convert relative coordinates to absolute coordinates
141
- abs_x = x * image.shape[1]
142
- abs_y = y * image.shape[0]
143
- abs_width = width * image.shape[1]
144
- abs_height = height * image.shape[0]
145
-
146
- # Create a rectangle patch
147
- rect = plt.Rectangle(
148
- (abs_x - abs_width / 2, abs_y - abs_height / 2),
149
- abs_width, abs_height,
150
- linewidth=3, edgecolor='r', facecolor='none'
151
- )
152
-
153
- # Add the rectangle to the axes
154
- ax.add_patch(rect)
155
-
156
- # Display label and confidence
157
- ax.text(abs_x - abs_width / 2, abs_y - abs_height / 2 - 5,
158
- f"{label} {confidence:.2f}", color='r', fontsize=8,
159
- bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
160
-
161
- # Show the image with bounding boxes
162
- plt.show()
163
79
 
164
80
  @dataclass
165
81
  class LeapGraph:
@@ -183,40 +99,6 @@ class LeapGraph:
183
99
  validate_type(self.data.dtype, np.float32)
184
100
  validate_type(len(self.data.shape), 2, 'Graph must be of shape 2')
185
101
 
186
- def plot_visualizer(self) -> None:
187
- """
188
- Display the line chart contained in the LeapGraph object.
189
-
190
- Returns:
191
- None
192
-
193
- Example:
194
- graph_data = np.random.rand(100, 3).astype(np.float32)
195
- leap_graph = LeapGraph(data=graph_data)
196
- leap_graph.plot_visualizer()
197
- """
198
- graph_data = self.data
199
- num_variables = graph_data.shape[1]
200
-
201
- fig, ax = plt.subplots(figsize=(10, 6))
202
-
203
- # Set the background color to black
204
- fig.patch.set_facecolor('black')
205
- ax.set_facecolor('black')
206
-
207
- for i in range(num_variables):
208
- plt.plot(graph_data[:, i], label=f'Variable {i + 1}')
209
-
210
- ax.set_xlabel('Data Points', color='white')
211
- ax.set_ylabel('Values', color='white')
212
- ax.set_title('Leap Graph Visualization', color='white')
213
- ax.legend()
214
- ax.grid(True, color='white')
215
-
216
- # Change the color of the tick labels to white
217
- ax.tick_params(colors='white')
218
-
219
- plt.show()
220
102
 
221
103
  @dataclass
222
104
  class LeapText:
@@ -242,41 +124,6 @@ class LeapText:
242
124
  for value in self.data:
243
125
  validate_type(type(value), str)
244
126
 
245
- def plot_visualizer(self) -> None:
246
- """
247
- Display the text contained in the LeapText object.
248
-
249
- Returns:
250
- None
251
-
252
- Example:
253
- text_data = ['I', 'ate', 'a', 'banana', '', '', '']
254
- leap_text = LeapText(data=text_data)
255
- leap_text.plot_visualizer()
256
- """
257
- text_data = self.data
258
- # Join the text tokens into a single string, ignoring empty strings
259
- display_text = ' '.join([token for token in text_data if token])
260
-
261
- # Create a black image using Matplotlib
262
- fig, ax = plt.subplots(figsize=(10, 5))
263
- fig.patch.set_facecolor('black')
264
- ax.set_facecolor('black')
265
-
266
- # Hide the axes
267
- ax.axis('off')
268
-
269
- # Set the text properties
270
- font_size = 20
271
- font_color = 'white'
272
-
273
- # Add the text to the image
274
- ax.text(0.5, 0.5, display_text, color=font_color, fontsize=font_size, ha='center', va='center')
275
- ax.set_title('Leap Text Visualization', color='white')
276
-
277
- # Display the image
278
- plt.show()
279
-
280
127
 
281
128
  @dataclass
282
129
  class LeapHorizontalBar:
@@ -309,39 +156,6 @@ class LeapHorizontalBar:
309
156
  for label in self.labels:
310
157
  validate_type(type(label), str)
311
158
 
312
- def plot_visualizer(self) -> None:
313
- """
314
- Display the horizontal bar chart contained in the LeapHorizontalBar object.
315
-
316
- Returns:
317
- None
318
-
319
- Example:
320
- body_data = np.random.rand(5).astype(np.float32)
321
- labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
322
- leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
323
- leap_horizontal_bar.plot_visualizer()
324
- """
325
- body_data = self.body
326
- labels = self.labels
327
-
328
- fig, ax = plt.subplots()
329
-
330
- fig.patch.set_facecolor('black')
331
- ax.set_facecolor('black')
332
-
333
- # Plot horizontal bar chart
334
- ax.barh(labels, body_data, color='green')
335
-
336
- # Set the color of the labels and title to white
337
- ax.set_xlabel('Scores', color='white')
338
- ax.set_title('Leap Horizontal Bar Visualization', color='white')
339
-
340
- # Set the color of the ticks to white
341
- ax.tick_params(axis='x', colors='white')
342
- ax.tick_params(axis='y', colors='white')
343
-
344
- plt.show()
345
159
 
346
160
  @dataclass
347
161
  class LeapImageMask:
@@ -379,51 +193,6 @@ class LeapImageMask:
379
193
  for label in self.labels:
380
194
  validate_type(type(label), str)
381
195
 
382
- def plot_visualizer(self) -> None:
383
- """
384
- Plots an image with overlaid masks given a LeapImageMask visualizer object.
385
-
386
- Returns:
387
- None
388
-
389
-
390
- Example:
391
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
392
- mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
393
- labels = ["background", "object"]
394
- leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
395
- leap_image_mask.plot_visualizer()
396
- """
397
-
398
- image = self.image
399
- mask = self.mask
400
- labels = self.labels
401
-
402
- # Create a color map for each label
403
- colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
404
-
405
- # Make a copy of the image to draw on
406
- overlayed_image = image.copy()
407
-
408
- # Iterate through the unique values in the mask (excluding 0)
409
- for i, label in enumerate(labels):
410
- # Extract binary mask for the current instance
411
- instance_mask = (mask == (i + 1))
412
-
413
- # fill the instance mask with a translucent color
414
- overlayed_image[instance_mask] = (
415
- overlayed_image[instance_mask] * (1 - 0.5) + np.array(colors[i][:3], dtype=np.uint8) * 0.5)
416
-
417
- # Display the result using matplotlib
418
- fig, ax = plt.subplots(1)
419
- fig.patch.set_facecolor('black') # Set the figure background to black
420
- ax.set_facecolor('black') # Set the axis background to black
421
-
422
- ax.imshow(overlayed_image)
423
- ax.set_title('Leap Image With Mask Visualization', color='white')
424
- plt.axis('off') # Hide the axis
425
- plt.show()
426
-
427
196
 
428
197
  @dataclass
429
198
  class LeapTextMask:
@@ -461,55 +230,6 @@ class LeapTextMask:
461
230
  for label in self.labels:
462
231
  validate_type(type(label), str)
463
232
 
464
- def plot_visualizer(self) -> None:
465
- """
466
- Plots text with overlaid masks given a LeapTextMask visualizer object.
467
-
468
- Returns:
469
- None
470
-
471
- Example:
472
- text_data = ['I', 'ate', 'a', 'banana', '', '', '']
473
- mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
474
- labels = ["object"]
475
- leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
476
- """
477
-
478
- text_data = self.text
479
- mask_data = self.mask
480
- labels = self.labels
481
-
482
- # Create a color map for each label
483
- colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
484
-
485
- # Create a figure and axis
486
- fig, ax = plt.subplots()
487
-
488
- # Set background to black
489
- fig.patch.set_facecolor('black')
490
- ax.set_facecolor('black')
491
- ax.set_title('Leap Text Mask Visualization', color='white')
492
- ax.axis('off')
493
-
494
- # Set initial position
495
- x_pos, y_pos = 0.01, 0.5 # Adjusted initial position for better visibility
496
-
497
- # Display the text with colors
498
- for token, mask_value in zip(text_data, mask_data):
499
- if mask_value > 0:
500
- color = colors[mask_value % len(colors)]
501
- bbox = dict(facecolor=color, edgecolor='none',
502
- boxstyle='round,pad=0.3') # Background color for masked tokens
503
- else:
504
- bbox = None
505
-
506
- ax.text(x_pos, y_pos, token, fontsize=12, color='white', ha='left', va='center', bbox=bbox)
507
-
508
- # Update the x position for the next token
509
- x_pos += len(token) * 0.03 + 0.02 # Adjust the spacing between tokens
510
-
511
- plt.show()
512
-
513
233
 
514
234
  @dataclass
515
235
  class LeapImageWithHeatmap:
@@ -547,46 +267,6 @@ class LeapImageWithHeatmap:
547
267
  raise LeapValidationError(
548
268
  'Number of heatmaps and labels must be equal')
549
269
 
550
- def plot_visualizer(self) -> None:
551
- """
552
- Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.
553
-
554
- Returns:
555
- None
556
-
557
- Example:
558
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
559
- heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
560
- labels = ["heatmap1", "heatmap2", "heatmap3"]
561
- leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
562
- leap_image_with_heatmap.plot_visualizer()
563
- """
564
- image = self.image
565
- heatmaps = self.heatmaps
566
- labels = self.labels
567
-
568
- # Plot the base image
569
- fig, ax = plt.subplots()
570
- fig.patch.set_facecolor('black') # Set the figure background to black
571
- ax.set_facecolor('black') # Set the axis background to black
572
- ax.imshow(image, cmap='gray')
573
-
574
- # Overlay each heatmap with some transparency
575
- for i in range(len(labels)):
576
- heatmap = heatmaps[i]
577
- ax.imshow(heatmap, cmap='jet', alpha=0.5) # Adjust alpha for transparency
578
- ax.set_title(f'Heatmap: {labels[i]}', color='white')
579
-
580
- # Display a colorbar for the heatmap
581
- cbar = plt.colorbar(ax.imshow(heatmap, cmap='jet', alpha=0.5))
582
- cbar.set_label(labels[i], color='white')
583
- cbar.ax.yaxis.set_tick_params(color='white') # Set color for the colorbar ticks
584
- plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white') # Set color for the colorbar labels
585
-
586
- plt.axis('off')
587
- plt.title('Leap Image With Heatmaps Visualization', color='white')
588
- plt.show()
589
-
590
270
 
591
271
  map_leap_data_type_to_visualizer_class = {
592
272
  LeapDataType.Image.value: LeapImage,
@@ -8,7 +8,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
8
8
  GroundTruthHandler, MetadataHandler, DatasetIntegrationSetup, VisualizerHandler, PreprocessResponse, \
9
9
  PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
10
10
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
- CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, VisualizerCallableReturnType, \
11
+ CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
12
12
  CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
13
13
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
14
14
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
@@ -103,7 +103,7 @@ class LeapBinder:
103
103
  if visualizer_type.value not in map_leap_data_type_to_visualizer_class:
104
104
  raise Exception(
105
105
  f'The visualizer_type is invalid. current visualizer_type: {visualizer_type}, ' # type: ignore[attr-defined]
106
- f'should be one of : {", ".join([arg.__name__ for arg in VisualizerCallableReturnType.__args__])}')
106
+ f'should be one of : {", ".join([arg.__name__ for arg in LeapData.__args__])}')
107
107
 
108
108
  func_annotations = function.__annotations__
109
109
  if "return" not in func_annotations:
@@ -113,10 +113,10 @@ class LeapBinder:
113
113
  f"https://docs.python.org/3/library/typing.html")
114
114
  else:
115
115
  return_type = func_annotations["return"]
116
- if return_type not in VisualizerCallableReturnType.__args__: # type: ignore[attr-defined]
116
+ if return_type not in LeapData.__args__: # type: ignore[attr-defined]
117
117
  raise Exception(
118
118
  f'The return type of function {function.__name__} is invalid. current return type: {return_type}, ' # type: ignore[attr-defined]
119
- f'should be one of : {", ".join([arg.__name__ for arg in VisualizerCallableReturnType.__args__])}')
119
+ f'should be one of : {", ".join([arg.__name__ for arg in LeapData.__args__])}')
120
120
 
121
121
  expected_return_type = map_leap_data_type_to_visualizer_class[visualizer_type.value]
122
122
  if not issubclass(return_type, expected_return_type):
@@ -389,35 +389,16 @@ class LeapBinder:
389
389
  if preprocess is None:
390
390
  raise Exception("Please make sure you call the leap_binder.set_preprocess method")
391
391
  preprocess_results = preprocess.function()
392
- preprocess_result_dict = {}
393
- for i, preprocess_result in enumerate(preprocess_results):
394
- if preprocess_result.state is None:
395
- state_enum = DataStateEnum(i)
396
- preprocess_result.state = DataStateType(state_enum.name)
397
- else:
398
- state_enum = DataStateEnum[preprocess_result.state.name]
399
-
400
- if state_enum in preprocess_result_dict:
401
- raise Exception(f"Duplicate state {state_enum.name} in preprocess results")
402
- preprocess_result_dict[state_enum] = preprocess_result
403
-
404
- if DataStateEnum.unlabeled not in preprocess_result_dict:
405
- preprocess_unlabeled_result = self.get_preprocess_unlabeled_result()
406
- if preprocess_unlabeled_result is not None:
407
- preprocess_result_dict[DataStateEnum.unlabeled] = preprocess_unlabeled_result
408
-
409
- if DataStateEnum.training not in preprocess_result_dict:
410
- raise Exception("Training data is required")
411
- if DataStateEnum.validation not in preprocess_result_dict:
412
- raise Exception("Validation data is required")
392
+ preprocess_result_dict = {
393
+ DataStateEnum(i): preprocess_result
394
+ for i, preprocess_result in enumerate(preprocess_results)
395
+ }
413
396
 
414
- return preprocess_result_dict
415
-
416
- def get_preprocess_unlabeled_result(self) -> Optional[PreprocessResponse]:
417
397
  unlabeled_preprocess = self.setup_container.unlabeled_data_preprocess
418
398
  if unlabeled_preprocess is not None:
419
- return unlabeled_preprocess.function()
420
- return None
399
+ preprocess_result_dict[DataStateEnum.unlabeled] = unlabeled_preprocess.function()
400
+
401
+ return preprocess_result_dict
421
402
 
422
403
  def _get_all_dataset_base_handlers(self) -> List[Union[DatasetBaseHandler, MetadataHandler]]:
423
404
  all_dataset_base_handlers: List[Union[DatasetBaseHandler, MetadataHandler]] = []
@@ -430,7 +411,7 @@ class LeapBinder:
430
411
  def check_handler(
431
412
  preprocess_response: PreprocessResponse, test_result: List[DatasetTestResultPayload],
432
413
  dataset_base_handler: Union[DatasetBaseHandler, MetadataHandler]) -> List[DatasetTestResultPayload]:
433
- raw_result = dataset_base_handler.function(preprocess_response.sample_ids[0], preprocess_response)
414
+ raw_result = dataset_base_handler.function(0, preprocess_response)
434
415
  handler_type = 'metadata' if isinstance(dataset_base_handler, MetadataHandler) else None
435
416
  if isinstance(dataset_base_handler, MetadataHandler) and isinstance(raw_result, dict):
436
417
  metadata_test_result_payloads = [
code_loader/leaploader.py CHANGED
@@ -2,17 +2,16 @@
2
2
  import importlib.util
3
3
  import io
4
4
  import sys
5
- import time
6
5
  from contextlib import redirect_stdout
7
6
  from functools import lru_cache
8
7
  from pathlib import Path
9
- from typing import Dict, List, Iterable, Union, Any, Type
8
+ from typing import Dict, List, Iterable, Union, Any
10
9
 
11
10
  import numpy as np
12
11
  import numpy.typing as npt
13
12
 
14
13
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
15
- PreprocessResponse, VisualizerHandler, VisualizerCallableReturnType, CustomLossHandler, \
14
+ PreprocessResponse, VisualizerHandler, LeapData, CustomLossHandler, \
16
15
  PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
17
16
  from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
18
17
  from code_loader.contract.exceptions import DatasetScriptException
@@ -28,8 +27,6 @@ class LeapLoader:
28
27
  self.code_entry_name = code_entry_name
29
28
  self.code_path = code_path
30
29
 
31
- self._preprocess_result_cached = None
32
-
33
30
  @lru_cache()
34
31
  def exec_script(self) -> None:
35
32
  try:
@@ -106,16 +103,12 @@ class LeapLoader:
106
103
  for prediction_type in setup.prediction_types
107
104
  }
108
105
 
109
- def get_sample(self, state: DataStateEnum, sample_id: Union[int, str]) -> DatasetSample:
106
+ def get_sample(self, state: DataStateEnum, idx: int) -> DatasetSample:
110
107
  self.exec_script()
111
- preprocess_result = self._preprocess_result()
112
- if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
113
- self._preprocess_result(update_unlabeled_preprocess=True)
114
-
115
- sample = DatasetSample(inputs=self._get_inputs(state, sample_id),
116
- gt=None if state == DataStateEnum.unlabeled else self._get_gt(state, sample_id),
117
- metadata=self._get_metadata(state, sample_id),
118
- index=sample_id,
108
+ sample = DatasetSample(inputs=self._get_inputs(state, idx),
109
+ gt=None if state == DataStateEnum.unlabeled else self._get_gt(state, idx),
110
+ metadata=self._get_metadata(state, idx),
111
+ index=idx,
119
112
  state=state)
120
113
  return sample
121
114
 
@@ -194,7 +187,7 @@ class LeapLoader:
194
187
  return all_dataset_base_handlers
195
188
 
196
189
  def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]],
197
- ) -> VisualizerCallableReturnType:
190
+ ) -> LeapData:
198
191
  return self.visualizer_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
199
192
 
200
193
  def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
@@ -286,42 +279,27 @@ class LeapLoader:
286
279
  ]
287
280
  return ModelSetup(custom_layer_instances)
288
281
 
289
- def _preprocess_result(self, update_unlabeled_preprocess=False) -> Dict[DataStateEnum, PreprocessResponse]:
282
+ @lru_cache()
283
+ def _preprocess_result(self) -> Dict[DataStateEnum, PreprocessResponse]:
290
284
  self.exec_script()
291
-
292
- if self._preprocess_result_cached is None:
293
- self._preprocess_result_cached = global_leap_binder.get_preprocess_result()
294
-
295
- if update_unlabeled_preprocess:
296
- self._preprocess_result_cached[
297
- DataStateEnum.unlabeled] = global_leap_binder.get_preprocess_unlabeled_result()
298
-
299
- return self._preprocess_result_cached
300
-
301
- def get_preprocess_sample_ids(self, update_unlabeled_preprocess=False) -> Dict[DataStateEnum, Union[List[int], List[str]]]:
302
- preprocess_result = self._preprocess_result(update_unlabeled_preprocess)
303
- sample_ids = {}
304
- for state, preprocess_response in preprocess_result.items():
305
- sample_ids[state] = preprocess_response.sample_ids
306
-
307
- return sample_ids
285
+ return global_leap_binder.get_preprocess_result()
308
286
 
309
287
  def _get_dataset_handlers(self, handlers: Iterable[DatasetBaseHandler],
310
- state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
288
+ state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
311
289
  result_agg = {}
312
290
  preprocess_result = self._preprocess_result()
313
291
  preprocess_state = preprocess_result[state]
314
292
  for handler in handlers:
315
- handler_result = handler.function(sample_id, preprocess_state)
293
+ handler_result = handler.function(idx, preprocess_state)
316
294
  handler_name = handler.name
317
295
  result_agg[handler_name] = handler_result
318
296
  return result_agg
319
297
 
320
- def _get_inputs(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
321
- return self._get_dataset_handlers(global_leap_binder.setup_container.inputs, state, sample_id)
298
+ def _get_inputs(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
299
+ return self._get_dataset_handlers(global_leap_binder.setup_container.inputs, state, idx)
322
300
 
323
- def _get_gt(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
324
- return self._get_dataset_handlers(global_leap_binder.setup_container.ground_truths, state, sample_id)
301
+ def _get_gt(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
302
+ return self._get_dataset_handlers(global_leap_binder.setup_container.ground_truths, state, idx)
325
303
 
326
304
  @lru_cache()
327
305
  def _metadata_name_to_type(self) -> Dict[str, DatasetMetadataType]:
@@ -356,12 +334,12 @@ class LeapLoader:
356
334
 
357
335
  return converted_value
358
336
 
359
- def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, Union[str, int, bool, float]]:
337
+ def _get_metadata(self, state: DataStateEnum, idx: int) -> Dict[str, Union[str, int, bool, float]]:
360
338
  result_agg = {}
361
339
  preprocess_result = self._preprocess_result()
362
340
  preprocess_state = preprocess_result[state]
363
341
  for handler in global_leap_binder.setup_container.metadata:
364
- handler_result = handler.function(sample_id, preprocess_state)
342
+ handler_result = handler.function(idx, preprocess_state)
365
343
  if isinstance(handler_result, dict):
366
344
  for single_metadata_name, single_metadata_result in handler_result.items():
367
345
  handler_name = f'{handler.name}_{single_metadata_name}'
@@ -371,14 +349,3 @@ class LeapLoader:
371
349
  result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name, handler_result)
372
350
 
373
351
  return result_agg
374
-
375
- @lru_cache()
376
- def get_sample_id_type(self) -> Type:
377
- preprocess_results = list(self._preprocess_result().values())
378
- id_type = preprocess_results[0].sample_id_type
379
- for preprocess_result in preprocess_results:
380
- if preprocess_result.sample_id_type != id_type:
381
- raise Exception("Different id types in preprocess results")
382
-
383
- return id_type
384
-
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.49.dev8
3
+ Version: 1.0.50
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -13,7 +13,6 @@ Classifier: Programming Language :: Python :: 3.8
13
13
  Classifier: Programming Language :: Python :: 3.9
14
14
  Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Programming Language :: Python :: 3.11
16
- Requires-Dist: matplotlib (>=3.3,<3.4)
17
16
  Requires-Dist: numpy (>=1.22.3,<2.0.0)
18
17
  Requires-Dist: psutil (>=5.9.5,<6.0.0)
19
18
  Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
@@ -1,12 +1,11 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
- code_loader/code_inegration_processes_manager.py,sha256=XslWOPeNQk4RAFJ_f3tP5Oe3EgcIR7BE7Y8r9Ty73-o,3261
4
3
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- code_loader/contract/datasetclasses.py,sha256=U-9bKvVkojKTqBLFBVUxKzhPu-pvP5eYXP8IM509-Zw,6584
4
+ code_loader/contract/datasetclasses.py,sha256=HPm-z82EbkIk_C_vkCpD8oBs5pgUpStzciMRV0auMlI,5679
6
5
  code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
7
6
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
8
7
  code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
9
- code_loader/contract/visualizer_classes.py,sha256=Ka8fJSVKrOeZ12Eg8-dOBGMW0UswYCIkuhnNMd-7z9s,22948
8
+ code_loader/contract/visualizer_classes.py,sha256=iIa_O2rKvPTwN5ILCTZvRpsGYiiFABKdwQwfIXGigDo,11928
10
9
  code_loader/experiment_api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
10
  code_loader/experiment_api/api.py,sha256=a7wh6Hhe7IaVxu46eV2soSz-yxnmXG3ipU1BBtsEAaQ,2493
12
11
  code_loader/experiment_api/cli_config_utils.py,sha256=n6JMyNrquxql3KKxHhAP8jAzezlRT-PV2KWI95kKsm0,1140
@@ -18,12 +17,12 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
18
17
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
19
18
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
20
19
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
21
- code_loader/inner_leap_binder/leapbinder.py,sha256=URm7W1W0KnmDqFMp1yPQrZQwm2B9DQ9oHVvAwRFkJuc,24934
22
- code_loader/leaploader.py,sha256=rXuYkkk0aJR1mczQ7s4U2swnAWeaW7dg8XF9U0nQHXE,19071
20
+ code_loader/inner_leap_binder/leapbinder.py,sha256=ALUtiRYBxxP1xjza8WWZvVt3jNmfevRnxPYIQ4wy3g4,23808
21
+ code_loader/leaploader.py,sha256=_iB23STM_6PuedtRsI_tod3dUoe1j5YoNuuoASBLLPc,17481
23
22
  code_loader/utils.py,sha256=TZAoUbA2pE8eK3Le3s5Xr4eRaYdeDMQtxotx6rh-5oE,2185
24
23
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
24
  code_loader/visualizers/default_visualizers.py,sha256=VoqO9FN84yXyMjRjHjUTOt2GdTkJRMbHbXJ1cJkREkk,2230
26
- code_loader-1.0.49.dev8.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
27
- code_loader-1.0.49.dev8.dist-info/METADATA,sha256=uaktiu0VsFQfasyvzFeZaHVgcU851Kh7z6g_p-__1Eo,893
28
- code_loader-1.0.49.dev8.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
29
- code_loader-1.0.49.dev8.dist-info/RECORD,,
25
+ code_loader-1.0.50.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
26
+ code_loader-1.0.50.dist-info/METADATA,sha256=ThRDw4Frh9tilH0mroPmMI2bYeKQ-JXjheuLpRI0Bn8,849
27
+ code_loader-1.0.50.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
28
+ code_loader-1.0.50.dist-info/RECORD,,
@@ -1,83 +0,0 @@
1
- # mypy: ignore-errors
2
- import traceback
3
- from dataclasses import dataclass
4
-
5
- from typing import List, Tuple, Optional
6
-
7
- from multiprocessing import Process, Queue
8
-
9
- from code_loader.leap_loader_parallelized_base import LeapLoaderParallelizedBase
10
- from code_loader.leaploader import LeapLoader
11
- from code_loader.contract.enums import DataStateEnum
12
- from code_loader.metric_calculator_parallelized import MetricCalculatorParallelized
13
- from code_loader.samples_generator_parallelized import SamplesGeneratorParallelized
14
-
15
-
16
- @dataclass
17
- class SampleSerializableError:
18
- state: DataStateEnum
19
- index: int
20
- leap_script_trace: str
21
- exception_as_str: str
22
-
23
-
24
- class CodeIntegrationProcessesManager:
25
- def __init__(self, code_path: str, code_entry_name: str, n_workers: Optional[int] = 2,
26
- max_samples_in_queue: int = 128) -> None:
27
- self.metric_calculator_parallelized = MetricCalculatorParallelized(code_path, code_entry_name)
28
- self.samples_generator_parallelized = SamplesGeneratorParallelized(code_path, code_entry_name)
29
-
30
- def _create_and_start_process(self) -> Process:
31
- process = self.multiprocessing_context.Process(
32
- target=CodeIntegrationProcessesManager._process_func,
33
- args=(self.code_path, self.code_entry_name, self._inputs_waiting_to_be_process,
34
- self._ready_processed_results))
35
- process.daemon = True
36
- process.start()
37
- return process
38
-
39
- def _run_and_warm_first_process(self):
40
- process = self._create_and_start_process()
41
- self.processes = [process]
42
-
43
- # needed in order to make sure the preprocess func runs once in nonparallel
44
- self._start_process_inputs([(DataStateEnum.training, 0)])
45
- self._get_next_ready_processed_result()
46
-
47
- def _operation_decider(self):
48
- if self.metric_calculator_parallelized._ready_processed_results.empty() and not \
49
- self.metric_calculator_parallelized._inputs_waiting_to_be_process.empty():
50
- return 'metric'
51
-
52
- if self.samples_generator_parallelized._ready_processed_results.empty() and not \
53
- self.samples_generator_parallelized._inputs_waiting_to_be_process.empty():
54
- return 'dataset'
55
-
56
-
57
-
58
-
59
- @staticmethod
60
- def _process_func(code_path: str, code_entry_name: str,
61
- samples_to_process: Queue, ready_samples: Queue,
62
- metrics_to_process: Queue, ready_metrics: Queue) -> None:
63
- import os
64
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
65
-
66
- leap_loader = LeapLoader(code_path, code_entry_name)
67
- while True:
68
-
69
- # decide on sample or metric to process
70
- state, idx = samples_to_process.get(block=True)
71
- leap_loader._preprocess_result()
72
- try:
73
- sample = leap_loader.get_sample(state, idx)
74
- except Exception as e:
75
- leap_script_trace = traceback.format_exc().split('File "<string>"')[-1]
76
- ready_samples.put(SampleSerializableError(state, idx, leap_script_trace, str(e)))
77
- continue
78
-
79
- ready_samples.put(sample)
80
-
81
- def generate_samples(self, sample_identities: List[Tuple[DataStateEnum, int]]):
82
- return self.start_process_inputs(sample_identities)
83
-