code-loader 1.0.101.dev0 (tar.gz) → 1.0.102 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of code-loader might be problematic; see the package's registry page for more details.

Files changed (35)
  1. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/PKG-INFO +1 -1
  2. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/contract/datasetclasses.py +5 -0
  3. code_loader-1.0.102/code_loader/helpers/plot_functions.py +425 -0
  4. code_loader-1.0.102/code_loader/helpers/visualize.py +19 -0
  5. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/leaploader.py +11 -7
  6. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/utils.py +1 -26
  7. code_loader-1.0.102/code_loader/visualizers/__init__.py +0 -0
  8. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/pyproject.toml +1 -1
  9. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/LICENSE +0 -0
  10. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/README.md +0 -0
  11. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/__init__.py +0 -0
  12. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/contract/__init__.py +0 -0
  13. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/contract/enums.py +0 -0
  14. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/contract/exceptions.py +0 -0
  15. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/contract/mapping.py +0 -0
  16. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/contract/responsedataclasses.py +0 -0
  17. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/contract/visualizer_classes.py +0 -0
  18. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/default_losses.py +0 -0
  19. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/default_metrics.py +0 -0
  20. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/__init__.py +0 -0
  21. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/api.py +0 -0
  22. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/cli_config_utils.py +0 -0
  23. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/client.py +0 -0
  24. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/epoch.py +0 -0
  25. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/experiment.py +0 -0
  26. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/experiment_context.py +0 -0
  27. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/types.py +0 -0
  28. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/utils.py +0 -0
  29. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/experiment_api/workingspace_config_utils.py +0 -0
  30. {code_loader-1.0.101.dev0/code_loader/visualizers → code_loader-1.0.102/code_loader/helpers}/__init__.py +0 -0
  31. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/inner_leap_binder/__init__.py +0 -0
  32. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/inner_leap_binder/leapbinder.py +0 -0
  33. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/inner_leap_binder/leapbinder_decorators.py +0 -0
  34. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/leaploaderbase.py +0 -0
  35. {code_loader-1.0.101.dev0 → code_loader-1.0.102}/code_loader/visualizers/default_visualizers.py +0 -0
File: PKG-INFO
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.101.dev0
3
+ Version: 1.0.102
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
File: code_loader/contract/datasetclasses.py
@@ -43,6 +43,9 @@ class PreprocessResponse:
43
43
  instance_ids_to_names: Optional[Dict[str, str]] = None # in use only for element instance
44
44
 
45
45
  def __post_init__(self) -> None:
46
+ def is_valid_string(s: str) -> bool:
47
+ return bool(re.match(r'^[A-Za-z0-9_]+$', s))
48
+
46
49
  assert self.sample_ids_to_instance_mappings is None, f"Keep sample_ids_to_instance_mappings None when initializing PreprocessResponse"
47
50
  assert self.instance_to_sample_ids_mappings is None, f"Keep instance_to_sample_ids_mappings None when initializing PreprocessResponse"
48
51
  assert self.instance_ids_to_names is None, f"Keep instance_ids_to_names None when initializing PreprocessResponse"
@@ -57,6 +60,8 @@ class PreprocessResponse:
57
60
  if self.sample_id_type == str:
58
61
  for sample_id in self.sample_ids:
59
62
  assert isinstance(sample_id, str), f"Sample id should be of type str. Got: {type(sample_id)}"
63
+ if not is_valid_string(sample_id):
64
+ raise Exception(f"Sample id should contain only letters (A-Z, a-z), numbers or '_'. Got: {sample_id}")
60
65
  else:
61
66
  raise Exception("length is deprecated.")
62
67
 
File: code_loader/helpers/plot_functions.py (new file)
@@ -0,0 +1,425 @@
1
+
2
+ import os
3
+
4
+ from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
5
+
6
+ if not os.environ.get(mapping_runtime_mode_env_var_mame):
7
+ try:
8
+ import matplotlib.pyplot as plt # type: ignore
9
+ except ImportError:
10
+ raise ImportError(
11
+ "Matplotlib is not installed. Please install it using 'pip install matplotlib' to visualize Leap data."
12
+ )
13
+
14
+
15
+ import numpy as np
16
+ from code_loader.contract.enums import LeapDataType
17
+ from textwrap import wrap
18
+ import math
19
+
20
+ from code_loader.contract.visualizer_classes import LeapImage, LeapImageWithBBox, LeapGraph, LeapText, \
21
+ LeapHorizontalBar, LeapImageMask, LeapTextMask, LeapImageWithHeatmap
22
+
23
+
24
+ def plot_image_with_b_box(leap_data: LeapImageWithBBox, title: str) -> None:
25
+ """
26
+ Plot an image with overlaid bounding boxes.
27
+
28
+ Returns:
29
+ None
30
+
31
+ Example:
32
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
33
+ bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
34
+ leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
35
+ title = "Image With bbox"
36
+ visualize(leap_image_with_bbox, title)
37
+ """
38
+
39
+ image = leap_data.data
40
+ bounding_boxes = leap_data.bounding_boxes
41
+
42
+ # Create figure and axes
43
+ fig, ax = plt.subplots(1)
44
+ fig.patch.set_facecolor('black')
45
+ ax.set_facecolor('black')
46
+
47
+ # Display the image
48
+ ax.imshow(image)
49
+ ax.set_title(title, color='white')
50
+
51
+ # Draw bounding boxes on the image
52
+ for bbox in bounding_boxes:
53
+ x, y, width, height = bbox.x, bbox.y, bbox.width, bbox.height
54
+ confidence, label = bbox.confidence, bbox.label
55
+
56
+ # Convert relative coordinates to absolute coordinates
57
+ abs_x = x * image.shape[1]
58
+ abs_y = y * image.shape[0]
59
+ abs_width = width * image.shape[1]
60
+ abs_height = height * image.shape[0]
61
+
62
+ # Create a rectangle patch
63
+ rect = plt.Rectangle(
64
+ (abs_x - abs_width / 2, abs_y - abs_height / 2),
65
+ abs_width, abs_height,
66
+ linewidth=3, edgecolor='r', facecolor='none'
67
+ )
68
+
69
+ # Add the rectangle to the axes
70
+ ax.add_patch(rect)
71
+
72
+ # Display label and confidence
73
+ ax.text(abs_x - abs_width / 2, abs_y - abs_height / 2 - 5,
74
+ f"{label} {confidence:.2f}", color='r', fontsize=8,
75
+ bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
76
+
77
+ # Show the image with bounding boxes
78
+ plt.show()
79
+
80
+
81
+ def plot_image(leap_data: LeapImage, title: str) -> None:
82
+ """
83
+ Display the image contained in the LeapImage object.
84
+
85
+ Returns:
86
+ None
87
+
88
+ Example:
89
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
90
+ leap_image = LeapImage(data=image_data)
91
+ title = "Image"
92
+ visualize(leap_image, title)
93
+ """
94
+ image_data = leap_data.data
95
+
96
+ # If the image has one channel, convert it to a 3-channel image for display
97
+ if image_data.shape[2] == 1:
98
+ image_data = np.repeat(image_data, 3, axis=2)
99
+
100
+ fig, ax = plt.subplots()
101
+ fig.patch.set_facecolor('black')
102
+ ax.set_facecolor('black')
103
+
104
+ ax.imshow(image_data)
105
+
106
+ plt.axis('off')
107
+ plt.title(title, color='white')
108
+ plt.show()
109
+
110
+
111
+ def plot_graph(leap_data: LeapGraph, title: str) -> None:
112
+ """
113
+ Display the line chart contained in the LeapGraph object.
114
+
115
+ Returns:
116
+ None
117
+
118
+ Example:
119
+ graph_data = np.random.rand(100, 3).astype(np.float32)
120
+ leap_graph = LeapGraph(data=graph_data)
121
+ title = "Graph"
122
+ visualize(leap_graph, title)
123
+ """
124
+ graph_data = leap_data.data
125
+ num_variables = graph_data.shape[1]
126
+
127
+ fig, ax = plt.subplots(figsize=(10, 6))
128
+
129
+ # Set the background color to black
130
+ fig.patch.set_facecolor('black')
131
+ ax.set_facecolor('black')
132
+
133
+ for i in range(num_variables):
134
+ plt.plot(graph_data[:, i], label=f'Variable {i + 1}')
135
+
136
+ ax.set_xlabel('Data Points', color='white')
137
+ ax.set_ylabel('Values', color='white')
138
+ ax.set_title(title, color='white')
139
+ ax.legend()
140
+ ax.grid(True, color='white')
141
+
142
+ # Change the color of the tick labels to white
143
+ ax.tick_params(colors='white')
144
+
145
+ plt.show()
146
+
147
+
148
+ def plot_text_with_heatmap(leap_data: LeapText, title: str) -> None:
149
+ """
150
+ Display the text contained in the LeapText object with a heatmap overlay.
151
+
152
+ Args:
153
+ leap_data (LeapData): The LeapText object containing text tokens and an optional heatmap.
154
+ title (str): The title of the visualization.
155
+
156
+ Returns:
157
+ None
158
+
159
+ Example:
160
+ text_data = ['I', 'ate', 'a', 'banana', '', '', '']
161
+ heatmap = [0.1, 0.3, 0.2, 0.9, 0.0, 0.0, 0.0]
162
+ leap_text = LeapText(data=text_data, heatmap=heatmap) # Create LeapText object
163
+ title = "Text with Heatmap"
164
+ visualize(leap_text, title)
165
+ """
166
+ text_data = leap_data.data
167
+ heatmap = leap_data.heatmap
168
+
169
+ text_data = [s for s in text_data if s != "[PAD]"]
170
+
171
+ fig, ax = plt.subplots(figsize=(12, 5))
172
+ fig.patch.set_facecolor('black')
173
+ ax.set_facecolor('black')
174
+ ax.axis('off') # Hide axes
175
+
176
+ font_size = 20
177
+
178
+ if heatmap is not None:
179
+ heatmap = heatmap[:len(text_data)]
180
+ if len(heatmap) != len(text_data):
181
+ raise ValueError(
182
+ f"Heatmap length ({len(heatmap)}) must match the number of tokens in `data` ({len(text_data)}).")
183
+
184
+ max_tokens_per_row = 10
185
+ num_rows = math.ceil(len(text_data) / max_tokens_per_row)
186
+
187
+ fig.set_size_inches(12, num_rows * 1.2)
188
+ for idx, (token, value) in enumerate(zip(text_data, heatmap)):
189
+ if token:
190
+ row = idx // max_tokens_per_row
191
+ col = idx % max_tokens_per_row
192
+
193
+ x_pos = col / max_tokens_per_row + 0.03
194
+ y_pos = 1 - (row + 0.5) / num_rows
195
+ color = plt.cm.jet(value)
196
+ ax.text(
197
+ x_pos,
198
+ y_pos,
199
+ token,
200
+ fontsize=font_size,
201
+ color=color,
202
+ ha="left",
203
+ va="center"
204
+ )
205
+ else:
206
+ display_text = ' '.join([token for token in text_data if token])
207
+ wrapped_text = "\n".join(wrap(display_text, width=80))
208
+ font_color = 'white'
209
+ ax.text(0.5, 0.5, wrapped_text, color=font_color, fontsize=font_size, ha='center', va='center')
210
+
211
+ ax.set_title(title, color='white', fontsize=16)
212
+
213
+ plt.tight_layout()
214
+ plt.show()
215
+
216
+
217
+ def plot_hbar(leap_data: LeapHorizontalBar, title: str) -> None:
218
+ """
219
+ Display the horizontal bar chart contained in the LeapHorizontalBar object.
220
+
221
+ Returns:
222
+ None
223
+
224
+ Example:
225
+ body_data = np.random.rand(5).astype(np.float32)
226
+ gt_data = np.random.rand(5).astype(np.float32)
227
+ labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
228
+ leap_horizontal_bar = LeapHorizontalBar(body=body_data, gt=gt_data, labels=labels)
229
+ title = "Horizontal Bar"
230
+ visualize(leap_horizontal_bar, title)
231
+ """
232
+ body_data = leap_data.body
233
+ labels = leap_data.labels
234
+
235
+ # Check if 'gt' attribute exists and is not None
236
+ gt_data = getattr(leap_data, 'gt', None)
237
+
238
+ fig, ax = plt.subplots()
239
+
240
+ fig.patch.set_facecolor('black')
241
+ ax.set_facecolor('black')
242
+
243
+ # Adjust positions for side-by-side bars
244
+ y_positions = range(len(labels))
245
+ bar_width = 0.4
246
+
247
+ # Plot horizontal bar chart
248
+ if gt_data is not None:
249
+ ax.barh([y - bar_width / 2 for y in y_positions], body_data, color='green', height=bar_width, label='Prediction')
250
+ ax.barh([y + bar_width / 2 for y in y_positions], gt_data, color='orange', height=bar_width, label='GT')
251
+ else:
252
+ ax.barh(y_positions, body_data, color='green', label='Body Data')
253
+
254
+ # Set the y-ticks to align with the center of the bars
255
+ ax.set_yticks(y_positions)
256
+ ax.set_yticklabels(labels, color='white')
257
+
258
+ # Set the color of the labels and title to white
259
+ ax.set_xlabel('Scores', color='white')
260
+ ax.set_title(title, color='white')
261
+
262
+ # Set the color of the ticks to white
263
+ ax.tick_params(axis='x', colors='white')
264
+ ax.tick_params(axis='y', colors='white')
265
+
266
+ # Add legend if gt is present
267
+ if gt_data is not None:
268
+ ax.legend(loc='best', facecolor='black', edgecolor='white', labelcolor='white')
269
+
270
+ plt.show()
271
+
272
+
273
+ def plot_image_mask(leap_data: LeapImageMask, title: str) -> None:
274
+ """
275
+ Plots an image with overlaid masks given a LeapImageMask visualizer object.
276
+
277
+ Returns:
278
+ None
279
+
280
+
281
+ Example:
282
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
283
+ mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
284
+ labels = ["background", "object"]
285
+ leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
286
+ title = "Image Mask"
287
+ visualize(leap_image_mask, title)
288
+ """
289
+
290
+ image = leap_data.image
291
+ mask = leap_data.mask
292
+ labels = leap_data.labels
293
+
294
+ # Create a color map for each label
295
+ colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
296
+ if image.dtype == np.uint8:
297
+ colors = colors * 255
298
+
299
+ # Make a copy of the image to draw on
300
+ overlayed_image = image.copy()
301
+
302
+ # Iterate through the unique values in the mask (excluding 0)
303
+ for i, label in enumerate(labels):
304
+ # Extract binary mask for the current instance
305
+ instance_mask = (mask == (i + 1))
306
+
307
+ # fill the instance mask with a translucent color
308
+ overlayed_image[instance_mask] = (
309
+ overlayed_image[instance_mask] * (1 - 0.5) + np.array(colors[i][:image.shape[-1]], dtype=np.uint8) * 0.5)
310
+
311
+ # Display the result using matplotlib
312
+ fig, ax = plt.subplots(1)
313
+ fig.patch.set_facecolor('black') # Set the figure background to black
314
+ ax.set_facecolor('black') # Set the axis background to black
315
+
316
+ ax.imshow(overlayed_image)
317
+ ax.set_title(title, color='white')
318
+ plt.axis('off') # Hide the axis
319
+ plt.show()
320
+
321
+
322
+ def plot_text_mask(leap_data: LeapTextMask, title: str) -> None:
323
+ """
324
+ Plots text with overlaid masks given a LeapTextMask visualizer object.
325
+
326
+ Returns:
327
+ None
328
+
329
+ Example:
330
+ text_data = ['I', 'ate', 'a', 'banana', '', '', '']
331
+ mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
332
+ labels = ["object"]
333
+ leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
334
+ title = "Text Mask"
335
+ visualize(leap_text_mask, title)
336
+ """
337
+
338
+ text_data = leap_data.text
339
+ mask_data = leap_data.mask
340
+ labels = leap_data.labels
341
+
342
+ # Create a color map for each label
343
+ colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
344
+
345
+ # Create a figure and axis
346
+ fig, ax = plt.subplots()
347
+
348
+ # Set background to black
349
+ fig.patch.set_facecolor('black')
350
+ ax.set_facecolor('black')
351
+ ax.set_title(title, color='white')
352
+ ax.axis('off')
353
+
354
+ # Set initial position
355
+ x_pos, y_pos = 0.01, 0.5 # Adjusted initial position for better visibility
356
+
357
+ # Display the text with colors
358
+ for token, mask_value in zip(text_data, mask_data):
359
+ if mask_value > 0:
360
+ color = colors[mask_value % len(colors)]
361
+ bbox = dict(facecolor=color, edgecolor='none',
362
+ boxstyle='round,pad=0.3') # Background color for masked tokens
363
+ else:
364
+ bbox = None
365
+
366
+ ax.text(x_pos, y_pos, token, fontsize=12, color='white', ha='left', va='center', bbox=bbox)
367
+
368
+ # Update the x position for the next token
369
+ x_pos += len(token) * 0.03 + 0.02 # Adjust the spacing between tokens
370
+
371
+ plt.show()
372
+
373
+
374
+ def plot_image_with_heatmap(leap_data: LeapImageWithHeatmap, title: str) -> None:
375
+ """
376
+ Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.
377
+
378
+ Returns:
379
+ None
380
+
381
+ Example:
382
+ image_data = np.random.rand(100, 100, 3).astype(np.float32)
383
+ heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
384
+ labels = ["heatmap1", "heatmap2", "heatmap3"]
385
+ leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
386
+ title = "Image With Heatmap"
387
+ visualize(leap_image_with_heatmap, title)
388
+ """
389
+ image = leap_data.image
390
+ heatmaps = leap_data.heatmaps
391
+ labels = leap_data.labels
392
+
393
+ # Plot the base image
394
+ fig, ax = plt.subplots()
395
+ fig.patch.set_facecolor('black') # Set the figure background to black
396
+ ax.set_facecolor('black') # Set the axis background to black
397
+ ax.imshow(image, cmap='gray')
398
+
399
+ # Overlay each heatmap with some transparency
400
+ for i in range(len(labels)):
401
+ heatmap = heatmaps[i]
402
+ ax.imshow(heatmap, cmap='jet', alpha=0.5) # Adjust alpha for transparency
403
+ ax.set_title(f'Heatmap: {labels[i]}', color='white')
404
+
405
+ # Display a colorbar for the heatmap
406
+ cbar = plt.colorbar(ax.imshow(heatmap, cmap='jet', alpha=0.5))
407
+ cbar.set_label(labels[i], color='white')
408
+ cbar.ax.yaxis.set_tick_params(color='white') # Set color for the colorbar ticks
409
+ plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white') # Set color for the colorbar labels
410
+
411
+ plt.axis('off')
412
+ plt.title(title, color='white')
413
+ plt.show()
414
+
415
+
416
+ plot_switch = {
417
+ LeapDataType.Image: plot_image,
418
+ LeapDataType.Text: plot_text_with_heatmap,
419
+ LeapDataType.Graph: plot_graph,
420
+ LeapDataType.HorizontalBar: plot_hbar,
421
+ LeapDataType.ImageMask: plot_image_mask,
422
+ LeapDataType.TextMask: plot_text_mask,
423
+ LeapDataType.ImageWithHeatmap: plot_image_with_heatmap,
424
+ LeapDataType.ImageWithBBox: plot_image_with_b_box,
425
+ }
File: code_loader/helpers/visualize.py (new file)
@@ -0,0 +1,19 @@
1
+ import sys
2
+
3
+ from code_loader.contract.datasetclasses import LeapData
4
+
5
+ from typing import Optional
6
+
7
+ from code_loader.helpers.plot_functions import plot_switch
8
+
9
+
10
+ def visualize(leap_data: LeapData, title: Optional[str] = None) -> None:
11
+ vis_function = plot_switch.get(leap_data.type)
12
+ if vis_function is None:
13
+ print(f"Error: leap data type is not supported, leap data type: {leap_data.type}")
14
+ sys.exit(1)
15
+
16
+ if not title:
17
+ title = f"Leap {leap_data.type.name} Visualization"
18
+ vis_function(leap_data, title) # type: ignore[operator]
19
+
File: code_loader/leaploader.py
@@ -26,7 +26,7 @@ from code_loader.contract.responsedataclasses import DatasetIntegParseResult, Da
26
26
  from code_loader.inner_leap_binder import global_leap_binder
27
27
  from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
28
28
  from code_loader.leaploaderbase import LeapLoaderBase
29
- from code_loader.utils import get_root_exception_file_and_line_number, flatten
29
+ from code_loader.utils import get_root_exception_file_and_line_number
30
30
 
31
31
 
32
32
  class LeapLoader(LeapLoaderBase):
@@ -514,18 +514,22 @@ class LeapLoader(LeapLoaderBase):
514
514
 
515
515
  return converted_value, is_none
516
516
 
517
- def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Tuple[
518
- Dict[str, Union[str, int, bool, float]], Dict[str, bool]]:
517
+ def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Tuple[Dict[str, Union[str, int, bool, float]], Dict[str, bool]]:
519
518
  result_agg = {}
520
519
  is_none = {}
521
520
  preprocess_result = self._preprocess_result()
522
521
  preprocess_state = preprocess_result[state]
523
522
  for handler in global_leap_binder.setup_container.metadata:
524
523
  handler_result = handler.function(sample_id, preprocess_state)
525
-
526
- for flat_name, flat_result in flatten(handler_result, prefix=handler.name):
527
- result_agg[flat_name], is_none[flat_name] = self._convert_metadata_to_correct_type(
528
- flat_name, flat_result)
524
+ if isinstance(handler_result, dict):
525
+ for single_metadata_name, single_metadata_result in handler_result.items():
526
+ handler_name = f'{handler.name}_{single_metadata_name}'
527
+ result_agg[handler_name], is_none[handler_name] = self._convert_metadata_to_correct_type(
528
+ handler_name, single_metadata_result)
529
+ else:
530
+ handler_name = handler.name
531
+ result_agg[handler_name], is_none[handler_name] = self._convert_metadata_to_correct_type(
532
+ handler_name, handler_result)
529
533
 
530
534
  return result_agg, is_none
531
535
 
File: code_loader/utils.py
@@ -1,7 +1,7 @@
1
1
  import sys
2
2
  from pathlib import Path
3
3
  from types import TracebackType
4
- from typing import List, Union, Tuple, Any, Iterator, Callable
4
+ from typing import List, Union, Tuple, Any, Callable
5
5
  import traceback
6
6
  import numpy as np
7
7
  import numpy.typing as npt
@@ -76,28 +76,3 @@ def rescale_min_max(image: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
76
76
  return image
77
77
 
78
78
 
79
- def flatten(
80
- value: Any,
81
- *,
82
- prefix: str = "",
83
- list_token: str = "e",
84
- ) -> Iterator[Tuple[str, Any]]:
85
- """
86
- Recursively walk `value` and yield (flat_key, leaf_value) pairs.
87
-
88
- • Dicts → descend with new_prefix = f"{prefix}_{key}" (or just key if top level)
89
- • Sequences → descend with new_prefix = f"{prefix}_{list_token}{idx}"
90
- • Leaf scalars → yield the accumulated flat key and the scalar itself
91
- """
92
- if isinstance(value, dict):
93
- for k, v in value.items():
94
- new_prefix = f"{prefix}_{k}" if prefix else k
95
- yield from flatten(v, prefix=new_prefix, list_token=list_token)
96
-
97
- elif isinstance(value, (list, tuple)):
98
- for idx, v in enumerate(value):
99
- new_prefix = f"{prefix}_{list_token}{idx}"
100
- yield from flatten(v, prefix=new_prefix, list_token=list_token)
101
-
102
- else: # primitive leaf (str, int, float, bool, None…)
103
- yield prefix, value
File: pyproject.toml
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "code-loader"
3
- version = "1.0.101.dev0"
3
+ version = "1.0.102"
4
4
  description = ""
5
5
  authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
6
6
  license = "MIT"