vision-agent 0.2.224__py3-none-any.whl → 0.2.226__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/.sim_tools/df.csv +49 -91
- vision_agent/.sim_tools/embs.npy +0 -0
- vision_agent/agent/agent_utils.py +13 -0
- vision_agent/agent/vision_agent_coder_prompts_v2.py +1 -1
- vision_agent/agent/vision_agent_coder_v2.py +6 -1
- vision_agent/agent/vision_agent_planner_prompts_v2.py +42 -33
- vision_agent/agent/vision_agent_v2.py +30 -22
- vision_agent/tools/planner_tools.py +4 -2
- vision_agent/tools/tools.py +119 -123
- vision_agent/utils/sim.py +6 -0
- vision_agent/utils/video_tracking.py +305 -0
- {vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/METADATA +1 -1
- {vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/RECORD +15 -14
- {vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/WHEEL +0 -0
vision_agent/.sim_tools/df.csv
CHANGED
@@ -65,25 +65,30 @@ desc,doc,name
|
|
65
65
|
},
|
66
66
|
]
|
67
67
|
",owlv2_sam2_instance_segmentation
|
68
|
-
"'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names,
|
69
|
-
'owlv2_sam2_video_tracking' is a tool that can segment multiple
|
70
|
-
prompt such as category names or referring
|
71
|
-
prompt are separated by commas. It returns
|
72
|
-
|
68
|
+
"'owlv2_sam2_video_tracking' is a tool that can track and segment multiple objects in a video given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, masks and associated probability scores and is useful for tracking and counting without duplicating counts.","owlv2_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10, fine_tune_id: Optional[str] = None) -> List[List[Dict[str, Any]]]:
|
69
|
+
'owlv2_sam2_video_tracking' is a tool that can track and segment multiple
|
70
|
+
objects in a video given a text prompt such as category names or referring
|
71
|
+
expressions. The categories in the text prompt are separated by commas. It returns
|
72
|
+
a list of bounding boxes, label names, masks and associated probability scores and
|
73
|
+
is useful for tracking and counting without duplicating counts.
|
73
74
|
|
74
75
|
Parameters:
|
75
76
|
prompt (str): The prompt to ground to the image.
|
76
|
-
|
77
|
+
frames (List[np.ndarray]): The list of frames to ground the prompt to.
|
78
|
+
chunk_length (Optional[int]): The number of frames to re-run owlv2 to find
|
79
|
+
new objects.
|
77
80
|
fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
|
78
81
|
fine-tuned model ID here to use it.
|
79
82
|
|
80
83
|
Returns:
|
81
|
-
List[Dict[str, Any]]: A list of dictionaries containing the
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
the
|
84
|
+
List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
|
85
|
+
label, segmentation mask and bounding boxes. The outer list represents each
|
86
|
+
frame and the inner list is the entities per frame. The detected objects
|
87
|
+
have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
|
88
|
+
and ymin are the coordinates of the top-left and xmax and ymax are the
|
89
|
+
coordinates of the bottom-right of the bounding box. The mask is binary 2D
|
90
|
+
numpy array where 1 indicates the object and 0 indicates the background.
|
91
|
+
The label names are prefixed with their ID to represent the total count.
|
87
92
|
|
88
93
|
Example
|
89
94
|
-------
|
@@ -170,25 +175,28 @@ desc,doc,name
|
|
170
175
|
},
|
171
176
|
]
|
172
177
|
",countgd_sam2_instance_segmentation
|
173
|
-
"'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names,
|
174
|
-
'countgd_sam2_video_tracking' is a tool that can segment multiple
|
175
|
-
prompt such as category names or referring
|
176
|
-
prompt are separated by commas. It returns
|
177
|
-
|
178
|
+
"'countgd_sam2_video_tracking' is a tool that can track and segment multiple objects in a video given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, masks and associated probability scores and is useful for tracking and counting without duplicating counts.","countgd_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10) -> List[List[Dict[str, Any]]]:
|
179
|
+
'countgd_sam2_video_tracking' is a tool that can track and segment multiple
|
180
|
+
objects in a video given a text prompt such as category names or referring
|
181
|
+
expressions. The categories in the text prompt are separated by commas. It returns
|
182
|
+
a list of bounding boxes, label names, masks and associated probability scores and
|
183
|
+
is useful for tracking and counting without duplicating counts.
|
178
184
|
|
179
185
|
Parameters:
|
180
186
|
prompt (str): The prompt to ground to the image.
|
181
|
-
|
182
|
-
chunk_length (Optional[int]): The number of frames to re-run
|
187
|
+
frames (List[np.ndarray]): The list of frames to ground the prompt to.
|
188
|
+
chunk_length (Optional[int]): The number of frames to re-run countgd to find
|
183
189
|
new objects.
|
184
190
|
|
185
191
|
Returns:
|
186
|
-
List[Dict[str, Any]]: A list of dictionaries containing the
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
the
|
192
|
+
List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
|
193
|
+
label, segmentation mask and bounding boxes. The outer list represents each
|
194
|
+
frame and the inner list is the entities per frame. The detected objects
|
195
|
+
have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
|
196
|
+
and ymin are the coordinates of the top-left and xmax and ymax are the
|
197
|
+
coordinates of the bottom-right of the bounding box. The mask is binary 2D
|
198
|
+
numpy array where 1 indicates the object and 0 indicates the background.
|
199
|
+
The label names are prefixed with their ID to represent the total count.
|
192
200
|
|
193
201
|
Example
|
194
202
|
-------
|
@@ -265,12 +273,12 @@ desc,doc,name
|
|
265
273
|
},
|
266
274
|
]
|
267
275
|
",florence2_sam2_instance_segmentation
|
268
|
-
'florence2_sam2_video_tracking' is a tool that can
|
269
|
-
'florence2_sam2_video_tracking' is a tool that can
|
270
|
-
|
271
|
-
expressions.
|
272
|
-
|
273
|
-
|
276
|
+
"'florence2_sam2_video_tracking' is a tool that can track and segment multiple objects in a video given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, masks and associated probability scores and is useful for tracking and counting without duplicating counts.","florence2_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10, fine_tune_id: Optional[str] = None) -> List[List[Dict[str, Any]]]:
|
277
|
+
'florence2_sam2_video_tracking' is a tool that can track and segment multiple
|
278
|
+
objects in a video given a text prompt such as category names or referring
|
279
|
+
expressions. The categories in the text prompt are separated by commas. It returns
|
280
|
+
a list of bounding boxes, label names, masks and associated probability scores and
|
281
|
+
is useful for tracking and counting without duplicating counts.
|
274
282
|
|
275
283
|
Parameters:
|
276
284
|
prompt (str): The prompt to ground to the video.
|
@@ -282,10 +290,13 @@ desc,doc,name
|
|
282
290
|
|
283
291
|
Returns:
|
284
292
|
List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
293
|
+
label, segmentation mask and bounding boxes. The outer list represents each
|
294
|
+
frame and the inner list is the entities per frame. The detected objects
|
295
|
+
have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
|
296
|
+
and ymin are the coordinates of the top-left and xmax and ymax are the
|
297
|
+
coordinates of the bottom-right of the bounding box. The mask is binary 2D
|
298
|
+
numpy array where 1 indicates the object and 0 indicates the background.
|
299
|
+
The label names are prefixed with their ID to represent the total count.
|
289
300
|
|
290
301
|
Example
|
291
302
|
-------
|
@@ -445,43 +456,6 @@ desc,doc,name
|
|
445
456
|
>>> qwen2_vl_video_vqa('Which football player made the goal?', frames)
|
446
457
|
'Lionel Messi'
|
447
458
|
",qwen2_vl_video_vqa
|
448
|
-
"'detr_segmentation' is a tool that can segment common objects in an image without any text prompt. It returns a list of detected objects as labels, their regions as masks and their scores.","detr_segmentation(image: numpy.ndarray) -> List[Dict[str, Any]]:
|
449
|
-
'detr_segmentation' is a tool that can segment common objects in an
|
450
|
-
image without any text prompt. It returns a list of detected objects
|
451
|
-
as labels, their regions as masks and their scores.
|
452
|
-
|
453
|
-
Parameters:
|
454
|
-
image (np.ndarray): The image used to segment things and objects
|
455
|
-
|
456
|
-
Returns:
|
457
|
-
List[Dict[str, Any]]: A list of dictionaries containing the score, label
|
458
|
-
and mask of the detected objects. The mask is binary 2D numpy array where 1
|
459
|
-
indicates the object and 0 indicates the background.
|
460
|
-
|
461
|
-
Example
|
462
|
-
-------
|
463
|
-
>>> detr_segmentation(image)
|
464
|
-
[
|
465
|
-
{
|
466
|
-
'score': 0.45,
|
467
|
-
'label': 'window',
|
468
|
-
'mask': array([[0, 0, 0, ..., 0, 0, 0],
|
469
|
-
[0, 0, 0, ..., 0, 0, 0],
|
470
|
-
...,
|
471
|
-
[0, 0, 0, ..., 0, 0, 0],
|
472
|
-
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
|
473
|
-
},
|
474
|
-
{
|
475
|
-
'score': 0.70,
|
476
|
-
'label': 'bird',
|
477
|
-
'mask': array([[0, 0, 0, ..., 0, 0, 0],
|
478
|
-
[0, 0, 0, ..., 0, 0, 0],
|
479
|
-
...,
|
480
|
-
[0, 0, 0, ..., 0, 0, 0],
|
481
|
-
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
|
482
|
-
},
|
483
|
-
]
|
484
|
-
",detr_segmentation
|
485
459
|
'depth_anything_v2' is a tool that runs depth_anythingv2 model to generate a depth image from a given RGB image. The returned depth image is monochrome and represents depth values as pixel intensities with pixel values ranging from 0 to 255.,"depth_anything_v2(image: numpy.ndarray) -> numpy.ndarray:
|
486
460
|
'depth_anything_v2' is a tool that runs depth_anythingv2 model to generate a
|
487
461
|
depth image from a given RGB image. The returned depth image is monochrome and
|
@@ -522,22 +496,6 @@ desc,doc,name
|
|
522
496
|
[10, 11, 15, ..., 202, 202, 205],
|
523
497
|
[10, 10, 10, ..., 200, 200, 200]], dtype=uint8),
|
524
498
|
",generate_pose_image
|
525
|
-
'vit_image_classification' is a tool that can classify an image. It returns a list of classes and their probability scores based on image content.,"vit_image_classification(image: numpy.ndarray) -> Dict[str, Any]:
|
526
|
-
'vit_image_classification' is a tool that can classify an image. It returns a
|
527
|
-
list of classes and their probability scores based on image content.
|
528
|
-
|
529
|
-
Parameters:
|
530
|
-
image (np.ndarray): The image to classify or tag
|
531
|
-
|
532
|
-
Returns:
|
533
|
-
Dict[str, Any]: A dictionary containing the labels and scores. One dictionary
|
534
|
-
contains a list of labels and other a list of scores.
|
535
|
-
|
536
|
-
Example
|
537
|
-
-------
|
538
|
-
>>> vit_image_classification(image)
|
539
|
-
{""labels"": [""leopard"", ""lemur, otter"", ""bird""], ""scores"": [0.68, 0.30, 0.02]},
|
540
|
-
",vit_image_classification
|
541
499
|
'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'. It returns the predicted label and their probability scores based on image content.,"vit_nsfw_classification(image: numpy.ndarray) -> Dict[str, Any]:
|
542
500
|
'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'.
|
543
501
|
It returns the predicted label and their probability scores based on image content.
|
@@ -566,7 +524,7 @@ desc,doc,name
|
|
566
524
|
prompt (str): The question about the video
|
567
525
|
frames (List[np.ndarray]): The reference frames used for the question
|
568
526
|
model (str): The model to use for the inference. Valid values are
|
569
|
-
'qwen2vl', 'gpt4o'
|
527
|
+
'qwen2vl', 'gpt4o'.
|
570
528
|
chunk_length_frames (Optional[int]): length of each chunk in frames
|
571
529
|
|
572
530
|
Returns:
|
@@ -641,7 +599,7 @@ desc,doc,name
|
|
641
599
|
>>> closest_distance(det1, det2, image_size)
|
642
600
|
141.42
|
643
601
|
",minimum_distance
|
644
|
-
"'extract_frames_and_timestamps' extracts frames and timestamps from a video which can be a file path, url or youtube link, returns a list of dictionaries with keys ""frame"" and ""timestamp"" where ""frame"" is a numpy array and ""timestamp"" is the relative time in seconds where the frame was captured. The frame is a numpy array.","extract_frames_and_timestamps(video_uri: Union[str, pathlib.Path], fps: float =
|
602
|
+
"'extract_frames_and_timestamps' extracts frames and timestamps from a video which can be a file path, url or youtube link, returns a list of dictionaries with keys ""frame"" and ""timestamp"" where ""frame"" is a numpy array and ""timestamp"" is the relative time in seconds where the frame was captured. The frame is a numpy array.","extract_frames_and_timestamps(video_uri: Union[str, pathlib.Path], fps: float = 5) -> List[Dict[str, Union[numpy.ndarray, float]]]:
|
645
603
|
'extract_frames_and_timestamps' extracts frames and timestamps from a video
|
646
604
|
which can be a file path, url or youtube link, returns a list of dictionaries
|
647
605
|
with keys ""frame"" and ""timestamp"" where ""frame"" is a numpy array and ""timestamp"" is
|
@@ -651,7 +609,7 @@ desc,doc,name
|
|
651
609
|
Parameters:
|
652
610
|
video_uri (Union[str, Path]): The path to the video file, url or youtube link
|
653
611
|
fps (float, optional): The frame rate per second to extract the frames. Defaults
|
654
|
-
to
|
612
|
+
to 5.
|
655
613
|
|
656
614
|
Returns:
|
657
615
|
List[Dict[str, Union[np.ndarray, float]]]: A list of dictionaries containing the
|
vision_agent/.sim_tools/embs.npy
CHANGED
Binary file
|
@@ -153,6 +153,19 @@ def format_plan_v2(plan: PlanContext) -> str:
|
|
153
153
|
return plan_str
|
154
154
|
|
155
155
|
|
156
|
+
def format_conversation(chat: List[AgentMessage]) -> str:
|
157
|
+
chat = copy.deepcopy(chat)
|
158
|
+
prompt = ""
|
159
|
+
for chat_i in chat:
|
160
|
+
if chat_i.role == "user":
|
161
|
+
prompt += f"USER: {chat_i.content}\n\n"
|
162
|
+
elif chat_i.role == "observation" or chat_i.role == "coder":
|
163
|
+
prompt += f"OBSERVATION: {chat_i.content}\n\n"
|
164
|
+
elif chat_i.role == "conversation":
|
165
|
+
prompt += f"AGENT: {chat_i.content}\n\n"
|
166
|
+
return prompt
|
167
|
+
|
168
|
+
|
156
169
|
def format_plans(plans: Dict[str, Any]) -> str:
|
157
170
|
plan_str = ""
|
158
171
|
for k, v in plans.items():
|
@@ -65,7 +65,7 @@ This is the documentation for the functions you have access to. You may call any
|
|
65
65
|
7. DO NOT assert the output value, run the code and assert only the output format or data structure.
|
66
66
|
8. DO NOT use try except block to handle the error, let the error be raised if the code is incorrect.
|
67
67
|
9. DO NOT import the testing function as it will be available in the testing environment.
|
68
|
-
10. Print the output of the function that is being tested.
|
68
|
+
10. Print the output of the function that is being tested and ensure it is not empty.
|
69
69
|
11. Use the output of the function that is being tested as the return value of the testing function.
|
70
70
|
12. Run the testing function in the end and don't assign a variable to its output.
|
71
71
|
13. Output your test code using <code> tags:
|
@@ -202,7 +202,12 @@ def write_and_test_code(
|
|
202
202
|
tool_docs=tool_docs,
|
203
203
|
plan=plan,
|
204
204
|
)
|
205
|
-
|
205
|
+
try:
|
206
|
+
code = strip_function_calls(code)
|
207
|
+
except Exception:
|
208
|
+
# the code may be malformed, this will fail in the exec call and the agent
|
209
|
+
# will attempt to debug it
|
210
|
+
pass
|
206
211
|
test = write_test(
|
207
212
|
tester=tester,
|
208
213
|
chat=chat,
|
@@ -136,8 +136,9 @@ Tool Documentation:
|
|
136
136
|
countgd_object_detection(prompt: str, image: numpy.ndarray, box_threshold: float = 0.23) -> List[Dict[str, Any]]:
|
137
137
|
'countgd_object_detection' is a tool that can detect multiple instances of an
|
138
138
|
object given a text prompt. It is particularly useful when trying to detect and
|
139
|
-
count a large number of objects.
|
140
|
-
|
139
|
+
count a large number of objects. You can optionally separate object names in the
|
140
|
+
prompt with commas. It returns a list of bounding boxes with normalized
|
141
|
+
coordinates, label names and associated confidence scores.
|
141
142
|
|
142
143
|
Parameters:
|
143
144
|
prompt (str): The object that needs to be counted.
|
@@ -272,40 +273,47 @@ OBSERVATION:
|
|
272
273
|
[get_tool_for_task output]
|
273
274
|
For tracking boxes moving on a conveyor belt, we need a tool that can consistently track the same box across frames without losing it or double counting. Looking at the outputs: florence2_sam2_video_tracking successfully tracks the single box across all 5 frames, maintaining consistent tracking IDs and showing the box's movement along the conveyor.
|
274
275
|
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
276
|
+
Tool Documentation:
|
277
|
+
def florence2_sam2_video_tracking(prompt: str, frames: List[np.ndarray], chunk_length: Optional[int] = 10) -> List[List[Dict[str, Any]]]:
|
278
|
+
'florence2_sam2_video_tracking' is a tool that can track and segment multiple
|
279
|
+
objects in a video given a text prompt such as category names or referring
|
280
|
+
expressions. The categories in the text prompt are separated by commas. It returns
|
281
|
+
a list of bounding boxes, label names, masks and associated probability scores and
|
282
|
+
is useful for tracking and counting without duplicating counts.
|
280
283
|
|
281
|
-
Parameters:
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
284
|
+
Parameters:
|
285
|
+
prompt (str): The prompt to ground to the video.
|
286
|
+
frames (List[np.ndarray]): The list of frames to ground the prompt to.
|
287
|
+
chunk_length (Optional[int]): The number of frames to re-run florence2 to find
|
288
|
+
new objects.
|
289
|
+
fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
|
290
|
+
fine-tuned model ID here to use it.
|
286
291
|
|
287
|
-
Returns:
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
292
|
+
Returns:
|
293
|
+
List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
|
294
|
+
label, segmentation mask and bounding boxes. The outer list represents each
|
295
|
+
frame and the inner list is the entities per frame. The detected objects
|
296
|
+
have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
|
297
|
+
and ymin are the coordinates of the top-left and xmax and ymax are the
|
298
|
+
coordinates of the bottom-right of the bounding box. The mask is binary 2D
|
299
|
+
numpy array where 1 indicates the object and 0 indicates the background.
|
300
|
+
The label names are prefixed with their ID to represent the total count.
|
293
301
|
|
294
|
-
Example
|
295
|
-
-------
|
296
|
-
|
297
|
-
[
|
302
|
+
Example
|
303
|
+
-------
|
304
|
+
>>> florence2_sam2_video_tracking("car, dinosaur", frames)
|
298
305
|
[
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
...,
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
306
|
+
[
|
307
|
+
{
|
308
|
+
'label': '0: dinosaur',
|
309
|
+
'bbox': [0.1, 0.11, 0.35, 0.4],
|
310
|
+
'mask': array([[0, 0, 0, ..., 0, 0, 0],
|
311
|
+
...,
|
312
|
+
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
|
313
|
+
},
|
314
|
+
],
|
315
|
+
...
|
316
|
+
]
|
309
317
|
[end of get_tool_for_task output]
|
310
318
|
<count>8</count>
|
311
319
|
|
@@ -691,7 +699,8 @@ FINALIZE_PLAN = """
|
|
691
699
|
4. Specifically call out the tools used and the order in which they were used. Only include tools obtained from calling `get_tool_for_task`.
|
692
700
|
5. Do not include {excluded_tools} tools in your instructions.
|
693
701
|
6. Add final instructions for visualizing the output with `overlay_bounding_boxes` or `overlay_segmentation_masks` and saving it to a file with `save_file` or `save_video`.
|
694
|
-
|
702
|
+
7. Use the default FPS for extracting frames from videos unless otherwise specified by the user.
|
703
|
+
8. Respond in the following format with JSON surrounded by <json> tags and code surrounded by <code> tags:
|
695
704
|
|
696
705
|
<json>
|
697
706
|
{{
|
@@ -1,13 +1,14 @@
|
|
1
1
|
import copy
|
2
2
|
import json
|
3
3
|
from pathlib import Path
|
4
|
-
from typing import Any, Callable, Dict, List, Optional, Union, cast
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
|
5
5
|
|
6
6
|
from vision_agent.agent import Agent, AgentCoder, VisionAgentCoderV2
|
7
7
|
from vision_agent.agent.agent_utils import (
|
8
8
|
add_media_to_chat,
|
9
9
|
convert_message_to_agentmessage,
|
10
10
|
extract_tag,
|
11
|
+
format_conversation,
|
11
12
|
)
|
12
13
|
from vision_agent.agent.types import (
|
13
14
|
AgentMessage,
|
@@ -22,19 +23,6 @@ from vision_agent.lmm.types import Message
|
|
22
23
|
from vision_agent.utils.execute import CodeInterpreter, CodeInterpreterFactory
|
23
24
|
|
24
25
|
|
25
|
-
def format_conversation(chat: List[AgentMessage]) -> str:
|
26
|
-
chat = copy.deepcopy(chat)
|
27
|
-
prompt = ""
|
28
|
-
for chat_i in chat:
|
29
|
-
if chat_i.role == "user":
|
30
|
-
prompt += f"USER: {chat_i.content}\n\n"
|
31
|
-
elif chat_i.role == "observation" or chat_i.role == "coder":
|
32
|
-
prompt += f"OBSERVATION: {chat_i.content}\n\n"
|
33
|
-
elif chat_i.role == "conversation":
|
34
|
-
prompt += f"AGENT: {chat_i.content}\n\n"
|
35
|
-
return prompt
|
36
|
-
|
37
|
-
|
38
26
|
def run_conversation(agent: LMM, chat: List[AgentMessage]) -> str:
|
39
27
|
# only keep last 10 messages
|
40
28
|
conv = format_conversation(chat[-10:])
|
@@ -55,23 +43,39 @@ def check_for_interaction(chat: List[AgentMessage]) -> bool:
|
|
55
43
|
|
56
44
|
def extract_conversation_for_generate_code(
|
57
45
|
chat: List[AgentMessage],
|
58
|
-
) -> List[AgentMessage]:
|
46
|
+
) -> Tuple[List[AgentMessage], Optional[str]]:
|
59
47
|
chat = copy.deepcopy(chat)
|
60
48
|
|
61
49
|
# if we are in the middle of an interaction, return all the intermediate planning
|
62
50
|
# steps
|
63
51
|
if check_for_interaction(chat):
|
64
|
-
return chat
|
52
|
+
return chat, None
|
65
53
|
|
66
54
|
extracted_chat = []
|
67
55
|
for chat_i in chat:
|
68
56
|
if chat_i.role == "user":
|
69
57
|
extracted_chat.append(chat_i)
|
70
58
|
elif chat_i.role == "coder":
|
71
|
-
if "<final_code>" in chat_i.content
|
59
|
+
if "<final_code>" in chat_i.content:
|
72
60
|
extracted_chat.append(chat_i)
|
73
61
|
|
74
|
-
|
62
|
+
# only keep the last <final_code> and <final_test>
|
63
|
+
final_code = None
|
64
|
+
extracted_chat_strip_code: List[AgentMessage] = []
|
65
|
+
for chat_i in reversed(extracted_chat):
|
66
|
+
if "<final_code>" in chat_i.content and final_code is None:
|
67
|
+
extracted_chat_strip_code = [chat_i] + extracted_chat_strip_code
|
68
|
+
final_code = extract_tag(chat_i.content, "final_code")
|
69
|
+
if final_code is not None:
|
70
|
+
test_code = extract_tag(chat_i.content, "final_test")
|
71
|
+
final_code += "\n" + test_code if test_code is not None else ""
|
72
|
+
|
73
|
+
if "<final_code>" in chat_i.content and final_code is not None:
|
74
|
+
continue
|
75
|
+
|
76
|
+
extracted_chat_strip_code = [chat_i] + extracted_chat_strip_code
|
77
|
+
|
78
|
+
return extracted_chat_strip_code[-5:], final_code
|
75
79
|
|
76
80
|
|
77
81
|
def maybe_run_action(
|
@@ -81,7 +85,7 @@ def maybe_run_action(
|
|
81
85
|
code_interpreter: Optional[CodeInterpreter] = None,
|
82
86
|
) -> Optional[List[AgentMessage]]:
|
83
87
|
if action == "generate_or_edit_vision_code":
|
84
|
-
extracted_chat = extract_conversation_for_generate_code(chat)
|
88
|
+
extracted_chat, _ = extract_conversation_for_generate_code(chat)
|
85
89
|
# there's an issue here because coder.generate_code will send its code_context
|
86
90
|
# to the outside user via its update_callback, but we don't necessarily have
|
87
91
|
# access to that update_callback here, so we re-create the message using
|
@@ -101,11 +105,15 @@ def maybe_run_action(
|
|
101
105
|
)
|
102
106
|
]
|
103
107
|
elif action == "edit_code":
|
104
|
-
extracted_chat = extract_conversation_for_generate_code(chat)
|
108
|
+
extracted_chat, final_code = extract_conversation_for_generate_code(chat)
|
105
109
|
plan_context = PlanContext(
|
106
110
|
plan="Edit the latest code observed in the fewest steps possible according to the user's feedback.",
|
107
|
-
instructions=[
|
108
|
-
|
111
|
+
instructions=[
|
112
|
+
chat_i.content
|
113
|
+
for chat_i in extracted_chat
|
114
|
+
if chat_i.role == "user" and "<final_code>" not in chat_i.content
|
115
|
+
],
|
116
|
+
code=final_code if final_code is not None else "",
|
109
117
|
)
|
110
118
|
context = coder.generate_code_from_plan(
|
111
119
|
extracted_chat, plan_context, code_interpreter=code_interpreter
|
@@ -193,8 +193,10 @@ def get_tool_for_task(
|
|
193
193
|
- Depth and pose estimation
|
194
194
|
- Video object tracking
|
195
195
|
|
196
|
-
|
197
|
-
|
196
|
+
Only ask for one type of task at a time, for example a task needing to identify
|
197
|
+
text is one OCR task while needing to identify non-text objects is an OD task. Wait
|
198
|
+
until the documentation is printed to use the function so you know what the input
|
199
|
+
and output signatures are.
|
198
200
|
|
199
201
|
Parameters:
|
200
202
|
task: str: The task to accomplish.
|
vision_agent/tools/tools.py
CHANGED
@@ -6,7 +6,6 @@ import tempfile
|
|
6
6
|
import urllib.request
|
7
7
|
from base64 import b64encode
|
8
8
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
9
|
-
from enum import Enum
|
10
9
|
from importlib import resources
|
11
10
|
from pathlib import Path
|
12
11
|
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
@@ -54,6 +53,13 @@ from vision_agent.utils.video import (
|
|
54
53
|
frames_to_bytes,
|
55
54
|
video_writer,
|
56
55
|
)
|
56
|
+
from vision_agent.utils.video_tracking import (
|
57
|
+
ODModels,
|
58
|
+
merge_segments,
|
59
|
+
post_process,
|
60
|
+
process_segment,
|
61
|
+
split_frames_into_segments,
|
62
|
+
)
|
57
63
|
|
58
64
|
register_heif_opener()
|
59
65
|
|
@@ -224,12 +230,6 @@ def sam2(
|
|
224
230
|
return ret["return_data"] # type: ignore
|
225
231
|
|
226
232
|
|
227
|
-
class ODModels(str, Enum):
|
228
|
-
COUNTGD = "countgd"
|
229
|
-
FLORENCE2 = "florence2"
|
230
|
-
OWLV2 = "owlv2"
|
231
|
-
|
232
|
-
|
233
233
|
def od_sam2_video_tracking(
|
234
234
|
od_model: ODModels,
|
235
235
|
prompt: str,
|
@@ -237,105 +237,92 @@ def od_sam2_video_tracking(
|
|
237
237
|
chunk_length: Optional[int] = 10,
|
238
238
|
fine_tune_id: Optional[str] = None,
|
239
239
|
) -> Dict[str, Any]:
|
240
|
-
|
240
|
+
SEGMENT_SIZE = 50
|
241
|
+
OVERLAP = 1 # Number of overlapping frames between segments
|
241
242
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
243
|
+
image_size = frames[0].shape[:2]
|
244
|
+
|
245
|
+
# Split frames into segments with overlap
|
246
|
+
segments = split_frames_into_segments(frames, SEGMENT_SIZE, OVERLAP)
|
247
|
+
|
248
|
+
def _apply_object_detection( # inner method to avoid circular importing issues.
|
249
|
+
od_model: ODModels,
|
250
|
+
prompt: str,
|
251
|
+
segment_index: int,
|
252
|
+
frame_number: int,
|
253
|
+
fine_tune_id: str,
|
254
|
+
segment_frames: list,
|
255
|
+
) -> tuple:
|
256
|
+
"""
|
257
|
+
Applies the specified object detection model to the given image.
|
258
|
+
|
259
|
+
Args:
|
260
|
+
od_model: The object detection model to use.
|
261
|
+
prompt: The prompt for the object detection model.
|
262
|
+
segment_index: The index of the current segment.
|
263
|
+
frame_number: The number of the current frame.
|
264
|
+
fine_tune_id: Optional fine-tune ID for the model.
|
265
|
+
segment_frames: List of frames for the current segment.
|
266
|
+
|
267
|
+
Returns:
|
268
|
+
A tuple containing the object detection results and the name of the function used.
|
269
|
+
"""
|
248
270
|
|
249
|
-
for idx in range(0, len(frames), step):
|
250
271
|
if od_model == ODModels.COUNTGD:
|
251
|
-
|
272
|
+
segment_results = countgd_object_detection(
|
273
|
+
prompt=prompt, image=segment_frames[frame_number]
|
274
|
+
)
|
252
275
|
function_name = "countgd_object_detection"
|
276
|
+
|
253
277
|
elif od_model == ODModels.OWLV2:
|
254
|
-
|
255
|
-
prompt=prompt,
|
278
|
+
segment_results = owlv2_object_detection(
|
279
|
+
prompt=prompt,
|
280
|
+
image=segment_frames[frame_number],
|
281
|
+
fine_tune_id=fine_tune_id,
|
256
282
|
)
|
257
283
|
function_name = "owlv2_object_detection"
|
284
|
+
|
258
285
|
elif od_model == ODModels.FLORENCE2:
|
259
|
-
|
260
|
-
prompt=prompt,
|
286
|
+
segment_results = florence2_object_detection(
|
287
|
+
prompt=prompt,
|
288
|
+
image=segment_frames[frame_number],
|
289
|
+
fine_tune_id=fine_tune_id,
|
261
290
|
)
|
262
291
|
function_name = "florence2_object_detection"
|
292
|
+
|
263
293
|
else:
|
264
294
|
raise NotImplementedError(
|
265
295
|
f"Object detection model '{od_model}' is not implemented."
|
266
296
|
)
|
267
297
|
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
{
|
285
|
-
"labels": labels,
|
286
|
-
"bboxes": bboxes,
|
287
|
-
}
|
288
|
-
)
|
289
|
-
else:
|
290
|
-
output_list.append(None)
|
291
|
-
|
292
|
-
return output_list
|
298
|
+
return segment_results, function_name
|
299
|
+
|
300
|
+
# Process each segment and collect detections
|
301
|
+
detections_per_segment: List[Any] = []
|
302
|
+
for segment_index, segment in enumerate(segments):
|
303
|
+
segment_detections = process_segment(
|
304
|
+
segment_frames=segment,
|
305
|
+
od_model=od_model,
|
306
|
+
prompt=prompt,
|
307
|
+
fine_tune_id=fine_tune_id,
|
308
|
+
chunk_length=chunk_length,
|
309
|
+
image_size=image_size,
|
310
|
+
segment_index=segment_index,
|
311
|
+
object_detection_tool=_apply_object_detection,
|
312
|
+
)
|
313
|
+
detections_per_segment.append(segment_detections)
|
293
314
|
|
294
|
-
|
315
|
+
merged_detections = merge_segments(detections_per_segment)
|
316
|
+
post_processed = post_process(merged_detections, image_size)
|
295
317
|
|
296
318
|
buffer_bytes = frames_to_bytes(frames)
|
297
319
|
files = [("video", buffer_bytes)]
|
298
|
-
payload = {"bboxes": json.dumps(output), "chunk_length_frames": chunk_length}
|
299
|
-
metadata = {"function_name": function_name}
|
300
|
-
|
301
|
-
detections = send_task_inference_request(
|
302
|
-
payload,
|
303
|
-
"sam2",
|
304
|
-
files=files,
|
305
|
-
metadata=metadata,
|
306
|
-
)
|
307
|
-
|
308
|
-
return_data = []
|
309
|
-
for frame in detections:
|
310
|
-
return_frame_data = []
|
311
|
-
for detection in frame:
|
312
|
-
mask = rle_decode_array(detection["mask"])
|
313
|
-
label = str(detection["id"]) + ": " + detection["label"]
|
314
|
-
return_frame_data.append(
|
315
|
-
{"label": label, "mask": mask, "score": 1.0, "rle": detection["mask"]}
|
316
|
-
)
|
317
|
-
return_data.append(return_frame_data)
|
318
|
-
return_data = add_bboxes_from_masks(return_data)
|
319
|
-
return_data = nms(return_data, iou_threshold=0.95)
|
320
320
|
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
for obj in frame:
|
327
|
-
display_frame_data.append(
|
328
|
-
{
|
329
|
-
"label": obj["label"],
|
330
|
-
"score": obj["score"],
|
331
|
-
"bbox": denormalize_bbox(obj["bbox"], image_size),
|
332
|
-
"mask": obj["rle"],
|
333
|
-
}
|
334
|
-
)
|
335
|
-
del obj["rle"]
|
336
|
-
display_data.append(display_frame_data)
|
337
|
-
|
338
|
-
return {"files": files, "return_data": return_data, "display_data": detections}
|
321
|
+
return {
|
322
|
+
"files": files,
|
323
|
+
"return_data": post_processed["return_data"],
|
324
|
+
"display_data": post_processed["display_data"],
|
325
|
+
}
|
339
326
|
|
340
327
|
|
341
328
|
# Owl V2 Tools
|
@@ -528,24 +515,29 @@ def owlv2_sam2_video_tracking(
|
|
528
515
|
chunk_length: Optional[int] = 10,
|
529
516
|
fine_tune_id: Optional[str] = None,
|
530
517
|
) -> List[List[Dict[str, Any]]]:
|
531
|
-
"""'owlv2_sam2_video_tracking' is a tool that can segment multiple
|
532
|
-
prompt such as category names or referring
|
533
|
-
prompt are separated by commas. It returns
|
534
|
-
|
518
|
+
"""'owlv2_sam2_video_tracking' is a tool that can track and segment multiple
|
519
|
+
objects in a video given a text prompt such as category names or referring
|
520
|
+
expressions. The categories in the text prompt are separated by commas. It returns
|
521
|
+
a list of bounding boxes, label names, masks and associated probability scores and
|
522
|
+
is useful for tracking and counting without duplicating counts.
|
535
523
|
|
536
524
|
Parameters:
|
537
525
|
prompt (str): The prompt to ground to the image.
|
538
|
-
|
526
|
+
frames (List[np.ndarray]): The list of frames to ground the prompt to.
|
527
|
+
chunk_length (Optional[int]): The number of frames to re-run owlv2 to find
|
528
|
+
new objects.
|
539
529
|
fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
|
540
530
|
fine-tuned model ID here to use it.
|
541
531
|
|
542
532
|
Returns:
|
543
|
-
List[Dict[str, Any]]: A list of dictionaries containing the
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
the
|
533
|
+
List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
|
534
|
+
label, segmentation mask and bounding boxes. The outer list represents each
|
535
|
+
frame and the inner list is the entities per frame. The detected objects
|
536
|
+
have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
|
537
|
+
and ymin are the coordinates of the top-left and xmax and ymax are the
|
538
|
+
coordinates of the bottom-right of the bounding box. The mask is binary 2D
|
539
|
+
numpy array where 1 indicates the object and 0 indicates the background.
|
540
|
+
The label names are prefixed with their ID represent the total count.
|
549
541
|
|
550
542
|
Example
|
551
543
|
-------
|
@@ -755,11 +747,11 @@ def florence2_sam2_video_tracking(
|
|
755
747
|
chunk_length: Optional[int] = 10,
|
756
748
|
fine_tune_id: Optional[str] = None,
|
757
749
|
) -> List[List[Dict[str, Any]]]:
|
758
|
-
"""'florence2_sam2_video_tracking' is a tool that can
|
759
|
-
|
760
|
-
expressions.
|
761
|
-
|
762
|
-
|
750
|
+
"""'florence2_sam2_video_tracking' is a tool that can track and segment multiple
|
751
|
+
objects in a video given a text prompt such as category names or referring
|
752
|
+
expressions. The categories in the text prompt are separated by commas. It returns
|
753
|
+
a list of bounding boxes, label names, masks and associated probability scores and
|
754
|
+
is useful for tracking and counting without duplicating counts.
|
763
755
|
|
764
756
|
Parameters:
|
765
757
|
prompt (str): The prompt to ground to the video.
|
@@ -771,10 +763,13 @@ def florence2_sam2_video_tracking(
|
|
771
763
|
|
772
764
|
Returns:
|
773
765
|
List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
766
|
+
label, segmentation mask and bounding boxes. The outer list represents each
|
767
|
+
frame and the inner list is the entities per frame. The detected objects
|
768
|
+
have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
|
769
|
+
and ymin are the coordinates of the top-left and xmax and ymax are the
|
770
|
+
coordinates of the bottom-right of the bounding box. The mask is binary 2D
|
771
|
+
numpy array where 1 indicates the object and 0 indicates the background.
|
772
|
+
The label names are prefixed with their ID represent the total count.
|
778
773
|
|
779
774
|
Example
|
780
775
|
-------
|
@@ -1089,24 +1084,27 @@ def countgd_sam2_video_tracking(
|
|
1089
1084
|
frames: List[np.ndarray],
|
1090
1085
|
chunk_length: Optional[int] = 10,
|
1091
1086
|
) -> List[List[Dict[str, Any]]]:
|
1092
|
-
"""'countgd_sam2_video_tracking' is a tool that can segment multiple
|
1093
|
-
prompt such as category names or referring
|
1094
|
-
prompt are separated by commas. It returns
|
1095
|
-
|
1087
|
+
"""'countgd_sam2_video_tracking' is a tool that can track and segment multiple
|
1088
|
+
objects in a video given a text prompt such as category names or referring
|
1089
|
+
expressions. The categories in the text prompt are separated by commas. It returns
|
1090
|
+
a list of bounding boxes, label names, masks and associated probability scores and
|
1091
|
+
is useful for tracking and counting without duplicating counts.
|
1096
1092
|
|
1097
1093
|
Parameters:
|
1098
1094
|
prompt (str): The prompt to ground to the image.
|
1099
|
-
|
1100
|
-
chunk_length (Optional[int]): The number of frames to re-run
|
1095
|
+
frames (List[np.ndarray]): The list of frames to ground the prompt to.
|
1096
|
+
chunk_length (Optional[int]): The number of frames to re-run countgd to find
|
1101
1097
|
new objects.
|
1102
1098
|
|
1103
1099
|
Returns:
|
1104
|
-
List[Dict[str, Any]]: A list of dictionaries containing the
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
the
|
1100
|
+
List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
|
1101
|
+
label, segmentation mask and bounding boxes. The outer list represents each
|
1102
|
+
frame and the inner list is the entities per frame. The detected objects
|
1103
|
+
have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
|
1104
|
+
and ymin are the coordinates of the top-left and xmax and ymax are the
|
1105
|
+
coordinates of the bottom-right of the bounding box. The mask is binary 2D
|
1106
|
+
numpy array where 1 indicates the object and 0 indicates the background.
|
1107
|
+
The label names are prefixed with their ID represent the total count.
|
1110
1108
|
|
1111
1109
|
Example
|
1112
1110
|
-------
|
@@ -1546,7 +1544,7 @@ def video_temporal_localization(
|
|
1546
1544
|
prompt (str): The question about the video
|
1547
1545
|
frames (List[np.ndarray]): The reference frames used for the question
|
1548
1546
|
model (str): The model to use for the inference. Valid values are
|
1549
|
-
'qwen2vl', 'gpt4o'
|
1547
|
+
'qwen2vl', 'gpt4o'.
|
1550
1548
|
chunk_length_frames (Optional[int]): length of each chunk in frames
|
1551
1549
|
|
1552
1550
|
Returns:
|
@@ -2115,7 +2113,7 @@ def closest_box_distance(
|
|
2115
2113
|
|
2116
2114
|
|
2117
2115
|
def extract_frames_and_timestamps(
|
2118
|
-
video_uri: Union[str, Path], fps: float =
|
2116
|
+
video_uri: Union[str, Path], fps: float = 5
|
2119
2117
|
) -> List[Dict[str, Union[np.ndarray, float]]]:
|
2120
2118
|
"""'extract_frames_and_timestamps' extracts frames and timestamps from a video
|
2121
2119
|
which can be a file path, url or youtube link, returns a list of dictionaries
|
@@ -2126,7 +2124,7 @@ def extract_frames_and_timestamps(
|
|
2126
2124
|
Parameters:
|
2127
2125
|
video_uri (Union[str, Path]): The path to the video file, url or youtube link
|
2128
2126
|
fps (float, optional): The frame rate per second to extract the frames. Defaults
|
2129
|
-
to
|
2127
|
+
to 5.
|
2130
2128
|
|
2131
2129
|
Returns:
|
2132
2130
|
List[Dict[str, Union[np.ndarray, float]]]: A list of dictionaries containing the
|
@@ -2649,10 +2647,8 @@ FUNCTION_TOOLS = [
|
|
2649
2647
|
ocr,
|
2650
2648
|
qwen2_vl_images_vqa,
|
2651
2649
|
qwen2_vl_video_vqa,
|
2652
|
-
detr_segmentation,
|
2653
2650
|
depth_anything_v2,
|
2654
2651
|
generate_pose_image,
|
2655
|
-
vit_image_classification,
|
2656
2652
|
vit_nsfw_classification,
|
2657
2653
|
video_temporal_localization,
|
2658
2654
|
flux_image_inpainting,
|
vision_agent/utils/sim.py
CHANGED
@@ -133,6 +133,12 @@ class Sim:
|
|
133
133
|
df: pd.DataFrame,
|
134
134
|
) -> bool:
|
135
135
|
load_dir = Path(load_dir)
|
136
|
+
if (
|
137
|
+
not Path(load_dir / "df.csv").exists()
|
138
|
+
or not Path(load_dir / "embs.npy").exists()
|
139
|
+
):
|
140
|
+
return False
|
141
|
+
|
136
142
|
df_load = pd.read_csv(load_dir / "df.csv")
|
137
143
|
if platform.system() == "Windows":
|
138
144
|
df_load["doc"] = df_load["doc"].apply(lambda x: x.replace("\r", ""))
|
@@ -0,0 +1,305 @@
|
|
1
|
+
import json
|
2
|
+
from enum import Enum
|
3
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from vision_agent.tools.tool_utils import (
|
8
|
+
add_bboxes_from_masks,
|
9
|
+
nms,
|
10
|
+
send_task_inference_request,
|
11
|
+
)
|
12
|
+
from vision_agent.utils.image_utils import denormalize_bbox, rle_decode_array
|
13
|
+
from vision_agent.utils.video import frames_to_bytes
|
14
|
+
|
15
|
+
|
16
|
+
class ODModels(str, Enum):
    """Names of the object detection models accepted as ``od_model``.

    The class mixes in ``str``, so every member is itself a string and
    compares equal to its raw value (e.g. ``ODModels.OWLV2 == "owlv2"``).
    """

    # Member order is kept as declared; it is the iteration order of the enum.
    COUNTGD = "countgd"
    FLORENCE2 = "florence2"
    OWLV2 = "owlv2"
|
20
|
+
|
21
|
+
|
22
|
+
def split_frames_into_segments(
    frames: List[np.ndarray], segment_size: int = 50, overlap: int = 1
) -> List[List[np.ndarray]]:
    """
    Splits the list of frames into segments with a specified size and overlap.

    The first segment holds up to ``segment_size`` frames. Every later segment
    is additionally prefixed with the last ``overlap`` frames that precede it,
    so consecutive segments share frames.

    Args:
        frames (List[np.ndarray]): List of video frames.
        segment_size (int, optional): Number of frames per segment. Defaults to 50.
        overlap (int, optional): Number of overlapping frames between segments. Defaults to 1.

    Returns:
        List[List[np.ndarray]]: List of frame segments. Empty if ``frames`` is empty.

    Raises:
        ValueError: If ``segment_size`` is not positive (the split could never
            advance through the frames).
    """
    if segment_size <= 0:
        raise ValueError("segment_size must be a positive integer.")

    total = len(frames)
    segments = []
    for start in range(0, total, segment_size):
        end = min(start + segment_size, total)
        # Reach back into the previous segment for the overlapping frames.
        # Clamp at 0 so an overlap larger than `start` cannot turn into a
        # negative index, which would silently slice from the end of the list.
        segments.append(frames[max(start - overlap, 0) : end])
    return segments
|
52
|
+
|
53
|
+
|
54
|
+
def process_segment(
    segment_frames: List[np.ndarray],
    od_model: ODModels,
    prompt: str,
    fine_tune_id: Optional[str],
    chunk_length: Optional[int],
    image_size: Tuple[int, ...],
    segment_index: int,
    object_detection_tool: Callable,
) -> Any:
    """
    Runs object detection on selected frames of a segment and sends the result to "sam2".

    Args:
        segment_frames (List[np.ndarray]): Frames in the segment.
        od_model (ODModels): Object detection model to use.
        prompt (str): Prompt for the model.
        fine_tune_id (Optional[str]): Fine-tune model ID.
        chunk_length (Optional[int]): Detection is run on every
            ``chunk_length``-th frame; ``None`` runs it on every frame. The
            value is also forwarded unchanged as ``chunk_length_frames``.
        image_size (Tuple[int, int]): Size of the images.
        segment_index (int): Index of the segment.
        object_detection_tool (Callable): Called as
            ``(od_model, prompt, segment_index, frame_number, fine_tune_id,
            segment_frames)``; must return a ``(detections, function_name)`` pair.

    Returns:
        Any: Whatever ``send_task_inference_request`` returns for the segment.

    Raises:
        ValueError: If ``chunk_length`` is zero or negative.
    """
    if chunk_length is not None and chunk_length <= 0:
        raise ValueError("chunk_length must be a positive integer or None.")
    stride = 1 if chunk_length is None else chunk_length

    # Frames that are not sampled keep a None placeholder.
    per_frame: List[Optional[List[Dict[str, Any]]]] = [None] * len(segment_frames)

    # Stays empty when the segment has no frames; otherwise the name reported
    # by the last detector call wins.
    function_name = ""
    for frame_number in range(0, len(segment_frames), stride):
        per_frame[frame_number], function_name = object_detection_tool(
            od_model, prompt, segment_index, frame_number, fine_tune_id, segment_frames
        )

    boxes_per_frame = transform_detections(per_frame, image_size, segment_index)
    video_bytes = frames_to_bytes(segment_frames)

    request_payload = {
        "bboxes": json.dumps(boxes_per_frame),
        "chunk_length_frames": chunk_length,
    }
    return send_task_inference_request(
        request_payload,
        "sam2",
        files=[("video", video_bytes)],
        metadata={"function_name": function_name},
    )
|
117
|
+
|
118
|
+
|
119
|
+
def transform_detections(
    input_list: List[Optional[List[Dict[str, Any]]]],
    image_size: Tuple[int, ...],
    segment_index: int,
) -> List[Optional[Dict[str, Any]]]:
    """
    Regroups per-frame detections into one ``labels``/``bboxes`` dict per frame.

    Args:
        input_list (List[Optional[List[Dict[str, Any]]]]): Raw detections, one
            entry per frame; ``None`` marks a frame without detector output.
        image_size (Tuple[int, int]): Size of the images, passed to
            ``denormalize_bbox`` together with each detection's ``"bbox"``.
        segment_index (int): Index of the segment. Not used by this function.

    Returns:
        List[Optional[Dict[str, Any]]]: Same length as ``input_list``. ``None``
        entries are kept as ``None``; every other frame becomes
        ``{"labels": [...], "bboxes": [...]}``.
    """
    return [
        None
        if frame is None
        else {
            "labels": [entry["label"] for entry in frame],
            "bboxes": [denormalize_bbox(entry["bbox"], image_size) for entry in frame],
        }
        for frame in input_list
    ]
|
152
|
+
|
153
|
+
|
154
|
+
def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """Return the intersection-over-union of two masks.

    Both masks are cast to ``bool`` first, so any non-zero entry counts as
    part of the mask. When the union is empty the result is ``0.0``.
    """
    first = mask1.astype(bool)
    second = mask2.astype(bool)

    overlap = np.logical_and(first, second).sum()
    covered = np.logical_or(first, second).sum()

    # Two empty masks have no union; report no overlap instead of dividing by zero.
    return 0.0 if covered == 0 else overlap / covered
|
167
|
+
|
168
|
+
|
169
|
+
def _match_by_iou(
    first_param: List[Dict],
    second_param: List[Dict],
    iou_threshold: float = 0.8,
) -> Tuple[List[Dict], Dict[int, int]]:
    """Match the items of ``second_param`` to ``first_param`` by mask IoU.

    Each item of ``second_param`` is compared, in order, with the items of
    ``first_param``. The first existing item whose ``"decoded_mask"`` IoU
    exceeds ``iou_threshold`` donates its ``"id"``. Items with no match get a
    fresh id, counting up from the largest id in ``first_param``.

    The ``"id"`` of every item in ``second_param`` is overwritten in place.

    Args:
        first_param (List[Dict]): Existing items; each needs ``"id"`` and
            ``"decoded_mask"``.
        second_param (List[Dict]): New items with the same keys.
        iou_threshold (float): IoU that must be strictly exceeded for a match.

    Returns:
        Tuple[List[Dict], Dict[int, int]]: ``first_param`` followed by the
        unmatched items of ``second_param``, and a mapping from each new
        item's original id to the id it was assigned.
    """
    max_id = max((item["id"] for item in first_param), default=0)

    matched_new_item_indices = set()
    id_mapping: Dict[int, int] = {}

    for new_index, new_item in enumerate(second_param):
        matched_id = None

        for existing_item in first_param:
            iou = _calculate_mask_iou(
                existing_item["decoded_mask"], new_item["decoded_mask"]
            )
            if iou > iou_threshold:
                matched_id = existing_item["id"]
                matched_new_item_indices.add(new_index)
                break

        # Test against None, not truthiness: an existing id of 0 is a valid
        # match and must not be mistaken for "no match", which would hand the
        # object a brand new id.
        if matched_id is None:
            max_id += 1
            assigned_id = max_id
        else:
            assigned_id = matched_id

        id_mapping[new_item["id"]] = assigned_id
        new_item["id"] = assigned_id

    unmatched_items = [
        item for i, item in enumerate(second_param) if i not in matched_new_item_indices
    ]
    combined_list = first_param + unmatched_items

    return combined_list, id_mapping
|
205
|
+
|
206
|
+
|
207
|
+
def _update_ids(detections: List[List[Dict]], id_mapping: Dict[int, int]) -> None:
    """Rewrite the ``"id"`` of every detection according to ``id_mapping``.

    Both arguments are modified in place. A detection whose id is missing from
    ``id_mapping`` gets the next unused id (one above the largest mapped
    value). That assignment is recorded under the detection's ORIGINAL id, so
    the same object receives the same new id in every later frame.

    Args:
        detections (List[List[Dict]]): Frames of detections; each detection
            must have an ``"id"`` key.
        id_mapping (Dict[int, int]): Original id -> new id. Extended with
            entries for ids first seen here.
    """
    # New ids are always handed out above the current maximum, so a running
    # counter is equivalent to recomputing max(id_mapping.values()) each time.
    highest_id = max(id_mapping.values(), default=0)

    for inner_list in detections:
        for detection in inner_list:
            original_id = detection["id"]
            if original_id not in id_mapping:
                highest_id += 1
                # Key on the original id (not the new one) so later frames
                # resolve this object to the same id.
                id_mapping[original_id] = highest_id
            detection["id"] = id_mapping[original_id]
|
216
|
+
|
217
|
+
|
218
|
+
def _convert_to_2d(detections_per_segment: List[Any]) -> List[Any]:
    """Concatenate per-segment frame lists into one flat list of frames.

    The first segment is taken whole. Every later segment contributes all but
    its first entry.
    """
    flattened: List[Any] = []
    for position, segment_frames in enumerate(detections_per_segment):
        # Skip the leading entry of every segment except the first one.
        flattened.extend(segment_frames if position == 0 else segment_frames[1:])
    return flattened
|
226
|
+
|
227
|
+
|
228
|
+
def merge_segments(detections_per_segment: List[Any]) -> List[Any]:
    """
    Merges detections from all segments into a unified result.

    Every detection first gets a ``"decoded_mask"`` entry computed from its
    ``"mask"`` with ``rle_decode_array``. Each pair of neighbouring segments
    is then reconciled: the last frame of the earlier segment is matched
    against the first frame of the later one, and the resulting id mapping is
    applied to the whole later segment. The input is modified in place.

    Args:
        detections_per_segment (List[Any]): List of detections per segment.

    Returns:
        List[Any]: Merged detections, flattened to one list of frames.
    """
    # Decode every mask up front; the matching step reads "decoded_mask".
    for segment_frames in detections_per_segment:
        for frame in segment_frames:
            for entry in frame:
                entry["decoded_mask"] = rle_decode_array(entry["mask"])

    # Walk neighbouring segments in order so ids propagate forward.
    for earlier, later in zip(detections_per_segment, detections_per_segment[1:]):
        _, id_mapping = _match_by_iou(earlier[-1], later[0])
        _update_ids(later, id_mapping)

    return _convert_to_2d(detections_per_segment)
|
253
|
+
|
254
|
+
|
255
|
+
def post_process(
    merged_detections: List[Any],
    image_size: Tuple[int, ...],
) -> Dict[str, Any]:
    """
    Performs post-processing on merged detections, including NMS and preparing display data.

    The input detections are modified: their ``"decoded_mask"`` entry is
    removed once it has been copied into the output.

    Args:
        merged_detections (List[Any]): Merged detections from all segments.
            Each detection needs ``"id"``, ``"label"``, ``"mask"`` and
            ``"decoded_mask"``.
        image_size (Tuple[int, int]): Size of the images, passed to
            ``denormalize_bbox`` for the display boxes.

    Returns:
        Dict[str, Any]: ``"return_data"`` (per-frame objects carrying the
        decoded mask) and ``"display_data"`` (per-frame objects carrying the
        RLE mask and a denormalized bbox).
    """
    frames_out = []
    for frame in merged_detections:
        objects_out = []
        for detection in frame:
            objects_out.append(
                {
                    "label": f"{detection['id']}: {detection['label']}",
                    "mask": detection["decoded_mask"],
                    "rle": detection["mask"],
                    "score": 1.0,
                }
            )
            # The decoded mask now lives in the output object only.
            del detection["decoded_mask"]
        frames_out.append(objects_out)

    return_data = nms(add_bboxes_from_masks(frames_out), iou_threshold=0.95)

    # The RLE is carried along only so display data can reuse it; re-calculating
    # RLE can get very expensive. It is removed from return_data afterwards
    # because the numpy masks are what gets returned there.
    display_data = []
    for frame in return_data:
        shown = []
        for obj in frame:
            shown.append(
                {
                    "label": obj["label"],
                    "bbox": denormalize_bbox(obj["bbox"], image_size),
                    "mask": obj["rle"],
                    "score": obj["score"],
                }
            )
            del obj["rle"]
        display_data.append(shown)

    return {"return_data": return_data, "display_data": display_data}
|
@@ -1,23 +1,23 @@
|
|
1
|
-
vision_agent/.sim_tools/df.csv,sha256=
|
2
|
-
vision_agent/.sim_tools/embs.npy,sha256=
|
1
|
+
vision_agent/.sim_tools/df.csv,sha256=Vamicw8MiSGildK1r3-HXY4cKiq17GZxsgBsHbk7jpM,42158
|
2
|
+
vision_agent/.sim_tools/embs.npy,sha256=YJe8EcKVNmeX_75CS2T1sbY-sUS_1HQAMT-34zc18a0,254080
|
3
3
|
vision_agent/__init__.py,sha256=EAb4-f9iyuEYkBrX4ag1syM8Syx8118_t0R6_C34M9w,57
|
4
4
|
vision_agent/agent/README.md,sha256=Q4w7FWw38qaWosQYAZ7NqWx8Q5XzuWrlv7nLhjUd1-8,5527
|
5
5
|
vision_agent/agent/__init__.py,sha256=M8CffavdIh8Zh-skznLHIaQkYGCGK7vk4dq1FaVkbs4,617
|
6
6
|
vision_agent/agent/agent.py,sha256=_1tHWAs7Jm5tqDzEcPfCRvJV3uRRveyh4n9_9pd6I1w,1565
|
7
|
-
vision_agent/agent/agent_utils.py,sha256=
|
7
|
+
vision_agent/agent/agent_utils.py,sha256=pP4u5tiami7C3ChgjgYLqJITnmkTI1_GsUj6g5czSRk,13994
|
8
8
|
vision_agent/agent/types.py,sha256=DkFm3VMMrKlhYyfxEmZx4keppD72Ov3wmLCbM2J2o10,2437
|
9
9
|
vision_agent/agent/vision_agent.py,sha256=I75bEU-os9Lf9OSICKfvQ_H_ftg-zOwgTwWnu41oIdo,23555
|
10
10
|
vision_agent/agent/vision_agent_coder.py,sha256=flUxOibyGZK19BCSK5mhaD3HjCxHw6c6FtKom6N2q1E,27359
|
11
11
|
vision_agent/agent/vision_agent_coder_prompts.py,sha256=gPLVXQMNSzYnQYpNm0wlH_5FPkOTaFDV24bqzK3jQ40,12221
|
12
|
-
vision_agent/agent/vision_agent_coder_prompts_v2.py,sha256=
|
13
|
-
vision_agent/agent/vision_agent_coder_v2.py,sha256=
|
12
|
+
vision_agent/agent/vision_agent_coder_prompts_v2.py,sha256=idmSMfxebPULqqvllz3gqRzGDchEvS5dkGngvBs4PGo,4872
|
13
|
+
vision_agent/agent/vision_agent_coder_v2.py,sha256=i1qgXp5YsWVRoA_qO429Ef-aKZBakveCl1F_2ZbSzk8,16287
|
14
14
|
vision_agent/agent/vision_agent_planner.py,sha256=fFzjNkZBKkh8Y_oS06ATI4qz31xmIJvixb_tV1kX8KA,18590
|
15
15
|
vision_agent/agent/vision_agent_planner_prompts.py,sha256=mn9NlZpRkW4XAvlNuMZwIs1ieHCFds5aYZJ55WXupZY,6733
|
16
|
-
vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=
|
16
|
+
vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=YgemW2PRPYd8o8XpmwSJBUOJSxMUXMNr2DZNQnS4jEI,34988
|
17
17
|
vision_agent/agent/vision_agent_planner_v2.py,sha256=vvxfmGydBIKB8CtNSAJyPvdEXkG7nIO5-Hs2SjNc48Y,20465
|
18
18
|
vision_agent/agent/vision_agent_prompts.py,sha256=NtGdCfzzilCRtscKALC9FK55d1h4CBpMnbhLzg0PYlc,13772
|
19
19
|
vision_agent/agent/vision_agent_prompts_v2.py,sha256=-vCWat-ARlCOOOeIDIFhg-kcwRRwjTXYEwsvvqPeaCs,1972
|
20
|
-
vision_agent/agent/vision_agent_v2.py,sha256=
|
20
|
+
vision_agent/agent/vision_agent_v2.py,sha256=1wu_vH_onic2kLYPKW2nAF2e6Zz5vmUt5Acv4Seq3sQ,10796
|
21
21
|
vision_agent/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
22
22
|
vision_agent/clients/http.py,sha256=k883i6M_4nl7zwwHSI-yP5sAgQZIDPM1nrKD6YFJ3Xs,2009
|
23
23
|
vision_agent/clients/landing_public_api.py,sha256=lU2ev6E8NICmR8DMUljuGcVFy5VNJQ4WQkWC8WnnJEc,1503
|
@@ -28,19 +28,20 @@ vision_agent/lmm/lmm.py,sha256=x_nIyDNDZwq4-pfjnJTmcyyJZ2_B7TjkA5jZp88YVO8,17103
|
|
28
28
|
vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
|
29
29
|
vision_agent/tools/__init__.py,sha256=15O7eQVn0bitmzUO5OxKdA618PoiLt6Z02gmKsSNMFM,2765
|
30
30
|
vision_agent/tools/meta_tools.py,sha256=TPeS7QWnc_PmmU_ndiDT03dXbQ5yDSP33E7U8cSj7Ls,28660
|
31
|
-
vision_agent/tools/planner_tools.py,sha256=
|
31
|
+
vision_agent/tools/planner_tools.py,sha256=qQvPuCif-KbFi7KsXKkTCfpgEQEJJ6oq6WB3gOuG2Xg,13686
|
32
32
|
vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
|
33
33
|
vision_agent/tools/tool_utils.py,sha256=q9cqXO2AvigUdO1krjnOy8o0goYhgS6eILl6-F5Kxyk,10211
|
34
|
-
vision_agent/tools/tools.py,sha256=
|
34
|
+
vision_agent/tools/tools.py,sha256=zqoo4ml9ZS99kOeOIN6Zplq7pxOwBrVZKKFUVIzsjfw,91712
|
35
35
|
vision_agent/tools/tools_types.py,sha256=8hYf2OZhI58gvf65KGaeGkt4EQ56nwLFqIQDPHioOBc,2339
|
36
36
|
vision_agent/utils/__init__.py,sha256=QKk4zVjMwGxQI0MQ-aZZA50N-qItxRY4EB9CwQkZ2HY,185
|
37
37
|
vision_agent/utils/exceptions.py,sha256=booSPSuoULF7OXRr_YbC4dtKt6gM_HyiFQHBuaW86C4,2052
|
38
38
|
vision_agent/utils/execute.py,sha256=vOEP5Ys7S2lc0_7pOJbgk7OaWi85hrCNu9_8Bo3zk6I,29356
|
39
39
|
vision_agent/utils/image_utils.py,sha256=z_ONgcza125B10NkoGwPOzXnL470bpTWZbkB16NeeH0,12188
|
40
|
-
vision_agent/utils/sim.py,sha256=
|
40
|
+
vision_agent/utils/sim.py,sha256=qr-6UWAxxGwtwIAKZjZCY_pu9VwBI_TTB8bfrGsaABg,9282
|
41
41
|
vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
|
42
42
|
vision_agent/utils/video.py,sha256=e1VwKhXzzlC5LcFMyrcQYrPnpnX4wxDpnQ-76sB4jgM,6001
|
43
|
-
vision_agent
|
44
|
-
vision_agent-0.2.
|
45
|
-
vision_agent-0.2.
|
46
|
-
vision_agent-0.2.
|
43
|
+
vision_agent/utils/video_tracking.py,sha256=EeOiSY8gjvvneuAnv-BO7yOyMBF_-1Irk_lLLOt3bDM,9452
|
44
|
+
vision_agent-0.2.226.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
45
|
+
vision_agent-0.2.226.dist-info/METADATA,sha256=_7jZokNbQLK6Ups2psyRKbPDjUIzU3daxCpfrHZ6gSU,20039
|
46
|
+
vision_agent-0.2.226.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
|
47
|
+
vision_agent-0.2.226.dist-info/RECORD,,
|
File without changes
|
File without changes
|