vision-agent 0.2.224__py3-none-any.whl → 0.2.226__py3-none-any.whl

@@ -65,25 +65,30 @@ desc,doc,name
65
65
  },
66
66
  ]
67
67
  ",owlv2_sam2_instance_segmentation
68
- "'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, mask file names and associated probability scores.","owlv2_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10, fine_tune_id: Optional[str] = None) -> List[List[Dict[str, Any]]]:
69
- 'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text
70
- prompt such as category names or referring expressions. The categories in the text
71
- prompt are separated by commas. It returns a list of bounding boxes, label names,
72
- mask file names and associated probability scores.
68
+ "'owlv2_sam2_video_tracking' is a tool that can track and segment multiple objects in a video given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, masks and associated probability scores and is useful for tracking and counting without duplicating counts.","owlv2_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10, fine_tune_id: Optional[str] = None) -> List[List[Dict[str, Any]]]:
69
+ 'owlv2_sam2_video_tracking' is a tool that can track and segment multiple
70
+ objects in a video given a text prompt such as category names or referring
71
+ expressions. The categories in the text prompt are separated by commas. It returns
72
+ a list of bounding boxes, label names, masks and associated probability scores and
73
+ is useful for tracking and counting without duplicating counts.
73
74
 
74
75
  Parameters:
75
76
  prompt (str): The prompt to ground to the image.
76
- image (np.ndarray): The image to ground the prompt to.
77
+ frames (List[np.ndarray]): The list of frames to ground the prompt to.
78
+ chunk_length (Optional[int]): The number of frames to re-run owlv2 to find
79
+ new objects.
77
80
  fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
78
81
  fine-tuned model ID here to use it.
79
82
 
80
83
  Returns:
81
- List[Dict[str, Any]]: A list of dictionaries containing the score, label,
82
- bounding box, and mask of the detected objects with normalized coordinates
83
- (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
84
- and xmax and ymax are the coordinates of the bottom-right of the bounding box.
85
- The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
86
- the background.
84
+ List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
85
+ label, segmentation mask and bounding boxes. The outer list represents each
86
+ frame and the inner list is the entities per frame. The detected objects
87
+ have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
88
+ and ymin are the coordinates of the top-left and xmax and ymax are the
89
+ coordinates of the bottom-right of the bounding box. The mask is binary 2D
90
+ numpy array where 1 indicates the object and 0 indicates the background.
91
+ The label names are prefixed with their ID, which can be used to get the total count.
87
92
 
88
93
  Example
89
94
  -------
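A minimal usage sketch of the updated tracking API documented above; it is illustrative only, assumes both tools are importable from vision_agent.tools, uses a hypothetical video path, and relies solely on the documented output format (labels prefixed with a tracking ID).

# Illustrative sketch; "warehouse.mp4" and the import path are assumptions.
from vision_agent.tools import extract_frames_and_timestamps, owlv2_sam2_video_tracking

frames = [f["frame"] for f in extract_frames_and_timestamps("warehouse.mp4")]
tracks = owlv2_sam2_video_tracking("box", frames, chunk_length=10)

# Each label is prefixed with its tracking ID (e.g. "0: box"), so counting
# distinct prefixes counts objects without duplicating across frames.
unique_ids = {det["label"].split(":")[0] for frame_dets in tracks for det in frame_dets}
print(f"{len(unique_ids)} unique boxes across {len(tracks)} frames")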
@@ -170,25 +175,28 @@ desc,doc,name
170
175
  },
171
176
  ]
172
177
  ",countgd_sam2_instance_segmentation
173
- "'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, mask file names and associated probability scores.","countgd_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10) -> List[List[Dict[str, Any]]]:
174
- 'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text
175
- prompt such as category names or referring expressions. The categories in the text
176
- prompt are separated by commas. It returns a list of bounding boxes, label names,
177
- mask file names and associated probability scores.
178
+ "'countgd_sam2_video_tracking' is a tool that can track and segment multiple objects in a video given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, masks and associated probability scores and is useful for tracking and counting without duplicating counts.","countgd_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10) -> List[List[Dict[str, Any]]]:
179
+ 'countgd_sam2_video_tracking' is a tool that can track and segment multiple
180
+ objects in a video given a text prompt such as category names or referring
181
+ expressions. The categories in the text prompt are separated by commas. It returns
182
+ a list of bounding boxes, label names, masks and associated probability scores and
183
+ is useful for tracking and counting without duplicating counts.
178
184
 
179
185
  Parameters:
180
186
  prompt (str): The prompt to ground to the image.
181
- image (np.ndarray): The image to ground the prompt to.
182
- chunk_length (Optional[int]): The number of frames to re-run florence2 to find
187
+ frames (List[np.ndarray]): The list of frames to ground the prompt to.
188
+ chunk_length (Optional[int]): The number of frames to re-run countgd to find
183
189
  new objects.
184
190
 
185
191
  Returns:
186
- List[Dict[str, Any]]: A list of dictionaries containing the score, label,
187
- bounding box, and mask of the detected objects with normalized coordinates
188
- (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
189
- and xmax and ymax are the coordinates of the bottom-right of the bounding box.
190
- The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
191
- the background.
192
+ List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
193
+ label, segmentation mask and bounding boxes. The outer list represents each
194
+ frame and the inner list is the entities per frame. The detected objects
195
+ have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
196
+ and ymin are the coordinates of the top-left and xmax and ymax are the
197
+ coordinates of the bottom-right of the bounding box. The mask is binary 2D
198
+ numpy array where 1 indicates the object and 0 indicates the background.
199
+ The label names are prefixed with their ID, which can be used to get the total count.
192
200
 
193
201
  Example
194
202
  -------
@@ -265,12 +273,12 @@ desc,doc,name
265
273
  },
266
274
  ]
267
275
  ",florence2_sam2_instance_segmentation
268
- 'florence2_sam2_video_tracking' is a tool that can segment and track multiple entities in a video given a text prompt such as category names or referring expressions. You can optionally separate the categories in the text with commas. It can find new objects every 'chunk_length' frames and is useful for tracking and counting without duplicating counts and always outputs scores of 1.0.,"florence2_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10, fine_tune_id: Optional[str] = None) -> List[List[Dict[str, Any]]]:
269
- 'florence2_sam2_video_tracking' is a tool that can segment and track multiple
270
- entities in a video given a text prompt such as category names or referring
271
- expressions. You can optionally separate the categories in the text with commas. It
272
- can find new objects every 'chunk_length' frames and is useful for tracking and
273
- counting without duplicating counts and always outputs scores of 1.0.
276
+ "'florence2_sam2_video_tracking' is a tool that can track and segment multiple objects in a video given a text prompt such as category names or referring expressions. The categories in the text prompt are separated by commas. It returns a list of bounding boxes, label names, masks and associated probability scores and is useful for tracking and counting without duplicating counts.","florence2_sam2_video_tracking(prompt: str, frames: List[numpy.ndarray], chunk_length: Optional[int] = 10, fine_tune_id: Optional[str] = None) -> List[List[Dict[str, Any]]]:
277
+ 'florence2_sam2_video_tracking' is a tool that can track and segment multiple
278
+ objects in a video given a text prompt such as category names or referring
279
+ expressions. The categories in the text prompt are separated by commas. It returns
280
+ a list of bounding boxes, label names, masks and associated probability scores and
281
+ is useful for tracking and counting without duplicating counts.
274
282
 
275
283
  Parameters:
276
284
  prompt (str): The prompt to ground to the video.
@@ -282,10 +290,13 @@ desc,doc,name
282
290
 
283
291
  Returns:
284
292
  List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
285
- label, segment mask and bounding boxes. The outer list represents each frame
286
- and the inner list is the entities per frame. The label contains the object ID
287
- followed by the label name. The objects are only identified in the first framed
288
- and tracked throughout the video.
293
+ label, segmentation mask and bounding boxes. The outer list represents each
294
+ frame and the inner list is the entities per frame. The detected objects
295
+ have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
296
+ and ymin are the coordinates of the top-left and xmax and ymax are the
297
+ coordinates of the bottom-right of the bounding box. The mask is binary 2D
298
+ numpy array where 1 indicates the object and 0 indicates the background.
299
+ The label names are prefixed with their ID, which can be used to get the total count.
289
300
 
290
301
  Example
291
302
  -------
@@ -445,43 +456,6 @@ desc,doc,name
445
456
  >>> qwen2_vl_video_vqa('Which football player made the goal?', frames)
446
457
  'Lionel Messi'
447
458
  ",qwen2_vl_video_vqa
448
- "'detr_segmentation' is a tool that can segment common objects in an image without any text prompt. It returns a list of detected objects as labels, their regions as masks and their scores.","detr_segmentation(image: numpy.ndarray) -> List[Dict[str, Any]]:
449
- 'detr_segmentation' is a tool that can segment common objects in an
450
- image without any text prompt. It returns a list of detected objects
451
- as labels, their regions as masks and their scores.
452
-
453
- Parameters:
454
- image (np.ndarray): The image used to segment things and objects
455
-
456
- Returns:
457
- List[Dict[str, Any]]: A list of dictionaries containing the score, label
458
- and mask of the detected objects. The mask is binary 2D numpy array where 1
459
- indicates the object and 0 indicates the background.
460
-
461
- Example
462
- -------
463
- >>> detr_segmentation(image)
464
- [
465
- {
466
- 'score': 0.45,
467
- 'label': 'window',
468
- 'mask': array([[0, 0, 0, ..., 0, 0, 0],
469
- [0, 0, 0, ..., 0, 0, 0],
470
- ...,
471
- [0, 0, 0, ..., 0, 0, 0],
472
- [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
473
- },
474
- {
475
- 'score': 0.70,
476
- 'label': 'bird',
477
- 'mask': array([[0, 0, 0, ..., 0, 0, 0],
478
- [0, 0, 0, ..., 0, 0, 0],
479
- ...,
480
- [0, 0, 0, ..., 0, 0, 0],
481
- [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
482
- },
483
- ]
484
- ",detr_segmentation
485
459
'depth_anything_v2' is a tool that runs depth_anythingv2 model to generate a depth image from a given RGB image. The returned depth image is monochrome and represents depth values as pixel intensities with pixel values ranging from 0 to 255.,"depth_anything_v2(image: numpy.ndarray) -> numpy.ndarray:
486
460
  'depth_anything_v2' is a tool that runs depth_anythingv2 model to generate a
487
461
  depth image from a given RGB image. The returned depth image is monochrome and
@@ -522,22 +496,6 @@ desc,doc,name
522
496
  [10, 11, 15, ..., 202, 202, 205],
523
497
  [10, 10, 10, ..., 200, 200, 200]], dtype=uint8),
524
498
  ",generate_pose_image
525
- 'vit_image_classification' is a tool that can classify an image. It returns a list of classes and their probability scores based on image content.,"vit_image_classification(image: numpy.ndarray) -> Dict[str, Any]:
526
- 'vit_image_classification' is a tool that can classify an image. It returns a
527
- list of classes and their probability scores based on image content.
528
-
529
- Parameters:
530
- image (np.ndarray): The image to classify or tag
531
-
532
- Returns:
533
- Dict[str, Any]: A dictionary containing the labels and scores. One dictionary
534
- contains a list of labels and other a list of scores.
535
-
536
- Example
537
- -------
538
- >>> vit_image_classification(image)
539
- {""labels"": [""leopard"", ""lemur, otter"", ""bird""], ""scores"": [0.68, 0.30, 0.02]},
540
- ",vit_image_classification
541
499
  'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'. It returns the predicted label and their probability scores based on image content.,"vit_nsfw_classification(image: numpy.ndarray) -> Dict[str, Any]:
542
500
  'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'.
543
501
  It returns the predicted label and their probability scores based on image content.
@@ -566,7 +524,7 @@ desc,doc,name
566
524
  prompt (str): The question about the video
567
525
  frames (List[np.ndarray]): The reference frames used for the question
568
526
  model (str): The model to use for the inference. Valid values are
569
- 'qwen2vl', 'gpt4o', 'internlm-xcomposer'
527
+ 'qwen2vl', 'gpt4o'.
570
528
  chunk_length_frames (Optional[int]): length of each chunk in frames
571
529
 
572
530
  Returns:
@@ -641,7 +599,7 @@ desc,doc,name
641
599
  >>> closest_distance(det1, det2, image_size)
642
600
  141.42
643
601
  ",minimum_distance
644
- "'extract_frames_and_timestamps' extracts frames and timestamps from a video which can be a file path, url or youtube link, returns a list of dictionaries with keys ""frame"" and ""timestamp"" where ""frame"" is a numpy array and ""timestamp"" is the relative time in seconds where the frame was captured. The frame is a numpy array.","extract_frames_and_timestamps(video_uri: Union[str, pathlib.Path], fps: float = 1) -> List[Dict[str, Union[numpy.ndarray, float]]]:
602
+ "'extract_frames_and_timestamps' extracts frames and timestamps from a video which can be a file path, url or youtube link, returns a list of dictionaries with keys ""frame"" and ""timestamp"" where ""frame"" is a numpy array and ""timestamp"" is the relative time in seconds where the frame was captured. The frame is a numpy array.","extract_frames_and_timestamps(video_uri: Union[str, pathlib.Path], fps: float = 5) -> List[Dict[str, Union[numpy.ndarray, float]]]:
645
603
  'extract_frames_and_timestamps' extracts frames and timestamps from a video
646
604
  which can be a file path, url or youtube link, returns a list of dictionaries
647
605
  with keys ""frame"" and ""timestamp"" where ""frame"" is a numpy array and ""timestamp"" is
@@ -651,7 +609,7 @@ desc,doc,name
651
609
  Parameters:
652
610
  video_uri (Union[str, Path]): The path to the video file, url or youtube link
653
611
  fps (float, optional): The frame rate per second to extract the frames. Defaults
654
- to 1.
612
+ to 5.
655
613
 
656
614
  Returns:
657
615
  List[Dict[str, Union[np.ndarray, float]]]: A list of dictionaries containing the
Binary file
@@ -153,6 +153,19 @@ def format_plan_v2(plan: PlanContext) -> str:
153
153
  return plan_str
154
154
 
155
155
 
156
+ def format_conversation(chat: List[AgentMessage]) -> str:
157
+ chat = copy.deepcopy(chat)
158
+ prompt = ""
159
+ for chat_i in chat:
160
+ if chat_i.role == "user":
161
+ prompt += f"USER: {chat_i.content}\n\n"
162
+ elif chat_i.role == "observation" or chat_i.role == "coder":
163
+ prompt += f"OBSERVATION: {chat_i.content}\n\n"
164
+ elif chat_i.role == "conversation":
165
+ prompt += f"AGENT: {chat_i.content}\n\n"
166
+ return prompt
167
+
168
+
156
169
  def format_plans(plans: Dict[str, Any]) -> str:
157
170
  plan_str = ""
158
171
  for k, v in plans.items():
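A rough sketch of what the new format_conversation helper produces, assuming AgentMessage accepts role and content keyword arguments (its exact constructor is not shown in this diff):

chat = [
    AgentMessage(role="user", content="Count the boxes in the video."),
    AgentMessage(role="conversation", content="I will track the boxes first."),
    AgentMessage(role="observation", content="Found 3 unique boxes."),
]
print(format_conversation(chat))
# USER: Count the boxes in the video.
#
# AGENT: I will track the boxes first.
#
# OBSERVATION: Found 3 unique boxes.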
@@ -65,7 +65,7 @@ This is the documentation for the functions you have access to. You may call any
65
65
  7. DO NOT assert the output value, run the code and assert only the output format or data structure.
66
66
  8. DO NOT use try except block to handle the error, let the error be raised if the code is incorrect.
67
67
9. DO NOT import the testing function as it will be available in the testing environment.
68
- 10. Print the output of the function that is being tested.
68
+ 10. Print the output of the function that is being tested and ensure it is not empty.
69
69
  11. Use the output of the function that is being tested as the return value of the testing function.
70
70
  12. Run the testing function in the end and don't assign a variable to its output.
71
71
  13. Output your test code using <code> tags:
@@ -202,7 +202,12 @@ def write_and_test_code(
202
202
  tool_docs=tool_docs,
203
203
  plan=plan,
204
204
  )
205
- code = strip_function_calls(code)
205
+ try:
206
+ code = strip_function_calls(code)
207
+ except Exception:
208
+ # the code may be malformatted, this will fail in the exec call and the agent
209
+ # will attempt to debug it
210
+ pass
206
211
  test = write_test(
207
212
  tester=tester,
208
213
  chat=chat,
@@ -136,8 +136,9 @@ Tool Documentation:
136
136
  countgd_object_detection(prompt: str, image: numpy.ndarray, box_threshold: float = 0.23) -> List[Dict[str, Any]]:
137
137
  'countgd_object_detection' is a tool that can detect multiple instances of an
138
138
  object given a text prompt. It is particularly useful when trying to detect and
139
- count a large number of objects. It returns a list of bounding boxes with
140
- normalized coordinates, label names and associated confidence scores.
139
+ count a large number of objects. You can optionally separate object names in the
140
+ prompt with commas. It returns a list of bounding boxes with normalized
141
+ coordinates, label names and associated confidence scores.
141
142
 
142
143
  Parameters:
143
144
  prompt (str): The object that needs to be counted.
@@ -272,40 +273,47 @@ OBSERVATION:
272
273
  [get_tool_for_task output]
273
274
  For tracking boxes moving on a conveyor belt, we need a tool that can consistently track the same box across frames without losing it or double counting. Looking at the outputs: florence2_sam2_video_tracking successfully tracks the single box across all 5 frames, maintaining consistent tracking IDs and showing the box's movement along the conveyor.
274
275
 
275
- 'florence2_sam2_video_tracking' is a tool that can segment and track multiple
276
- entities in a video given a text prompt such as category names or referring
277
- expressions. You can optionally separate the categories in the text with commas. It
278
- can find new objects every 'chunk_length' frames and is useful for tracking and
279
- counting without duplicating counts and always outputs scores of 1.0.
276
+ Tool Documentation:
277
+ def florence2_sam2_video_tracking(prompt: str, frames: List[np.ndarray], chunk_length: Optional[int] = 10) -> List[List[Dict[str, Any]]]:
278
+ 'florence2_sam2_video_tracking' is a tool that can track and segment multiple
279
+ objects in a video given a text prompt such as category names or referring
280
+ expressions. The categories in the text prompt are separated by commas. It returns
281
+ a list of bounding boxes, label names, masks and associated probability scores and
282
+ is useful for tracking and counting without duplicating counts.
280
283
 
281
- Parameters:
282
- prompt (str): The prompt to ground to the video.
283
- frames (List[np.ndarray]): The list of frames to ground the prompt to.
284
- chunk_length (Optional[int]): The number of frames to re-run florence2 to find
285
- new objects.
284
+ Parameters:
285
+ prompt (str): The prompt to ground to the video.
286
+ frames (List[np.ndarray]): The list of frames to ground the prompt to.
287
+ chunk_length (Optional[int]): The number of frames to re-run florence2 to find
288
+ new objects.
289
+ fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
290
+ fine-tuned model ID here to use it.
286
291
 
287
- Returns:
288
- List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
289
- label,segment mask and bounding boxes. The outer list represents each frame and
290
- the inner list is the entities per frame. The label contains the object ID
291
- followed by the label name. The objects are only identified in the first framed
292
- and tracked throughout the video.
292
+ Returns:
293
+ List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
294
+ label, segmentation mask and bounding boxes. The outer list represents each
295
+ frame and the inner list is the entities per frame. The detected objects
296
+ have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
297
+ and ymin are the coordinates of the top-left and xmax and ymax are the
298
+ coordinates of the bottom-right of the bounding box. The mask is binary 2D
299
+ numpy array where 1 indicates the object and 0 indicates the background.
300
+ The label names are prefixed with their ID, which can be used to get the total count.
293
301
 
294
- Example
295
- -------
296
- >>> florence2_sam2_video("car, dinosaur", frames)
297
- [
302
+ Example
303
+ -------
304
+ >>> florence2_sam2_video_tracking("car, dinosaur", frames)
298
305
  [
299
- {
300
- 'label': '0: dinosaur',
301
- 'bbox': [0.1, 0.11, 0.35, 0.4],
302
- 'mask': array([[0, 0, 0, ..., 0, 0, 0],
303
- ...,
304
- [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
305
- },
306
- ],
307
- ...
308
- ]
306
+ [
307
+ {
308
+ 'label': '0: dinosaur',
309
+ 'bbox': [0.1, 0.11, 0.35, 0.4],
310
+ 'mask': array([[0, 0, 0, ..., 0, 0, 0],
311
+ ...,
312
+ [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
313
+ },
314
+ ],
315
+ ...
316
+ ]
309
317
  [end of get_tool_for_task output]
310
318
  <count>8</count>
311
319
 
@@ -691,7 +699,8 @@ FINALIZE_PLAN = """
691
699
  4. Specifically call out the tools used and the order in which they were used. Only include tools obtained from calling `get_tool_for_task`.
692
700
  5. Do not include {excluded_tools} tools in your instructions.
693
701
  6. Add final instructions for visualizing the output with `overlay_bounding_boxes` or `overlay_segmentation_masks` and saving it to a file with `save_file` or `save_video`.
694
- 6. Respond in the following format with JSON surrounded by <json> tags and code surrounded by <code> tags:
702
+ 7. Use the default FPS for extracting frames from videos unless otherwise specified by the user.
703
+ 8. Respond in the following format with JSON surrounded by <json> tags and code surrounded by <code> tags:
695
704
 
696
705
  <json>
697
706
  {{
@@ -1,13 +1,14 @@
1
1
  import copy
2
2
  import json
3
3
  from pathlib import Path
4
- from typing import Any, Callable, Dict, List, Optional, Union, cast
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
5
5
 
6
6
  from vision_agent.agent import Agent, AgentCoder, VisionAgentCoderV2
7
7
  from vision_agent.agent.agent_utils import (
8
8
  add_media_to_chat,
9
9
  convert_message_to_agentmessage,
10
10
  extract_tag,
11
+ format_conversation,
11
12
  )
12
13
  from vision_agent.agent.types import (
13
14
  AgentMessage,
@@ -22,19 +23,6 @@ from vision_agent.lmm.types import Message
22
23
  from vision_agent.utils.execute import CodeInterpreter, CodeInterpreterFactory
23
24
 
24
25
 
25
- def format_conversation(chat: List[AgentMessage]) -> str:
26
- chat = copy.deepcopy(chat)
27
- prompt = ""
28
- for chat_i in chat:
29
- if chat_i.role == "user":
30
- prompt += f"USER: {chat_i.content}\n\n"
31
- elif chat_i.role == "observation" or chat_i.role == "coder":
32
- prompt += f"OBSERVATION: {chat_i.content}\n\n"
33
- elif chat_i.role == "conversation":
34
- prompt += f"AGENT: {chat_i.content}\n\n"
35
- return prompt
36
-
37
-
38
26
  def run_conversation(agent: LMM, chat: List[AgentMessage]) -> str:
39
27
  # only keep last 10 messages
40
28
  conv = format_conversation(chat[-10:])
@@ -55,23 +43,39 @@ def check_for_interaction(chat: List[AgentMessage]) -> bool:
55
43
 
56
44
  def extract_conversation_for_generate_code(
57
45
  chat: List[AgentMessage],
58
- ) -> List[AgentMessage]:
46
+ ) -> Tuple[List[AgentMessage], Optional[str]]:
59
47
  chat = copy.deepcopy(chat)
60
48
 
61
49
  # if we are in the middle of an interaction, return all the intermediate planning
62
50
  # steps
63
51
  if check_for_interaction(chat):
64
- return chat
52
+ return chat, None
65
53
 
66
54
  extracted_chat = []
67
55
  for chat_i in chat:
68
56
  if chat_i.role == "user":
69
57
  extracted_chat.append(chat_i)
70
58
  elif chat_i.role == "coder":
71
- if "<final_code>" in chat_i.content and "<final_test>" in chat_i.content:
59
+ if "<final_code>" in chat_i.content:
72
60
  extracted_chat.append(chat_i)
73
61
 
74
- return extracted_chat
62
+ # only keep the last <final_code> and <final_test>
63
+ final_code = None
64
+ extracted_chat_strip_code: List[AgentMessage] = []
65
+ for chat_i in reversed(extracted_chat):
66
+ if "<final_code>" in chat_i.content and final_code is None:
67
+ extracted_chat_strip_code = [chat_i] + extracted_chat_strip_code
68
+ final_code = extract_tag(chat_i.content, "final_code")
69
+ if final_code is not None:
70
+ test_code = extract_tag(chat_i.content, "final_test")
71
+ final_code += "\n" + test_code if test_code is not None else ""
72
+
73
+ if "<final_code>" in chat_i.content and final_code is not None:
74
+ continue
75
+
76
+ extracted_chat_strip_code = [chat_i] + extracted_chat_strip_code
77
+
78
+ return extracted_chat_strip_code[-5:], final_code
75
79
 
76
80
 
77
81
  def maybe_run_action(
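The helper now returns a (chat, final_code) tuple, recovering the latest code from the last coder message. A small sketch of the tag extraction it relies on, assuming extract_tag simply returns the text between the named tags (consistent with how it is used above):

content = "<final_code>print('hello')</final_code><final_test>assert True</final_test>"
final_code = extract_tag(content, "final_code")   # "print('hello')"
test_code = extract_tag(content, "final_test")    # "assert True"
if test_code is not None:
    final_code += "\n" + test_code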
@@ -81,7 +85,7 @@ def maybe_run_action(
81
85
  code_interpreter: Optional[CodeInterpreter] = None,
82
86
  ) -> Optional[List[AgentMessage]]:
83
87
  if action == "generate_or_edit_vision_code":
84
- extracted_chat = extract_conversation_for_generate_code(chat)
88
+ extracted_chat, _ = extract_conversation_for_generate_code(chat)
85
89
  # there's an issue here because coder.generate_code will send it's code_context
86
90
  # to the outside user via it's update_callback, but we don't necessarily have
87
91
  # access to that update_callback here, so we re-create the message using
@@ -101,11 +105,15 @@ def maybe_run_action(
101
105
  )
102
106
  ]
103
107
  elif action == "edit_code":
104
- extracted_chat = extract_conversation_for_generate_code(chat)
108
+ extracted_chat, final_code = extract_conversation_for_generate_code(chat)
105
109
  plan_context = PlanContext(
106
110
  plan="Edit the latest code observed in the fewest steps possible according to the user's feedback.",
107
- instructions=[],
108
- code="",
111
+ instructions=[
112
+ chat_i.content
113
+ for chat_i in extracted_chat
114
+ if chat_i.role == "user" and "<final_code>" not in chat_i.content
115
+ ],
116
+ code=final_code if final_code is not None else "",
109
117
  )
110
118
  context = coder.generate_code_from_plan(
111
119
  extracted_chat, plan_context, code_interpreter=code_interpreter
@@ -193,8 +193,10 @@ def get_tool_for_task(
193
193
  - Depth and pose estimation
194
194
  - Video object tracking
195
195
 
196
- Wait until the documentation is printed to use the function so you know what the
197
- input and output signatures are.
196
+ Only ask for one type of task at a time, for example a task needing to identify
197
+ text is one OCR task while needing to identify non-text objects is an OD task. Wait
198
+ until the documentation is printed to use the function so you know what the input
199
+ and output signatures are.
198
200
 
199
201
  Parameters:
200
202
  task: str: The task to accomplish.
@@ -6,7 +6,6 @@ import tempfile
6
6
  import urllib.request
7
7
  from base64 import b64encode
8
8
  from concurrent.futures import ThreadPoolExecutor, as_completed
9
- from enum import Enum
10
9
  from importlib import resources
11
10
  from pathlib import Path
12
11
  from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -54,6 +53,13 @@ from vision_agent.utils.video import (
54
53
  frames_to_bytes,
55
54
  video_writer,
56
55
  )
56
+ from vision_agent.utils.video_tracking import (
57
+ ODModels,
58
+ merge_segments,
59
+ post_process,
60
+ process_segment,
61
+ split_frames_into_segments,
62
+ )
57
63
 
58
64
  register_heif_opener()
59
65
 
@@ -224,12 +230,6 @@ def sam2(
224
230
  return ret["return_data"] # type: ignore
225
231
 
226
232
 
227
- class ODModels(str, Enum):
228
- COUNTGD = "countgd"
229
- FLORENCE2 = "florence2"
230
- OWLV2 = "owlv2"
231
-
232
-
233
233
  def od_sam2_video_tracking(
234
234
  od_model: ODModels,
235
235
  prompt: str,
@@ -237,105 +237,92 @@ def od_sam2_video_tracking(
237
237
  chunk_length: Optional[int] = 10,
238
238
  fine_tune_id: Optional[str] = None,
239
239
  ) -> Dict[str, Any]:
240
- results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames)
240
+ SEGMENT_SIZE = 50
241
+ OVERLAP = 1 # Number of overlapping frames between segments
241
242
 
242
- if chunk_length is None:
243
- step = 1 # Process every frame
244
- elif chunk_length <= 0:
245
- raise ValueError("chunk_length must be a positive integer or None.")
246
- else:
247
- step = chunk_length # Process frames with the specified step size
243
+ image_size = frames[0].shape[:2]
244
+
245
+ # Split frames into segments with overlap
246
+ segments = split_frames_into_segments(frames, SEGMENT_SIZE, OVERLAP)
247
+
248
+ def _apply_object_detection( # inner method to avoid circular importing issues.
249
+ od_model: ODModels,
250
+ prompt: str,
251
+ segment_index: int,
252
+ frame_number: int,
253
+ fine_tune_id: str,
254
+ segment_frames: list,
255
+ ) -> tuple:
256
+ """
257
+ Applies the specified object detection model to the given image.
258
+
259
+ Args:
260
+ od_model: The object detection model to use.
261
+ prompt: The prompt for the object detection model.
262
+ segment_index: The index of the current segment.
263
+ frame_number: The number of the current frame.
264
+ fine_tune_id: Optional fine-tune ID for the model.
265
+ segment_frames: List of frames for the current segment.
266
+
267
+ Returns:
268
+ A tuple containing the object detection results and the name of the function used.
269
+ """
248
270
 
249
- for idx in range(0, len(frames), step):
250
271
  if od_model == ODModels.COUNTGD:
251
- results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx])
272
+ segment_results = countgd_object_detection(
273
+ prompt=prompt, image=segment_frames[frame_number]
274
+ )
252
275
  function_name = "countgd_object_detection"
276
+
253
277
  elif od_model == ODModels.OWLV2:
254
- results[idx] = owlv2_object_detection(
255
- prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
278
+ segment_results = owlv2_object_detection(
279
+ prompt=prompt,
280
+ image=segment_frames[frame_number],
281
+ fine_tune_id=fine_tune_id,
256
282
  )
257
283
  function_name = "owlv2_object_detection"
284
+
258
285
  elif od_model == ODModels.FLORENCE2:
259
- results[idx] = florence2_object_detection(
260
- prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
286
+ segment_results = florence2_object_detection(
287
+ prompt=prompt,
288
+ image=segment_frames[frame_number],
289
+ fine_tune_id=fine_tune_id,
261
290
  )
262
291
  function_name = "florence2_object_detection"
292
+
263
293
  else:
264
294
  raise NotImplementedError(
265
295
  f"Object detection model '{od_model}' is not implemented."
266
296
  )
267
297
 
268
- image_size = frames[0].shape[:2]
269
-
270
- def _transform_detections(
271
- input_list: List[Optional[List[Dict[str, Any]]]],
272
- ) -> List[Optional[Dict[str, Any]]]:
273
- output_list: List[Optional[Dict[str, Any]]] = []
274
-
275
- for _, frame in enumerate(input_list):
276
- if frame is not None:
277
- labels = [detection["label"] for detection in frame]
278
- bboxes = [
279
- denormalize_bbox(detection["bbox"], image_size)
280
- for detection in frame
281
- ]
282
-
283
- output_list.append(
284
- {
285
- "labels": labels,
286
- "bboxes": bboxes,
287
- }
288
- )
289
- else:
290
- output_list.append(None)
291
-
292
- return output_list
298
+ return segment_results, function_name
299
+
300
+ # Process each segment and collect detections
301
+ detections_per_segment: List[Any] = []
302
+ for segment_index, segment in enumerate(segments):
303
+ segment_detections = process_segment(
304
+ segment_frames=segment,
305
+ od_model=od_model,
306
+ prompt=prompt,
307
+ fine_tune_id=fine_tune_id,
308
+ chunk_length=chunk_length,
309
+ image_size=image_size,
310
+ segment_index=segment_index,
311
+ object_detection_tool=_apply_object_detection,
312
+ )
313
+ detections_per_segment.append(segment_detections)
293
314
 
294
- output = _transform_detections(results)
315
+ merged_detections = merge_segments(detections_per_segment)
316
+ post_processed = post_process(merged_detections, image_size)
295
317
 
296
318
  buffer_bytes = frames_to_bytes(frames)
297
319
  files = [("video", buffer_bytes)]
298
- payload = {"bboxes": json.dumps(output), "chunk_length_frames": chunk_length}
299
- metadata = {"function_name": function_name}
300
-
301
- detections = send_task_inference_request(
302
- payload,
303
- "sam2",
304
- files=files,
305
- metadata=metadata,
306
- )
307
-
308
- return_data = []
309
- for frame in detections:
310
- return_frame_data = []
311
- for detection in frame:
312
- mask = rle_decode_array(detection["mask"])
313
- label = str(detection["id"]) + ": " + detection["label"]
314
- return_frame_data.append(
315
- {"label": label, "mask": mask, "score": 1.0, "rle": detection["mask"]}
316
- )
317
- return_data.append(return_frame_data)
318
- return_data = add_bboxes_from_masks(return_data)
319
- return_data = nms(return_data, iou_threshold=0.95)
320
320
 
321
- # We save the RLE for display purposes, re-calculting RLE can get very expensive.
322
- # Deleted here because we are returning the numpy masks instead
323
- display_data = []
324
- for frame in return_data:
325
- display_frame_data = []
326
- for obj in frame:
327
- display_frame_data.append(
328
- {
329
- "label": obj["label"],
330
- "score": obj["score"],
331
- "bbox": denormalize_bbox(obj["bbox"], image_size),
332
- "mask": obj["rle"],
333
- }
334
- )
335
- del obj["rle"]
336
- display_data.append(display_frame_data)
337
-
338
- return {"files": files, "return_data": return_data, "display_data": detections}
321
+ return {
322
+ "files": files,
323
+ "return_data": post_processed["return_data"],
324
+ "display_data": post_processed["display_data"],
325
+ }
339
326
 
340
327
 
341
328
  # Owl V2 Tools
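To illustrate the new segmented flow above: with SEGMENT_SIZE=50 and OVERLAP=1, a 120-frame video would be split as follows (a worked example of split_frames_into_segments, which is defined in vision_agent/utils/video_tracking.py further down):

segments = split_frames_into_segments(frames, segment_size=50, overlap=1)
# segment 0: frames[0:50]    -> 50 frames
# segment 1: frames[49:100]  -> 51 frames (first frame shared with segment 0)
# segment 2: frames[99:120]  -> 21 frames (first frame shared with segment 1)
# The one-frame overlap lets merge_segments match object IDs across segment
# boundaries by mask IoU, so counts are not duplicated at the seams.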
@@ -528,24 +515,29 @@ def owlv2_sam2_video_tracking(
528
515
  chunk_length: Optional[int] = 10,
529
516
  fine_tune_id: Optional[str] = None,
530
517
  ) -> List[List[Dict[str, Any]]]:
531
- """'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text
532
- prompt such as category names or referring expressions. The categories in the text
533
- prompt are separated by commas. It returns a list of bounding boxes, label names,
534
- mask file names and associated probability scores.
518
+ """'owlv2_sam2_video_tracking' is a tool that can track and segment multiple
519
+ objects in a video given a text prompt such as category names or referring
520
+ expressions. The categories in the text prompt are separated by commas. It returns
521
+ a list of bounding boxes, label names, masks and associated probability scores and
522
+ is useful for tracking and counting without duplicating counts.
535
523
 
536
524
  Parameters:
537
525
  prompt (str): The prompt to ground to the image.
538
- image (np.ndarray): The image to ground the prompt to.
526
+ frames (List[np.ndarray]): The list of frames to ground the prompt to.
527
+ chunk_length (Optional[int]): The number of frames to re-run owlv2 to find
528
+ new objects.
539
529
  fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
540
530
  fine-tuned model ID here to use it.
541
531
 
542
532
  Returns:
543
- List[Dict[str, Any]]: A list of dictionaries containing the score, label,
544
- bounding box, and mask of the detected objects with normalized coordinates
545
- (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
546
- and xmax and ymax are the coordinates of the bottom-right of the bounding box.
547
- The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
548
- the background.
533
+ List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
534
+ label, segmentation mask and bounding boxes. The outer list represents each
535
+ frame and the inner list is the entities per frame. The detected objects
536
+ have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
537
+ and ymin are the coordinates of the top-left and xmax and ymax are the
538
+ coordinates of the bottom-right of the bounding box. The mask is binary 2D
539
+ numpy array where 1 indicates the object and 0 indicates the background.
540
+ The label names are prefixed with their ID, which can be used to get the total count.
549
541
 
550
542
  Example
551
543
  -------
@@ -755,11 +747,11 @@ def florence2_sam2_video_tracking(
755
747
  chunk_length: Optional[int] = 10,
756
748
  fine_tune_id: Optional[str] = None,
757
749
  ) -> List[List[Dict[str, Any]]]:
758
- """'florence2_sam2_video_tracking' is a tool that can segment and track multiple
759
- entities in a video given a text prompt such as category names or referring
760
- expressions. You can optionally separate the categories in the text with commas. It
761
- can find new objects every 'chunk_length' frames and is useful for tracking and
762
- counting without duplicating counts and always outputs scores of 1.0.
750
+ """'florence2_sam2_video_tracking' is a tool that can track and segment multiple
751
+ objects in a video given a text prompt such as category names or referring
752
+ expressions. The categories in the text prompt are separated by commas. It returns
753
+ a list of bounding boxes, label names, masks and associated probability scores and
754
+ is useful for tracking and counting without duplicating counts.
763
755
 
764
756
  Parameters:
765
757
  prompt (str): The prompt to ground to the video.
@@ -771,10 +763,13 @@ def florence2_sam2_video_tracking(
771
763
 
772
764
  Returns:
773
765
  List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
774
- label, segment mask and bounding boxes. The outer list represents each frame
775
- and the inner list is the entities per frame. The label contains the object ID
776
- followed by the label name. The objects are only identified in the first framed
777
- and tracked throughout the video.
766
+ label, segmentation mask and bounding boxes. The outer list represents each
767
+ frame and the inner list is the entities per frame. The detected objects
768
+ have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
769
+ and ymin are the coordinates of the top-left and xmax and ymax are the
770
+ coordinates of the bottom-right of the bounding box. The mask is binary 2D
771
+ numpy array where 1 indicates the object and 0 indicates the background.
772
+ The label names are prefixed with their ID, which can be used to get the total count.
778
773
 
779
774
  Example
780
775
  -------
@@ -1089,24 +1084,27 @@ def countgd_sam2_video_tracking(
1089
1084
  frames: List[np.ndarray],
1090
1085
  chunk_length: Optional[int] = 10,
1091
1086
  ) -> List[List[Dict[str, Any]]]:
1092
- """'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text
1093
- prompt such as category names or referring expressions. The categories in the text
1094
- prompt are separated by commas. It returns a list of bounding boxes, label names,
1095
- mask file names and associated probability scores.
1087
+ """'countgd_sam2_video_tracking' is a tool that can track and segment multiple
1088
+ objects in a video given a text prompt such as category names or referring
1089
+ expressions. The categories in the text prompt are separated by commas. It returns
1090
+ a list of bounding boxes, label names, masks and associated probability scores and
1091
+ is useful for tracking and counting without duplicating counts.
1096
1092
 
1097
1093
  Parameters:
1098
1094
  prompt (str): The prompt to ground to the image.
1099
- image (np.ndarray): The image to ground the prompt to.
1100
- chunk_length (Optional[int]): The number of frames to re-run florence2 to find
1095
+ frames (List[np.ndarray]): The list of frames to ground the prompt to.
1096
+ chunk_length (Optional[int]): The number of frames to re-run countgd to find
1101
1097
  new objects.
1102
1098
 
1103
1099
  Returns:
1104
- List[Dict[str, Any]]: A list of dictionaries containing the score, label,
1105
- bounding box, and mask of the detected objects with normalized coordinates
1106
- (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
1107
- and xmax and ymax are the coordinates of the bottom-right of the bounding box.
1108
- The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
1109
- the background.
1100
+ List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
1101
+ label, segmentation mask and bounding boxes. The outer list represents each
1102
+ frame and the inner list is the entities per frame. The detected objects
1103
+ have normalized coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin
1104
+ and ymin are the coordinates of the top-left and xmax and ymax are the
1105
+ coordinates of the bottom-right of the bounding box. The mask is binary 2D
1106
+ numpy array where 1 indicates the object and 0 indicates the background.
1107
+ The label names are prefixed with their ID, which can be used to get the total count.
1110
1108
 
1111
1109
  Example
1112
1110
  -------
@@ -1546,7 +1544,7 @@ def video_temporal_localization(
1546
1544
  prompt (str): The question about the video
1547
1545
  frames (List[np.ndarray]): The reference frames used for the question
1548
1546
  model (str): The model to use for the inference. Valid values are
1549
- 'qwen2vl', 'gpt4o', 'internlm-xcomposer'
1547
+ 'qwen2vl', 'gpt4o'.
1550
1548
  chunk_length_frames (Optional[int]): length of each chunk in frames
1551
1549
 
1552
1550
  Returns:
@@ -2115,7 +2113,7 @@ def closest_box_distance(
2115
2113
 
2116
2114
 
2117
2115
  def extract_frames_and_timestamps(
2118
- video_uri: Union[str, Path], fps: float = 1
2116
+ video_uri: Union[str, Path], fps: float = 5
2119
2117
  ) -> List[Dict[str, Union[np.ndarray, float]]]:
2120
2118
  """'extract_frames_and_timestamps' extracts frames and timestamps from a video
2121
2119
  which can be a file path, url or youtube link, returns a list of dictionaries
@@ -2126,7 +2124,7 @@ def extract_frames_and_timestamps(
2126
2124
  Parameters:
2127
2125
  video_uri (Union[str, Path]): The path to the video file, url or youtube link
2128
2126
  fps (float, optional): The frame rate per second to extract the frames. Defaults
2129
- to 1.
2127
+ to 5.
2130
2128
 
2131
2129
  Returns:
2132
2130
  List[Dict[str, Union[np.ndarray, float]]]: A list of dictionaries containing the
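A quick sketch of the effect of the new default frame rate (the clip name is hypothetical):

frames_and_ts = extract_frames_and_timestamps("clip.mp4")               # now ~5 frames per second
frames_and_ts_sparse = extract_frames_and_timestamps("clip.mp4", fps=1)  # previous default
frames = [d["frame"] for d in frames_and_ts]
timestamps = [d["timestamp"] for d in frames_and_ts]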
@@ -2649,10 +2647,8 @@ FUNCTION_TOOLS = [
2649
2647
  ocr,
2650
2648
  qwen2_vl_images_vqa,
2651
2649
  qwen2_vl_video_vqa,
2652
- detr_segmentation,
2653
2650
  depth_anything_v2,
2654
2651
  generate_pose_image,
2655
- vit_image_classification,
2656
2652
  vit_nsfw_classification,
2657
2653
  video_temporal_localization,
2658
2654
  flux_image_inpainting,
vision_agent/utils/sim.py CHANGED
@@ -133,6 +133,12 @@ class Sim:
133
133
  df: pd.DataFrame,
134
134
  ) -> bool:
135
135
  load_dir = Path(load_dir)
136
+ if (
137
+ not Path(load_dir / "df.csv").exists()
138
+ or not Path(load_dir / "embs.npy").exists()
139
+ ):
140
+ return False
141
+
136
142
  df_load = pd.read_csv(load_dir / "df.csv")
137
143
  if platform.system() == "Windows":
138
144
  df_load["doc"] = df_load["doc"].apply(lambda x: x.replace("\r", ""))
@@ -0,0 +1,305 @@
1
+ import json
2
+ from enum import Enum
3
+ from typing import Any, Callable, Dict, List, Optional, Tuple
4
+
5
+ import numpy as np
6
+
7
+ from vision_agent.tools.tool_utils import (
8
+ add_bboxes_from_masks,
9
+ nms,
10
+ send_task_inference_request,
11
+ )
12
+ from vision_agent.utils.image_utils import denormalize_bbox, rle_decode_array
13
+ from vision_agent.utils.video import frames_to_bytes
14
+
15
+
16
+ class ODModels(str, Enum):
17
+ COUNTGD = "countgd"
18
+ FLORENCE2 = "florence2"
19
+ OWLV2 = "owlv2"
20
+
21
+
22
+ def split_frames_into_segments(
23
+ frames: List[np.ndarray], segment_size: int = 50, overlap: int = 1
24
+ ) -> List[List[np.ndarray]]:
25
+ """
26
+ Splits the list of frames into segments with a specified size and overlap.
27
+
28
+ Args:
29
+ frames (List[np.ndarray]): List of video frames.
30
+ segment_size (int, optional): Number of frames per segment. Defaults to 50.
31
+ overlap (int, optional): Number of overlapping frames between segments. Defaults to 1.
32
+
33
+ Returns:
34
+ List[List[np.ndarray]]: List of frame segments.
35
+ """
36
+ segments = []
37
+ start = 0
38
+ segment_count = 0
39
+ while start < len(frames):
40
+ end = start + segment_size
41
+ if end > len(frames):
42
+ end = len(frames)
43
+ if start != 0:
44
+ # Include the last frame of the previous segment
45
+ segment = frames[start - overlap : end]
46
+ else:
47
+ segment = frames[start:end]
48
+ segments.append(segment)
49
+ start += segment_size
50
+ segment_count += 1
51
+ return segments
52
+
53
+
54
+ def process_segment(
55
+ segment_frames: List[np.ndarray],
56
+ od_model: ODModels,
57
+ prompt: str,
58
+ fine_tune_id: Optional[str],
59
+ chunk_length: Optional[int],
60
+ image_size: Tuple[int, ...],
61
+ segment_index: int,
62
+ object_detection_tool: Callable,
63
+ ) -> Any:
64
+ """
65
+ Processes a segment of frames with the specified object detection model.
66
+
67
+ Args:
68
+ segment_frames (List[np.ndarray]): Frames in the segment.
69
+ od_model (ODModels): Object detection model to use.
70
+ prompt (str): Prompt for the model.
71
+ fine_tune_id (Optional[str]): Fine-tune model ID.
72
+ chunk_length (Optional[int]): Chunk length for processing.
73
+ image_size (Tuple[int, int]): Size of the images.
74
+ segment_index (int): Index of the segment.
75
+ object_detection_tool (Callable): Object detection tool to use.
76
+
77
+ Returns:
78
+ Any: Detections for the segment.
79
+ """
80
+ segment_results: List[Optional[List[Dict[str, Any]]]] = [None] * len(segment_frames)
81
+
82
+ if chunk_length is None:
83
+ step = 1
84
+ elif chunk_length <= 0:
85
+ raise ValueError("chunk_length must be a positive integer or None.")
86
+ else:
87
+ step = chunk_length
88
+
89
+ function_name = ""
90
+
91
+ for idx in range(0, len(segment_frames), step):
92
+ frame_number = idx
93
+ segment_results[idx], function_name = object_detection_tool(
94
+ od_model, prompt, segment_index, frame_number, fine_tune_id, segment_frames
95
+ )
96
+
97
+ transformed_detections = transform_detections(
98
+ segment_results, image_size, segment_index
99
+ )
100
+
101
+ buffer_bytes = frames_to_bytes(segment_frames)
102
+ files = [("video", buffer_bytes)]
103
+ payload = {
104
+ "bboxes": json.dumps(transformed_detections),
105
+ "chunk_length_frames": chunk_length,
106
+ }
107
+ metadata = {"function_name": function_name}
108
+
109
+ segment_detections = send_task_inference_request(
110
+ payload,
111
+ "sam2",
112
+ files=files,
113
+ metadata=metadata,
114
+ )
115
+
116
+ return segment_detections
117
+
118
+
119
+ def transform_detections(
120
+ input_list: List[Optional[List[Dict[str, Any]]]],
121
+ image_size: Tuple[int, ...],
122
+ segment_index: int,
123
+ ) -> List[Optional[Dict[str, Any]]]:
124
+ """
125
+ Transforms raw detections into a standardized format.
126
+
127
+ Args:
128
+ input_list (List[Optional[List[Dict[str, Any]]]]): Raw detections.
129
+ image_size (Tuple[int, int]): Size of the images.
130
+ segment_index (int): Index of the segment.
131
+
132
+ Returns:
133
+ List[Optional[Dict[str, Any]]]: Transformed detections.
134
+ """
135
+ output_list: List[Optional[Dict[str, Any]]] = []
136
+ for frame_idx, frame in enumerate(input_list):
137
+ if frame is not None:
138
+ labels = [detection["label"] for detection in frame]
139
+ bboxes = [
140
+ denormalize_bbox(detection["bbox"], image_size) for detection in frame
141
+ ]
142
+
143
+ output_list.append(
144
+ {
145
+ "labels": labels,
146
+ "bboxes": bboxes,
147
+ }
148
+ )
149
+ else:
150
+ output_list.append(None)
151
+ return output_list
152
+
153
+
154
+ def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
155
+ mask1 = mask1.astype(bool)
156
+ mask2 = mask2.astype(bool)
157
+
158
+ intersection = np.sum(np.logical_and(mask1, mask2))
159
+ union = np.sum(np.logical_or(mask1, mask2))
160
+
161
+ if union == 0:
162
+ iou = 0.0
163
+ else:
164
+ iou = intersection / union
165
+
166
+ return iou
167
+
168
+
169
+ def _match_by_iou(
170
+ first_param: List[Dict],
171
+ second_param: List[Dict],
172
+ iou_threshold: float = 0.8,
173
+ ) -> Tuple[List[Dict], Dict[int, int]]:
174
+ max_id = max((item["id"] for item in first_param), default=0)
175
+
176
+ matched_new_item_indices = set()
177
+ id_mapping = {}
178
+
179
+ for new_index, new_item in enumerate(second_param):
180
+ matched_id = None
181
+
182
+ for existing_item in first_param:
183
+ iou = _calculate_mask_iou(
184
+ existing_item["decoded_mask"], new_item["decoded_mask"]
185
+ )
186
+ if iou > iou_threshold:
187
+ matched_id = existing_item["id"]
188
+ matched_new_item_indices.add(new_index)
189
+ id_mapping[new_item["id"]] = matched_id
190
+ break
191
+
192
+ if matched_id:
193
+ new_item["id"] = matched_id
194
+ else:
195
+ max_id += 1
196
+ id_mapping[new_item["id"]] = max_id
197
+ new_item["id"] = max_id
198
+
199
+ unmatched_items = [
200
+ item for i, item in enumerate(second_param) if i not in matched_new_item_indices
201
+ ]
202
+ combined_list = first_param + unmatched_items
203
+
204
+ return combined_list, id_mapping
205
+
206
+
207
+ def _update_ids(detections: List[Dict], id_mapping: Dict[int, int]) -> None:
208
+ for inner_list in detections:
209
+ for detection in inner_list:
210
+ if detection["id"] in id_mapping:
211
+ detection["id"] = id_mapping[detection["id"]]
212
+ else:
213
+ max_new_id = max(id_mapping.values(), default=0)
214
+ detection["id"] = max_new_id + 1
215
+ id_mapping[detection["id"]] = detection["id"]
216
+
217
+
218
+ def _convert_to_2d(detections_per_segment: List[Any]) -> List[Any]:
219
+ result = []
220
+ for i, segment in enumerate(detections_per_segment):
221
+ if i == 0:
222
+ result.extend(segment)
223
+ else:
224
+ result.extend(segment[1:])
225
+ return result
226
+
227
+
228
+ def merge_segments(detections_per_segment: List[Any]) -> List[Any]:
229
+ """
230
+ Merges detections from all segments into a unified result.
231
+
232
+ Args:
233
+ detections_per_segment (List[Any]): List of detections per segment.
234
+
235
+ Returns:
236
+ List[Any]: Merged detections.
237
+ """
238
+ for segment in detections_per_segment:
239
+ for detection in segment:
240
+ for item in detection:
241
+ item["decoded_mask"] = rle_decode_array(item["mask"])
242
+
243
+ for segment_idx in range(len(detections_per_segment) - 1):
244
+ combined_detection, id_mapping = _match_by_iou(
245
+ detections_per_segment[segment_idx][-1],
246
+ detections_per_segment[segment_idx + 1][0],
247
+ )
248
+ _update_ids(detections_per_segment[segment_idx + 1], id_mapping)
249
+
250
+ merged_result = _convert_to_2d(detections_per_segment)
251
+
252
+ return merged_result
253
+
254
+
255
+ def post_process(
256
+ merged_detections: List[Any],
257
+ image_size: Tuple[int, ...],
258
+ ) -> Dict[str, Any]:
259
+ """
260
+ Performs post-processing on merged detections, including NMS and preparing display data.
261
+
262
+ Args:
263
+ merged_detections (List[Any]): Merged detections from all segments.
264
+ image_size (Tuple[int, int]): Size of the images.
265
+
266
+ Returns:
267
+ Dict[str, Any]: Post-processed data including return_data and display_data.
268
+ """
269
+ return_data = []
270
+ for frame_idx, frame in enumerate(merged_detections):
271
+ return_frame_data = []
272
+ for detection in frame:
273
+ label = f"{detection['id']}: {detection['label']}"
274
+ return_frame_data.append(
275
+ {
276
+ "label": label,
277
+ "mask": detection["decoded_mask"],
278
+ "rle": detection["mask"],
279
+ "score": 1.0,
280
+ }
281
+ )
282
+ del detection["decoded_mask"]
283
+ return_data.append(return_frame_data)
284
+
285
+ return_data = add_bboxes_from_masks(return_data)
286
+ return_data = nms(return_data, iou_threshold=0.95)
287
+
288
+ # We save the RLE for display purposes, re-calculating RLE can get very expensive.
289
+ # Deleted here because we are returning the numpy masks instead
290
+ display_data = []
291
+ for frame in return_data:
292
+ display_frame_data = []
293
+ for obj in frame:
294
+ display_frame_data.append(
295
+ {
296
+ "label": obj["label"],
297
+ "bbox": denormalize_bbox(obj["bbox"], image_size),
298
+ "mask": obj["rle"],
299
+ "score": obj["score"],
300
+ }
301
+ )
302
+ del obj["rle"]
303
+ display_data.append(display_frame_data)
304
+
305
+ return {"return_data": return_data, "display_data": display_data}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vision-agent
3
- Version: 0.2.224
3
+ Version: 0.2.226
4
4
  Summary: Toolset for Vision Agent
5
5
  Author: Landing AI
6
6
  Author-email: dev@landing.ai
@@ -1,23 +1,23 @@
1
- vision_agent/.sim_tools/df.csv,sha256=1cpUFKN48Iuq6HvaG5OhbHs2RghESicb3ouKVTLhm-s,43360
2
- vision_agent/.sim_tools/embs.npy,sha256=Nji50P_8aV0hhyxp-Kfh_YmXAFWTHSwwBGrNJWitHPU,270464
1
+ vision_agent/.sim_tools/df.csv,sha256=Vamicw8MiSGildK1r3-HXY4cKiq17GZxsgBsHbk7jpM,42158
2
+ vision_agent/.sim_tools/embs.npy,sha256=YJe8EcKVNmeX_75CS2T1sbY-sUS_1HQAMT-34zc18a0,254080
3
3
  vision_agent/__init__.py,sha256=EAb4-f9iyuEYkBrX4ag1syM8Syx8118_t0R6_C34M9w,57
4
4
  vision_agent/agent/README.md,sha256=Q4w7FWw38qaWosQYAZ7NqWx8Q5XzuWrlv7nLhjUd1-8,5527
5
5
  vision_agent/agent/__init__.py,sha256=M8CffavdIh8Zh-skznLHIaQkYGCGK7vk4dq1FaVkbs4,617
6
6
  vision_agent/agent/agent.py,sha256=_1tHWAs7Jm5tqDzEcPfCRvJV3uRRveyh4n9_9pd6I1w,1565
7
- vision_agent/agent/agent_utils.py,sha256=NmrqjhSb6fpnrB8XGWtaywZjr9n89otusOZpcbWLf9k,13534
7
+ vision_agent/agent/agent_utils.py,sha256=pP4u5tiami7C3ChgjgYLqJITnmkTI1_GsUj6g5czSRk,13994
8
8
  vision_agent/agent/types.py,sha256=DkFm3VMMrKlhYyfxEmZx4keppD72Ov3wmLCbM2J2o10,2437
9
9
  vision_agent/agent/vision_agent.py,sha256=I75bEU-os9Lf9OSICKfvQ_H_ftg-zOwgTwWnu41oIdo,23555
10
10
  vision_agent/agent/vision_agent_coder.py,sha256=flUxOibyGZK19BCSK5mhaD3HjCxHw6c6FtKom6N2q1E,27359
11
11
  vision_agent/agent/vision_agent_coder_prompts.py,sha256=gPLVXQMNSzYnQYpNm0wlH_5FPkOTaFDV24bqzK3jQ40,12221
12
- vision_agent/agent/vision_agent_coder_prompts_v2.py,sha256=9v5HwbNidSzYUEFl6ZMniWWOmyLITM_moWLtKVaTen8,4845
13
- vision_agent/agent/vision_agent_coder_v2.py,sha256=G3I8O89gzE2VczQGPWV149aYaOjbbfB1lmgGuwFWvo4,16118
12
+ vision_agent/agent/vision_agent_coder_prompts_v2.py,sha256=idmSMfxebPULqqvllz3gqRzGDchEvS5dkGngvBs4PGo,4872
13
+ vision_agent/agent/vision_agent_coder_v2.py,sha256=i1qgXp5YsWVRoA_qO429Ef-aKZBakveCl1F_2ZbSzk8,16287
14
14
  vision_agent/agent/vision_agent_planner.py,sha256=fFzjNkZBKkh8Y_oS06ATI4qz31xmIJvixb_tV1kX8KA,18590
15
15
  vision_agent/agent/vision_agent_planner_prompts.py,sha256=mn9NlZpRkW4XAvlNuMZwIs1ieHCFds5aYZJ55WXupZY,6733
16
- vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=lzfJFvBYW_-Ue4OevgljI8bAQxgKC4Rdv5SmP6UsAxE,34102
16
+ vision_agent/agent/vision_agent_planner_prompts_v2.py,sha256=YgemW2PRPYd8o8XpmwSJBUOJSxMUXMNr2DZNQnS4jEI,34988
17
17
  vision_agent/agent/vision_agent_planner_v2.py,sha256=vvxfmGydBIKB8CtNSAJyPvdEXkG7nIO5-Hs2SjNc48Y,20465
18
18
  vision_agent/agent/vision_agent_prompts.py,sha256=NtGdCfzzilCRtscKALC9FK55d1h4CBpMnbhLzg0PYlc,13772
19
19
  vision_agent/agent/vision_agent_prompts_v2.py,sha256=-vCWat-ARlCOOOeIDIFhg-kcwRRwjTXYEwsvvqPeaCs,1972
20
- vision_agent/agent/vision_agent_v2.py,sha256=6gGVV3FlL4NLzHRpjMqMz-fEP6f_JhwwOjUKczZ3TPA,10231
20
+ vision_agent/agent/vision_agent_v2.py,sha256=1wu_vH_onic2kLYPKW2nAF2e6Zz5vmUt5Acv4Seq3sQ,10796
21
21
  vision_agent/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  vision_agent/clients/http.py,sha256=k883i6M_4nl7zwwHSI-yP5sAgQZIDPM1nrKD6YFJ3Xs,2009
23
23
  vision_agent/clients/landing_public_api.py,sha256=lU2ev6E8NICmR8DMUljuGcVFy5VNJQ4WQkWC8WnnJEc,1503
@@ -28,19 +28,20 @@ vision_agent/lmm/lmm.py,sha256=x_nIyDNDZwq4-pfjnJTmcyyJZ2_B7TjkA5jZp88YVO8,17103
28
28
  vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
29
29
  vision_agent/tools/__init__.py,sha256=15O7eQVn0bitmzUO5OxKdA618PoiLt6Z02gmKsSNMFM,2765
30
30
  vision_agent/tools/meta_tools.py,sha256=TPeS7QWnc_PmmU_ndiDT03dXbQ5yDSP33E7U8cSj7Ls,28660
31
- vision_agent/tools/planner_tools.py,sha256=CvaJ2vGM8O_CYvsoSk1avxAMqpIu3tv4C2bY0p1X-X4,13519
31
+ vision_agent/tools/planner_tools.py,sha256=qQvPuCif-KbFi7KsXKkTCfpgEQEJJ6oq6WB3gOuG2Xg,13686
32
32
  vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
33
33
  vision_agent/tools/tool_utils.py,sha256=q9cqXO2AvigUdO1krjnOy8o0goYhgS6eILl6-F5Kxyk,10211
34
- vision_agent/tools/tools.py,sha256=60S5ItFG9yKzVb8FU8oLFj_aouDg2-4vlieDbSgfPdQ,91306
34
+ vision_agent/tools/tools.py,sha256=zqoo4ml9ZS99kOeOIN6Zplq7pxOwBrVZKKFUVIzsjfw,91712
35
35
  vision_agent/tools/tools_types.py,sha256=8hYf2OZhI58gvf65KGaeGkt4EQ56nwLFqIQDPHioOBc,2339
36
36
  vision_agent/utils/__init__.py,sha256=QKk4zVjMwGxQI0MQ-aZZA50N-qItxRY4EB9CwQkZ2HY,185
37
37
  vision_agent/utils/exceptions.py,sha256=booSPSuoULF7OXRr_YbC4dtKt6gM_HyiFQHBuaW86C4,2052
38
38
  vision_agent/utils/execute.py,sha256=vOEP5Ys7S2lc0_7pOJbgk7OaWi85hrCNu9_8Bo3zk6I,29356
39
39
  vision_agent/utils/image_utils.py,sha256=z_ONgcza125B10NkoGwPOzXnL470bpTWZbkB16NeeH0,12188
40
- vision_agent/utils/sim.py,sha256=znsInUDrsyBi3OlgAlV3rDn5UQQRfJAWXTXm7D7eJA8,9125
40
+ vision_agent/utils/sim.py,sha256=qr-6UWAxxGwtwIAKZjZCY_pu9VwBI_TTB8bfrGsaABg,9282
41
41
  vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
42
42
  vision_agent/utils/video.py,sha256=e1VwKhXzzlC5LcFMyrcQYrPnpnX4wxDpnQ-76sB4jgM,6001
43
- vision_agent-0.2.224.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
44
- vision_agent-0.2.224.dist-info/METADATA,sha256=wT49_byW9-Oz6-1eSlP3cW_AFGbWaxtKrYsGB4nT62o,20039
45
- vision_agent-0.2.224.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
46
- vision_agent-0.2.224.dist-info/RECORD,,
43
+ vision_agent/utils/video_tracking.py,sha256=EeOiSY8gjvvneuAnv-BO7yOyMBF_-1Irk_lLLOt3bDM,9452
44
+ vision_agent-0.2.226.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
45
+ vision_agent-0.2.226.dist-info/METADATA,sha256=_7jZokNbQLK6Ups2psyRKbPDjUIzU3daxCpfrHZ6gSU,20039
46
+ vision_agent-0.2.226.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
47
+ vision_agent-0.2.226.dist-info/RECORD,,