code-loader 1.0.101__py3-none-any.whl → 1.0.101.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -43,9 +43,6 @@ class PreprocessResponse:
43
43
  instance_ids_to_names: Optional[Dict[str, str]] = None # in use only for element instance
44
44
 
45
45
  def __post_init__(self) -> None:
46
- def is_valid_string(s: str) -> bool:
47
- return bool(re.match(r'^[A-Za-z0-9_]+$', s))
48
-
49
46
  assert self.sample_ids_to_instance_mappings is None, f"Keep sample_ids_to_instance_mappings None when initializing PreprocessResponse"
50
47
  assert self.instance_to_sample_ids_mappings is None, f"Keep instance_to_sample_ids_mappings None when initializing PreprocessResponse"
51
48
  assert self.instance_ids_to_names is None, f"Keep instance_ids_to_names None when initializing PreprocessResponse"
@@ -60,8 +57,6 @@ class PreprocessResponse:
60
57
  if self.sample_id_type == str:
61
58
  for sample_id in self.sample_ids:
62
59
  assert isinstance(sample_id, str), f"Sample id should be of type str. Got: {type(sample_id)}"
63
- if not is_valid_string(sample_id):
64
- raise Exception(f"Sample id should contain only letters (A-Z, a-z), numbers or '_'. Got: {sample_id}")
65
60
  else:
66
61
  raise Exception("length is deprecated.")
67
62
 
code_loader/leaploader.py CHANGED
@@ -26,7 +26,7 @@ from code_loader.contract.responsedataclasses import DatasetIntegParseResult, Da
26
26
  from code_loader.inner_leap_binder import global_leap_binder
27
27
  from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
28
28
  from code_loader.leaploaderbase import LeapLoaderBase
29
- from code_loader.utils import get_root_exception_file_and_line_number
29
+ from code_loader.utils import get_root_exception_file_and_line_number, flatten
30
30
 
31
31
 
32
32
  class LeapLoader(LeapLoaderBase):
@@ -514,22 +514,18 @@ class LeapLoader(LeapLoaderBase):
514
514
 
515
515
  return converted_value, is_none
516
516
 
517
- def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Tuple[Dict[str, Union[str, int, bool, float]], Dict[str, bool]]:
517
+ def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Tuple[
518
+ Dict[str, Union[str, int, bool, float]], Dict[str, bool]]:
518
519
  result_agg = {}
519
520
  is_none = {}
520
521
  preprocess_result = self._preprocess_result()
521
522
  preprocess_state = preprocess_result[state]
522
523
  for handler in global_leap_binder.setup_container.metadata:
523
524
  handler_result = handler.function(sample_id, preprocess_state)
524
- if isinstance(handler_result, dict):
525
- for single_metadata_name, single_metadata_result in handler_result.items():
526
- handler_name = f'{handler.name}_{single_metadata_name}'
527
- result_agg[handler_name], is_none[handler_name] = self._convert_metadata_to_correct_type(
528
- handler_name, single_metadata_result)
529
- else:
530
- handler_name = handler.name
531
- result_agg[handler_name], is_none[handler_name] = self._convert_metadata_to_correct_type(
532
- handler_name, handler_result)
525
+
526
+ for flat_name, flat_result in flatten(handler_result, prefix=handler.name):
527
+ result_agg[flat_name], is_none[flat_name] = self._convert_metadata_to_correct_type(
528
+ flat_name, flat_result)
533
529
 
534
530
  return result_agg, is_none
535
531
 
code_loader/utils.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import sys
2
2
  from pathlib import Path
3
3
  from types import TracebackType
4
- from typing import List, Union, Tuple, Any, Callable
4
+ from typing import List, Union, Tuple, Any, Iterator, Callable
5
5
  import traceback
6
6
  import numpy as np
7
7
  import numpy.typing as npt
@@ -76,3 +76,28 @@ def rescale_min_max(image: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
76
76
  return image
77
77
 
78
78
 
79
+ def flatten(
80
+ value: Any,
81
+ *,
82
+ prefix: str = "",
83
+ list_token: str = "e",
84
+ ) -> Iterator[Tuple[str, Any]]:
85
+ """
86
+ Recursively walk `value` and yield (flat_key, leaf_value) pairs.
87
+
88
+ • Dicts → descend with new_prefix = f"{prefix}_{key}" (or just key if top level)
89
+ • Sequences → descend with new_prefix = f"{prefix}_{list_token}{idx}"
90
+ • Leaf scalars → yield the accumulated flat key and the scalar itself
91
+ """
92
+ if isinstance(value, dict):
93
+ for k, v in value.items():
94
+ new_prefix = f"{prefix}_{k}" if prefix else k
95
+ yield from flatten(v, prefix=new_prefix, list_token=list_token)
96
+
97
+ elif isinstance(value, (list, tuple)):
98
+ for idx, v in enumerate(value):
99
+ new_prefix = f"{prefix}_{list_token}{idx}"
100
+ yield from flatten(v, prefix=new_prefix, list_token=list_token)
101
+
102
+ else: # primitive leaf (str, int, float, bool, None…)
103
+ yield prefix, value
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.101
3
+ Version: 1.0.101.dev0
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -13,7 +13,6 @@ Classifier: Programming Language :: Python :: 3.8
13
13
  Classifier: Programming Language :: Python :: 3.9
14
14
  Classifier: Programming Language :: Python :: 3.10
15
15
  Classifier: Programming Language :: Python :: 3.11
16
- Requires-Dist: matplotlib (>=3.3.4)
17
16
  Requires-Dist: numpy (>=1.22.3,<2.0.0)
18
17
  Requires-Dist: psutil (>=5.9.5,<6.0.0)
19
18
  Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
@@ -1,7 +1,7 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
3
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- code_loader/contract/datasetclasses.py,sha256=6hwzFOtVlAHcDYgz4n-yIQrpqjda-ZoBGeDzH1kogXc,9123
4
+ code_loader/contract/datasetclasses.py,sha256=gJsXu4zVAaiBlq6GJwPxfTD2e0gICTtI_6Ir61MRL48,8838
5
5
  code_loader/contract/enums.py,sha256=GEFkvUMXnCNt-GOoz7NJ9ecQZ2PPDettJNOsxsiM0wk,1622
6
6
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
7
  code_loader/contract/mapping.py,sha256=e11h_sprwOyE32PcqgRq9JvyahQrPzwqgkhmbQLKLQY,1165
@@ -19,18 +19,15 @@ code_loader/experiment_api/experiment_context.py,sha256=kdzUbuzXo1pMVslOC3TKeJwW
19
19
  code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1oEB_NPg,485
20
20
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
21
21
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
22
- code_loader/helpers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
- code_loader/helpers/plot_functions.py,sha256=QyCDxWrXmqAz9jbwOcDrRZ2xXAMmzwfr1CaY5X9wcTY,13426
24
- code_loader/helpers/visualize.py,sha256=AUd5OSJs4YhL_8LJ44MQECJhsBiYfGrLk269GWhnp58,555
25
22
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
26
23
  code_loader/inner_leap_binder/leapbinder.py,sha256=mi9wp98iywHedCe2GwrbiqE14zbGo1rR0huodG96ZXY,32508
27
24
  code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=j38nYWfc6yll1SMggV8gABEvSyQwEBVf5RdFnmQ1aD0,38523
28
- code_loader/leaploader.py,sha256=vfN92-uoLeo8pojhwzPh4iu3gaoIQNqQklYwOy0kbtM,29225
25
+ code_loader/leaploader.py,sha256=kCNiLdbmGZBo1a6hE1gDRZyOeJLWH2THweO9AtepO3s,28869
29
26
  code_loader/leaploaderbase.py,sha256=lKdw2pd6H9hFsxVmc7jJMoZd_vlG5He1ooqT-cR_yq8,4496
30
- code_loader/utils.py,sha256=_j8b60pimoNAvWMRj7hEkkT6C76qES6cZoBFHpXHMxA,2698
27
+ code_loader/utils.py,sha256=lzisPgCxMo10dn_VFIlkM1fJaYjwaKXgiMB8zZo7oYw,3664
31
28
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
29
  code_loader/visualizers/default_visualizers.py,sha256=onRnLE_TXfgLN4o52hQIOOhUcFexGlqJ3xSpQDVLuZM,2604
33
- code_loader-1.0.101.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
34
- code_loader-1.0.101.dist-info/METADATA,sha256=-mfEuwf_Ltx_i9hrG3HhFwCibS1KYSvLcD3O6cwssAE,886
35
- code_loader-1.0.101.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
36
- code_loader-1.0.101.dist-info/RECORD,,
30
+ code_loader-1.0.101.dev0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
31
+ code_loader-1.0.101.dev0.dist-info/METADATA,sha256=XNQDZDS2yCb-da4qiLyq6WM3ymDBp4e9MbNdZJltUBc,855
32
+ code_loader-1.0.101.dev0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
33
+ code_loader-1.0.101.dev0.dist-info/RECORD,,
File without changes
@@ -1,410 +0,0 @@
1
- import matplotlib.pyplot as plt # type: ignore
2
- import numpy as np
3
- from code_loader.contract.enums import LeapDataType # type: ignore
4
- from code_loader.contract.datasetclasses import LeapData # type: ignore
5
- from textwrap import wrap
6
- import math
7
-
8
-
9
- def plot_image_with_b_box(leap_data: LeapData, title: str) -> None:
10
- """
11
- Plot an image with overlaid bounding boxes.
12
-
13
- Returns:
14
- None
15
-
16
- Example:
17
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
18
- bbox = BoundingBox(x=0.5, y=0.5, width=0.2, height=0.2, confidence=0.9, label="object")
19
- leap_image_with_bbox = LeapImageWithBBox(data=image_data, bounding_boxes=[bbox])
20
- title = "Image With bbox"
21
- visualize(leap_image_with_bbox, title)
22
- """
23
-
24
- image = leap_data.data
25
- bounding_boxes = leap_data.bounding_boxes
26
-
27
- # Create figure and axes
28
- fig, ax = plt.subplots(1)
29
- fig.patch.set_facecolor('black')
30
- ax.set_facecolor('black')
31
-
32
- # Display the image
33
- ax.imshow(image)
34
- ax.set_title(title, color='white')
35
-
36
- # Draw bounding boxes on the image
37
- for bbox in bounding_boxes:
38
- x, y, width, height = bbox.x, bbox.y, bbox.width, bbox.height
39
- confidence, label = bbox.confidence, bbox.label
40
-
41
- # Convert relative coordinates to absolute coordinates
42
- abs_x = x * image.shape[1]
43
- abs_y = y * image.shape[0]
44
- abs_width = width * image.shape[1]
45
- abs_height = height * image.shape[0]
46
-
47
- # Create a rectangle patch
48
- rect = plt.Rectangle(
49
- (abs_x - abs_width / 2, abs_y - abs_height / 2),
50
- abs_width, abs_height,
51
- linewidth=3, edgecolor='r', facecolor='none'
52
- )
53
-
54
- # Add the rectangle to the axes
55
- ax.add_patch(rect)
56
-
57
- # Display label and confidence
58
- ax.text(abs_x - abs_width / 2, abs_y - abs_height / 2 - 5,
59
- f"{label} {confidence:.2f}", color='r', fontsize=8,
60
- bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', boxstyle='round,pad=0.3'))
61
-
62
- # Show the image with bounding boxes
63
- plt.show()
64
-
65
-
66
- def plot_image(leap_data: LeapData, title: str) -> None:
67
- """
68
- Display the image contained in the LeapImage object.
69
-
70
- Returns:
71
- None
72
-
73
- Example:
74
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
75
- leap_image = LeapImage(data=image_data)
76
- title = "Image"
77
- visualize(leap_image, title)
78
- """
79
- image_data = leap_data.data
80
-
81
- # If the image has one channel, convert it to a 3-channel image for display
82
- if image_data.shape[2] == 1:
83
- image_data = np.repeat(image_data, 3, axis=2)
84
-
85
- fig, ax = plt.subplots()
86
- fig.patch.set_facecolor('black')
87
- ax.set_facecolor('black')
88
-
89
- ax.imshow(image_data)
90
-
91
- plt.axis('off')
92
- plt.title(title, color='white')
93
- plt.show()
94
-
95
-
96
- def plot_graph(leap_data: LeapData, title: str) -> None:
97
- """
98
- Display the line chart contained in the LeapGraph object.
99
-
100
- Returns:
101
- None
102
-
103
- Example:
104
- graph_data = np.random.rand(100, 3).astype(np.float32)
105
- leap_graph = LeapGraph(data=graph_data)
106
- title = "Graph"
107
- visualize(leap_graph, title)
108
- """
109
- graph_data = leap_data.data
110
- num_variables = graph_data.shape[1]
111
-
112
- fig, ax = plt.subplots(figsize=(10, 6))
113
-
114
- # Set the background color to black
115
- fig.patch.set_facecolor('black')
116
- ax.set_facecolor('black')
117
-
118
- for i in range(num_variables):
119
- plt.plot(graph_data[:, i], label=f'Variable {i + 1}')
120
-
121
- ax.set_xlabel('Data Points', color='white')
122
- ax.set_ylabel('Values', color='white')
123
- ax.set_title(title, color='white')
124
- ax.legend()
125
- ax.grid(True, color='white')
126
-
127
- # Change the color of the tick labels to white
128
- ax.tick_params(colors='white')
129
-
130
- plt.show()
131
-
132
-
133
- def plot_text_with_heatmap(leap_data: LeapData, title: str) -> None:
134
- """
135
- Display the text contained in the LeapText object with a heatmap overlay.
136
-
137
- Args:
138
- leap_data (LeapData): The LeapText object containing text tokens and an optional heatmap.
139
- title (str): The title of the visualization.
140
-
141
- Returns:
142
- None
143
-
144
- Example:
145
- text_data = ['I', 'ate', 'a', 'banana', '', '', '']
146
- heatmap = [0.1, 0.3, 0.2, 0.9, 0.0, 0.0, 0.0]
147
- leap_text = LeapText(data=text_data, heatmap=heatmap) # Create LeapText object
148
- title = "Text with Heatmap"
149
- visualize(leap_text, title)
150
- """
151
- text_data = leap_data.data
152
- heatmap = leap_data.heatmap
153
-
154
- text_data = [s for s in text_data if s != "[PAD]"]
155
-
156
- fig, ax = plt.subplots(figsize=(12, 5))
157
- fig.patch.set_facecolor('black')
158
- ax.set_facecolor('black')
159
- ax.axis('off') # Hide axes
160
-
161
- font_size = 20
162
-
163
- if heatmap is not None:
164
- heatmap = heatmap[:len(text_data)]
165
- if len(heatmap) != len(text_data):
166
- raise ValueError(
167
- f"Heatmap length ({len(heatmap)}) must match the number of tokens in `data` ({len(text_data)}).")
168
-
169
- max_tokens_per_row = 10
170
- num_rows = math.ceil(len(text_data) / max_tokens_per_row)
171
-
172
- fig.set_size_inches(12, num_rows * 1.2)
173
- for idx, (token, value) in enumerate(zip(text_data, heatmap)):
174
- if token:
175
- row = idx // max_tokens_per_row
176
- col = idx % max_tokens_per_row
177
-
178
- x_pos = col / max_tokens_per_row + 0.03
179
- y_pos = 1 - (row + 0.5) / num_rows
180
- color = plt.cm.jet(value)
181
- ax.text(
182
- x_pos,
183
- y_pos,
184
- token,
185
- fontsize=font_size,
186
- color=color,
187
- ha="left",
188
- va="center"
189
- )
190
- else:
191
- display_text = ' '.join([token for token in text_data if token])
192
- wrapped_text = "\n".join(wrap(display_text, width=80))
193
- font_color = 'white'
194
- ax.text(0.5, 0.5, wrapped_text, color=font_color, fontsize=font_size, ha='center', va='center')
195
-
196
- ax.set_title(title, color='white', fontsize=16)
197
-
198
- plt.tight_layout()
199
- plt.show()
200
-
201
-
202
- def plot_hbar(leap_data: LeapData, title: str) -> None:
203
- """
204
- Display the horizontal bar chart contained in the LeapHorizontalBar object.
205
-
206
- Returns:
207
- None
208
-
209
- Example:
210
- body_data = np.random.rand(5).astype(np.float32)
211
- gt_data = np.random.rand(5).astype(np.float32)
212
- labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
213
- leap_horizontal_bar = LeapHorizontalBar(body=body_data, gt=gt_data, labels=labels)
214
- title = "Horizontal Bar"
215
- visualize(leap_horizontal_bar, title)
216
- """
217
- body_data = leap_data.body
218
- labels = leap_data.labels
219
-
220
- # Check if 'gt' attribute exists and is not None
221
- gt_data = getattr(leap_data, 'gt', None)
222
-
223
- fig, ax = plt.subplots()
224
-
225
- fig.patch.set_facecolor('black')
226
- ax.set_facecolor('black')
227
-
228
- # Adjust positions for side-by-side bars
229
- y_positions = range(len(labels))
230
- bar_width = 0.4
231
-
232
- # Plot horizontal bar chart
233
- if gt_data is not None:
234
- ax.barh([y - bar_width / 2 for y in y_positions], body_data, color='green', height=bar_width, label='Prediction')
235
- ax.barh([y + bar_width / 2 for y in y_positions], gt_data, color='orange', height=bar_width, label='GT')
236
- else:
237
- ax.barh(y_positions, body_data, color='green', label='Body Data')
238
-
239
- # Set the y-ticks to align with the center of the bars
240
- ax.set_yticks(y_positions)
241
- ax.set_yticklabels(labels, color='white')
242
-
243
- # Set the color of the labels and title to white
244
- ax.set_xlabel('Scores', color='white')
245
- ax.set_title(title, color='white')
246
-
247
- # Set the color of the ticks to white
248
- ax.tick_params(axis='x', colors='white')
249
- ax.tick_params(axis='y', colors='white')
250
-
251
- # Add legend if gt is present
252
- if gt_data is not None:
253
- ax.legend(loc='best', facecolor='black', edgecolor='white', labelcolor='white')
254
-
255
- plt.show()
256
-
257
-
258
- def plot_image_mask(leap_data: LeapData, title: str) -> None:
259
- """
260
- Plots an image with overlaid masks given a LeapImageMask visualizer object.
261
-
262
- Returns:
263
- None
264
-
265
-
266
- Example:
267
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
268
- mask_data = np.random.randint(0, 2, (100, 100)).astype(np.uint8)
269
- labels = ["background", "object"]
270
- leap_image_mask = LeapImageMask(image=image_data, mask=mask_data, labels=labels)
271
- title = "Image Mask"
272
- visualize(leap_image_mask, title)
273
- """
274
-
275
- image = leap_data.image
276
- mask = leap_data.mask
277
- labels = leap_data.labels
278
-
279
- # Create a color map for each label
280
- colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
281
- if image.dtype == np.uint8:
282
- colors = colors * 255
283
-
284
- # Make a copy of the image to draw on
285
- overlayed_image = image.copy()
286
-
287
- # Iterate through the unique values in the mask (excluding 0)
288
- for i, label in enumerate(labels):
289
- # Extract binary mask for the current instance
290
- instance_mask = (mask == (i + 1))
291
-
292
- # fill the instance mask with a translucent color
293
- overlayed_image[instance_mask] = (
294
- overlayed_image[instance_mask] * (1 - 0.5) + np.array(colors[i][:image.shape[-1]], dtype=np.uint8) * 0.5)
295
-
296
- # Display the result using matplotlib
297
- fig, ax = plt.subplots(1)
298
- fig.patch.set_facecolor('black') # Set the figure background to black
299
- ax.set_facecolor('black') # Set the axis background to black
300
-
301
- ax.imshow(overlayed_image)
302
- ax.set_title(title, color='white')
303
- plt.axis('off') # Hide the axis
304
- plt.show()
305
-
306
-
307
- def plot_text_mask(leap_data: LeapData, title: str) -> None:
308
- """
309
- Plots text with overlaid masks given a LeapTextMask visualizer object.
310
-
311
- Returns:
312
- None
313
-
314
- Example:
315
- text_data = ['I', 'ate', 'a', 'banana', '', '', '']
316
- mask_data = np.array([0, 0, 0, 1, 0, 0, 0]).astype(np.uint8)
317
- labels = ["object"]
318
- leap_text_mask = LeapTextMask(text=text_data, mask=mask_data, labels=labels)
319
- title = "Text Mask"
320
- visualize(leap_text_mask, title)
321
- """
322
-
323
- text_data = leap_data.text
324
- mask_data = leap_data.mask
325
- labels = leap_data.labels
326
-
327
- # Create a color map for each label
328
- colors = plt.cm.jet(np.linspace(0, 1, len(labels)))
329
-
330
- # Create a figure and axis
331
- fig, ax = plt.subplots()
332
-
333
- # Set background to black
334
- fig.patch.set_facecolor('black')
335
- ax.set_facecolor('black')
336
- ax.set_title(title, color='white')
337
- ax.axis('off')
338
-
339
- # Set initial position
340
- x_pos, y_pos = 0.01, 0.5 # Adjusted initial position for better visibility
341
-
342
- # Display the text with colors
343
- for token, mask_value in zip(text_data, mask_data):
344
- if mask_value > 0:
345
- color = colors[mask_value % len(colors)]
346
- bbox = dict(facecolor=color, edgecolor='none',
347
- boxstyle='round,pad=0.3') # Background color for masked tokens
348
- else:
349
- bbox = None
350
-
351
- ax.text(x_pos, y_pos, token, fontsize=12, color='white', ha='left', va='center', bbox=bbox)
352
-
353
- # Update the x position for the next token
354
- x_pos += len(token) * 0.03 + 0.02 # Adjust the spacing between tokens
355
-
356
- plt.show()
357
-
358
-
359
- def plot_image_with_heatmap(leap_data: LeapData, title: str) -> None:
360
- """
361
- Display the image with overlaid heatmaps contained in the LeapImageWithHeatmap object.
362
-
363
- Returns:
364
- None
365
-
366
- Example:
367
- image_data = np.random.rand(100, 100, 3).astype(np.float32)
368
- heatmaps = np.random.rand(3, 100, 100).astype(np.float32)
369
- labels = ["heatmap1", "heatmap2", "heatmap3"]
370
- leap_image_with_heatmap = LeapImageWithHeatmap(image=image_data, heatmaps=heatmaps, labels=labels)
371
- title = "Image With Heatmap"
372
- visualize(leap_image_with_heatmap, title)
373
- """
374
- image = leap_data.image
375
- heatmaps = leap_data.heatmaps
376
- labels = leap_data.labels
377
-
378
- # Plot the base image
379
- fig, ax = plt.subplots()
380
- fig.patch.set_facecolor('black') # Set the figure background to black
381
- ax.set_facecolor('black') # Set the axis background to black
382
- ax.imshow(image, cmap='gray')
383
-
384
- # Overlay each heatmap with some transparency
385
- for i in range(len(labels)):
386
- heatmap = heatmaps[i]
387
- ax.imshow(heatmap, cmap='jet', alpha=0.5) # Adjust alpha for transparency
388
- ax.set_title(f'Heatmap: {labels[i]}', color='white')
389
-
390
- # Display a colorbar for the heatmap
391
- cbar = plt.colorbar(ax.imshow(heatmap, cmap='jet', alpha=0.5))
392
- cbar.set_label(labels[i], color='white')
393
- cbar.ax.yaxis.set_tick_params(color='white') # Set color for the colorbar ticks
394
- plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white') # Set color for the colorbar labels
395
-
396
- plt.axis('off')
397
- plt.title(title, color='white')
398
- plt.show()
399
-
400
-
401
- plot_switch = {
402
- LeapDataType.Image: plot_image,
403
- LeapDataType.Text: plot_text_with_heatmap,
404
- LeapDataType.Graph: plot_graph,
405
- LeapDataType.HorizontalBar: plot_hbar,
406
- LeapDataType.ImageMask: plot_image_mask,
407
- LeapDataType.TextMask: plot_text_mask,
408
- LeapDataType.ImageWithHeatmap: plot_image_with_heatmap,
409
- LeapDataType.ImageWithBBox: plot_image_with_b_box,
410
- }
@@ -1,19 +0,0 @@
1
- import sys
2
-
3
- from code_loader.contract.datasetclasses import LeapData # type: ignore
4
-
5
- from typing import Optional
6
-
7
- from code_loader.helpers.plot_functions import plot_switch
8
-
9
-
10
- def visualize(leap_data: LeapData, title: Optional[str] = None) -> None:
11
- vis_function = plot_switch.get(leap_data.type)
12
- if vis_function is None:
13
- print(f"Error: leap data type is not supported, leap data type: {leap_data.type}")
14
- sys.exit(1)
15
-
16
- if not title:
17
- title = f"Leap {leap_data.type.name} Visualization"
18
- vis_function(leap_data, title)
19
-