vision-agent 0.2.223__tar.gz → 0.2.225__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. {vision_agent-0.2.223 → vision_agent-0.2.225}/PKG-INFO +1 -1
  2. {vision_agent-0.2.223 → vision_agent-0.2.225}/pyproject.toml +1 -1
  3. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/tool_utils.py +5 -1
  4. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/tools.py +85 -89
  5. vision_agent-0.2.225/vision_agent/utils/video_tracking.py +305 -0
  6. {vision_agent-0.2.223 → vision_agent-0.2.225}/LICENSE +0 -0
  7. {vision_agent-0.2.223 → vision_agent-0.2.225}/README.md +0 -0
  8. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/.sim_tools/df.csv +0 -0
  9. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/.sim_tools/embs.npy +0 -0
  10. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/__init__.py +0 -0
  11. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/README.md +0 -0
  12. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/__init__.py +0 -0
  13. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/agent.py +0 -0
  14. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/agent_utils.py +0 -0
  15. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/types.py +0 -0
  16. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent.py +0 -0
  17. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_coder.py +0 -0
  18. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_coder_prompts.py +0 -0
  19. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_coder_prompts_v2.py +0 -0
  20. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_coder_v2.py +0 -0
  21. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_planner.py +0 -0
  22. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_planner_prompts.py +0 -0
  23. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_planner_prompts_v2.py +0 -0
  24. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_planner_v2.py +0 -0
  25. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_prompts.py +0 -0
  26. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_prompts_v2.py +0 -0
  27. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_v2.py +0 -0
  28. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/clients/__init__.py +0 -0
  29. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/clients/http.py +0 -0
  30. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/clients/landing_public_api.py +0 -0
  31. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/fonts/__init__.py +0 -0
  32. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
  33. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/lmm/__init__.py +0 -0
  34. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/lmm/lmm.py +0 -0
  35. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/lmm/types.py +0 -0
  36. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/__init__.py +0 -0
  37. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/meta_tools.py +0 -0
  38. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/planner_tools.py +0 -0
  39. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/prompts.py +0 -0
  40. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/tools_types.py +0 -0
  41. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/__init__.py +0 -0
  42. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/exceptions.py +0 -0
  43. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/execute.py +0 -0
  44. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/image_utils.py +0 -0
  45. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/sim.py +0 -0
  46. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/type_defs.py +0 -0
  47. {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/video.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vision-agent
3
- Version: 0.2.223
3
+ Version: 0.2.225
4
4
  Summary: Toolset for Vision Agent
5
5
  Author: Landing AI
6
6
  Author-email: dev@landing.ai
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "vision-agent"
7
- version = "0.2.223"
7
+ version = "0.2.225"
8
8
  description = "Toolset for Vision Agent"
9
9
  authors = ["Landing AI <dev@landing.ai>"]
10
10
  readme = "README.md"
@@ -25,6 +25,10 @@ _LND_API_URL = f"{_LND_BASE_URL}/v1/agent/model"
25
25
  _LND_API_URL_v2 = f"{_LND_BASE_URL}/v1/tools"
26
26
 
27
27
 
28
def should_report_tool_traces() -> bool:
    """Return whether tool-call traces should be reported.

    Controlled by the ``REPORT_TOOL_TRACES`` environment variable. An unset or
    empty variable disables reporting. The usual "off" spellings ``0``,
    ``false``, ``no`` and ``off`` also disable it; matching is case-insensitive
    and ignores surrounding whitespace. Any other value enables reporting.
    Previously any non-empty string, including ``"false"``, enabled it.
    """
    value = os.environ.get("REPORT_TOOL_TRACES", "")
    return value.strip().lower() not in {"", "0", "false", "no", "off"}
30
+
31
+
28
32
  class ToolCallTrace(BaseModel):
29
33
  endpoint_url: str
30
34
  type: str
@@ -251,7 +255,7 @@ def _call_post(
251
255
  tool_call_trace.response = result
252
256
  return result
253
257
  finally:
254
- if tool_call_trace is not None:
258
+ if tool_call_trace is not None and should_report_tool_traces():
255
259
  trace = tool_call_trace.model_dump()
256
260
  display({MimeType.APPLICATION_JSON: trace}, raw=True)
257
261
 
@@ -6,7 +6,6 @@ import tempfile
6
6
  import urllib.request
7
7
  from base64 import b64encode
8
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
- from enum import Enum
10
9
  from importlib import resources
11
10
  from pathlib import Path
12
11
  from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -32,6 +31,7 @@ from vision_agent.tools.tool_utils import (
32
31
  nms,
33
32
  send_inference_request,
34
33
  send_task_inference_request,
34
+ should_report_tool_traces,
35
35
  single_nms,
36
36
  )
37
37
  from vision_agent.tools.tools_types import JobStatus
@@ -53,6 +53,13 @@ from vision_agent.utils.video import (
53
53
  frames_to_bytes,
54
54
  video_writer,
55
55
  )
56
+ from vision_agent.utils.video_tracking import (
57
+ ODModels,
58
+ merge_segments,
59
+ post_process,
60
+ process_segment,
61
+ split_frames_into_segments,
62
+ )
56
63
 
57
64
  register_heif_opener()
58
65
 
@@ -94,6 +101,9 @@ def _display_tool_trace(
94
101
  # such as video bytes, which can be slow. Since this is calculated inside the
95
102
  # function we can't capture it with a decarator without adding it as a return value
96
103
  # which would change the function signature and affect the agent.
104
+ if not should_report_tool_traces():
105
+ return
106
+
97
107
  files_in_b64: List[Tuple[str, str]]
98
108
  if isinstance(files, str):
99
109
  files_in_b64 = [("images", files)]
@@ -220,12 +230,6 @@ def sam2(
220
230
  return ret["return_data"] # type: ignore
221
231
 
222
232
 
223
- class ODModels(str, Enum):
224
- COUNTGD = "countgd"
225
- FLORENCE2 = "florence2"
226
- OWLV2 = "owlv2"
227
-
228
-
229
233
  def od_sam2_video_tracking(
230
234
  od_model: ODModels,
231
235
  prompt: str,
@@ -233,105 +237,92 @@ def od_sam2_video_tracking(
233
237
  chunk_length: Optional[int] = 10,
234
238
  fine_tune_id: Optional[str] = None,
235
239
  ) -> Dict[str, Any]:
236
- results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames)
240
+ SEGMENT_SIZE = 50
241
+ OVERLAP = 1 # Number of overlapping frames between segments
237
242
 
238
- if chunk_length is None:
239
- step = 1 # Process every frame
240
- elif chunk_length <= 0:
241
- raise ValueError("chunk_length must be a positive integer or None.")
242
- else:
243
- step = chunk_length # Process frames with the specified step size
243
+ image_size = frames[0].shape[:2]
244
+
245
+ # Split frames into segments with overlap
246
+ segments = split_frames_into_segments(frames, SEGMENT_SIZE, OVERLAP)
247
+
248
+ def _apply_object_detection( # inner method to avoid circular importing issues.
249
+ od_model: ODModels,
250
+ prompt: str,
251
+ segment_index: int,
252
+ frame_number: int,
253
+ fine_tune_id: str,
254
+ segment_frames: list,
255
+ ) -> tuple:
256
+ """
257
+ Applies the specified object detection model to the given image.
258
+
259
+ Args:
260
+ od_model: The object detection model to use.
261
+ prompt: The prompt for the object detection model.
262
+ segment_index: The index of the current segment.
263
+ frame_number: The number of the current frame.
264
+ fine_tune_id: Optional fine-tune ID for the model.
265
+ segment_frames: List of frames for the current segment.
266
+
267
+ Returns:
268
+ A tuple containing the object detection results and the name of the function used.
269
+ """
244
270
 
245
- for idx in range(0, len(frames), step):
246
271
  if od_model == ODModels.COUNTGD:
247
- results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx])
272
+ segment_results = countgd_object_detection(
273
+ prompt=prompt, image=segment_frames[frame_number]
274
+ )
248
275
  function_name = "countgd_object_detection"
276
+
249
277
  elif od_model == ODModels.OWLV2:
250
- results[idx] = owlv2_object_detection(
251
- prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
278
+ segment_results = owlv2_object_detection(
279
+ prompt=prompt,
280
+ image=segment_frames[frame_number],
281
+ fine_tune_id=fine_tune_id,
252
282
  )
253
283
  function_name = "owlv2_object_detection"
284
+
254
285
  elif od_model == ODModels.FLORENCE2:
255
- results[idx] = florence2_object_detection(
256
- prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
286
+ segment_results = florence2_object_detection(
287
+ prompt=prompt,
288
+ image=segment_frames[frame_number],
289
+ fine_tune_id=fine_tune_id,
257
290
  )
258
291
  function_name = "florence2_object_detection"
292
+
259
293
  else:
260
294
  raise NotImplementedError(
261
295
  f"Object detection model '{od_model}' is not implemented."
262
296
  )
263
297
 
264
- image_size = frames[0].shape[:2]
265
-
266
- def _transform_detections(
267
- input_list: List[Optional[List[Dict[str, Any]]]]
268
- ) -> List[Optional[Dict[str, Any]]]:
269
- output_list: List[Optional[Dict[str, Any]]] = []
270
-
271
- for _, frame in enumerate(input_list):
272
- if frame is not None:
273
- labels = [detection["label"] for detection in frame]
274
- bboxes = [
275
- denormalize_bbox(detection["bbox"], image_size)
276
- for detection in frame
277
- ]
278
-
279
- output_list.append(
280
- {
281
- "labels": labels,
282
- "bboxes": bboxes,
283
- }
284
- )
285
- else:
286
- output_list.append(None)
287
-
288
- return output_list
298
+ return segment_results, function_name
299
+
300
+ # Process each segment and collect detections
301
+ detections_per_segment: List[Any] = []
302
+ for segment_index, segment in enumerate(segments):
303
+ segment_detections = process_segment(
304
+ segment_frames=segment,
305
+ od_model=od_model,
306
+ prompt=prompt,
307
+ fine_tune_id=fine_tune_id,
308
+ chunk_length=chunk_length,
309
+ image_size=image_size,
310
+ segment_index=segment_index,
311
+ object_detection_tool=_apply_object_detection,
312
+ )
313
+ detections_per_segment.append(segment_detections)
289
314
 
290
- output = _transform_detections(results)
315
+ merged_detections = merge_segments(detections_per_segment)
316
+ post_processed = post_process(merged_detections, image_size)
291
317
 
292
318
  buffer_bytes = frames_to_bytes(frames)
293
319
  files = [("video", buffer_bytes)]
294
- payload = {"bboxes": json.dumps(output), "chunk_length_frames": chunk_length}
295
- metadata = {"function_name": function_name}
296
320
 
297
- detections = send_task_inference_request(
298
- payload,
299
- "sam2",
300
- files=files,
301
- metadata=metadata,
302
- )
303
-
304
- return_data = []
305
- for frame in detections:
306
- return_frame_data = []
307
- for detection in frame:
308
- mask = rle_decode_array(detection["mask"])
309
- label = str(detection["id"]) + ": " + detection["label"]
310
- return_frame_data.append(
311
- {"label": label, "mask": mask, "score": 1.0, "rle": detection["mask"]}
312
- )
313
- return_data.append(return_frame_data)
314
- return_data = add_bboxes_from_masks(return_data)
315
- return_data = nms(return_data, iou_threshold=0.95)
316
-
317
- # We save the RLE for display purposes, re-calculting RLE can get very expensive.
318
- # Deleted here because we are returning the numpy masks instead
319
- display_data = []
320
- for frame in return_data:
321
- display_frame_data = []
322
- for obj in frame:
323
- display_frame_data.append(
324
- {
325
- "label": obj["label"],
326
- "score": obj["score"],
327
- "bbox": denormalize_bbox(obj["bbox"], image_size),
328
- "mask": obj["rle"],
329
- }
330
- )
331
- del obj["rle"]
332
- display_data.append(display_frame_data)
333
-
334
- return {"files": files, "return_data": return_data, "display_data": detections}
321
+ return {
322
+ "files": files,
323
+ "return_data": post_processed["return_data"],
324
+ "display_data": post_processed["display_data"],
325
+ }
335
326
 
336
327
 
337
328
  # Owl V2 Tools
@@ -2243,15 +2234,17 @@ def save_image(image: np.ndarray, file_path: str) -> None:
2243
2234
  >>> save_image(image)
2244
2235
  """
2245
2236
  Path(file_path).parent.mkdir(parents=True, exist_ok=True)
2246
- from IPython.display import display
2247
-
2248
2237
  if not isinstance(image, np.ndarray) or (
2249
2238
  image.shape[0] == 0 and image.shape[1] == 0
2250
2239
  ):
2251
2240
  raise ValueError("The image is not a valid NumPy array with shape (H, W, C)")
2252
2241
 
2253
2242
  pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
2254
- display(pil_image)
2243
+ if should_report_tool_traces():
2244
+ from IPython.display import display
2245
+
2246
+ display(pil_image)
2247
+
2255
2248
  pil_image.save(file_path)
2256
2249
 
2257
2250
 
@@ -2302,6 +2295,9 @@ def save_video(
2302
2295
 
2303
2296
  def _save_video_to_result(video_uri: str) -> None:
2304
2297
  """Saves a video into the result of the code execution (as an intermediate output)."""
2298
+ if not should_report_tool_traces():
2299
+ return
2300
+
2305
2301
  from IPython.display import display
2306
2302
 
2307
2303
  serializer = FileSerializer(video_uri)
@@ -0,0 +1,305 @@
1
+ import json
2
+ from enum import Enum
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from vision_agent.tools.tool_utils import (
8
+ add_bboxes_from_masks,
9
+ nms,
10
+ send_task_inference_request,
11
+ )
12
+ from vision_agent.utils.image_utils import denormalize_bbox, rle_decode_array
13
+ from vision_agent.utils.video import frames_to_bytes
14
+
15
+
16
class ODModels(str, Enum):
    """Object detection models that can be passed as ``od_model`` to ``process_segment``.

    Because the enum mixes in ``str``, each member is itself a string and
    compares equal to its value (e.g. ``ODModels.OWLV2 == "owlv2"``).
    """

    COUNTGD = "countgd"
    FLORENCE2 = "florence2"
    OWLV2 = "owlv2"
20
+
21
+
22
def split_frames_into_segments(
    frames: List[np.ndarray], segment_size: int = 50, overlap: int = 1
) -> List[List[np.ndarray]]:
    """
    Splits the list of frames into segments with a specified size and overlap.

    Segment ``k`` contributes the new frames
    ``[k * segment_size, (k + 1) * segment_size)``. Every segment after the
    first is also prefixed with the ``overlap`` frames that precede it, or
    fewer near the start of the video. Consecutive segments therefore share
    frames. The final segment may be shorter than ``segment_size``.

    Args:
        frames (List[np.ndarray]): List of video frames.
        segment_size (int, optional): Number of new frames per segment. Defaults to 50.
        overlap (int, optional): Number of preceding frames repeated at the start
            of each segment after the first. Defaults to 1.

    Returns:
        List[List[np.ndarray]]: List of frame segments (empty if ``frames`` is empty).

    Raises:
        ValueError: If ``segment_size`` is not positive or ``overlap`` is negative.
    """
    if segment_size <= 0:
        # A non-positive step would never advance through the frames (infinite loop).
        raise ValueError("segment_size must be a positive integer.")
    if overlap < 0:
        # A negative overlap would silently skip frames between segments.
        raise ValueError("overlap must be a non-negative integer.")

    segments: List[List[np.ndarray]] = []
    for start in range(0, len(frames), segment_size):
        end = min(start + segment_size, len(frames))
        # Clamp at 0: if overlap exceeds start, a negative slice start would
        # wrap around to the end of the list and yield a wrong (often empty) segment.
        first = max(start - overlap, 0)
        segments.append(frames[first:end])
    return segments
52
+
53
+
54
def process_segment(
    segment_frames: List[np.ndarray],
    od_model: ODModels,
    prompt: str,
    fine_tune_id: Optional[str],
    chunk_length: Optional[int],
    image_size: Tuple[int, ...],
    segment_index: int,
    object_detection_tool: Callable,
) -> Any:
    """
    Runs object detection on sampled frames of one segment and sends the
    segment to the ``sam2`` task.

    Detection runs on every ``chunk_length``-th frame of the segment, or on
    every frame when ``chunk_length`` is None. All other frames are sent with
    no boxes (``None``).

    Args:
        segment_frames (List[np.ndarray]): Frames in the segment.
        od_model (ODModels): Object detection model to use.
        prompt (str): Prompt for the model.
        fine_tune_id (Optional[str]): Fine-tune model ID.
        chunk_length (Optional[int]): Detection stride in frames; also sent to
            the task as ``chunk_length_frames``.
        image_size (Tuple[int, ...]): Size of the images, forwarded to
            ``transform_detections``.
        segment_index (int): Index of the segment.
        object_detection_tool (Callable): Called as
            ``object_detection_tool(od_model, prompt, segment_index, frame_index,
            fine_tune_id, segment_frames)``. It must return a
            ``(detections, function_name)`` pair.

    Returns:
        Any: The result of ``send_task_inference_request`` for the ``sam2`` task.

    Raises:
        ValueError: If ``chunk_length`` is neither None nor a positive integer.
    """
    per_frame: List[Optional[List[Dict[str, Any]]]] = [None] * len(segment_frames)

    if chunk_length is not None and chunk_length <= 0:
        raise ValueError("chunk_length must be a positive integer or None.")
    stride = 1 if chunk_length is None else chunk_length

    # Reported in the request metadata; stays "" if no frame was processed.
    function_name = ""
    for frame_index in range(0, len(segment_frames), stride):
        detections, function_name = object_detection_tool(
            od_model, prompt, segment_index, frame_index, fine_tune_id, segment_frames
        )
        per_frame[frame_index] = detections

    boxes = transform_detections(per_frame, image_size, segment_index)
    files = [("video", frames_to_bytes(segment_frames))]
    payload = {"bboxes": json.dumps(boxes), "chunk_length_frames": chunk_length}
    metadata = {"function_name": function_name}

    return send_task_inference_request(payload, "sam2", files=files, metadata=metadata)
117
+
118
+
119
def transform_detections(
    input_list: List[Optional[List[Dict[str, Any]]]],
    image_size: Tuple[int, ...],
    segment_index: int,
) -> List[Optional[Dict[str, Any]]]:
    """
    Transforms raw per-frame detections into the box format sent to ``sam2``.

    Each frame that has detections becomes ``{"labels": [...], "bboxes": [...]}``.
    Every box is passed through ``denormalize_bbox`` with ``image_size``.
    Frames given as None stay None.

    Args:
        input_list (List[Optional[List[Dict[str, Any]]]]): Per-frame detections;
            each detection dict must have ``label`` and ``bbox`` keys.
        image_size (Tuple[int, ...]): Size of the images.
        segment_index (int): Index of the segment (not used by this function).

    Returns:
        List[Optional[Dict[str, Any]]]: One entry per input frame.
    """

    def _frame_to_boxes(
        frame: Optional[List[Dict[str, Any]]]
    ) -> Optional[Dict[str, Any]]:
        if frame is None:
            return None
        return {
            "labels": [det["label"] for det in frame],
            "bboxes": [denormalize_bbox(det["bbox"], image_size) for det in frame],
        }

    return [_frame_to_boxes(frame) for frame in input_list]
152
+
153
+
154
+ def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
155
+ mask1 = mask1.astype(bool)
156
+ mask2 = mask2.astype(bool)
157
+
158
+ intersection = np.sum(np.logical_and(mask1, mask2))
159
+ union = np.sum(np.logical_or(mask1, mask2))
160
+
161
+ if union == 0:
162
+ iou = 0.0
163
+ else:
164
+ iou = intersection / union
165
+
166
+ return iou
167
+
168
+
169
+ def _match_by_iou(
170
+ first_param: List[Dict],
171
+ second_param: List[Dict],
172
+ iou_threshold: float = 0.8,
173
+ ) -> Tuple[List[Dict], Dict[int, int]]:
174
+ max_id = max((item["id"] for item in first_param), default=0)
175
+
176
+ matched_new_item_indices = set()
177
+ id_mapping = {}
178
+
179
+ for new_index, new_item in enumerate(second_param):
180
+ matched_id = None
181
+
182
+ for existing_item in first_param:
183
+ iou = _calculate_mask_iou(
184
+ existing_item["decoded_mask"], new_item["decoded_mask"]
185
+ )
186
+ if iou > iou_threshold:
187
+ matched_id = existing_item["id"]
188
+ matched_new_item_indices.add(new_index)
189
+ id_mapping[new_item["id"]] = matched_id
190
+ break
191
+
192
+ if matched_id:
193
+ new_item["id"] = matched_id
194
+ else:
195
+ max_id += 1
196
+ id_mapping[new_item["id"]] = max_id
197
+ new_item["id"] = max_id
198
+
199
+ unmatched_items = [
200
+ item for i, item in enumerate(second_param) if i not in matched_new_item_indices
201
+ ]
202
+ combined_list = first_param + unmatched_items
203
+
204
+ return combined_list, id_mapping
205
+
206
+
207
+ def _update_ids(detections: List[Dict], id_mapping: Dict[int, int]) -> None:
208
+ for inner_list in detections:
209
+ for detection in inner_list:
210
+ if detection["id"] in id_mapping:
211
+ detection["id"] = id_mapping[detection["id"]]
212
+ else:
213
+ max_new_id = max(id_mapping.values(), default=0)
214
+ detection["id"] = max_new_id + 1
215
+ id_mapping[detection["id"]] = detection["id"]
216
+
217
+
218
+ def _convert_to_2d(detections_per_segment: List[Any]) -> List[Any]:
219
+ result = []
220
+ for i, segment in enumerate(detections_per_segment):
221
+ if i == 0:
222
+ result.extend(segment)
223
+ else:
224
+ result.extend(segment[1:])
225
+ return result
226
+
227
+
228
def merge_segments(detections_per_segment: List[Any]) -> List[Any]:
    """
    Merges detections from all segments into a unified result.

    Each segment is sent to the ``sam2`` task separately (see
    ``process_segment``), so object ids are presumably local to a segment.
    Consecutive segments are expected to share exactly one frame: the last
    frame of segment ``i`` is the first frame of segment ``i + 1``.
    ``_convert_to_2d`` drops that duplicate.

    Ids are made consistent across the whole video:

    * The first segment keeps its ids.
    * In each later segment, an object in the first frame inherits the id of
      the first object in the previous (already merged) segment's last frame
      whose mask IoU with it exceeds 0.8.
    * Every other id in the segment gets a fresh id, larger than any id
      handed out so far. It can therefore never collide with an object from
      an earlier segment, including objects that disappeared before that
      segment's last frame.

    Each local id is mapped once per segment and applied to all of its frames.
    An object that first appears mid-segment thus keeps a single id
    throughout the segment.

    Every detection gains a ``decoded_mask`` key, decoded from its ``mask``
    RLE. The input is modified in place.

    Args:
        detections_per_segment (List[Any]): Per segment, a list of frames.
            Each frame is a list of detection dicts with at least ``id`` and
            ``mask`` (RLE) keys.

    Returns:
        List[Any]: The frames of all segments concatenated, without the
        duplicated first frame of each segment after the first.
    """
    iou_threshold = 0.8  # same threshold as the _match_by_iou default

    for segment in detections_per_segment:
        for frame in segment:
            for item in frame:
                item["decoded_mask"] = rle_decode_array(item["mask"])

    # Highest id handed out so far; every id in already-merged segments is <= it.
    last_assigned_id = 0
    previous_segment: List[Any] = []
    for segment_idx, segment in enumerate(detections_per_segment):
        if segment_idx == 0:
            last_assigned_id = max(
                (item["id"] for frame in segment for item in frame), default=0
            )
        else:
            previous_last_frame = previous_segment[-1] if previous_segment else []
            first_frame = segment[0] if segment else []

            # Local id -> global id for this segment.
            id_mapping: Dict[int, int] = {}
            for new_item in first_frame:
                for existing_item in previous_last_frame:
                    iou = _calculate_mask_iou(
                        existing_item["decoded_mask"], new_item["decoded_mask"]
                    )
                    if iou > iou_threshold:
                        id_mapping[new_item["id"]] = existing_item["id"]
                        break

            # Map every local id exactly once, so all frames of the segment agree.
            for frame in segment:
                for item in frame:
                    local_id = item["id"]
                    if local_id not in id_mapping:
                        last_assigned_id += 1
                        id_mapping[local_id] = last_assigned_id
                    item["id"] = id_mapping[local_id]
        previous_segment = segment

    return _convert_to_2d(detections_per_segment)
253
+
254
+
255
def post_process(
    merged_detections: List[Any],
    image_size: Tuple[int, ...],
) -> Dict[str, Any]:
    """
    Performs post-processing on merged detections, including NMS and preparing display data.

    Each detection is relabelled ``"<id>: <label>"`` and given a score of 1.0.
    The frames are then passed through ``add_bboxes_from_masks`` and ``nms``
    (IoU threshold 0.95). ``return_data`` carries the decoded masks.
    ``display_data`` carries the RLE masks and boxes denormalized to
    ``image_size``.

    Mutates its input: ``decoded_mask`` is removed from every detection in
    ``merged_detections``, and ``rle`` from every object in ``return_data``.

    Args:
        merged_detections (List[Any]): Merged detections from all segments;
            each detection needs ``id``, ``label``, ``mask`` and ``decoded_mask``.
        image_size (Tuple[int, ...]): Size of the images.

    Returns:
        Dict[str, Any]: Post-processed data including return_data and display_data.
    """

    def _to_return_entry(detection: Dict[str, Any]) -> Dict[str, Any]:
        entry = {
            "label": f"{detection['id']}: {detection['label']}",
            "mask": detection["decoded_mask"],
            "rle": detection["mask"],
            "score": 1.0,
        }
        del detection["decoded_mask"]
        return entry

    def _to_display_entry(obj: Dict[str, Any]) -> Dict[str, Any]:
        entry = {
            "label": obj["label"],
            "bbox": denormalize_bbox(obj["bbox"], image_size),
            "mask": obj["rle"],
            "score": obj["score"],
        }
        # The RLE is kept only for display because recalculating it is
        # expensive; return_data carries the numpy mask instead.
        del obj["rle"]
        return entry

    return_data = [
        [_to_return_entry(detection) for detection in frame]
        for frame in merged_detections
    ]
    return_data = nms(add_bboxes_from_masks(return_data), iou_threshold=0.95)

    display_data = [[_to_display_entry(obj) for obj in frame] for frame in return_data]

    return {"return_data": return_data, "display_data": display_data}
File without changes
File without changes