vision-agent 0.2.224__py3-none-any.whl → 0.2.226__py3-none-any.whl
- vision_agent/.sim_tools/df.csv +49 -91
- vision_agent/.sim_tools/embs.npy +0 -0
- vision_agent/agent/agent_utils.py +13 -0
- vision_agent/agent/vision_agent_coder_prompts_v2.py +1 -1
- vision_agent/agent/vision_agent_coder_v2.py +6 -1
- vision_agent/agent/vision_agent_planner_prompts_v2.py +42 -33
- vision_agent/agent/vision_agent_v2.py +30 -22
- vision_agent/tools/planner_tools.py +4 -2
- vision_agent/tools/tools.py +119 -123
- vision_agent/utils/sim.py +6 -0
- vision_agent/utils/video_tracking.py +305 -0
- {vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/METADATA +1 -1
- {vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/RECORD +15 -14
- {vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/WHEEL +0 -0
vision_agent/.sim_tools/df.csv
CHANGED
@@ -65,25 +65,30 @@ desc,doc,name
         },
     ]
 ",owlv2_sam2_instance_segmentation
-"'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names,
-    'owlv2_sam2_video_tracking' is a tool that can segment multiple
-    prompt such as category names or referring
-    prompt are separated by commas. It returns
-
+"'owlv2_sam2_video_tracking' is a tool that can track and segment multiple objects in a video given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, masks and associated probability scores and is useful for tracking and counting without duplicating counts.","owlv2_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10, fine_tune_id: Optional[str] = None) -> List[List[Dict[str, Any]]]:
+    'owlv2_sam2_video_tracking' is a tool that can track and segment multiple
+    objects in a video given a text prompt such as category names or referring
+    expressions. The categories in the text prompt are separated by commas. It returns
+    a list of bounding boxes, label names, masks and associated probability scores and
+    is useful for tracking and counting without duplicating counts.
 
     Parameters:
         prompt (str): The prompt to ground to the image.
-
+        frames (List[np.ndarray]): The list of frames to ground the prompt to.
+        chunk_length (Optional[int]): The number of frames to re-run owlv2 to find
+            new objects.
         fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
             fine-tuned model ID here to use it.
 
     Returns:
-        List[Dict[str, Any]]: A list of dictionaries containing the
-
-
-
-
-            the
+        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+        label, segmentation mask and bounding boxes. The outer list represents each
+        frame and the inner list is the entities per frame. The detected objects
+        have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
+        and ymin are the coordinates of the top-left and xmax and ymax are the
+        coordinates of the bottom-right of the bounding box. The mask is binary 2D
+        numpy array where 1 indicates the object and 0 indicates the background.
+        The label names are prefixed with their ID represent the total count.
 
     Example
     -------
@@ -170,25 +175,28 @@ desc,doc,name
         },
     ]
 ",countgd_sam2_instance_segmentation
-"'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names,
-    'countgd_sam2_video_tracking' is a tool that can segment multiple
-    prompt such as category names or referring
-    prompt are separated by commas. It returns
-
+"'countgd_sam2_video_tracking' is a tool that can track and segment multiple objects in a video given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, masks and associated probability scores and is useful for tracking and counting without duplicating counts.","countgd_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10) -> List[List[Dict[str, Any]]]:
+    'countgd_sam2_video_tracking' is a tool that can track and segment multiple
+    objects in a video given a text prompt such as category names or referring
+    expressions. The categories in the text prompt are separated by commas. It returns
+    a list of bounding boxes, label names, masks and associated probability scores and
+    is useful for tracking and counting without duplicating counts.
 
     Parameters:
         prompt (str): The prompt to ground to the image.
-
-        chunk_length (Optional[int]): The number of frames to re-run
+        frames (List[np.ndarray]): The list of frames to ground the prompt to.
+        chunk_length (Optional[int]): The number of frames to re-run countgd to find
            new objects.
 
     Returns:
-        List[Dict[str, Any]]: A list of dictionaries containing the
-
-
-
-
-            the
+        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+        label, segmentation mask and bounding boxes. The outer list represents each
+        frame and the inner list is the entities per frame. The detected objects
+        have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
+        and ymin are the coordinates of the top-left and xmax and ymax are the
+        coordinates of the bottom-right of the bounding box. The mask is binary 2D
+        numpy array where 1 indicates the object and 0 indicates the background.
+        The label names are prefixed with their ID represent the total count.
 
     Example
     -------
@@ -265,12 +273,12 @@ desc,doc,name
         },
     ]
 ",florence2_sam2_instance_segmentation
-'florence2_sam2_video_tracking' is a tool that can
-    'florence2_sam2_video_tracking' is a tool that can
-
-    expressions.
-
-
+"'florence2_sam2_video_tracking' is a tool that can track and segment multiple objects in a video given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, masks and associated probability scores and is useful for tracking and counting without duplicating counts.","florence2_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10, fine_tune_id: Optional[str] = None) -> List[List[Dict[str, Any]]]:
+    'florence2_sam2_video_tracking' is a tool that can track and segment multiple
+    objects in a video given a text prompt such as category names or referring
+    expressions. The categories in the text prompt are separated by commas. It returns
+    a list of bounding boxes, label names, masks and associated probability scores and
+    is useful for tracking and counting without duplicating counts.
 
     Parameters:
         prompt (str): The prompt to ground to the video.
@@ -282,10 +290,13 @@ desc,doc,name
 
     Returns:
         List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
-
-
-
-
+        label, segmentation mask and bounding boxes. The outer list represents each
+        frame and the inner list is the entities per frame. The detected objects
+        have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
+        and ymin are the coordinates of the top-left and xmax and ymax are the
+        coordinates of the bottom-right of the bounding box. The mask is binary 2D
+        numpy array where 1 indicates the object and 0 indicates the background.
+        The label names are prefixed with their ID represent the total count.
 
     Example
     -------
@@ -445,43 +456,6 @@ desc,doc,name
     >>> qwen2_vl_video_vqa('Which football player made the goal?', frames)
     'Lionel Messi'
 ",qwen2_vl_video_vqa
-"'detr_segmentation' is a tool that can segment common objects in an image without any text prompt. It returns a list of detected objects as labels, their regions as masks and their scores.","detr_segmentation(image: numpy.ndarray) -> List[Dict[str, Any]]:
-    'detr_segmentation' is a tool that can segment common objects in an
-    image without any text prompt. It returns a list of detected objects
-    as labels, their regions as masks and their scores.
-
-    Parameters:
-        image (np.ndarray): The image used to segment things and objects
-
-    Returns:
-        List[Dict[str, Any]]: A list of dictionaries containing the score, label
-        and mask of the detected objects. The mask is binary 2D numpy array where 1
-        indicates the object and 0 indicates the background.
-
-    Example
-    -------
-    >>> detr_segmentation(image)
-    [
-        {
-            'score': 0.45,
-            'label': 'window',
-            'mask': array([[0, 0, 0, ..., 0, 0, 0],
-                [0, 0, 0, ..., 0, 0, 0],
-                ...,
-                [0, 0, 0, ..., 0, 0, 0],
-                [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
-        },
-        {
-            'score': 0.70,
-            'label': 'bird',
-            'mask': array([[0, 0, 0, ..., 0, 0, 0],
-                [0, 0, 0, ..., 0, 0, 0],
-                ...,
-                [0, 0, 0, ..., 0, 0, 0],
-                [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
-        },
-    ]
-",detr_segmentation
 'depth_anything_v2' is a tool that runs depth_anythingv2 model to generate a depth image from a given RGB image. The returned depth image is monochrome and represents depth values as pixel intesities with pixel values ranging from 0 to 255.,"depth_anything_v2(image: numpy.ndarray) -> numpy.ndarray:
     'depth_anything_v2' is a tool that runs depth_anythingv2 model to generate a
     depth image from a given RGB image. The returned depth image is monochrome and
@@ -522,22 +496,6 @@ desc,doc,name
     [10, 11, 15, ..., 202, 202, 205],
     [10, 10, 10, ..., 200, 200, 200]], dtype=uint8),
 ",generate_pose_image
-'vit_image_classification' is a tool that can classify an image. It returns a list of classes and their probability scores based on image content.,"vit_image_classification(image: numpy.ndarray) -> Dict[str, Any]:
-    'vit_image_classification' is a tool that can classify an image. It returns a
-    list of classes and their probability scores based on image content.
-
-    Parameters:
-        image (np.ndarray): The image to classify or tag
-
-    Returns:
-        Dict[str, Any]: A dictionary containing the labels and scores. One dictionary
-        contains a list of labels and other a list of scores.
-
-    Example
-    -------
-    >>> vit_image_classification(image)
-    {""labels"": [""leopard"", ""lemur, otter"", ""bird""], ""scores"": [0.68, 0.30, 0.02]},
-",vit_image_classification
 'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'. It returns the predicted label and their probability scores based on image content.,"vit_nsfw_classification(image: numpy.ndarray) -> Dict[str, Any]:
     'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'.
     It returns the predicted label and their probability scores based on image content.
@@ -566,7 +524,7 @@ desc,doc,name
         prompt (str): The question about the video
         frames (List[np.ndarray]): The reference frames used for the question
         model (str): The model to use for the inference. Valid values are
-            'qwen2vl', 'gpt4o'
+            'qwen2vl', 'gpt4o'.
         chunk_length_frames (Optional[int]): length of each chunk in frames
 
     Returns:
@@ -641,7 +599,7 @@ desc,doc,name
     >>> closest_distance(det1, det2, image_size)
     141.42
 ",minimum_distance
-"'extract_frames_and_timestamps' extracts frames and timestamps from a video which can be a file path, url or youtube link, returns a list of dictionaries with keys ""frame"" and ""timestamp"" where ""frame"" is a numpy array and ""timestamp"" is the relative time in seconds where the frame was captured. The frame is a numpy array.","extract_frames_and_timestamps(video_uri: Union[str, pathlib.Path], fps: float =
+"'extract_frames_and_timestamps' extracts frames and timestamps from a video which can be a file path, url or youtube link, returns a list of dictionaries with keys ""frame"" and ""timestamp"" where ""frame"" is a numpy array and ""timestamp"" is the relative time in seconds where the frame was captured. The frame is a numpy array.","extract_frames_and_timestamps(video_uri: Union[str, pathlib.Path], fps: float = 5) -> List[Dict[str, Union[numpy.ndarray, float]]]:
     'extract_frames_and_timestamps' extracts frames and timestamps from a video
     which can be a file path, url or youtube link, returns a list of dictionaries
     with keys ""frame"" and ""timestamp"" where ""frame"" is a numpy array and ""timestamp"" is
@@ -651,7 +609,7 @@ desc,doc,name
     Parameters:
         video_uri (Union[str, Path]): The path to the video file, url or youtube link
         fps (float, optional): The frame rate per second to extract the frames. Defaults
-            to
+            to 5.
 
     Returns:
         List[Dict[str, Union[np.ndarray, float]]]: A list of dictionaries containing the
vision_agent/.sim_tools/embs.npy
CHANGED
Binary file
vision_agent/agent/agent_utils.py
CHANGED
@@ -153,6 +153,19 @@ def format_plan_v2(plan: PlanContext) -> str:
     return plan_str
 
 
+def format_conversation(chat: List[AgentMessage]) -> str:
+    chat = copy.deepcopy(chat)
+    prompt = ""
+    for chat_i in chat:
+        if chat_i.role == "user":
+            prompt += f"USER: {chat_i.content}\n\n"
+        elif chat_i.role == "observation" or chat_i.role == "coder":
+            prompt += f"OBSERVATION: {chat_i.content}\n\n"
+        elif chat_i.role == "conversation":
+            prompt += f"AGENT: {chat_i.content}\n\n"
+    return prompt
+
+
 def format_plans(plans: Dict[str, Any]) -> str:
     plan_str = ""
     for k, v in plans.items():
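
For reference, a small sketch of how the new format_conversation helper renders a chat. It assumes AgentMessage accepts role and content keyword arguments; the message texts are made up for illustration.

from vision_agent.agent.agent_utils import format_conversation
from vision_agent.agent.types import AgentMessage

chat = [
    AgentMessage(role="user", content="Count the boxes in video.mp4"),
    AgentMessage(role="conversation", content="I will track the boxes with a video tracking tool."),
    AgentMessage(role="observation", content="Tracked 3 unique boxes."),
]
# Produces USER:, AGENT: and OBSERVATION: blocks separated by blank lines.
print(format_conversation(chat))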
vision_agent/agent/vision_agent_coder_prompts_v2.py
CHANGED
@@ -65,7 +65,7 @@ This is the documentation for the functions you have access to. You may call any
 7. DO NOT assert the output value, run the code and assert only the output format or data structure.
 8. DO NOT use try except block to handle the error, let the error be raised if the code is incorrect.
 9. DO NOT import the testing function as it will available in the testing environment.
-10. Print the output of the function that is being tested.
+10. Print the output of the function that is being tested and ensure it is not empty.
 11. Use the output of the function that is being tested as the return value of the testing function.
 12. Run the testing function in the end and don't assign a variable to its output.
 13. Output your test code using <code> tags:
vision_agent/agent/vision_agent_coder_v2.py
CHANGED
@@ -202,7 +202,12 @@ def write_and_test_code(
         tool_docs=tool_docs,
         plan=plan,
     )
-
+    try:
+        code = strip_function_calls(code)
+    except Exception:
+        # the code may be malformatted, this will fail in the exec call and the agent
+        # will attempt to debug it
+        pass
     test = write_test(
         tester=tester,
         chat=chat,
vision_agent/agent/vision_agent_planner_prompts_v2.py
CHANGED
@@ -136,8 +136,9 @@ Tool Documentation:
 countgd_object_detection(prompt: str, image: numpy.ndarray, box_threshold: float = 0.23) -> List[Dict[str, Any]]:
     'countgd_object_detection' is a tool that can detect multiple instances of an
     object given a text prompt. It is particularly useful when trying to detect and
-    count a large number of objects.
-
+    count a large number of objects. You can optionally separate object names in the
+    prompt with commas. It returns a list of bounding boxes with normalized
+    coordinates, label names and associated confidence scores.
 
     Parameters:
         prompt (str): The object that needs to be counted.
@@ -272,40 +273,47 @@ OBSERVATION:
 [get_tool_for_task output]
 For tracking boxes moving on a conveyor belt, we need a tool that can consistently track the same box across frames without losing it or double counting. Looking at the outputs: florence2_sam2_video_tracking successfully tracks the single box across all 5 frames, maintaining consistent tracking IDs and showing the box's movement along the conveyor.
 
-
-
-
-
-
+Tool Documentation:
+def florence2_sam2_video_tracking(prompt: str, frames: List[np.ndarray], chunk_length: Optional[int] = 10) -> List[List[Dict[str, Any]]]:
+    'florence2_sam2_video_tracking' is a tool that can track and segment multiple
+    objects in a video given a text prompt such as category names or referring
+    expressions. The categories in the text prompt are separated by commas. It returns
+    a list of bounding boxes, label names, masks and associated probability scores and
+    is useful for tracking and counting without duplicating counts.
 
-    Parameters:
-
-
-
-
+    Parameters:
+        prompt (str): The prompt to ground to the video.
+        frames (List[np.ndarray]): The list of frames to ground the prompt to.
+        chunk_length (Optional[int]): The number of frames to re-run florence2 to find
+            new objects.
+        fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
+            fine-tuned model ID here to use it.
 
-    Returns:
-
-
-
-
-
+    Returns:
+        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+        label, segmentation mask and bounding boxes. The outer list represents each
+        frame and the inner list is the entities per frame. The detected objects
+        have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
+        and ymin are the coordinates of the top-left and xmax and ymax are the
+        coordinates of the bottom-right of the bounding box. The mask is binary 2D
+        numpy array where 1 indicates the object and 0 indicates the background.
+        The label names are prefixed with their ID represent the total count.
 
-    Example
-    -------
-
-    [
+    Example
+    -------
+    >>> florence2_sam2_video_tracking("car, dinosaur", frames)
     [
-
-
-
-
-    ...,
-
-
-
-
+        [
+            {
+                'label': '0: dinosaur',
+                'bbox': [0.1, 0.11, 0.35, 0.4],
+                'mask': array([[0, 0, 0, ..., 0, 0, 0],
+                    ...,
+                    [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
+            },
+        ],
+        ...
+    ]
 [end of get_tool_for_task output]
 <count>8</count>
 
@@ -691,7 +699,8 @@ FINALIZE_PLAN = """
 4. Specifically call out the tools used and the order in which they were used. Only include tools obtained from calling `get_tool_for_task`.
 5. Do not include {excluded_tools} tools in your instructions.
 6. Add final instructions for visualizing the output with `overlay_bounding_boxes` or `overlay_segmentation_masks` and saving it to a file with `save_file` or `save_video`.
-
+7. Use the default FPS for extracting frames from videos unless otherwise specified by the user.
+8. Respond in the following format with JSON surrounded by <json> tags and code surrounded by <code> tags:
 
 <json>
 {{
vision_agent/agent/vision_agent_v2.py
CHANGED
@@ -1,13 +1,14 @@
 import copy
 import json
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 
 from vision_agent.agent import Agent, AgentCoder, VisionAgentCoderV2
 from vision_agent.agent.agent_utils import (
     add_media_to_chat,
     convert_message_to_agentmessage,
     extract_tag,
+    format_conversation,
 )
 from vision_agent.agent.types import (
     AgentMessage,
@@ -22,19 +23,6 @@ from vision_agent.lmm.types import Message
 from vision_agent.utils.execute import CodeInterpreter, CodeInterpreterFactory
 
 
-def format_conversation(chat: List[AgentMessage]) -> str:
-    chat = copy.deepcopy(chat)
-    prompt = ""
-    for chat_i in chat:
-        if chat_i.role == "user":
-            prompt += f"USER: {chat_i.content}\n\n"
-        elif chat_i.role == "observation" or chat_i.role == "coder":
-            prompt += f"OBSERVATION: {chat_i.content}\n\n"
-        elif chat_i.role == "conversation":
-            prompt += f"AGENT: {chat_i.content}\n\n"
-    return prompt
-
-
 def run_conversation(agent: LMM, chat: List[AgentMessage]) -> str:
     # only keep last 10 messages
     conv = format_conversation(chat[-10:])
@@ -55,23 +43,39 @@ def check_for_interaction(chat: List[AgentMessage]) -> bool:
 
 def extract_conversation_for_generate_code(
     chat: List[AgentMessage],
-) -> List[AgentMessage]:
+) -> Tuple[List[AgentMessage], Optional[str]]:
     chat = copy.deepcopy(chat)
 
     # if we are in the middle of an interaction, return all the intermediate planning
     # steps
     if check_for_interaction(chat):
-        return chat
+        return chat, None
 
     extracted_chat = []
     for chat_i in chat:
         if chat_i.role == "user":
             extracted_chat.append(chat_i)
         elif chat_i.role == "coder":
-            if "<final_code>" in chat_i.content
+            if "<final_code>" in chat_i.content:
                 extracted_chat.append(chat_i)
 
-
+    # only keep the last <final_code> and <final_test>
+    final_code = None
+    extracted_chat_strip_code: List[AgentMessage] = []
+    for chat_i in reversed(extracted_chat):
+        if "<final_code>" in chat_i.content and final_code is None:
+            extracted_chat_strip_code = [chat_i] + extracted_chat_strip_code
+            final_code = extract_tag(chat_i.content, "final_code")
+            if final_code is not None:
+                test_code = extract_tag(chat_i.content, "final_test")
+                final_code += "\n" + test_code if test_code is not None else ""
+
+        if "<final_code>" in chat_i.content and final_code is not None:
+            continue
+
+        extracted_chat_strip_code = [chat_i] + extracted_chat_strip_code
+
+    return extracted_chat_strip_code[-5:], final_code
 
 
 def maybe_run_action(
@@ -81,7 +85,7 @@ def maybe_run_action(
     code_interpreter: Optional[CodeInterpreter] = None,
 ) -> Optional[List[AgentMessage]]:
     if action == "generate_or_edit_vision_code":
-        extracted_chat = extract_conversation_for_generate_code(chat)
+        extracted_chat, _ = extract_conversation_for_generate_code(chat)
         # there's an issue here because coder.generate_code will send it's code_context
         # to the outside user via it's update_callback, but we don't necessarily have
         # access to that update_callback here, so we re-create the message using
@@ -101,11 +105,15 @@ def maybe_run_action(
             )
         ]
     elif action == "edit_code":
-        extracted_chat = extract_conversation_for_generate_code(chat)
+        extracted_chat, final_code = extract_conversation_for_generate_code(chat)
         plan_context = PlanContext(
             plan="Edit the latest code observed in the fewest steps possible according to the user's feedback.",
-            instructions=[
-
+            instructions=[
+                chat_i.content
+                for chat_i in extracted_chat
+                if chat_i.role == "user" and "<final_code>" not in chat_i.content
+            ],
+            code=final_code if final_code is not None else "",
         )
         context = coder.generate_code_from_plan(
             extracted_chat, plan_context, code_interpreter=code_interpreter
vision_agent/tools/planner_tools.py
CHANGED
@@ -193,8 +193,10 @@ def get_tool_for_task(
     - Depth and pose estimation
     - Video object tracking
 
-
-
+    Only ask for one type of task at a time, for example a task needing to identify
+    text is one OCR task while needing to identify non-text objects is an OD task. Wait
+    until the documentation is printed to use the function so you know what the input
+    and output signatures are.
 
     Parameters:
         task: str: The task to accomplish.
vision_agent/tools/tools.py
CHANGED
@@ -6,7 +6,6 @@ import tempfile
 import urllib.request
 from base64 import b64encode
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from enum import Enum
 from importlib import resources
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -54,6 +53,13 @@ from vision_agent.utils.video import (
     frames_to_bytes,
     video_writer,
 )
+from vision_agent.utils.video_tracking import (
+    ODModels,
+    merge_segments,
+    post_process,
+    process_segment,
+    split_frames_into_segments,
+)
 
 register_heif_opener()
 
@@ -224,12 +230,6 @@ def sam2(
     return ret["return_data"]  # type: ignore
 
 
-class ODModels(str, Enum):
-    COUNTGD = "countgd"
-    FLORENCE2 = "florence2"
-    OWLV2 = "owlv2"
-
-
 def od_sam2_video_tracking(
     od_model: ODModels,
     prompt: str,
@@ -237,105 +237,92 @@ def od_sam2_video_tracking(
     chunk_length: Optional[int] = 10,
     fine_tune_id: Optional[str] = None,
 ) -> Dict[str, Any]:
-
+    SEGMENT_SIZE = 50
+    OVERLAP = 1  # Number of overlapping frames between segments
 
-
-
-
-
-
-
+    image_size = frames[0].shape[:2]
+
+    # Split frames into segments with overlap
+    segments = split_frames_into_segments(frames, SEGMENT_SIZE, OVERLAP)
+
+    def _apply_object_detection(  # inner method to avoid circular importing issues.
+        od_model: ODModels,
+        prompt: str,
+        segment_index: int,
+        frame_number: int,
+        fine_tune_id: str,
+        segment_frames: list,
+    ) -> tuple:
+        """
+        Applies the specified object detection model to the given image.
+
+        Args:
+            od_model: The object detection model to use.
+            prompt: The prompt for the object detection model.
+            segment_index: The index of the current segment.
+            frame_number: The number of the current frame.
+            fine_tune_id: Optional fine-tune ID for the model.
+            segment_frames: List of frames for the current segment.
+
+        Returns:
+            A tuple containing the object detection results and the name of the function used.
+        """
 
-    for idx in range(0, len(frames), step):
         if od_model == ODModels.COUNTGD:
-
+            segment_results = countgd_object_detection(
+                prompt=prompt, image=segment_frames[frame_number]
+            )
             function_name = "countgd_object_detection"
+
         elif od_model == ODModels.OWLV2:
-
-            prompt=prompt,
+            segment_results = owlv2_object_detection(
+                prompt=prompt,
+                image=segment_frames[frame_number],
+                fine_tune_id=fine_tune_id,
             )
             function_name = "owlv2_object_detection"
+
         elif od_model == ODModels.FLORENCE2:
-
-            prompt=prompt,
+            segment_results = florence2_object_detection(
+                prompt=prompt,
+                image=segment_frames[frame_number],
+                fine_tune_id=fine_tune_id,
             )
             function_name = "florence2_object_detection"
+
         else:
             raise NotImplementedError(
                 f"Object detection model '{od_model}' is not implemented."
             )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                {
-                    "labels": labels,
-                    "bboxes": bboxes,
-                }
-            )
-        else:
-            output_list.append(None)
-
-    return output_list
+        return segment_results, function_name
+
+    # Process each segment and collect detections
+    detections_per_segment: List[Any] = []
+    for segment_index, segment in enumerate(segments):
+        segment_detections = process_segment(
+            segment_frames=segment,
+            od_model=od_model,
+            prompt=prompt,
+            fine_tune_id=fine_tune_id,
+            chunk_length=chunk_length,
+            image_size=image_size,
+            segment_index=segment_index,
+            object_detection_tool=_apply_object_detection,
+        )
+        detections_per_segment.append(segment_detections)
 
-
+    merged_detections = merge_segments(detections_per_segment)
+    post_processed = post_process(merged_detections, image_size)
 
     buffer_bytes = frames_to_bytes(frames)
     files = [("video", buffer_bytes)]
-    payload = {"bboxes": json.dumps(output), "chunk_length_frames": chunk_length}
-    metadata = {"function_name": function_name}
-
-    detections = send_task_inference_request(
-        payload,
-        "sam2",
-        files=files,
-        metadata=metadata,
-    )
-
-    return_data = []
-    for frame in detections:
-        return_frame_data = []
-        for detection in frame:
-            mask = rle_decode_array(detection["mask"])
-            label = str(detection["id"]) + ": " + detection["label"]
-            return_frame_data.append(
-                {"label": label, "mask": mask, "score": 1.0, "rle": detection["mask"]}
-            )
-        return_data.append(return_frame_data)
-    return_data = add_bboxes_from_masks(return_data)
-    return_data = nms(return_data, iou_threshold=0.95)
 
-
-
-
-
-
-        for obj in frame:
-            display_frame_data.append(
-                {
-                    "label": obj["label"],
-                    "score": obj["score"],
-                    "bbox": denormalize_bbox(obj["bbox"], image_size),
-                    "mask": obj["rle"],
-                }
-            )
-            del obj["rle"]
-        display_data.append(display_frame_data)
-
-    return {"files": files, "return_data": return_data, "display_data": detections}
+    return {
+        "files": files,
+        "return_data": post_processed["return_data"],
+        "display_data": post_processed["display_data"],
+    }
 
 
 # Owl V2 Tools
@@ -528,24 +515,29 @@ def owlv2_sam2_video_tracking(
     chunk_length: Optional[int] = 10,
     fine_tune_id: Optional[str] = None,
 ) -> List[List[Dict[str, Any]]]:
-    """'owlv2_sam2_video_tracking' is a tool that can segment multiple
-    prompt such as category names or referring
-    prompt are separated by commas. It returns
-
+    """'owlv2_sam2_video_tracking' is a tool that can track and segment multiple
+    objects in a video given a text prompt such as category names or referring
+    expressions. The categories in the text prompt are separated by commas. It returns
+    a list of bounding boxes, label names, masks and associated probability scores and
+    is useful for tracking and counting without duplicating counts.
 
     Parameters:
         prompt (str): The prompt to ground to the image.
-
+        frames (List[np.ndarray]): The list of frames to ground the prompt to.
+        chunk_length (Optional[int]): The number of frames to re-run owlv2 to find
+            new objects.
        fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
            fine-tuned model ID here to use it.
 
     Returns:
-        List[Dict[str, Any]]: A list of dictionaries containing the
-
-
-
-
-            the
+        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+        label, segmentation mask and bounding boxes. The outer list represents each
+        frame and the inner list is the entities per frame. The detected objects
+        have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
+        and ymin are the coordinates of the top-left and xmax and ymax are the
+        coordinates of the bottom-right of the bounding box. The mask is binary 2D
+        numpy array where 1 indicates the object and 0 indicates the background.
+        The label names are prefixed with their ID represent the total count.
 
     Example
     -------
@@ -755,11 +747,11 @@ def florence2_sam2_video_tracking(
     chunk_length: Optional[int] = 10,
     fine_tune_id: Optional[str] = None,
 ) -> List[List[Dict[str, Any]]]:
-    """'florence2_sam2_video_tracking' is a tool that can
-
-    expressions.
-
-
+    """'florence2_sam2_video_tracking' is a tool that can track and segment multiple
+    objects in a video given a text prompt such as category names or referring
+    expressions. The categories in the text prompt are separated by commas. It returns
+    a list of bounding boxes, label names, masks and associated probability scores and
+    is useful for tracking and counting without duplicating counts.
 
     Parameters:
         prompt (str): The prompt to ground to the video.
@@ -771,10 +763,13 @@ def florence2_sam2_video_tracking(
 
     Returns:
         List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
-
-
-
-
+        label, segmentation mask and bounding boxes. The outer list represents each
+        frame and the inner list is the entities per frame. The detected objects
+        have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
+        and ymin are the coordinates of the top-left and xmax and ymax are the
+        coordinates of the bottom-right of the bounding box. The mask is binary 2D
+        numpy array where 1 indicates the object and 0 indicates the background.
+        The label names are prefixed with their ID represent the total count.
 
     Example
     -------
@@ -1089,24 +1084,27 @@ def countgd_sam2_video_tracking(
     frames: List[np.ndarray],
     chunk_length: Optional[int] = 10,
 ) -> List[List[Dict[str, Any]]]:
-    """'countgd_sam2_video_tracking' is a tool that can segment multiple
-    prompt such as category names or referring
-    prompt are separated by commas. It returns
-
+    """'countgd_sam2_video_tracking' is a tool that can track and segment multiple
+    objects in a video given a text prompt such as category names or referring
+    expressions. The categories in the text prompt are separated by commas. It returns
+    a list of bounding boxes, label names, masks and associated probability scores and
+    is useful for tracking and counting without duplicating counts.
 
     Parameters:
         prompt (str): The prompt to ground to the image.
-
-        chunk_length (Optional[int]): The number of frames to re-run
+        frames (List[np.ndarray]): The list of frames to ground the prompt to.
+        chunk_length (Optional[int]): The number of frames to re-run countgd to find
            new objects.
 
     Returns:
-        List[Dict[str, Any]]: A list of dictionaries containing the
-
-
-
-
-            the
+        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+        label, segmentation mask and bounding boxes. The outer list represents each
+        frame and the inner list is the entities per frame. The detected objects
+        have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
+        and ymin are the coordinates of the top-left and xmax and ymax are the
+        coordinates of the bottom-right of the bounding box. The mask is binary 2D
+        numpy array where 1 indicates the object and 0 indicates the background.
+        The label names are prefixed with their ID represent the total count.
 
     Example
     -------
@@ -1546,7 +1544,7 @@ def video_temporal_localization(
         prompt (str): The question about the video
         frames (List[np.ndarray]): The reference frames used for the question
         model (str): The model to use for the inference. Valid values are
-            'qwen2vl', 'gpt4o'
+            'qwen2vl', 'gpt4o'.
         chunk_length_frames (Optional[int]): length of each chunk in frames
 
     Returns:
@@ -2115,7 +2113,7 @@ def closest_box_distance(
 
 
 def extract_frames_and_timestamps(
-    video_uri: Union[str, Path], fps: float =
+    video_uri: Union[str, Path], fps: float = 5
 ) -> List[Dict[str, Union[np.ndarray, float]]]:
     """'extract_frames_and_timestamps' extracts frames and timestamps from a video
     which can be a file path, url or youtube link, returns a list of dictionaries
@@ -2126,7 +2124,7 @@ def extract_frames_and_timestamps(
     Parameters:
         video_uri (Union[str, Path]): The path to the video file, url or youtube link
         fps (float, optional): The frame rate per second to extract the frames. Defaults
-            to
+            to 5.
 
     Returns:
         List[Dict[str, Union[np.ndarray, float]]]: A list of dictionaries containing the
@@ -2649,10 +2647,8 @@ FUNCTION_TOOLS = [
     ocr,
     qwen2_vl_images_vqa,
     qwen2_vl_video_vqa,
-    detr_segmentation,
    depth_anything_v2,
    generate_pose_image,
-    vit_image_classification,
    vit_nsfw_classification,
    video_temporal_localization,
    flux_image_inpainting,
vision_agent/utils/sim.py
CHANGED
@@ -133,6 +133,12 @@ class Sim:
         df: pd.DataFrame,
     ) -> bool:
         load_dir = Path(load_dir)
+        if (
+            not Path(load_dir / "df.csv").exists()
+            or not Path(load_dir / "embs.npy").exists()
+        ):
+            return False
+
         df_load = pd.read_csv(load_dir / "df.csv")
         if platform.system() == "Windows":
             df_load["doc"] = df_load["doc"].apply(lambda x: x.replace("\r", ""))
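
The guard added above can be exercised standalone; a hypothetical helper that mirrors the same check (the directory path in the usage line is an assumption):

from pathlib import Path

def cached_sim_files_exist(load_dir: str) -> bool:
    # Mirrors the new Sim guard: both the docs CSV and the embeddings array
    # must exist before a cached index is considered loadable.
    path = Path(load_dir)
    return (path / "df.csv").exists() and (path / "embs.npy").exists()

print(cached_sim_files_exist(".sim_tools"))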
vision_agent/utils/video_tracking.py
ADDED
@@ -0,0 +1,305 @@
+import json
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+
+from vision_agent.tools.tool_utils import (
+    add_bboxes_from_masks,
+    nms,
+    send_task_inference_request,
+)
+from vision_agent.utils.image_utils import denormalize_bbox, rle_decode_array
+from vision_agent.utils.video import frames_to_bytes
+
+
+class ODModels(str, Enum):
+    COUNTGD = "countgd"
+    FLORENCE2 = "florence2"
+    OWLV2 = "owlv2"
+
+
+def split_frames_into_segments(
+    frames: List[np.ndarray], segment_size: int = 50, overlap: int = 1
+) -> List[List[np.ndarray]]:
+    """
+    Splits the list of frames into segments with a specified size and overlap.
+
+    Args:
+        frames (List[np.ndarray]): List of video frames.
+        segment_size (int, optional): Number of frames per segment. Defaults to 50.
+        overlap (int, optional): Number of overlapping frames between segments. Defaults to 1.
+
+    Returns:
+        List[List[np.ndarray]]: List of frame segments.
+    """
+    segments = []
+    start = 0
+    segment_count = 0
+    while start < len(frames):
+        end = start + segment_size
+        if end > len(frames):
+            end = len(frames)
+        if start != 0:
+            # Include the last frame of the previous segment
+            segment = frames[start - overlap : end]
+        else:
+            segment = frames[start:end]
+        segments.append(segment)
+        start += segment_size
+        segment_count += 1
+    return segments
+
+
+def process_segment(
+    segment_frames: List[np.ndarray],
+    od_model: ODModels,
+    prompt: str,
+    fine_tune_id: Optional[str],
+    chunk_length: Optional[int],
+    image_size: Tuple[int, ...],
+    segment_index: int,
+    object_detection_tool: Callable,
+) -> Any:
+    """
+    Processes a segment of frames with the specified object detection model.
+
+    Args:
+        segment_frames (List[np.ndarray]): Frames in the segment.
+        od_model (ODModels): Object detection model to use.
+        prompt (str): Prompt for the model.
+        fine_tune_id (Optional[str]): Fine-tune model ID.
+        chunk_length (Optional[int]): Chunk length for processing.
+        image_size (Tuple[int, int]): Size of the images.
+        segment_index (int): Index of the segment.
+        object_detection_tool (Callable): Object detection tool to use.
+
+    Returns:
+        Any: Detections for the segment.
+    """
+    segment_results: List[Optional[List[Dict[str, Any]]]] = [None] * len(segment_frames)
+
+    if chunk_length is None:
+        step = 1
+    elif chunk_length <= 0:
+        raise ValueError("chunk_length must be a positive integer or None.")
+    else:
+        step = chunk_length
+
+    function_name = ""
+
+    for idx in range(0, len(segment_frames), step):
+        frame_number = idx
+        segment_results[idx], function_name = object_detection_tool(
+            od_model, prompt, segment_index, frame_number, fine_tune_id, segment_frames
+        )
+
+    transformed_detections = transform_detections(
+        segment_results, image_size, segment_index
+    )
+
+    buffer_bytes = frames_to_bytes(segment_frames)
+    files = [("video", buffer_bytes)]
+    payload = {
+        "bboxes": json.dumps(transformed_detections),
+        "chunk_length_frames": chunk_length,
+    }
+    metadata = {"function_name": function_name}
+
+    segment_detections = send_task_inference_request(
+        payload,
+        "sam2",
+        files=files,
+        metadata=metadata,
+    )
+
+    return segment_detections
+
+
+def transform_detections(
+    input_list: List[Optional[List[Dict[str, Any]]]],
+    image_size: Tuple[int, ...],
+    segment_index: int,
+) -> List[Optional[Dict[str, Any]]]:
+    """
+    Transforms raw detections into a standardized format.
+
+    Args:
+        input_list (List[Optional[List[Dict[str, Any]]]]): Raw detections.
+        image_size (Tuple[int, int]): Size of the images.
+        segment_index (int): Index of the segment.
+
+    Returns:
+        List[Optional[Dict[str, Any]]]: Transformed detections.
+    """
+    output_list: List[Optional[Dict[str, Any]]] = []
+    for frame_idx, frame in enumerate(input_list):
+        if frame is not None:
+            labels = [detection["label"] for detection in frame]
+            bboxes = [
+                denormalize_bbox(detection["bbox"], image_size) for detection in frame
+            ]
+
+            output_list.append(
+                {
+                    "labels": labels,
+                    "bboxes": bboxes,
+                }
+            )
+        else:
+            output_list.append(None)
+    return output_list
+
+
+def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
+    mask1 = mask1.astype(bool)
+    mask2 = mask2.astype(bool)
+
+    intersection = np.sum(np.logical_and(mask1, mask2))
+    union = np.sum(np.logical_or(mask1, mask2))
+
+    if union == 0:
+        iou = 0.0
+    else:
+        iou = intersection / union
+
+    return iou
+
+
+def _match_by_iou(
+    first_param: List[Dict],
+    second_param: List[Dict],
+    iou_threshold: float = 0.8,
+) -> Tuple[List[Dict], Dict[int, int]]:
+    max_id = max((item["id"] for item in first_param), default=0)
+
+    matched_new_item_indices = set()
+    id_mapping = {}
+
+    for new_index, new_item in enumerate(second_param):
+        matched_id = None
+
+        for existing_item in first_param:
+            iou = _calculate_mask_iou(
+                existing_item["decoded_mask"], new_item["decoded_mask"]
+            )
+            if iou > iou_threshold:
+                matched_id = existing_item["id"]
+                matched_new_item_indices.add(new_index)
+                id_mapping[new_item["id"]] = matched_id
+                break
+
+        if matched_id:
+            new_item["id"] = matched_id
+        else:
+            max_id += 1
+            id_mapping[new_item["id"]] = max_id
+            new_item["id"] = max_id
+
+    unmatched_items = [
+        item for i, item in enumerate(second_param) if i not in matched_new_item_indices
+    ]
+    combined_list = first_param + unmatched_items
+
+    return combined_list, id_mapping
+
+
+def _update_ids(detections: List[Dict], id_mapping: Dict[int, int]) -> None:
+    for inner_list in detections:
+        for detection in inner_list:
+            if detection["id"] in id_mapping:
+                detection["id"] = id_mapping[detection["id"]]
+            else:
+                max_new_id = max(id_mapping.values(), default=0)
+                detection["id"] = max_new_id + 1
+                id_mapping[detection["id"]] = detection["id"]
+
+
+def _convert_to_2d(detections_per_segment: List[Any]) -> List[Any]:
+    result = []
+    for i, segment in enumerate(detections_per_segment):
+        if i == 0:
+            result.extend(segment)
+        else:
+            result.extend(segment[1:])
+    return result
+
+
+def merge_segments(detections_per_segment: List[Any]) -> List[Any]:
+    """
+    Merges detections from all segments into a unified result.
+
+    Args:
+        detections_per_segment (List[Any]): List of detections per segment.
+
+    Returns:
+        List[Any]: Merged detections.
+    """
+    for segment in detections_per_segment:
+        for detection in segment:
+            for item in detection:
+                item["decoded_mask"] = rle_decode_array(item["mask"])
+
+    for segment_idx in range(len(detections_per_segment) - 1):
+        combined_detection, id_mapping = _match_by_iou(
+            detections_per_segment[segment_idx][-1],
+            detections_per_segment[segment_idx + 1][0],
+        )
+        _update_ids(detections_per_segment[segment_idx + 1], id_mapping)
+
+    merged_result = _convert_to_2d(detections_per_segment)
+
+    return merged_result
+
+
+def post_process(
+    merged_detections: List[Any],
+    image_size: Tuple[int, ...],
+) -> Dict[str, Any]:
+    """
+    Performs post-processing on merged detections, including NMS and preparing display data.
+
+    Args:
+        merged_detections (List[Any]): Merged detections from all segments.
+        image_size (Tuple[int, int]): Size of the images.
+
+    Returns:
+        Dict[str, Any]: Post-processed data including return_data and display_data.
+    """
+    return_data = []
+    for frame_idx, frame in enumerate(merged_detections):
+        return_frame_data = []
+        for detection in frame:
+            label = f"{detection['id']}: {detection['label']}"
+            return_frame_data.append(
+                {
+                    "label": label,
+                    "mask": detection["decoded_mask"],
+                    "rle": detection["mask"],
+                    "score": 1.0,
+                }
+            )
+            del detection["decoded_mask"]
+        return_data.append(return_frame_data)
+
+    return_data = add_bboxes_from_masks(return_data)
+    return_data = nms(return_data, iou_threshold=0.95)
+
+    # We save the RLE for display purposes, re-calculting RLE can get very expensive.
+    # Deleted here because we are returning the numpy masks instead
+    display_data = []
+    for frame in return_data:
+        display_frame_data = []
+        for obj in frame:
+            display_frame_data.append(
+                {
+                    "label": obj["label"],
+                    "bbox": denormalize_bbox(obj["bbox"], image_size),
+                    "mask": obj["rle"],
+                    "score": obj["score"],
+                }
+            )
+            del obj["rle"]
+        display_data.append(display_frame_data)
+
+    return {"return_data": return_data, "display_data": display_data}
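
Taken together, the new module gives od_sam2_video_tracking in tools.py its segmented shape. A condensed, hypothetical sketch of that flow (the detector callable is a stand-in for the countgd/owlv2/florence2 inner function, and process_segment still requires access to the hosted sam2 inference endpoint):

from typing import Any, Callable, List

import numpy as np

from vision_agent.utils.video_tracking import (
    ODModels,
    merge_segments,
    post_process,
    process_segment,
    split_frames_into_segments,
)

def track_with_segments(frames: List[np.ndarray], detector: Callable, prompt: str) -> Any:
    # 50-frame segments with a 1-frame overlap, matching the constants used in tools.py
    segments = split_frames_into_segments(frames, segment_size=50, overlap=1)
    detections_per_segment = [
        process_segment(
            segment_frames=segment,
            od_model=ODModels.COUNTGD,
            prompt=prompt,
            fine_tune_id=None,
            chunk_length=10,
            image_size=frames[0].shape[:2],
            segment_index=i,
            object_detection_tool=detector,
        )
        for i, segment in enumerate(segments)
    ]
    # IoU matching on the overlapping frame keeps object IDs consistent across segments,
    # which is what prevents duplicated counts at segment boundaries.
    merged = merge_segments(detections_per_segment)
    return post_process(merged, frames[0].shape[:2])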
{vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/RECORD
CHANGED
@@ -1,23 +1,23 @@
-vision_agent/.sim_tools/df.csv,sha256=
-vision_agent/.sim_tools/embs.npy,sha256=
+vision_agent/.sim_tools/df.csv,sha256=Vamicw8MiSGildK1r3-HXY4cKiq17GZxsgBsHbk7jpM,42158
+vision_agent/.sim_tools/embs.npy,sha256=YJe8EcKVNmeX_75CS2T1sbY-sUS_1HQAMT-34zc18a0,254080
 vision_agent/__init__.py,sha256=EAb4-f9iyuEYkBrX4ag1syM8Syx8118_t0R6_C34M9w,57
 vision_agent/agent/README.md,sha256=Q4w7FWw38qaWosQYAZ7NqWx8Q5XzuWrlv7nLhjUd1-8,5527
 vision_agent/agent/__init__.py,sha256=M8CffavdIh8Zh-skznLHIaQkYGCGK7vk4dq1FaVkbs4,617
 vision_agent/agent/agent.py,sha256=_1tHWAs7Jm5tqDzEcPfCRvJV3uRRveyh4n9_9pd6I1w,1565
-vision_agent/agent/agent_utils.py,sha256=
+vision_agent/agent/agent_utils.py,sha256=pP4u5tiami7C3ChgjgYLqJITnmkTI1_GsUj6g5czSRk,13994
 vision_agent/agent/types.py,sha256=DkFm3VMMrKlhYyfxEmZx4keppD72Ov3wmLCbM2J2o10,2437
 vision_agent/agent/vision_agent.py,sha256=I75bEU-os9Lf9OSICKfvQ_H_ftg-zOwgTwWnu41oIdo,23555
 vision_agent/agent/vision_agent_coder.py,sha256=flUxOibyGZK19BCSK5mhaD3HjCxHw6c6FtKom6N2q1E,27359
 vision_agent/agent/vision_agent_coder_prompts.py,sha256=gPLVXQMNSzYnQYpNm0wlH_5FPkOTaFDV24bqzK3jQ40,12221
-vision_agent/agent/vision_agent_coder_prompts_v2.py,sha256=
-vision_agent/agent/vision_agent_coder_v2.py,sha256=
+vision_agent/agent/vision_agent_coder_prompts_v2.py,sha256=idmSMfxebPULqqvllz3gqRzGDchEvS5dkGngvBs4PGo,4872
+vision_agent/agent/vision_agent_coder_v2.py,sha256=i1qgXp5YsWVRoA_qO429Ef-aKZBakveCl1F_2ZbSzk8,16287
 vision_agent/agent/vision_agent_planner.py,sha256=fFzjNkZBKkh8Y_oS06ATI4qz31xmIJvixb_tV1kX8KA,18590
 vision_agent/agent/vision_agent_planner_prompts.py,sha256=mn9NlZpRkW4XAvlNuMZwIs1ieHCFds5aYZJ55WXupZY,6733
-vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=
+vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=YgemW2PRPYd8o8XpmwSJBUOJSxMUXMNr2DZNQnS4jEI,34988
 vision_agent/agent/vision_agent_planner_v2.py,sha256=vvxfmGydBIKB8CtNSAJyPvdEXkG7nIO5-Hs2SjNc48Y,20465
 vision_agent/agent/vision_agent_prompts.py,sha256=NtGdCfzzilCRtscKALC9FK55d1h4CBpMnbhLzg0PYlc,13772
 vision_agent/agent/vision_agent_prompts_v2.py,sha256=-vCWat-ARlCOOOeIDIFhg-kcwRRwjTXYEwsvvqPeaCs,1972
-vision_agent/agent/vision_agent_v2.py,sha256=
+vision_agent/agent/vision_agent_v2.py,sha256=1wu_vH_onic2kLYPKW2nAF2e6Zz5vmUt5Acv4Seq3sQ,10796
 vision_agent/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vision_agent/clients/http.py,sha256=k883i6M_4nl7zwwHSI-yP5sAgQZIDPM1nrKD6YFJ3Xs,2009
 vision_agent/clients/landing_public_api.py,sha256=lU2ev6E8NICmR8DMUljuGcVFy5VNJQ4WQkWC8WnnJEc,1503
@@ -28,19 +28,20 @@ vision_agent/lmm/lmm.py,sha256=x_nIyDNDZwq4-pfjnJTmcyyJZ2_B7TjkA5jZp88YVO8,17103
 vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
 vision_agent/tools/__init__.py,sha256=15O7eQVn0bitmzUO5OxKdA618PoiLt6Z02gmKsSNMFM,2765
 vision_agent/tools/meta_tools.py,sha256=TPeS7QWnc_PmmU_ndiDT03dXbQ5yDSP33E7U8cSj7Ls,28660
-vision_agent/tools/planner_tools.py,sha256=
+vision_agent/tools/planner_tools.py,sha256=qQvPuCif-KbFi7KsXKkTCfpgEQEJJ6oq6WB3gOuG2Xg,13686
 vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
 vision_agent/tools/tool_utils.py,sha256=q9cqXO2AvigUdO1krjnOy8o0goYhgS6eILl6-F5Kxyk,10211
-vision_agent/tools/tools.py,sha256=
+vision_agent/tools/tools.py,sha256=zqoo4ml9ZS99kOeOIN6Zplq7pxOwBrVZKKFUVIzsjfw,91712
 vision_agent/tools/tools_types.py,sha256=8hYf2OZhI58gvf65KGaeGkt4EQ56nwLFqIQDPHioOBc,2339
 vision_agent/utils/__init__.py,sha256=QKk4zVjMwGxQI0MQ-aZZA50N-qItxRY4EB9CwQkZ2HY,185
 vision_agent/utils/exceptions.py,sha256=booSPSuoULF7OXRr_YbC4dtKt6gM_HyiFQHBuaW86C4,2052
 vision_agent/utils/execute.py,sha256=vOEP5Ys7S2lc0_7pOJbgk7OaWi85hrCNu9_8Bo3zk6I,29356
 vision_agent/utils/image_utils.py,sha256=z_ONgcza125B10NkoGwPOzXnL470bpTWZbkB16NeeH0,12188
-vision_agent/utils/sim.py,sha256=
+vision_agent/utils/sim.py,sha256=qr-6UWAxxGwtwIAKZjZCY_pu9VwBI_TTB8bfrGsaABg,9282
 vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
 vision_agent/utils/video.py,sha256=e1VwKhXzzlC5LcFMyrcQYrPnpnX4wxDpnQ-76sB4jgM,6001
-vision_agent
-vision_agent-0.2.
-vision_agent-0.2.
-vision_agent-0.2.
+vision_agent/utils/video_tracking.py,sha256=EeOiSY8gjvvneuAnv-BO7yOyMBF_-1Irk_lLLOt3bDM,9452
+vision_agent-0.2.226.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.2.226.dist-info/METADATA,sha256=_7jZokNbQLK6Ups2psyRKbPDjUIzU3daxCpfrHZ6gSU,20039
+vision_agent-0.2.226.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.2.226.dist-info/RECORD,,
{vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/LICENSE: File without changes
{vision_agent-0.2.224.dist-info → vision_agent-0.2.226.dist-info}/WHEEL: File without changes