vision-agent 0.2.223__tar.gz → 0.2.225__tar.gz
- {vision_agent-0.2.223 → vision_agent-0.2.225}/PKG-INFO +1 -1
- {vision_agent-0.2.223 → vision_agent-0.2.225}/pyproject.toml +1 -1
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/tool_utils.py +5 -1
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/tools.py +85 -89
- vision_agent-0.2.225/vision_agent/utils/video_tracking.py +305 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/LICENSE +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/README.md +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/.sim_tools/df.csv +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/.sim_tools/embs.npy +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/__init__.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/README.md +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/__init__.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/agent.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/agent_utils.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/types.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_coder.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_coder_prompts.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_coder_prompts_v2.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_coder_v2.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_planner.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_planner_prompts.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_planner_prompts_v2.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_planner_v2.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_prompts.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_prompts_v2.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/agent/vision_agent_v2.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/clients/__init__.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/clients/http.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/clients/landing_public_api.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/fonts/__init__.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/lmm/__init__.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/lmm/lmm.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/lmm/types.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/__init__.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/meta_tools.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/planner_tools.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/prompts.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/tools_types.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/__init__.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/exceptions.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/execute.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/image_utils.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/sim.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/type_defs.py +0 -0
- {vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/utils/video.py +0 -0
{vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/tool_utils.py

@@ -25,6 +25,10 @@ _LND_API_URL = f"{_LND_BASE_URL}/v1/agent/model"
 _LND_API_URL_v2 = f"{_LND_BASE_URL}/v1/tools"
 
 
+def should_report_tool_traces() -> bool:
+    return bool(os.environ.get("REPORT_TOOL_TRACES", False))
+
+
 class ToolCallTrace(BaseModel):
     endpoint_url: str
     type: str

@@ -251,7 +255,7 @@ def _call_post(
         tool_call_trace.response = result
         return result
     finally:
-        if tool_call_trace is not None:
+        if tool_call_trace is not None and should_report_tool_traces():
             trace = tool_call_trace.model_dump()
             display({MimeType.APPLICATION_JSON: trace}, raw=True)
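
The new gate reads the raw environment string, and bool() of any non-empty string is True, so REPORT_TOOL_TRACES=0 still enables trace reporting; only unsetting the variable (or setting it to an empty string) disables it. A minimal sketch of the behavior (illustrative only, not part of the diff):

    import os

    os.environ["REPORT_TOOL_TRACES"] = "1"
    assert bool(os.environ.get("REPORT_TOOL_TRACES", False))      # enabled

    os.environ["REPORT_TOOL_TRACES"] = "0"                        # non-empty string
    assert bool(os.environ.get("REPORT_TOOL_TRACES", False))      # still enabled!

    os.environ.pop("REPORT_TOOL_TRACES", None)                    # unset
    assert not bool(os.environ.get("REPORT_TOOL_TRACES", False))  # disabled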
{vision_agent-0.2.223 → vision_agent-0.2.225}/vision_agent/tools/tools.py

@@ -6,7 +6,6 @@ import tempfile
 import urllib.request
 from base64 import b64encode
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from enum import Enum
 from importlib import resources
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union, cast

@@ -32,6 +31,7 @@ from vision_agent.tools.tool_utils import (
     nms,
     send_inference_request,
     send_task_inference_request,
+    should_report_tool_traces,
     single_nms,
 )
 from vision_agent.tools.tools_types import JobStatus

@@ -53,6 +53,13 @@ from vision_agent.utils.video import (
     frames_to_bytes,
     video_writer,
 )
+from vision_agent.utils.video_tracking import (
+    ODModels,
+    merge_segments,
+    post_process,
+    process_segment,
+    split_frames_into_segments,
+)
 
 register_heif_opener()
 

@@ -94,6 +101,9 @@ def _display_tool_trace(
     # such as video bytes, which can be slow. Since this is calculated inside the
     # function we can't capture it with a decarator without adding it as a return value
     # which would change the function signature and affect the agent.
+    if not should_report_tool_traces():
+        return
+
     files_in_b64: List[Tuple[str, str]]
     if isinstance(files, str):
         files_in_b64 = [("images", files)]

@@ -220,12 +230,6 @@ def sam2(
     return ret["return_data"]  # type: ignore
 
 
-class ODModels(str, Enum):
-    COUNTGD = "countgd"
-    FLORENCE2 = "florence2"
-    OWLV2 = "owlv2"
-
-
 def od_sam2_video_tracking(
     od_model: ODModels,
     prompt: str,

@@ -233,105 +237,92 @@ def od_sam2_video_tracking(
     chunk_length: Optional[int] = 10,
     fine_tune_id: Optional[str] = None,
 ) -> Dict[str, Any]:
-
+    SEGMENT_SIZE = 50
+    OVERLAP = 1  # Number of overlapping frames between segments
 
-
-
-
-
-
-
+    image_size = frames[0].shape[:2]
+
+    # Split frames into segments with overlap
+    segments = split_frames_into_segments(frames, SEGMENT_SIZE, OVERLAP)
+
+    def _apply_object_detection(  # inner method to avoid circular importing issues.
+        od_model: ODModels,
+        prompt: str,
+        segment_index: int,
+        frame_number: int,
+        fine_tune_id: str,
+        segment_frames: list,
+    ) -> tuple:
+        """
+        Applies the specified object detection model to the given image.
+
+        Args:
+            od_model: The object detection model to use.
+            prompt: The prompt for the object detection model.
+            segment_index: The index of the current segment.
+            frame_number: The number of the current frame.
+            fine_tune_id: Optional fine-tune ID for the model.
+            segment_frames: List of frames for the current segment.
+
+        Returns:
+            A tuple containing the object detection results and the name of the function used.
+        """
 
-    for idx in range(0, len(frames), step):
         if od_model == ODModels.COUNTGD:
-
+            segment_results = countgd_object_detection(
+                prompt=prompt, image=segment_frames[frame_number]
+            )
             function_name = "countgd_object_detection"
+
         elif od_model == ODModels.OWLV2:
-
-                prompt=prompt,
+            segment_results = owlv2_object_detection(
+                prompt=prompt,
+                image=segment_frames[frame_number],
+                fine_tune_id=fine_tune_id,
             )
             function_name = "owlv2_object_detection"
+
         elif od_model == ODModels.FLORENCE2:
-
-                prompt=prompt,
+            segment_results = florence2_object_detection(
+                prompt=prompt,
+                image=segment_frames[frame_number],
+                fine_tune_id=fine_tune_id,
             )
             function_name = "florence2_object_detection"
+
         else:
             raise NotImplementedError(
                 f"Object detection model '{od_model}' is not implemented."
             )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    {
-                        "labels": labels,
-                        "bboxes": bboxes,
-                    }
-                )
-            else:
-                output_list.append(None)
-
-        return output_list
+        return segment_results, function_name
+
+    # Process each segment and collect detections
+    detections_per_segment: List[Any] = []
+    for segment_index, segment in enumerate(segments):
+        segment_detections = process_segment(
+            segment_frames=segment,
+            od_model=od_model,
+            prompt=prompt,
+            fine_tune_id=fine_tune_id,
+            chunk_length=chunk_length,
+            image_size=image_size,
+            segment_index=segment_index,
+            object_detection_tool=_apply_object_detection,
+        )
+        detections_per_segment.append(segment_detections)
 
-
+    merged_detections = merge_segments(detections_per_segment)
+    post_processed = post_process(merged_detections, image_size)
 
     buffer_bytes = frames_to_bytes(frames)
     files = [("video", buffer_bytes)]
-    payload = {"bboxes": json.dumps(output), "chunk_length_frames": chunk_length}
-    metadata = {"function_name": function_name}
 
-    detections = send_task_inference_request(
-        payload,
-        "sam2",
-        files=files,
-        metadata=metadata,
-    )
-
-    return_data = []
-    for frame in detections:
-        return_frame_data = []
-        for detection in frame:
-            mask = rle_decode_array(detection["mask"])
-            label = str(detection["id"]) + ": " + detection["label"]
-            return_frame_data.append(
-                {"label": label, "mask": mask, "score": 1.0, "rle": detection["mask"]}
-            )
-        return_data.append(return_frame_data)
-    return_data = add_bboxes_from_masks(return_data)
-    return_data = nms(return_data, iou_threshold=0.95)
-
-    # We save the RLE for display purposes, re-calculting RLE can get very expensive.
-    # Deleted here because we are returning the numpy masks instead
-    display_data = []
-    for frame in return_data:
-        display_frame_data = []
-        for obj in frame:
-            display_frame_data.append(
-                {
-                    "label": obj["label"],
-                    "score": obj["score"],
-                    "bbox": denormalize_bbox(obj["bbox"], image_size),
-                    "mask": obj["rle"],
-                }
-            )
-            del obj["rle"]
-        display_data.append(display_frame_data)
-
-    return {"files": files, "return_data": return_data, "display_data": detections}
+    return {
+        "files": files,
+        "return_data": post_processed["return_data"],
+        "display_data": post_processed["display_data"],
+    }
 
 
 # Owl V2 Tools
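
After this refactor, od_sam2_video_tracking only orchestrates: it splits the video with split_frames_into_segments, runs detection plus SAM2 per segment through process_segment, then re-links tracks with merge_segments and post_process. A minimal usage sketch (assumptions: valid credentials for the hosted inference endpoints, and dummy frames standing in for a real video):

    import numpy as np
    from vision_agent.tools.tools import od_sam2_video_tracking
    from vision_agent.utils.video_tracking import ODModels

    # 120 frames > SEGMENT_SIZE (50), so the clip is processed as segments of
    # 50, 51, and 21 frames (later segments start one frame early because
    # OVERLAP = 1), tracked independently, then stitched back together.
    frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(120)]

    result = od_sam2_video_tracking(
        od_model=ODModels.COUNTGD,
        prompt="person",
        frames=frames,
        chunk_length=10,  # within a segment, re-run detection every 10th frame
    )
    detections_per_frame = result["return_data"]  # one list of {label, bbox, mask, score} per frame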
@@ -2243,15 +2234,17 @@ def save_image(image: np.ndarray, file_path: str) -> None:
     >>> save_image(image)
     """
     Path(file_path).parent.mkdir(parents=True, exist_ok=True)
-    from IPython.display import display
-
     if not isinstance(image, np.ndarray) or (
         image.shape[0] == 0 and image.shape[1] == 0
     ):
         raise ValueError("The image is not a valid NumPy array with shape (H, W, C)")
 
     pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
-
+    if should_report_tool_traces():
+        from IPython.display import display
+
+        display(pil_image)
+
     pil_image.save(file_path)

@@ -2302,6 +2295,9 @@ def save_video(
 
 def _save_video_to_result(video_uri: str) -> None:
     """Saves a video into the result of the code execution (as an intermediate output)."""
+    if not should_report_tool_traces():
+        return
+
     from IPython.display import display
 
     serializer = FileSerializer(video_uri)
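
Both savers now keep their IPython side effects behind the same gate, so the IPython import itself only executes when tracing is enabled. A quick sketch of the headless path (assuming an environment where IPython may not even be installed; the output path is hypothetical):

    import numpy as np
    from vision_agent.tools.tools import save_image

    # With REPORT_TOOL_TRACES unset, save_image skips the IPython import and
    # the display() call entirely and just writes the file to disk.
    save_image(np.zeros((64, 64, 3), dtype=np.uint8), "out/black.png")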
vision_agent-0.2.225/vision_agent/utils/video_tracking.py

@@ -0,0 +1,305 @@
+import json
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+
+from vision_agent.tools.tool_utils import (
+    add_bboxes_from_masks,
+    nms,
+    send_task_inference_request,
+)
+from vision_agent.utils.image_utils import denormalize_bbox, rle_decode_array
+from vision_agent.utils.video import frames_to_bytes
+
+
+class ODModels(str, Enum):
+    COUNTGD = "countgd"
+    FLORENCE2 = "florence2"
+    OWLV2 = "owlv2"
+
+
+def split_frames_into_segments(
+    frames: List[np.ndarray], segment_size: int = 50, overlap: int = 1
+) -> List[List[np.ndarray]]:
+    """
+    Splits the list of frames into segments with a specified size and overlap.
+
+    Args:
+        frames (List[np.ndarray]): List of video frames.
+        segment_size (int, optional): Number of frames per segment. Defaults to 50.
+        overlap (int, optional): Number of overlapping frames between segments. Defaults to 1.
+
+    Returns:
+        List[List[np.ndarray]]: List of frame segments.
+    """
+    segments = []
+    start = 0
+    segment_count = 0
+    while start < len(frames):
+        end = start + segment_size
+        if end > len(frames):
+            end = len(frames)
+        if start != 0:
+            # Include the last frame of the previous segment
+            segment = frames[start - overlap : end]
+        else:
+            segment = frames[start:end]
+        segments.append(segment)
+        start += segment_size
+        segment_count += 1
+    return segments
+
+
+def process_segment(
+    segment_frames: List[np.ndarray],
+    od_model: ODModels,
+    prompt: str,
+    fine_tune_id: Optional[str],
+    chunk_length: Optional[int],
+    image_size: Tuple[int, ...],
+    segment_index: int,
+    object_detection_tool: Callable,
+) -> Any:
+    """
+    Processes a segment of frames with the specified object detection model.
+
+    Args:
+        segment_frames (List[np.ndarray]): Frames in the segment.
+        od_model (ODModels): Object detection model to use.
+        prompt (str): Prompt for the model.
+        fine_tune_id (Optional[str]): Fine-tune model ID.
+        chunk_length (Optional[int]): Chunk length for processing.
+        image_size (Tuple[int, int]): Size of the images.
+        segment_index (int): Index of the segment.
+        object_detection_tool (Callable): Object detection tool to use.
+
+    Returns:
+        Any: Detections for the segment.
+    """
+    segment_results: List[Optional[List[Dict[str, Any]]]] = [None] * len(segment_frames)
+
+    if chunk_length is None:
+        step = 1
+    elif chunk_length <= 0:
+        raise ValueError("chunk_length must be a positive integer or None.")
+    else:
+        step = chunk_length
+
+    function_name = ""
+
+    for idx in range(0, len(segment_frames), step):
+        frame_number = idx
+        segment_results[idx], function_name = object_detection_tool(
+            od_model, prompt, segment_index, frame_number, fine_tune_id, segment_frames
+        )
+
+    transformed_detections = transform_detections(
+        segment_results, image_size, segment_index
+    )
+
+    buffer_bytes = frames_to_bytes(segment_frames)
+    files = [("video", buffer_bytes)]
+    payload = {
+        "bboxes": json.dumps(transformed_detections),
+        "chunk_length_frames": chunk_length,
+    }
+    metadata = {"function_name": function_name}
+
+    segment_detections = send_task_inference_request(
+        payload,
+        "sam2",
+        files=files,
+        metadata=metadata,
+    )
+
+    return segment_detections
+
+
+def transform_detections(
+    input_list: List[Optional[List[Dict[str, Any]]]],
+    image_size: Tuple[int, ...],
+    segment_index: int,
+) -> List[Optional[Dict[str, Any]]]:
+    """
+    Transforms raw detections into a standardized format.
+
+    Args:
+        input_list (List[Optional[List[Dict[str, Any]]]]): Raw detections.
+        image_size (Tuple[int, int]): Size of the images.
+        segment_index (int): Index of the segment.
+
+    Returns:
+        List[Optional[Dict[str, Any]]]: Transformed detections.
+    """
+    output_list: List[Optional[Dict[str, Any]]] = []
+    for frame_idx, frame in enumerate(input_list):
+        if frame is not None:
+            labels = [detection["label"] for detection in frame]
+            bboxes = [
+                denormalize_bbox(detection["bbox"], image_size) for detection in frame
+            ]
+
+            output_list.append(
+                {
+                    "labels": labels,
+                    "bboxes": bboxes,
+                }
+            )
+        else:
+            output_list.append(None)
+    return output_list
+
+
+def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
+    mask1 = mask1.astype(bool)
+    mask2 = mask2.astype(bool)
+
+    intersection = np.sum(np.logical_and(mask1, mask2))
+    union = np.sum(np.logical_or(mask1, mask2))
+
+    if union == 0:
+        iou = 0.0
+    else:
+        iou = intersection / union
+
+    return iou
+
+
+def _match_by_iou(
+    first_param: List[Dict],
+    second_param: List[Dict],
+    iou_threshold: float = 0.8,
+) -> Tuple[List[Dict], Dict[int, int]]:
+    max_id = max((item["id"] for item in first_param), default=0)
+
+    matched_new_item_indices = set()
+    id_mapping = {}
+
+    for new_index, new_item in enumerate(second_param):
+        matched_id = None
+
+        for existing_item in first_param:
+            iou = _calculate_mask_iou(
+                existing_item["decoded_mask"], new_item["decoded_mask"]
+            )
+            if iou > iou_threshold:
+                matched_id = existing_item["id"]
+                matched_new_item_indices.add(new_index)
+                id_mapping[new_item["id"]] = matched_id
+                break
+
+        if matched_id:
+            new_item["id"] = matched_id
+        else:
+            max_id += 1
+            id_mapping[new_item["id"]] = max_id
+            new_item["id"] = max_id
+
+    unmatched_items = [
+        item for i, item in enumerate(second_param) if i not in matched_new_item_indices
+    ]
+    combined_list = first_param + unmatched_items
+
+    return combined_list, id_mapping
+
+
+def _update_ids(detections: List[Dict], id_mapping: Dict[int, int]) -> None:
+    for inner_list in detections:
+        for detection in inner_list:
+            if detection["id"] in id_mapping:
+                detection["id"] = id_mapping[detection["id"]]
+            else:
+                max_new_id = max(id_mapping.values(), default=0)
+                detection["id"] = max_new_id + 1
+                id_mapping[detection["id"]] = detection["id"]
+
+
+def _convert_to_2d(detections_per_segment: List[Any]) -> List[Any]:
+    result = []
+    for i, segment in enumerate(detections_per_segment):
+        if i == 0:
+            result.extend(segment)
+        else:
+            result.extend(segment[1:])
+    return result
+
+
+def merge_segments(detections_per_segment: List[Any]) -> List[Any]:
+    """
+    Merges detections from all segments into a unified result.
+
+    Args:
+        detections_per_segment (List[Any]): List of detections per segment.
+
+    Returns:
+        List[Any]: Merged detections.
+    """
+    for segment in detections_per_segment:
+        for detection in segment:
+            for item in detection:
+                item["decoded_mask"] = rle_decode_array(item["mask"])
+
+    for segment_idx in range(len(detections_per_segment) - 1):
+        combined_detection, id_mapping = _match_by_iou(
+            detections_per_segment[segment_idx][-1],
+            detections_per_segment[segment_idx + 1][0],
+        )
+        _update_ids(detections_per_segment[segment_idx + 1], id_mapping)
+
+    merged_result = _convert_to_2d(detections_per_segment)
+
+    return merged_result
+
+
+def post_process(
+    merged_detections: List[Any],
+    image_size: Tuple[int, ...],
+) -> Dict[str, Any]:
+    """
+    Performs post-processing on merged detections, including NMS and preparing display data.
+
+    Args:
+        merged_detections (List[Any]): Merged detections from all segments.
+        image_size (Tuple[int, int]): Size of the images.
+
+    Returns:
+        Dict[str, Any]: Post-processed data including return_data and display_data.
+    """
+    return_data = []
+    for frame_idx, frame in enumerate(merged_detections):
+        return_frame_data = []
+        for detection in frame:
+            label = f"{detection['id']}: {detection['label']}"
+            return_frame_data.append(
+                {
+                    "label": label,
+                    "mask": detection["decoded_mask"],
+                    "rle": detection["mask"],
+                    "score": 1.0,
+                }
+            )
+            del detection["decoded_mask"]
+        return_data.append(return_frame_data)
+
+    return_data = add_bboxes_from_masks(return_data)
+    return_data = nms(return_data, iou_threshold=0.95)
+
+    # We save the RLE for display purposes, re-calculting RLE can get very expensive.
+    # Deleted here because we are returning the numpy masks instead
+    display_data = []
+    for frame in return_data:
+        display_frame_data = []
+        for obj in frame:
+            display_frame_data.append(
+                {
+                    "label": obj["label"],
+                    "bbox": denormalize_bbox(obj["bbox"], image_size),
+                    "mask": obj["rle"],
+                    "score": obj["score"],
+                }
+            )
+            del obj["rle"]
+        display_data.append(display_frame_data)
+
+    return {"return_data": return_data, "display_data": display_data}
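For reference, this is the intermediate shape transform_detections produces, which process_segment then serializes into the "bboxes" payload sent to the SAM2 endpoint (a sketch with hypothetical detection values; entries stay None for frames skipped by chunk_length):

    from vision_agent.utils.video_tracking import transform_detections

    raw = [
        [{"label": "person", "bbox": [0.1, 0.2, 0.3, 0.4], "score": 0.9}],  # frame 0
        None,                                                               # frame 1: skipped
    ]
    print(transform_detections(raw, image_size=(480, 640), segment_index=0))
    # -> [{'labels': ['person'], 'bboxes': [[<x1, y1, x2, y2> in pixels]]}, None]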
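_calculate_mask_iou is plain intersection-over-union on boolean masks, and _match_by_iou only re-links a track across the overlap frame when that IoU exceeds 0.8; anything below gets a fresh ID. (Worth noting: "if matched_id:" tests truthiness, so a matched track id of 0 would be treated as unmatched.) A small worked case; the function is module-private and is called here only to illustrate the arithmetic:

    import numpy as np
    from vision_agent.utils.video_tracking import _calculate_mask_iou

    a = np.zeros((4, 4), dtype=np.uint8)
    b = np.zeros((4, 4), dtype=np.uint8)
    a[0:2, :] = 1  # 8 pixels
    b[1:3, :] = 1  # 8 pixels, 4 of them shared with a
    print(_calculate_mask_iou(a, b))  # 4 / 12 = 0.333..., below 0.8 -> new object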
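Finally, because consecutive segments share the overlap frame, flattening has to drop the duplicate: _convert_to_2d keeps the first segment whole and skips frame 0 of every later segment. A toy illustration with placeholder per-frame detection lists:

    from vision_agent.utils.video_tracking import _convert_to_2d

    seg0 = [["f0"], ["f1"], ["f2"]]  # frames 0-2
    seg1 = [["f2"], ["f3"], ["f4"]]  # begins with the overlap copy of frame 2
    print(_convert_to_2d([seg0, seg1]))
    # [['f0'], ['f1'], ['f2'], ['f3'], ['f4']] -- the shared frame appears once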