vision-agent 0.2.229__py3-none-any.whl → 0.2.231__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
@@ -0,0 +1,160 @@
+ from typing import Type
+
+ from pydantic import BaseModel, Field
+
+ from vision_agent.lmm import LMM, OpenAILMM
+
+
+ class Config(BaseModel):
+     # for vision_agent_v2
+     agent: Type[LMM] = Field(default=OpenAILMM)
+     agent_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_planner_v2
+     planner: Type[LMM] = Field(default=OpenAILMM)
+     planner_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_planner_v2
+     summarizer: Type[LMM] = Field(default=OpenAILMM)
+     summarizer_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "o1",
+             "temperature": 1.0,
+             "image_size": 768,
+         }
+     )
+
+     # for vision_agent_planner_v2
+     critic: Type[LMM] = Field(default=OpenAILMM)
+     critic_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_coder_v2
+     coder: Type[LMM] = Field(default=OpenAILMM)
+     coder_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_coder_v2
+     tester: Type[LMM] = Field(default=OpenAILMM)
+     tester_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vision_agent_coder_v2
+     debugger: Type[LMM] = Field(default=OpenAILMM)
+     debugger_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for get_tool_for_task
+     tool_tester: Type[LMM] = Field(default=OpenAILMM)
+     tool_tester_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for get_tool_for_task
+     tool_chooser: Type[LMM] = Field(default=OpenAILMM)
+     tool_chooser_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 1.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for suggestions module
+     suggester: Type[LMM] = Field(default=OpenAILMM)
+     suggester_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 1.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     # for vqa module
+     vqa: Type[LMM] = Field(default=OpenAILMM)
+     vqa_kwargs: dict = Field(
+         default_factory=lambda: {
+             "model_name": "gpt-4o-2024-08-06",
+             "temperature": 0.0,
+             "image_size": 768,
+             "image_detail": "low",
+         }
+     )
+
+     def create_agent(self) -> LMM:
+         return self.agent(**self.agent_kwargs)
+
+     def create_planner(self) -> LMM:
+         return self.planner(**self.planner_kwargs)
+
+     def create_summarizer(self) -> LMM:
+         return self.summarizer(**self.summarizer_kwargs)
+
+     def create_critic(self) -> LMM:
+         return self.critic(**self.critic_kwargs)
+
+     def create_coder(self) -> LMM:
+         return self.coder(**self.coder_kwargs)
+
+     def create_tester(self) -> LMM:
+         return self.tester(**self.tester_kwargs)
+
+     def create_debugger(self) -> LMM:
+         return self.debugger(**self.debugger_kwargs)
+
+     def create_tool_tester(self) -> LMM:
+         return self.tool_tester(**self.tool_tester_kwargs)
+
+     def create_tool_chooser(self) -> LMM:
+         return self.tool_chooser(**self.tool_chooser_kwargs)
+
+     def create_suggester(self) -> LMM:
+         return self.suggester(**self.suggester_kwargs)
+
+     def create_vqa(self) -> LMM:
+         return self.vqa(**self.vqa_kwargs)
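For orientation, a minimal usage sketch of the new Config object. The class, its fields, and the create_* helpers come from the added file above; the AnthropicLMM override is purely illustrative, and the sketch assumes the relevant API keys (OPENAI_API_KEY, ANTHROPIC_API_KEY) are available in the environment.

from vision_agent.configs import Config
from vision_agent.lmm import AnthropicLMM

config = Config()                  # defaults: OpenAILMM with gpt-4o-2024-08-06 for most roles
planner = config.create_planner()  # equivalent to config.planner(**config.planner_kwargs)

# Config is a pydantic model, so individual roles can be swapped at construction time,
# e.g. running the coder on Claude instead of GPT-4o (illustrative values):
config = Config(
    coder=AnthropicLMM,
    coder_kwargs={"model_name": "claude-3-5-sonnet-20240620", "temperature": 0.0},
)
coder = config.create_coder()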
@@ -1,2 +1,2 @@
- from .lmm import LMM, AnthropicLMM, AzureOpenAILMM, OllamaLMM, OpenAILMM
+ from .lmm import LMM, AnthropicLMM, AzureOpenAILMM, GoogleLMM, OllamaLMM, OpenAILMM
  from .types import Message
vision_agent/lmm/lmm.py CHANGED
@@ -50,6 +50,8 @@ class OpenAILMM(LMM):
          api_key: Optional[str] = None,
          max_tokens: int = 4096,
          json_mode: bool = False,
+         image_size: int = 768,
+         image_detail: str = "low",
          **kwargs: Any,
      ):
          if not api_key:
@@ -59,7 +61,10 @@ class OpenAILMM(LMM):

          self.client = OpenAI(api_key=api_key)
          self.model_name = model_name
-         if "max_tokens" not in kwargs:
+         self.image_size = image_size
+         self.image_detail = image_detail
+         # o1 does not use max_tokens
+         if "max_tokens" not in kwargs and not model_name.startswith("o1"):
              kwargs["max_tokens"] = max_tokens
          if json_mode:
              kwargs["response_format"] = {"type": "json_object"}
@@ -94,7 +99,13 @@ class OpenAILMM(LMM):
              fixed_c["content"] = [{"type": "text", "text": c["content"]}]  # type: ignore
              if "media" in c:
                  for media in c["media"]:
-                     encoded_media = encode_media(cast(str, media))
+                     resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                     image_detail = (
+                         kwargs["image_detail"]
+                         if "image_detail" in kwargs
+                         else self.image_detail
+                     )
+                     encoded_media = encode_media(cast(str, media), resize=resize)

                      fixed_c["content"].append(  # type: ignore
                          {
@@ -106,7 +117,7 @@ class OpenAILMM(LMM):
                                      or encoded_media.startswith("data:image/")
                                      else f"data:image/png;base64,{encoded_media}"
                                  ),
-                                 "detail": "low",
+                                 "detail": image_detail,
                              },
                          },
                      )
@@ -144,7 +155,13 @@ class OpenAILMM(LMM):
          ]
          if media and len(media) > 0:
              for m in media:
-                 encoded_media = encode_media(m)
+                 resize = kwargs["resize"] if "resize" in kwargs else None
+                 image_detail = (
+                     kwargs["image_detail"]
+                     if "image_detail" in kwargs
+                     else self.image_detail
+                 )
+                 encoded_media = encode_media(m, resize=resize)
                  message[0]["content"].append(
                      {
                          "type": "image_url",
@@ -155,7 +172,7 @@
                              or encoded_media.startswith("data:image/")
                              else f"data:image/png;base64,{encoded_media}"
                          ),
-                         "detail": "low",
+                         "detail": image_detail,
                      },
                  },
              )
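These hunks make the previously hard-coded media handling configurable: image_size is forwarded to encode_media(..., resize=...) when media are attached, and image_detail replaces the fixed "detail": "low" field in the request payload. A minimal construction sketch, assuming OPENAI_API_KEY is set; the "high" detail value is only an illustration:

from vision_agent.lmm import OpenAILMM

# Defaults in the diff are image_size=768 and image_detail="low".
lmm = OpenAILMM(
    model_name="gpt-4o-2024-08-06",
    image_size=768,
    image_detail="high",
)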
@@ -186,6 +203,7 @@ class AzureOpenAILMM(OpenAILMM):
          azure_endpoint: Optional[str] = None,
          max_tokens: int = 4096,
          json_mode: bool = False,
+         image_detail: str = "low",
          **kwargs: Any,
      ):
          if not api_key:
@@ -208,6 +226,7 @@ class AzureOpenAILMM(OpenAILMM):
              azure_endpoint=azure_endpoint,
          )
          self.model_name = model_name
+         self.image_detail = image_detail

          if "max_tokens" not in kwargs:
              kwargs["max_tokens"] = max_tokens
@@ -225,6 +244,7 @@ class OllamaLMM(LMM):
          base_url: Optional[str] = "http://localhost:11434/api",
          json_mode: bool = False,
          num_ctx: int = 128_000,
+         image_size: int = 768,
          **kwargs: Any,
      ):
          """Initializes the Ollama LMM. kwargs are passed as 'options' to the model.
@@ -241,6 +261,7 @@ class OllamaLMM(LMM):

          self.url = base_url
          self.model_name = model_name
+         self.image_size = image_size
          self.kwargs = {"options": kwargs}

          if json_mode:
@@ -273,8 +294,9 @@ class OllamaLMM(LMM):
          fixed_chat = []
          for message in chat:
              if "media" in message:
+                 resize = kwargs["resize"] if "resize" in kwargs else self.image_size
                  message["images"] = [
-                     encode_media(cast(str, m)) for m in message["media"]
+                     encode_media(cast(str, m), resize=resize) for m in message["media"]
                  ]
                  del message["media"]
              fixed_chat.append(message)
@@ -328,7 +350,8 @@ class OllamaLMM(LMM):

          if media and len(media) > 0:
              for m in media:
-                 data["images"].append(encode_media(m))
+                 resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                 data["images"].append(encode_media(m, resize=resize))

          tmp_kwargs = self.kwargs | kwargs
          data.update(tmp_kwargs)
@@ -370,9 +393,11 @@ class AnthropicLMM(LMM):
          api_key: Optional[str] = None,
          model_name: str = "claude-3-5-sonnet-20240620",
          max_tokens: int = 4096,
+         image_size: int = 768,
          **kwargs: Any,
      ):
          self.client = anthropic.Anthropic(api_key=api_key)
+         self.image_size = image_size
          self.model_name = model_name
          if "max_tokens" not in kwargs:
              kwargs["max_tokens"] = max_tokens
@@ -399,7 +424,8 @@ class AnthropicLMM(LMM):
              ]
              if "media" in msg:
                  for media_path in msg["media"]:
-                     encoded_media = encode_media(media_path, resize=768)
+                     resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                     encoded_media = encode_media(media_path, resize=resize)
                      if encoded_media.startswith("data:image/png;base64,"):
                          encoded_media = encoded_media[len("data:image/png;base64,") :]
                      content.append(
@@ -448,7 +474,8 @@ class AnthropicLMM(LMM):
          ]
          if media:
              for m in media:
-                 encoded_media = encode_media(m, resize=768)
+                 resize = kwargs["resize"] if "resize" in kwargs else self.image_size
+                 encoded_media = encode_media(m, resize=resize)
                  if encoded_media.startswith("data:image/png;base64,"):
                      encoded_media = encoded_media[len("data:image/png;base64,") :]
                  content.append(
@@ -486,3 +513,30 @@ class AnthropicLMM(LMM):
              return f()
          else:
              return cast(str, response.content[0].text)
+
+
+ class GoogleLMM(OpenAILMM):
+     r"""An LMM class for the Google LMMs."""
+
+     def __init__(
+         self,
+         api_key: Optional[str] = None,
+         model_name: str = "gemini-2.0-flash-exp",
+         max_tokens: int = 4096,
+         image_detail: str = "low",
+         image_size: int = 768,
+         **kwargs: Any,
+     ):
+         base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
+         if not api_key:
+             api_key = os.environ.get("GEMINI_API_KEY")
+
+         self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+         self.model_name = model_name
+         self.image_size = image_size
+         self.image_detail = image_detail
+
+         if "max_tokens" not in kwargs:
+             kwargs["max_tokens"] = max_tokens
+         self.kwargs = kwargs
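The new GoogleLMM reuses the OpenAILMM request path against Gemini's OpenAI-compatible endpoint. A minimal sketch, assuming GEMINI_API_KEY is set in the environment; the prompt and image path are placeholders:

from vision_agent.lmm import GoogleLMM

# The constructor falls back to the GEMINI_API_KEY environment variable when api_key is omitted.
lmm = GoogleLMM(model_name="gemini-2.0-flash-exp", image_size=768, image_detail="low")
print(lmm.generate("Describe this image.", media=["example.png"]))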
@@ -23,6 +23,9 @@ from .tools import (
      TOOLS_INFO,
      UTIL_TOOLS,
      UTILITIES_DOCSTRING,
+     agentic_object_detection,
+     agentic_sam2_instance_segmentation,
+     agentic_sam2_video_tracking,
      claude35_text_extraction,
      closest_box_distance,
      closest_mask_distance,
@@ -30,6 +33,7 @@ from .tools import (
      countgd_sam2_instance_segmentation,
      countgd_sam2_video_tracking,
      countgd_visual_prompt_object_detection,
+     custom_object_detection,
      depth_anything_v2,
      detr_segmentation,
      document_extraction,
@@ -63,10 +67,6 @@ from .tools import (
      video_temporal_localization,
      vit_image_classification,
      vit_nsfw_classification,
-     custom_object_detection,
-     agentic_object_detection,
-     agentic_sam2_instance_segmentation,
-     agentic_sam2_video_tracking,
  )

  __new_tools__ = [
@@ -1,7 +1,7 @@
  import inspect
  import logging
- import shutil
  import tempfile
+ from concurrent.futures import ThreadPoolExecutor, as_completed
  from typing import Any, Callable, Dict, List, Optional, Tuple, cast

  import libcst as cst
@@ -10,12 +10,7 @@ from IPython.display import display
  from PIL import Image

  import vision_agent.tools as T
- from vision_agent.agent.agent_utils import (
-     DefaultImports,
-     extract_code,
-     extract_json,
-     extract_tag,
- )
+ from vision_agent.agent.agent_utils import DefaultImports, extract_json, extract_tag
  from vision_agent.agent.vision_agent_planner_prompts_v2 import (
      CATEGORIZE_TOOL_REQUEST,
      FINALIZE_PLAN,
@@ -24,6 +19,7 @@ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
      TEST_TOOLS_EXAMPLE1,
      TEST_TOOLS_EXAMPLE2,
  )
+ from vision_agent.configs import Config
  from vision_agent.lmm import LMM, AnthropicLMM
  from vision_agent.utils.execute import (
      CodeInterpreter,
@@ -35,7 +31,11 @@ from vision_agent.utils.image_utils import convert_to_b64
  from vision_agent.utils.sim import get_tool_recommender

  TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
+ LOAD_TOOLS_DOCSTRING = T.get_tool_documentation(
+     [T.load_image, T.extract_frames_and_timestamps]
+ )

+ CONFIG = Config()
  _LOGGER = logging.getLogger(__name__)
  EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"

@@ -50,6 +50,54 @@ def format_tool_output(tool_thoughts: str, tool_docstring: str) -> str:
      return return_str


+ def run_multi_judge(
+     tool_chooser: LMM,
+     tool_docs_str: str,
+     task: str,
+     code: str,
+     tool_output_str: str,
+     image_paths: List[str],
+ ) -> Tuple[Optional[Callable], str, str]:
+     error_message = ""
+     prompt = PICK_TOOL.format(
+         tool_docs=tool_docs_str,
+         user_request=task,
+         context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+         previous_attempts=error_message,
+     )
+
+     def run_judge() -> Tuple[Optional[Callable], str, str]:
+         response = tool_chooser.generate(prompt, media=image_paths, temperature=1.0)
+         tool_choice_context = extract_tag(response, "json")  # type: ignore
+         tool_choice_context_dict = extract_json(tool_choice_context)  # type: ignore
+         tool, tool_thoughts, tool_docstring, _ = extract_tool_info(
+             tool_choice_context_dict
+         )
+         return tool, tool_thoughts, tool_docstring
+
+     responses = []
+     with ThreadPoolExecutor() as executor:
+         futures = [executor.submit(run_judge) for _ in range(3)]
+         for future in as_completed(futures):
+             responses.append(future.result())
+
+     responses = [r for r in responses if r[0] is not None]
+     counts: Dict[str, int] = {}
+     for tool, tool_thoughts, tool_docstring in responses:
+         if tool is not None:
+             counts[tool.__name__] = counts.get(tool.__name__, 0) + 1
+             if counts[tool.__name__] >= 2:
+                 return tool, tool_thoughts, tool_docstring
+
+     if len(responses) == 0:
+         return (
+             None,
+             "No tool could be found, please try again with a different prompt or image",
+             "",
+         )
+     return responses[0]
+
+
  def extract_tool_info(
      tool_choice_context: Dict[str, Any],
  ) -> Tuple[Optional[Callable], str, str, str]:
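run_multi_judge fans the PICK_TOOL prompt out to three parallel judge calls and returns the first tool that collects two votes, falling back to the first valid response otherwise. A self-contained sketch of just that voting rule on canned judge outputs (the helper name and values are illustrative, not part of the package):

from typing import Dict, Optional, Tuple

def majority_pick(votes: list) -> Tuple[Optional[str], str, str]:
    # votes: (tool_name or None, thoughts, docstring) triples, e.g. from three judges.
    votes = [v for v in votes if v[0] is not None]
    counts: Dict[str, int] = {}
    for name, thoughts, docstring in votes:
        counts[name] = counts.get(name, 0) + 1
        if counts[name] >= 2:  # first tool to reach two votes wins
            return name, thoughts, docstring
    return votes[0] if votes else (None, "no valid judge response", "")

print(majority_pick([("ocr", "t1", "d1"), ("countgd", "t2", "d2"), ("ocr", "t3", "d3")]))
# ocr reaches two votes on the third response -> ('ocr', 't3', 'd3')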
@@ -129,6 +177,7 @@ def run_tool_testing(
          cleaned_tool_docs.append(tool_doc)
      tool_docs = cleaned_tool_docs
      tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
+     tool_docs_str += "\n" + LOAD_TOOLS_DOCSTRING

      prompt = TEST_TOOLS.format(
          tool_docs=tool_docs_str,
@@ -167,8 +216,15 @@
          examples=EXAMPLES,
          media=str(image_paths),
      )
-     code = extract_code(lmm.generate(prompt, media=image_paths))  # type: ignore
-     code = process_code(code)
+     response = cast(str, lmm.generate(prompt, media=image_paths))
+     code = extract_tag(response, "code")
+     if code is None:
+         code = response
+
+     try:
+         code = process_code(code)
+     except Exception as e:
+         _LOGGER.error(f"Error processing code: {e}")
      tool_output = code_interpreter.exec_isolation(
          DefaultImports.prepend_imports(code)
      )
@@ -212,7 +268,8 @@ def get_tool_for_task(
      --------
      >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
      """
-     lmm = AnthropicLMM()
+     tool_tester = CONFIG.create_tool_tester()
+     tool_chooser = CONFIG.create_tool_chooser()

      with (
          tempfile.TemporaryDirectory() as tmpdirname,
@@ -225,45 +282,14 @@
              image_paths.append(image_path)

          code, tool_docs_str, tool_output = run_tool_testing(
-             task, image_paths, lmm, exclude_tools, code_interpreter
+             task, image_paths, tool_tester, exclude_tools, code_interpreter
          )
          tool_output_str = tool_output.text(include_results=False).strip()

-         error_message = ""
-         prompt = PICK_TOOL.format(
-             tool_docs=tool_docs_str,
-             user_request=task,
-             context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
-             previous_attempts=error_message,
-         )
-
-         response = lmm.generate(prompt, media=image_paths)
-         tool_choice_context = extract_tag(response, "json")  # type: ignore
-         tool_choice_context_dict = extract_json(tool_choice_context)  # type: ignore
-
-         tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
-             tool_choice_context_dict
+         _, tool_thoughts, tool_docstring = run_multi_judge(
+             tool_chooser, tool_docs_str, task, code, tool_output_str, image_paths
          )

-         count = 1
-         while tool is None and count <= 3:
-             prompt = PICK_TOOL.format(
-                 tool_docs=tool_docs_str,
-                 user_request=task,
-                 context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
-                 previous_attempts=error_message,
-             )
-             tool_choice_context_dict = extract_json(
-                 lmm.generate(prompt, media=image_paths)  # type: ignore
-             )
-             tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
-                 tool_choice_context_dict
-             )
-         try:
-             shutil.rmtree(tmpdirname)
-         except Exception as e:
-             _LOGGER.error(f"Error removing temp directory: {e}")
-
      print(format_tool_output(tool_thoughts, tool_docstring))


@@ -277,7 +303,7 @@ def get_tool_for_task_human_reviewer(
      task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
  ) -> None:
      # NOTE: this will have the same documentation as get_tool_for_task
-     lmm = AnthropicLMM()
+     tool_tester = CONFIG.create_tool_tester()

      with (
          tempfile.TemporaryDirectory() as tmpdirname,
@@ -298,7 +324,7 @@
          _, _, tool_output = run_tool_testing(
              task,
              image_paths,
-             lmm,
+             tool_tester,
              exclude_tools,
              code_interpreter,
              process_code=lambda x: replace_box_threshold(x, tools, 0.05),
@@ -349,7 +375,7 @@ def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
      medias: List[np.ndarray]: The images to ask the question about, it could also
          be frames from a video. You can send up to 5 frames from a video.
      """
-     lmm = AnthropicLMM()
+     vqa = CONFIG.create_vqa()
      if isinstance(medias, np.ndarray):
          medias = [medias]
      if isinstance(medias, list) and len(medias) > 5:
@@ -358,7 +384,7 @@
          "data:image/png;base64," + convert_to_b64(media) for media in medias
      ]

-     response = cast(str, lmm.generate(prompt, media=all_media_b64))
+     response = cast(str, vqa.generate(prompt, media=all_media_b64))
      print(f"[claude35_vqa output]\n{response}\n[end of claude35_vqa output]")


@@ -318,6 +318,9 @@ def single_nms(
  def nms(
      all_preds: List[List[Dict[str, Any]]], iou_threshold: float
  ) -> List[List[Dict[str, Any]]]:
+     if not isinstance(all_preds[0], List):
+         all_preds = [all_preds]
+
      return_preds = []
      for frame_preds in all_preds:
          frame_preds = single_nms(frame_preds, iou_threshold)
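The added isinstance guard lets nms accept a bare single-frame prediction list as well as a list of per-frame predictions. A self-contained sketch mirroring that normalization (_normalize is a stand-in rather than a package function, and the prediction dicts are made up but follow the label/score/bbox shape the detection tools return):

from typing import Any, Dict, List

def _normalize(all_preds):
    # Stand-in mirroring the added guard: wrap a bare single frame into a list of frames.
    if not isinstance(all_preds[0], List):
        all_preds = [all_preds]
    return all_preds

frame_preds: List[Dict[str, Any]] = [
    {"label": "car", "score": 0.9, "bbox": [0.10, 0.10, 0.50, 0.50]},
    {"label": "car", "score": 0.7, "bbox": [0.12, 0.11, 0.52, 0.51]},
]

frames = _normalize(frame_preds)    # bare single frame gets wrapped
print(len(frames), len(frames[0]))  # 1 2
frames = _normalize([frame_preds])  # already a list of frames, left unchanged
print(len(frames), len(frames[0]))  # 1 2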