vision-agent 0.2.228__py3-none-any.whl → 0.2.230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,160 @@
+ from typing import Type
+
+ from pydantic import BaseModel, Field
+
+ from vision_agent.lmm import LMM, OpenAILMM
+
+
+ class Config(BaseModel):
+     # for vision_agent_v2
+     agent: Type[LMM] = Field(default=OpenAILMM)
+     agent_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_planner_v2
+     planner: Type[LMM] = Field(default=OpenAILMM)
+     planner_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_planner_v2
+     summarizer: Type[LMM] = Field(default=OpenAILMM)
+     summarizer_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "o1",
+             "temperature": 1.0,
+             "image_size": 768,
+         }
+     )
+
+     # for vision_agent_planner_v2
+     critic: Type[LMM] = Field(default=OpenAILMM)
+     critic_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_coder_v2
+     coder: Type[LMM] = Field(default=OpenAILMM)
+     coder_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_coder_v2
+     tester: Type[LMM] = Field(default=OpenAILMM)
+     tester_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_coder_v2
+     debugger: Type[LMM] = Field(default=OpenAILMM)
+     debugger_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for get_tool_for_task
+     tool_tester: Type[LMM] = Field(default=OpenAILMM)
+     tool_tester_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for get_tool_for_task
+     tool_chooser: Type[LMM] = Field(default=OpenAILMM)
+     tool_chooser_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for suggestions module
+     suggester: Type[LMM] = Field(default=OpenAILMM)
+     suggester_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vqa module
+     vqa: Type[LMM] = Field(default=OpenAILMM)
+     vqa_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     def create_agent(self) -> LMM:
+         return self.agent(**self.agent_kwargs)
+
+     def create_planner(self) -> LMM:
+         return self.planner(**self.planner_kwargs)
+
+     def create_summarizer(self) -> LMM:
+         return self.summarizer(**self.summarizer_kwargs)
+
+     def create_critic(self) -> LMM:
+         return self.critic(**self.critic_kwargs)
+
+     def create_coder(self) -> LMM:
+         return self.coder(**self.coder_kwargs)
+
+     def create_tester(self) -> LMM:
+         return self.tester(**self.tester_kwargs)
+
+     def create_debugger(self) -> LMM:
+         return self.debugger(**self.debugger_kwargs)
+
+     def create_tool_tester(self) -> LMM:
+         return self.tool_tester(**self.tool_tester_kwargs)
+
+     def create_tool_chooser(self) -> LMM:
+         return self.tool_chooser(**self.tool_chooser_kwargs)
+
+     def create_suggester(self) -> LMM:
+         return self.suggester(**self.suggester_kwargs)
+
+     def create_vqa(self) -> LMM:
+         return self.vqa(**self.vqa_kwargs)
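
The new Config class centralizes which LMM backs each vision-agent role. A minimal sketch of how it is consumed, mirroring the "from vision_agent.configs import Config" import that appears later in this diff; the AnthropicLMM override shown for the vqa role is illustrative only, not part of the release:

    from vision_agent.configs import Config
    from vision_agent.lmm import AnthropicLMM

    # Defaults: every role is an OpenAILMM built from the kwargs above.
    config = Config()
    planner = config.create_planner()  # OpenAILMM(model_name="gpt-4o-2024-08-06", ...)

    # Config is a pydantic BaseModel, so individual roles can be swapped at
    # construction time (hypothetical override, shown only for illustration).
    custom = Config(
        vqa=AnthropicLMM,
        vqa_kwargs={"model_name": "claude-3-5-sonnet-20240620"},
    )
    vqa_model = custom.create_vqa()
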
@@ -1,2 +1,2 @@
- from .lmm import LMM, AnthropicLMM, AzureOpenAILMM, OllamaLMM, OpenAILMM
+ from .lmm import LMM, AnthropicLMM, AzureOpenAILMM, GoogleLMM, OllamaLMM, OpenAILMM
  from .types import Message
vision_agent/lmm/lmm.py CHANGED
@@ -50,6 +50,8 @@ class OpenAILMM(LMM):
  api_key: Optional[str] = None,
  max_tokens: int = 4096,
  json_mode: bool = False,
+ image_size: int = 768,
+ image_detail: str = "low",
  **kwargs: Any,
  ):
  if not api_key:
@@ -59,7 +61,10 @@ class OpenAILMM(LMM):

  self.client = OpenAI(api_key=api_key)
  self.model_name = model_name
- if "max_tokens" not in kwargs:
+ self.image_size = image_size
+ self.image_detail = image_detail
+ # o1 does not use max_tokens
+ if "max_tokens" not in kwargs and not model_name.startswith("o1"):
  kwargs["max_tokens"] = max_tokens
  if json_mode:
  kwargs["response_format"] = {"type": "json_object"}
@@ -94,7 +99,13 @@ class OpenAILMM(LMM):
  fixed_c["content"] = [{"type": "text", "text": c["content"]}] # type: ignore
  if "media" in c:
  for media in c["media"]:
- encoded_media = encode_media(cast(str, media))
+ resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+ image_detail = (
+ kwargs["image_detail"]
+ if "image_detail" in kwargs
+ else self.image_detail
+ )
+ encoded_media = encode_media(cast(str, media), resize=resize)

  fixed_c["content"].append( # type: ignore
  {
@@ -106,7 +117,7 @@ class OpenAILMM(LMM):
  or encoded_media.startswith("data:image/")
  else f"data:image/png;base64,{encoded_media}"
  ),
- "detail": "low",
+ "detail": image_detail,
  },
  },
  )
@@ -144,7 +155,13 @@ class OpenAILMM(LMM):
  ]
  if media and len(media) > 0:
  for m in media:
- encoded_media = encode_media(m)
+ resize = kwargs["resize"] if "resize" in kwargs else None
+ image_detail = (
+ kwargs["image_detail"]
+ if "image_detail" in kwargs
+ else self.image_detail
+ )
+ encoded_media = encode_media(m, resize=resize)
  message[0]["content"].append(
  {
  "type": "image_url",
@@ -155,7 +172,7 @@ class OpenAILMM(LMM):
  or encoded_media.startswith("data:image/")
  else f"data:image/png;base64,{encoded_media}"
  ),
- "detail": "low",
+ "detail": image_detail,
  },
  },
  )
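
OpenAILMM now carries image_size and image_detail defaults and checks per-call kwargs for "resize" and "image_detail" overrides when encoding media (falling back to image_size in the chat path and to no resizing in the single-prompt path). A minimal construction sketch, assuming generate() keeps its existing prompt/media signature; photo.png is a stand-in path:

    from vision_agent.lmm import OpenAILMM

    # Images are resized to 768px and sent with detail="low" unless a call
    # supplies resize=.../image_detail=... overrides via kwargs.
    lmm = OpenAILMM(
        model_name="gpt-4o-2024-08-06",
        image_size=768,
        image_detail="low",
    )
    print(lmm.generate("Describe the scene.", media=["photo.png"]))
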
@@ -186,6 +203,7 @@ class AzureOpenAILMM(OpenAILMM):
  azure_endpoint: Optional[str] = None,
  max_tokens: int = 4096,
  json_mode: bool = False,
+ image_detail: str = "low",
  **kwargs: Any,
  ):
  if not api_key:
@@ -208,6 +226,7 @@ class AzureOpenAILMM(OpenAILMM):
  azure_endpoint=azure_endpoint,
  )
  self.model_name = model_name
+ self.image_detail = image_detail

  if "max_tokens" not in kwargs:
  kwargs["max_tokens"] = max_tokens
@@ -225,6 +244,7 @@ class OllamaLMM(LMM):
  base_url: Optional[str] = "http://localhost:11434/api",
  json_mode: bool = False,
  num_ctx: int = 128_000,
+ image_size: int = 768,
  **kwargs: Any,
  ):
  """Initializes the Ollama LMM. kwargs are passed as 'options' to the model.
@@ -241,6 +261,7 @@ class OllamaLMM(LMM):

  self.url = base_url
  self.model_name = model_name
+ self.image_size = image_size
  self.kwargs = {"options": kwargs}

  if json_mode:
@@ -273,8 +294,9 @@ class OllamaLMM(LMM):
  fixed_chat = []
  for message in chat:
  if "media" in message:
+ resize = kwargs["resize"] if "resize" in kwargs else self.image_size
  message["images"] = [
- encode_media(cast(str, m)) for m in message["media"]
+ encode_media(cast(str, m), resize=resize) for m in message["media"]
  ]
  del message["media"]
  fixed_chat.append(message)
@@ -328,7 +350,8 @@ class OllamaLMM(LMM):

  if media and len(media) > 0:
  for m in media:
- data["images"].append(encode_media(m))
+ resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+ data["images"].append(encode_media(m, resize=resize))

  tmp_kwargs = self.kwargs | kwargs
  data.update(tmp_kwargs)
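
OllamaLMM gains the same image_size default, applied via encode_media when chatting or generating. A usage sketch under the assumption of a local Ollama server with a multimodal model such as llava pulled; the model name and image path are stand-ins:

    from vision_agent.lmm import OllamaLMM

    # Media are resized to 768px before being base64-encoded for the Ollama API.
    lmm = OllamaLMM(
        model_name="llava",
        base_url="http://localhost:11434/api",
        image_size=768,
    )
    print(lmm.generate("Describe this image.", media=["photo.png"]))
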
@@ -370,9 +393,11 @@ class AnthropicLMM(LMM):
  api_key: Optional[str] = None,
  model_name: str = "claude-3-5-sonnet-20240620",
  max_tokens: int = 4096,
+ image_size: int = 768,
  **kwargs: Any,
  ):
  self.client = anthropic.Anthropic(api_key=api_key)
+ self.image_size = image_size
  self.model_name = model_name
  if "max_tokens" not in kwargs:
  kwargs["max_tokens"] = max_tokens
@@ -399,7 +424,8 @@ class AnthropicLMM(LMM):
  ]
  if "media" in msg:
  for media_path in msg["media"]:
- encoded_media = encode_media(media_path, resize=768)
+ resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+ encoded_media = encode_media(media_path, resize=resize)
  if encoded_media.startswith("data:image/png;base64,"):
  encoded_media = encoded_media[len("data:image/png;base64,") :]
  content.append(
@@ -448,7 +474,8 @@ class AnthropicLMM(LMM):
  ]
  if media:
  for m in media:
- encoded_media = encode_media(m, resize=768)
+ resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+ encoded_media = encode_media(m, resize=resize)
  if encoded_media.startswith("data:image/png;base64,"):
  encoded_media = encoded_media[len("data:image/png;base64,") :]
  content.append(
@@ -486,3 +513,30 @@ class AnthropicLMM(LMM):
  return f()
  else:
  return cast(str, response.content[0].text)
+
+
+ class GoogleLMM(OpenAILMM):
+     r"""An LMM class for the Google LMMs."""
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         model_name: str = "gemini-2.0-flash-exp",
+         max_tokens: int = 4096,
+         image_detail: str = "low",
+         image_size: int = 768,
+         **kwargs: Any,
+     ):
+         base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
+         if not api_key:
+             api_key = os.environ.get("GEMINI_API_KEY")
+
+         self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+         self.model_name = model_name
+         self.image_size = image_size
+         self.image_detail = image_detail
+
+         if "max_tokens" not in kwargs:
+             kwargs["max_tokens"] = max_tokens
+         self.kwargs = kwargs
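
GoogleLMM reuses the OpenAI-compatible client against Google's Generative Language endpoint, so it inherits OpenAILMM's chat/generate interface. A minimal usage sketch, assuming GEMINI_API_KEY is set in the environment (as read by the constructor above) and crowd.jpg is a stand-in image:

    from vision_agent.lmm import GoogleLMM

    # Defaults to gemini-2.0-flash-exp; api_key falls back to GEMINI_API_KEY.
    lmm = GoogleLMM()
    answer = lmm.generate("How many people are in this image?", media=["crowd.jpg"])
    print(answer)
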
@@ -1,7 +1,7 @@
  import inspect
  import logging
- import shutil
  import tempfile
+ from concurrent.futures import ThreadPoolExecutor, as_completed
  from typing import Any, Callable, Dict, List, Optional, Tuple, cast

  import libcst as cst
@@ -24,6 +24,7 @@ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
  TEST_TOOLS_EXAMPLE1,
  TEST_TOOLS_EXAMPLE2,
  )
+ from vision_agent.configs import Config
  from vision_agent.lmm import LMM, AnthropicLMM
  from vision_agent.utils.execute import (
  CodeInterpreter,
@@ -36,6 +37,7 @@ from vision_agent.utils.sim import get_tool_recommender

  TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}

+ CONFIG = Config()
  _LOGGER = logging.getLogger(__name__)
  EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"

@@ -50,6 +52,54 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
  return return_str


+ def run_multi_judge(
+     tool_chooser: LMM,
+     tool_docs_str: str,
+     task: str,
+     code: str,
+     tool_output_str: str,
+     image_paths: List[str],
+ ) -> Tuple[Optional[Callable], str, str]:
+     error_message = ""
+     prompt = PICK_TOOL.format(
+         tool_docs=tool_docs_str,
+         user_request=task,
+         context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+         previous_attempts=error_message,
+     )
+
+     def run_judge() -> Tuple[Optional[Callable], str, str]:
+         response = tool_chooser.generate(prompt, media=image_paths, temperature=1.0)
+         tool_choice_context = extract_tag(response, "json") # type: ignore
+         tool_choice_context_dict = extract_json(tool_choice_context) # type: ignore
+         tool, tool_thoughts, tool_docstring, _ = extract_tool_info(
+             tool_choice_context_dict
+         )
+         return tool, tool_thoughts, tool_docstring
+
+     responses = []
+     with ThreadPoolExecutor() as executor:
+         futures = [executor.submit(run_judge) for _ in range(3)]
+         for future in as_completed(futures):
+             responses.append(future.result())
+
+     responses = [r for r in responses if r[0] is not None]
+     counts: Dict[str, int] = {}
+     for tool, tool_thoughts, tool_docstring in responses:
+         if tool is not None:
+             counts[tool.__name__] = counts.get(tool.__name__, 0) + 1
+             if counts[tool.__name__] >= 2:
+                 return tool, tool_thoughts, tool_docstring
+
+     if len(responses) == 0:
+         return (
+             None,
+             "No tool could be found, please try again with a different prompt or image",
+             "",
+         )
+     return responses[0]
+
+
  def extract_tool_info(
  tool_choice_context: Dict[str, Any],
  ) -> Tuple[Optional[Callable], str, str, str]:
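
run_multi_judge replaces the old retry loop with three concurrent judge calls and a first-to-two-votes majority, falling back to the first valid response when no tool reaches two votes. The voting pattern, reduced to a self-contained sketch (judge, Candidate, and the return values are stand-ins, not the library's API):

    from concurrent.futures import ThreadPoolExecutor, as_completed
    from typing import Callable, Dict, List, Optional, Tuple

    Candidate = Tuple[Optional[str], str]  # (tool name, reasoning)

    def majority_vote(judge: Callable[[], Candidate], n: int = 3) -> Optional[Candidate]:
        # Run the judge n times in parallel; each call may pick a different tool.
        responses: List[Candidate] = []
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(judge) for _ in range(n)]
            for future in as_completed(futures):
                responses.append(future.result())

        # Drop failed judgments, then return the first tool to collect two votes.
        responses = [r for r in responses if r[0] is not None]
        counts: Dict[str, int] = {}
        for name, reasoning in responses:
            counts[name] = counts.get(name, 0) + 1
            if counts[name] >= 2:
                return name, reasoning

        # No majority: fall back to any valid answer, or None if all failed.
        return responses[0] if responses else None
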
@@ -212,7 +262,8 @@ def get_tool_for_task(
  --------
  >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
  """
- lmm = AnthropicLMM()
+ tool_tester = CONFIG.create_tool_tester()
+ tool_chooser = CONFIG.create_tool_chooser()

  with (
  tempfile.TemporaryDirectory() as tmpdirname,
@@ -225,45 +276,14 @@ def get_tool_for_task(
  image_paths.append(image_path)

  code, tool_docs_str, tool_output = run_tool_testing(
- task, image_paths, lmm, exclude_tools, code_interpreter
+ task, image_paths, tool_tester, exclude_tools, code_interpreter
  )
  tool_output_str = tool_output.text(include_results=False).strip()

- error_message = ""
- prompt = PICK_TOOL.format(
- tool_docs=tool_docs_str,
- user_request=task,
- context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
- previous_attempts=error_message,
- )
-
- response = lmm.generate(prompt, media=image_paths)
- tool_choice_context = extract_tag(response, "json") # type: ignore
- tool_choice_context_dict = extract_json(tool_choice_context) # type: ignore
-
- tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
- tool_choice_context_dict
+ _, tool_thoughts, tool_docstring = run_multi_judge(
+ tool_chooser, tool_docs_str, task, code, tool_output_str, image_paths
  )

- count = 1
- while tool is None and count <= 3:
- prompt = PICK_TOOL.format(
- tool_docs=tool_docs_str,
- user_request=task,
- context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
- previous_attempts=error_message,
- )
- tool_choice_context_dict = extract_json(
- lmm.generate(prompt, media=image_paths) # type: ignore
- )
- tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
- tool_choice_context_dict
- )
- try:
- shutil.rmtree(tmpdirname)
- except Exception as e:
- _LOGGER.error(f"Error removing temp directory: {e}")
-
  print(format_tool_output(tool_thoughts, tool_docstring))


@@ -277,7 +297,7 @@ def get_tool_for_task_human_reviewer(
  task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
  ) -> None:
  # NOTE: this will have the same documentation as get_tool_for_task
- lmm = AnthropicLMM()
+ tool_tester = CONFIG.create_tool_tester()

  with (
  tempfile.TemporaryDirectory() as tmpdirname,
@@ -298,7 +318,7 @@ def get_tool_for_task_human_reviewer(
  _, _, tool_output = run_tool_testing(
  task,
  image_paths,
- lmm,
+ tool_tester,
  exclude_tools,
  code_interpreter,
  process_code=lambda x: replace_box_threshold(x, tools, 0.05),
@@ -349,7 +369,7 @@ def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
  medias: List[np.ndarray]: The images to ask the question about, it could also
  be frames from a video. You can send up to 5 frames from a video.
  """
- lmm = AnthropicLMM()
+ vqa = CONFIG.create_vqa()
  if isinstance(medias, np.ndarray):
  medias = [medias]
  if isinstance(medias, list) and len(medias) > 5:
@@ -358,7 +378,7 @@ def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
  "data:image/png;base64," + convert_to_b64(media) for media in medias
  ]

- response = cast(str, lmm.generate(prompt, media=all_media_b64))
+ response = cast(str, vqa.generate(prompt, media=all_media_b64))
  print(f"[claude35_vqa output]\n{response}\n[end of claude35_vqa output]")

@@ -72,8 +72,7 @@ def send_inference_request(

  response = _call_post(url, payload, session, files, function_name, is_form)

- # TODO: consider making the response schema the same between below two sources
- return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
+ return response["data"]


  def send_task_inference_request(
@@ -595,14 +595,14 @@ def owlv2_sam2_video_tracking(
  def florence2_object_detection(
  prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
  ) -> List[Dict[str, Any]]:
- """'florence2_object_detection' is a tool that can detect multiple
- objects given a text prompt which can be object names or caption. You
- can optionally separate the object names in the text with commas. It returns a list
- of bounding boxes with normalized coordinates, label names and associated
- confidence scores of 1.0.
+ """'florence2_object_detection' is a tool that can detect multiple objects given a
+ text prompt which can be object names or caption. You can optionally separate the
+ object names in the text with commas. It returns a list of bounding boxes with
+ normalized coordinates, label names and associated confidence scores of 1.0.

  Parameters:
- prompt (str): The prompt to ground to the image.
+ prompt (str): The prompt to ground to the image. Use exclusive categories that
+ do not overlap such as 'person, car' and NOT 'person, athlete'.
  image (np.ndarray): The image to used to detect objects
  fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
  fine-tuned model ID here to use it.
@@ -681,7 +681,8 @@ def florence2_sam2_instance_segmentation(
  1.0.

  Parameters:
- prompt (str): The prompt to ground to the image.
+ prompt (str): The prompt to ground to the image. Use exclusive categories that
+ do not overlap such as 'person, car' and NOT 'person, athlete'.
  image (np.ndarray): The image to ground the prompt to.
  fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
  fine-tuned model ID here to use it.
@@ -769,7 +770,8 @@ def florence2_sam2_video_tracking(
  is useful for tracking and counting without duplicating counts.

  Parameters:
- prompt (str): The prompt to ground to the video.
+ prompt (str): The prompt to ground to the image. Use exclusive categories that
+ do not overlap such as 'person, car' and NOT 'person, athlete'.
  frames (List[np.ndarray]): The list of frames to ground the prompt to.
  chunk_length (Optional[int]): The number of frames to re-run florence2 to find
  new objects.
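
The updated docstrings ask for non-overlapping category names in detection prompts. A hypothetical call following that guidance, assuming the usual vision_agent.tools import and a stand-in image array:

    import numpy as np

    from vision_agent.tools import florence2_object_detection

    image = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in image
    # 'person' and 'car' are exclusive categories; avoid overlaps like 'person, athlete'.
    dets = florence2_object_detection("person, car", image)
    print(dets)  # list of dicts with normalized bounding boxes, labels and scores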