vision-agent 0.2.131__tar.gz → 0.2.133__tar.gz

Files changed (33)
  1. {vision_agent-0.2.131 → vision_agent-0.2.133}/PKG-INFO +1 -2
  2. {vision_agent-0.2.131 → vision_agent-0.2.133}/pyproject.toml +1 -2
  3. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/__init__.py +1 -0
  4. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/agent_utils.py +30 -18
  5. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/vision_agent.py +26 -3
  6. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/vision_agent_coder.py +86 -26
  7. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/vision_agent_coder_prompts.py +34 -8
  8. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/lmm/lmm.py +1 -1
  9. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/__init__.py +1 -0
  10. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/tools.py +42 -2
  11. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/execute.py +12 -10
  12. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/video.py +22 -11
  13. {vision_agent-0.2.131 → vision_agent-0.2.133}/LICENSE +0 -0
  14. {vision_agent-0.2.131 → vision_agent-0.2.133}/README.md +0 -0
  15. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/__init__.py +0 -0
  16. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/agent.py +0 -0
  17. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/vision_agent_prompts.py +0 -0
  18. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/clients/__init__.py +0 -0
  19. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/clients/http.py +0 -0
  20. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/clients/landing_public_api.py +0 -0
  21. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/fonts/__init__.py +0 -0
  22. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
  23. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/lmm/__init__.py +0 -0
  24. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/lmm/types.py +0 -0
  25. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/meta_tools.py +0 -0
  26. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/prompts.py +0 -0
  27. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/tool_utils.py +1 -1
  28. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/tools_types.py +0 -0
  29. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/__init__.py +0 -0
  30. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/exceptions.py +0 -0
  31. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/image_utils.py +0 -0
  32. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/sim.py +0 -0
  33. {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/type_defs.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vision-agent
-Version: 0.2.131
+Version: 0.2.133
 Summary: Toolset for Vision Agent
 Author: Landing AI
 Author-email: dev@landing.ai
@@ -13,7 +13,6 @@ Requires-Dist: anthropic (>=0.31.0,<0.32.0)
 Requires-Dist: av (>=11.0.0,<12.0.0)
 Requires-Dist: e2b (>=0.17.2a50,<0.18.0)
 Requires-Dist: e2b-code-interpreter (==0.0.11a37)
-Requires-Dist: eva-decord (>=0.6.1,<0.7.0)
 Requires-Dist: ipykernel (>=6.29.4,<7.0.0)
 Requires-Dist: langsmith (>=0.1.58,<0.2.0)
 Requires-Dist: nbclient (>=0.10.0,<0.11.0)
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "vision-agent"
-version = "0.2.131"
+version = "0.2.133"
 description = "Toolset for Vision Agent"
 authors = ["Landing AI <dev@landing.ai>"]
 readme = "README.md"
@@ -41,7 +41,6 @@ pillow-heif = "^0.16.0"
 pytube = "15.0.0"
 anthropic = "^0.31.0"
 pydantic = "2.7.4"
-eva-decord = "^0.6.1"
 av = "^11.0.0"
 
 [tool.poetry.group.dev.dependencies]
vision_agent/agent/__init__.py
@@ -2,6 +2,7 @@ from .agent import Agent
 from .vision_agent import VisionAgent
 from .vision_agent_coder import (
     AzureVisionAgentCoder,
+    ClaudeVisionAgentCoder,
     OllamaVisionAgentCoder,
     VisionAgentCoder,
 )
vision_agent/agent/agent_utils.py
@@ -14,6 +14,10 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
     if match:
         json_str = match.group()
         try:
+            # remove trailing comma
+            trailing_bracket_pattern = r",\s+\}"
+            json_str = re.sub(trailing_bracket_pattern, "}", json_str, flags=re.DOTALL)
+
             json_dict = json.loads(json_str)
             return json_dict  # type: ignore
         except json.JSONDecodeError:
@@ -21,29 +25,37 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
     return None
 
 
+def _find_markdown_json(json_str: str) -> str:
+    pattern = r"```json(.*?)```"
+    match = re.search(pattern, json_str, re.DOTALL)
+    if match:
+        return match.group(1).strip()
+    return json_str
+
+
+def _strip_markdown_code(inp_str: str) -> str:
+    pattern = r"```python.*?```"
+    cleaned_str = re.sub(pattern, "", inp_str, flags=re.DOTALL)
+    return cleaned_str
+
+
 def extract_json(json_str: str) -> Dict[str, Any]:
+    json_str = json_str.replace("\n", " ").strip()
+
     try:
-        json_str = json_str.replace("\n", " ")
-        json_dict = json.loads(json_str)
+        return json.loads(json_str)  # type: ignore
     except json.JSONDecodeError:
-        if "```json" in json_str:
-            json_str = json_str[json_str.find("```json") + len("```json") :]
-            json_str = json_str[: json_str.find("```")]
-        elif "```" in json_str:
-            json_str = json_str[json_str.find("```") + len("```") :]
-            # get the last ``` not one from an intermediate string
-            json_str = json_str[: json_str.find("}```")]
-        try:
-            json_dict = json.loads(json_str)
-        except json.JSONDecodeError as e:
-            json_dict = _extract_sub_json(json_str)
-            if json_dict is not None:
-                return json_dict  # type: ignore
-            error_msg = f"Could not extract JSON from the given str: {json_str}"
+        json_orig = json_str
+        json_str = _strip_markdown_code(json_str)
+        json_str = _find_markdown_json(json_str)
+        json_dict = _extract_sub_json(json_str)
+
+        if json_dict is None:
+            error_msg = f"Could not extract JSON from the given str: {json_orig}"
             _LOGGER.exception(error_msg)
-            raise ValueError(error_msg) from e
+            raise ValueError(error_msg)
 
-    return json_dict  # type: ignore
+        return json_dict
 
 
 def extract_code(code: str) -> str:
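For context, a minimal standalone sketch of the new extraction pipeline (a simplified re-implementation, not the packaged code verbatim; the sample `response` string is fabricated):

```python
import json
import re
from typing import Any, Dict, Optional


def _strip_markdown_code(inp_str: str) -> str:
    # drop fenced ```python blocks so stray code can't be mistaken for JSON
    return re.sub(r"```python.*?```", "", inp_str, flags=re.DOTALL)


def _find_markdown_json(json_str: str) -> str:
    # prefer an explicit ```json fence if the model produced one
    match = re.search(r"```json(.*?)```", json_str, re.DOTALL)
    return match.group(1).strip() if match else json_str


def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
    # fall back to the outermost {...} span, tolerating a trailing comma
    match = re.search(r"\{.*\}", json_str, re.DOTALL)
    if match:
        candidate = re.sub(r",\s+\}", "}", match.group(), flags=re.DOTALL)
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            return None
    return None


response = 'Sure! Here is my plan:\n```json\n{"best_plan": "plan1", }\n```'
flat = response.replace("\n", " ").strip()
print(_extract_sub_json(_find_markdown_json(_strip_markdown_code(flat))))
# {'best_plan': 'plan1'}
```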
vision_agent/agent/vision_agent.py
@@ -3,7 +3,7 @@ import logging
 import os
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable
 
 from vision_agent.agent import Agent
 from vision_agent.agent.agent_utils import extract_json
@@ -13,7 +13,7 @@ from vision_agent.agent.vision_agent_prompts import (
     VA_CODE,
 )
 from vision_agent.lmm import LMM, Message, OpenAILMM
-from vision_agent.tools import META_TOOL_DOCSTRING
+from vision_agent.tools import META_TOOL_DOCSTRING, save_image, load_image
 from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args
 from vision_agent.utils import CodeInterpreterFactory
 from vision_agent.utils.execute import CodeInterpreter, Execution
@@ -123,6 +123,7 @@ class VisionAgent(Agent):
         verbosity: int = 0,
         local_artifacts_path: Optional[Union[str, Path]] = None,
         code_sandbox_runtime: Optional[str] = None,
+        callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
     ) -> None:
         """Initialize the VisionAgent.
 
@@ -141,6 +142,7 @@
         self.max_iterations = 100
         self.verbosity = verbosity
         self.code_sandbox_runtime = code_sandbox_runtime
+        self.callback_message = callback_message
         if self.verbosity >= 1:
             _LOGGER.setLevel(logging.INFO)
         self.local_artifacts_path = cast(
@@ -220,7 +222,14 @@
         for chat_i in int_chat:
             if "media" in chat_i:
                 for media in chat_i["media"]:
-                    media = cast(str, media)
+                    if type(media) is str and media.startswith(("http", "https")):
+                        # TODO: Ideally we should not call VA.tools here; revisit
+                        # how to better support remote images later
+                        file_path = Path(media).name
+                        ndarray = load_image(media)
+                        save_image(ndarray, file_path)
+                        media = file_path
+                    else:
+                        media = cast(str, media)
                     artifacts.artifacts[Path(media).name] = open(media, "rb").read()
 
                     media_remote_path = (
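A sketch of what this enables from the caller's side, assuming the standard chat message shape; the URL is a placeholder:

```python
from vision_agent.agent import VisionAgent

agent = VisionAgent()
# Remote URLs in "media" are now fetched with load_image, re-saved locally
# with save_image, and registered as artifacts before the chat loop runs.
chat, artifacts = agent.chat_with_code(
    [
        {
            "role": "user",
            "content": "Count the people in this image.",
            "media": ["https://example.com/people.jpg"],
        }
    ]
)
```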
@@ -262,6 +271,7 @@
         artifacts_loaded = artifacts.show()
         int_chat.append({"role": "observation", "content": artifacts_loaded})
         orig_chat.append({"role": "observation", "content": artifacts_loaded})
+        self.streaming_message({"role": "observation", "content": artifacts_loaded})
 
         while not finished and iterations < self.max_iterations:
             response = run_conversation(self.agent, int_chat)
@@ -274,6 +284,8 @@
             if last_response == response:
                 response["let_user_respond"] = True
 
+            self.streaming_message({"role": "assistant", "content": response})
+
             if response["let_user_respond"]:
                 break
 
@@ -293,6 +305,13 @@
             orig_chat.append(
                 {"role": "observation", "content": obs, "execution": result}
             )
+            self.streaming_message(
+                {
+                    "role": "observation",
+                    "content": obs,
+                    "execution": result,
+                }
+            )
 
             iterations += 1
             last_response = response
@@ -305,5 +324,9 @@
         artifacts.save()
         return orig_chat, artifacts
 
+    def streaming_message(self, message: Dict[str, Any]) -> None:
+        if self.callback_message:
+            self.callback_message(message)
+
     def log_progress(self, data: Dict[str, Any]) -> None:
         pass
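A sketch of how a caller might consume the new `callback_message` hook; `on_message` is a hypothetical handler:

```python
from typing import Any, Dict

from vision_agent.agent import VisionAgent


def on_message(message: Dict[str, Any]) -> None:
    # Messages mirror the chat entries: {"role": ..., "content": ...},
    # plus an "execution" key on observation messages that ran code.
    print(f"[{message['role']}] {str(message['content'])[:80]}")


# Every observation and assistant turn is now pushed through the callback
# while the agent's chat loop runs, enabling streaming UIs.
agent = VisionAgent(callback_message=on_message)
```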
vision_agent/agent/vision_agent_coder.py
@@ -27,7 +27,14 @@ from vision_agent.agent.vision_agent_coder_prompts import (
     TEST_PLANS,
     USER_REQ,
 )
-from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
+from vision_agent.lmm import (
+    LMM,
+    AzureOpenAILMM,
+    ClaudeSonnetLMM,
+    Message,
+    OllamaLMM,
+    OpenAILMM,
+)
 from vision_agent.tools.meta_tools import get_diff
 from vision_agent.utils import CodeInterpreterFactory, Execution
 from vision_agent.utils.execute import CodeInterpreter
@@ -167,9 +174,10 @@ def pick_plan(
         }
     )
     tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
-    tool_output_str = ""
-    if len(tool_output.logs.stdout) > 0:
-        tool_output_str = tool_output.logs.stdout[0]
+    # Because of the way we trace function calls the trace information ends up in the
+    # results. We don't want to show this info to the LLM so we don't include it in the
+    # tool_output_str.
+    tool_output_str = tool_output.text(include_results=False).strip()
 
     if verbosity == 2:
         _print_code("Initial code and tests:", code)
@@ -196,7 +204,7 @@
             docstring=tool_info,
             plans=plan_str,
             previous_attempts=PREVIOUS_FAILED.format(
-                code=code, error=tool_output.text()
+                code=code, error="\n".join(tool_output_str.splitlines()[-50:])
             ),
             media=media,
         )
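The `splitlines()[-50:]` idiom used here (and again in `debug_code` below) caps how much error context is sent back to the model; a trivial dry-run:

```python
# Only the tail of a long tool output survives the cap:
long_output = "\n".join(f"log line {i}" for i in range(120))
tail = "\n".join(long_output.splitlines()[-50:])
print(tail.splitlines()[0], "...", tail.splitlines()[-1])
# log line 70 ... log line 119
```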
@@ -225,11 +233,11 @@
                 "status": "completed" if tool_output.success else "failed",
             }
         )
-        tool_output_str = tool_output.text().strip()
+        tool_output_str = tool_output.text(include_results=False).strip()
 
         if verbosity == 2:
             _print_code("Code and test after attempted fix:", code)
-            _LOGGER.info(f"Code execution result after attempt {count}")
+            _LOGGER.info(f"Code execution result after attempt {count + 1}")
 
         count += 1
 
@@ -387,7 +395,6 @@ def write_and_test_code(
             "code": DefaultImports.prepend_imports(code),
             "payload": {
                 "test": test,
-                # "result": result.to_json(),
             },
         }
     )
@@ -406,6 +413,7 @@
             working_memory,
             debugger,
             code_interpreter,
+            tool_info,
             code,
             test,
             result,
@@ -431,6 +439,7 @@ def debug_code(
     working_memory: List[Dict[str, str]],
     debugger: LMM,
     code_interpreter: CodeInterpreter,
+    tool_info: str,
    code: str,
     test: str,
     result: Execution,
@@ -451,17 +460,38 @@
     count = 0
     while not success and count < 3:
         try:
-            fixed_code_and_test = extract_json(
-                debugger(  # type: ignore
-                    FIX_BUG.format(
-                        code=code,
-                        tests=test,
-                        result="\n".join(result.text().splitlines()[-50:]),
-                        feedback=format_memory(working_memory + new_working_memory),
+            # LLMs write worse code when it's in JSON, so we have it write JSON
+            # followed by code each wrapped in markdown blocks.
+            fixed_code_and_test_str = debugger(
+                FIX_BUG.format(
+                    docstring=tool_info,
+                    code=code,
+                    tests=test,
+                    # Because of the way we trace function calls the trace information
+                    # ends up in the results. We don't want to show this info to the
+                    # LLM so we don't include it in the tool_output_str.
+                    result="\n".join(
+                        result.text(include_results=False).splitlines()[-50:]
                     ),
-                stream=False,
-                )
+                    feedback=format_memory(working_memory + new_working_memory),
+                ),
+                stream=False,
             )
+            fixed_code_and_test_str = cast(str, fixed_code_and_test_str)
+            fixed_code_and_test = extract_json(fixed_code_and_test_str)
+            code = extract_code(fixed_code_and_test_str)
+            if (
+                "which_code" in fixed_code_and_test
+                and fixed_code_and_test["which_code"] == "test"
+            ):
+                fixed_code_and_test["code"] = ""
+                fixed_code_and_test["test"] = code
+            else:  # for everything else always assume it's updating code
+                fixed_code_and_test["code"] = code
+                fixed_code_and_test["test"] = ""
+            if "which_code" in fixed_code_and_test:
+                del fixed_code_and_test["which_code"]
+
             success = True
         except Exception as e:
             _LOGGER.exception(f"Error while extracting JSON: {e}")
@@ -472,9 +502,9 @@
         old_test = test
 
         if fixed_code_and_test["code"].strip() != "":
-            code = extract_code(fixed_code_and_test["code"])
+            code = fixed_code_and_test["code"]
         if fixed_code_and_test["test"].strip() != "":
-            test = extract_code(fixed_code_and_test["test"])
+            test = fixed_code_and_test["test"]
 
         new_working_memory.append(
             {
@@ -628,9 +658,7 @@ class VisionAgentCoder(Agent):
         )
         self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
         self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
-        self.debugger = (
-            OpenAILMM(temperature=0.0, json_mode=True) if debugger is None else debugger
-        )
+        self.debugger = OpenAILMM(temperature=0.0) if debugger is None else debugger
         self.verbosity = verbosity
         if self.verbosity > 0:
             _LOGGER.setLevel(logging.INFO)
@@ -876,6 +904,40 @@
     )
 
 
+class ClaudeVisionAgentCoder(VisionAgentCoder):
+    def __init__(
+        self,
+        planner: Optional[LMM] = None,
+        coder: Optional[LMM] = None,
+        tester: Optional[LMM] = None,
+        debugger: Optional[LMM] = None,
+        tool_recommender: Optional[Sim] = None,
+        verbosity: int = 0,
+        report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+        code_sandbox_runtime: Optional[str] = None,
+    ) -> None:
+        # NOTE: Claude doesn't have an official JSON mode
+        self.planner = ClaudeSonnetLMM(temperature=0.0) if planner is None else planner
+        self.coder = ClaudeSonnetLMM(temperature=0.0) if coder is None else coder
+        self.tester = ClaudeSonnetLMM(temperature=0.0) if tester is None else tester
+        self.debugger = (
+            ClaudeSonnetLMM(temperature=0.0) if debugger is None else debugger
+        )
+        self.verbosity = verbosity
+        if self.verbosity > 0:
+            _LOGGER.setLevel(logging.INFO)
+
+        # Anthropic does not offer any embedding models and instead recommends
+        # Voyage; we're using OpenAI's embedder for now.
+        self.tool_recommender = (
+            Sim(T.TOOLS_DF, sim_key="desc")
+            if tool_recommender is None
+            else tool_recommender
+        )
+        self.report_progress_callback = report_progress_callback
+        self.code_sandbox_runtime = code_sandbox_runtime
+
+
 class OllamaVisionAgentCoder(VisionAgentCoder):
     """VisionAgentCoder that uses Ollama models for planning, coding, testing.
 
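A usage sketch for the new class, assuming the usual `ANTHROPIC_API_KEY` / `OPENAI_API_KEY` environment variables (the latter for the OpenAI-embedding tool recommender noted above) and a placeholder image path:

```python
from vision_agent.agent import ClaudeVisionAgentCoder

agent = ClaudeVisionAgentCoder(verbosity=1)
# Same interface as VisionAgentCoder, but planning, coding, testing, and
# debugging all go through Claude 3.5 Sonnet by default.
code = agent("Count the number of people in this image.", media="people.jpg")
print(code)
```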
@@ -920,7 +982,7 @@ class OllamaVisionAgentCoder(VisionAgentCoder):
                 else tester
             ),
             debugger=(
-                OllamaLMM(model_name="llama3.1", temperature=0.0, json_mode=True)
+                OllamaLMM(model_name="llama3.1", temperature=0.0)
                 if debugger is None
                 else debugger
             ),
@@ -983,9 +1045,7 @@ class AzureVisionAgentCoder(VisionAgentCoder):
             coder=AzureOpenAILMM(temperature=0.0) if coder is None else coder,
             tester=AzureOpenAILMM(temperature=0.0) if tester is None else tester,
             debugger=(
-                AzureOpenAILMM(temperature=0.0, json_mode=True)
-                if debugger is None
-                else debugger
+                AzureOpenAILMM(temperature=0.0) if debugger is None else debugger
             ),
             tool_recommender=(
                 AzureSim(T.TOOLS_DF, sim_key="desc")
vision_agent/agent/vision_agent_coder_prompts.py
@@ -63,6 +63,7 @@ This is the documentation for the functions you have access to. You may call any
 **Plans**:
 {plans}
 
+**Previous Attempts**:
 {previous_attempts}
 
 **Instructions**:
@@ -108,16 +109,27 @@ plan2:
 - Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video.
 plan3:
 - Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool.
-- Use the 'countgd_counting' tool with the prompt 'person' to detect where the people are in the video.
+- Use the 'florence2_sam2_video_tracking' tool with the prompt 'person' to detect where the people are in the video.
 
 
 ```python
-from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, countgd_counting
+import numpy as np
+from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, florence2_sam2_video_tracking
 
 # sample at 1 FPS and use the first 10 frames to reduce processing time
 frames = extract_frames("video.mp4", 1)
 frames = [f[0] for f in frames][:10]
 
+def remove_arrays(o):
+    if isinstance(o, list):
+        return [remove_arrays(e) for e in o]
+    elif isinstance(o, dict):
+        return {{k: remove_arrays(v) for k, v in o.items()}}
+    elif isinstance(o, np.ndarray):
+        return "array: " + str(o.shape)
+    else:
+        return o
+
 # plan1
 owl_v2_out = [owl_v2_image("person", f) for f in frames]
 
@@ -125,9 +137,10 @@ owl_v2_out = [owl_v2_image("person", f) for f in frames]
 florence2_out = [florence2_phrase_grounding("person", f) for f in frames]
 
 # plan3
-countgd_out = [countgd_counting(f) for f in frames]
+f2s2_tracking_out = florence2_sam2_video_tracking("person", frames)
+remove_arrays(f2s2_tracking_out)
 
-final_out = {{"owl_v2_image": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
+final_out = {{"owl_v2_image": owl_v2_out, "florence2_phrase_grounding": florence2_out, "florence2_sam2_video_tracking": f2s2_tracking_out}}
 print(final_out)
 ```
 """
@@ -161,9 +174,10 @@ PICK_PLAN = """
 
 **Instructions**:
 1. Given the plans, image, and tool outputs, decide which plan is the best to achieve the user request.
-2. Try solving the problem yourself given the image and pick the plan that matches your solution the best.
+2. Solve the problem yourself given the image and pick the plan that matches your solution the best.
 3. Output a JSON object with the following format:
 {{
+    "predicted_answer": str # the answer you would expect from the best plan
     "thoughts": str # your thought process for choosing the best plan
     "best_plan": str # the best plan you have chosen
 }}
@@ -311,6 +325,11 @@ This is the documentation for the functions you have access to. You may call any
 FIX_BUG = """
 **Role** As a coder, your job is to find the error in the code and fix it. You are running in a notebook setting so you can run !pip install to install missing packages.
 
+**Documentation**:
+This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task. They are available through importing `from vision_agent.tools import *`.
+
+{docstring}
+
 **Instructions**:
 Please re-complete the code to fix the error message. Here is the previous version:
 ```python
@@ -323,17 +342,24 @@ When we run this test code:
 ```
 
 It raises this error:
+```
 {result}
+```
 
 This is previous feedback provided on the code:
 {feedback}
 
-Please fix the bug by follow the error information and return a JSON object with the following format:
+Please fix the bug by correcting the error. Return the following JSON object followed by the fixed code in the below format:
+```json
 {{
     "reflections": str # any thoughts you have about the bug and how you fixed it
-    "code": str # the fixed code if any, else an empty string
-    "test": str # the fixed test code if any, else an empty string
+    "which_code": str # the code that was fixed, can only be 'code' or 'test'
 }}
+```
+
+```python
+# Your fixed code here
+```
 """
 
 
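A rough standalone approximation of how the agent now splits such a reply (the real path goes through `extract_json`/`extract_code`; the sample reply is fabricated):

```python
import json
import re

reply = (
    "```json\n"
    '{"reflections": "off-by-one in the loop bound", "which_code": "code"}\n'
    "```\n"
    "```python\n"
    "for i in range(10):\n"
    "    print(i)\n"
    "```"
)

# The JSON block carries metadata; the python block carries the actual fix.
meta = json.loads(re.search(r"```json\n(.*?)```", reply, re.DOTALL).group(1))
fix = re.search(r"```python\n(.*?)```", reply, re.DOTALL).group(1)
assert meta["which_code"] == "code"  # the fix targets the code, not the test
print(fix)
```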
vision_agent/lmm/lmm.py
@@ -371,7 +371,7 @@ class ClaudeSonnetLMM(LMM):
     def __init__(
         self,
         api_key: Optional[str] = None,
-        model_name: str = "claude-3-sonnet-20240229",
+        model_name: str = "claude-3-5-sonnet-20240620",
         max_tokens: int = 4096,
         **kwargs: Any,
     ):
vision_agent/tools/__init__.py
@@ -37,6 +37,7 @@ from .tools import (
     grounding_dino,
     grounding_sam,
     ixc25_image_vqa,
+    ixc25_temporal_localization,
     ixc25_video_vqa,
     load_image,
     loca_visual_prompt_counting,
vision_agent/tools/tools.py
@@ -468,7 +468,7 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
 
     pil_image = Image.fromarray(image).convert("RGB")
     image_size = pil_image.size[::-1]
-    if image_size[0] < 1 and image_size[1] < 1:
+    if image_size[0] < 1 or image_size[1] < 1:
         return []
     image_buffer = io.BytesIO()
     pil_image.save(image_buffer, format="PNG")
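The old guard only skipped OCR when *both* dimensions were degenerate, so a 0×100 image still reached the endpoint; a quick truth-table check of the fix:

```python
# (height, width) -> does each guard skip the image?
for h, w in [(0, 100), (100, 0), (0, 0), (100, 100)]:
    old_skip = h < 1 and w < 1   # only skipped 0x0
    new_skip = h < 1 or w < 1    # skips any empty dimension
    print((h, w), "old:", old_skip, "new:", new_skip)
```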
@@ -781,6 +781,44 @@ def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
     return cast(str, data["answer"])
 
 
+def ixc25_temporal_localization(prompt: str, frames: List[np.ndarray]) -> List[bool]:
+    """'ixc25_temporal_localization' uses ixc25_video_vqa to temporally segment a video
+    given a prompt that can be either an object or a phrase. It returns a list of
+    boolean values indicating whether the object or phrase is present in the
+    corresponding frame.
+
+    Parameters:
+        prompt (str): The question about the video
+        frames (List[np.ndarray]): The reference frames used for the question
+
+    Returns:
+        List[bool]: A list of boolean values indicating whether the object or phrase is
+            present in the corresponding frame.
+
+    Example
+    -------
+        >>> output = ixc25_temporal_localization('soccer goal', frames)
+        >>> print(output)
+        [False, False, False, True, True, True, False, False, False, False]
+        >>> save_video([f for i, f in enumerate(frames) if output[i]], 'output.mp4')
+    """
+
+    buffer_bytes = frames_to_bytes(frames)
+    files = [("video", buffer_bytes)]
+    payload = {
+        "prompt": prompt,
+        "chunk_length": 2,
+        "function_name": "ixc25_temporal_localization",
+    }
+    data: List[int] = send_inference_request(
+        payload, "video-temporal-localization", files=files, v2=True
+    )
+    chunk_size = round(len(frames) / len(data))
+    data_explode = [[elt] * chunk_size for elt in data]
+    data_bool = [bool(elt) for sublist in data_explode for elt in sublist]
+    return data_bool[: len(frames)]
+
+
 def gpt4o_image_vqa(prompt: str, image: np.ndarray) -> str:
     """'gpt4o_image_vqa' is a tool that can answer any questions about arbitrary images
     including regular images or images of documents or presentations. It returns text
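A dry-run of the chunk expansion at the end of the new tool, using fabricated service output (one 0/1 score per `chunk_length`-second chunk, exploded back to per-frame booleans):

```python
frames_count = 10
data = [0, 1, 0]  # hypothetical chunk-level predictions from the service

chunk_size = round(frames_count / len(data))  # 3
data_explode = [[elt] * chunk_size for elt in data]
data_bool = [bool(elt) for sublist in data_explode for elt in sublist]
print(data_bool[:frames_count])
# [False, False, False, True, True, True, False, False, False]
# Note: 3 chunks x 3 frames covers only 9 of the 10 frames, so the rounding
# can undershoot the frame count slightly.
```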
@@ -1112,6 +1150,8 @@ def florence2_ocr(image: np.ndarray) -> List[Dict[str, Any]]:
     """
 
     image_size = image.shape[:2]
+    if image_size[0] < 1 or image_size[1] < 1:
+        return []
     image_b64 = convert_to_b64(image)
     data = {
         "image": image_b64,
@@ -1467,7 +1507,7 @@
     Parameters:
         video_uri (Union[str, Path]): The path to the video file, url or youtube link
         fps (float, optional): The frame rate per second to extract the frames. Defaults
-            to 10.
+            to 1.
 
     Returns:
         List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
vision_agent/utils/execute.py
@@ -292,7 +292,7 @@ class Execution(BaseModel):
     error: Optional[Error] = None
     "Error object if an error occurred, None otherwise."
 
-    def text(self, include_logs: bool = True) -> str:
+    def text(self, include_logs: bool = True, include_results: bool = True) -> str:
         """Returns the text representation of this object, i.e. including the main
         result or the error traceback, optionally along with the logs (stdout, stderr).
         """
@@ -300,15 +300,17 @@ class Execution(BaseModel):
         if self.error:
             return prefix + "\n----- Error -----\n" + self.error.traceback
 
-        result_str = [
-            (
-                f"----- Final output -----\n{res.text}"
-                if res.is_main_result
-                else f"----- Intermediate output-----\n{res.text}"
-            )
-            for res in self.results
-        ]
-        return prefix + "\n" + "\n".join(result_str)
+        if include_results:
+            result_str = [
+                (
+                    f"----- Final output -----\n{res.text}"
+                    if res.is_main_result
+                    else f"----- Intermediate output-----\n{res.text}"
+                )
+                for res in self.results
+            ]
+            return prefix + "\n" + "\n".join(result_str)
+        return prefix
 
     @property
     def success(self) -> bool:
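A minimal stand-in (not the real pydantic model) showing the control flow the new flag introduces:

```python
from typing import List


class MiniExecution:
    """Toy model mirroring Execution.text's new include_results branch."""

    def __init__(self, logs: str, results: List[str]) -> None:
        self.logs = logs
        self.results = results

    def text(self, include_logs: bool = True, include_results: bool = True) -> str:
        prefix = self.logs if include_logs else ""
        if include_results:
            rendered = [f"----- Intermediate output-----\n{r}" for r in self.results]
            return prefix + "\n" + "\n".join(rendered)
        return prefix


ex = MiniExecution("stdout: 12 people detected", ["<function trace blob>"])
print(ex.text(include_results=False))  # logs only, no trace noise for the LLM
```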
vision_agent/utils/video.py
@@ -7,7 +7,6 @@ from typing import List, Optional, Tuple
 import av  # type: ignore
 import cv2
 import numpy as np
-from decord import VideoReader  # type: ignore
 
 _LOGGER = logging.getLogger(__name__)
 # The maximum length of the clip to extract frames from, in seconds
@@ -103,7 +102,7 @@ def frames_to_bytes(
 def extract_frames_from_video(
     video_uri: str, fps: float = 1.0
 ) -> List[Tuple[np.ndarray, float]]:
-    """Extract frames from a video
+    """Extract frames from a video along with the timestamp in seconds.
 
     Parameters:
         video_uri (str): the path to the video file or a video file url
@@ -115,12 +114,24 @@
     from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
     the video. The frames are sorted by the timestamp in ascending order.
     """
-    vr = VideoReader(video_uri)
-    orig_fps = vr.get_avg_fps()
-    if fps > orig_fps:
-        fps = orig_fps
-
-    s = orig_fps / fps
-    samples = [(int(i * s), int(i * s) / orig_fps) for i in range(int(len(vr) / s))]
-    frames = vr.get_batch([s[0] for s in samples]).asnumpy()
-    return [(frames[i, :, :, :], samples[i][1]) for i in range(len(samples))]
+
+    cap = cv2.VideoCapture(video_uri)
+    orig_fps = cap.get(cv2.CAP_PROP_FPS)
+    orig_frame_time = 1 / orig_fps
+    targ_frame_time = 1 / fps
+    frames: List[Tuple[np.ndarray, float]] = []
+    i = 0
+    elapsed_time = 0.0
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        elapsed_time += orig_frame_time
+        if elapsed_time >= targ_frame_time:
+            frames.append((cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), i / orig_fps))
+            elapsed_time -= targ_frame_time
+
+        i += 1
+    cap.release()
+    return frames
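A dry-run of the new elapsed-time sampling loop, with no video required; the FPS values are chosen to divide evenly so the floats stay exact:

```python
# Downsample a 4 FPS stream to 2 FPS: which frame indices are kept?
orig_fps, target_fps, n_frames = 4.0, 2.0, 8
orig_frame_time = 1 / orig_fps    # 0.25s per source frame
targ_frame_time = 1 / target_fps  # want one frame every 0.5s

kept = []
elapsed = 0.0
for i in range(n_frames):
    elapsed += orig_frame_time
    if elapsed >= targ_frame_time:
        kept.append((i, i / orig_fps))  # (frame index, timestamp in seconds)
        elapsed -= targ_frame_time

print(kept)
# [(1, 0.25), (3, 0.75), (5, 1.25), (7, 1.75)] -> every other frame survives
```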
vision_agent/tools/tool_utils.py
@@ -1,7 +1,7 @@
-from base64 import b64encode
 import inspect
 import logging
 import os
+from base64 import b64encode
 from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple
 
 import pandas as pd