vision-agent 0.2.131__tar.gz → 0.2.133__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {vision_agent-0.2.131 → vision_agent-0.2.133}/PKG-INFO +1 -2
- {vision_agent-0.2.131 → vision_agent-0.2.133}/pyproject.toml +1 -2
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/__init__.py +1 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/agent_utils.py +30 -18
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/vision_agent.py +26 -3
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/vision_agent_coder.py +86 -26
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/vision_agent_coder_prompts.py +34 -8
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/lmm/lmm.py +1 -1
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/__init__.py +1 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/tools.py +42 -2
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/execute.py +12 -10
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/video.py +22 -11
- {vision_agent-0.2.131 → vision_agent-0.2.133}/LICENSE +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/README.md +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/__init__.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/agent.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/vision_agent_prompts.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/clients/__init__.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/clients/http.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/clients/landing_public_api.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/fonts/__init__.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/lmm/__init__.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/lmm/types.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/meta_tools.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/prompts.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/tool_utils.py +1 -1
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/tools/tools_types.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/__init__.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/exceptions.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/image_utils.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/sim.py +0 -0
- {vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/utils/type_defs.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: vision-agent
|
3
|
-
Version: 0.2.131
|
3
|
+
Version: 0.2.133
|
4
4
|
Summary: Toolset for Vision Agent
|
5
5
|
Author: Landing AI
|
6
6
|
Author-email: dev@landing.ai
|
@@ -13,7 +13,6 @@ Requires-Dist: anthropic (>=0.31.0,<0.32.0)
|
|
13
13
|
Requires-Dist: av (>=11.0.0,<12.0.0)
|
14
14
|
Requires-Dist: e2b (>=0.17.2a50,<0.18.0)
|
15
15
|
Requires-Dist: e2b-code-interpreter (==0.0.11a37)
|
16
|
-
Requires-Dist: eva-decord (>=0.6.1,<0.7.0)
|
17
16
|
Requires-Dist: ipykernel (>=6.29.4,<7.0.0)
|
18
17
|
Requires-Dist: langsmith (>=0.1.58,<0.2.0)
|
19
18
|
Requires-Dist: nbclient (>=0.10.0,<0.11.0)
|
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
|
|
4
4
|
|
5
5
|
[tool.poetry]
|
6
6
|
name = "vision-agent"
|
7
|
-
version = "0.2.131"
|
7
|
+
version = "0.2.133"
|
8
8
|
description = "Toolset for Vision Agent"
|
9
9
|
authors = ["Landing AI <dev@landing.ai>"]
|
10
10
|
readme = "README.md"
|
@@ -41,7 +41,6 @@ pillow-heif = "^0.16.0"
|
|
41
41
|
pytube = "15.0.0"
|
42
42
|
anthropic = "^0.31.0"
|
43
43
|
pydantic = "2.7.4"
|
44
|
-
eva-decord = "^0.6.1"
|
45
44
|
av = "^11.0.0"
|
46
45
|
|
47
46
|
[tool.poetry.group.dev.dependencies]
|
@@ -14,6 +14,10 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
|
|
14
14
|
if match:
|
15
15
|
json_str = match.group()
|
16
16
|
try:
|
17
|
+
# remove trailing comma
|
18
|
+
trailing_bracket_pattern = r",\s+\}"
|
19
|
+
json_str = re.sub(trailing_bracket_pattern, "}", json_str, flags=re.DOTALL)
|
20
|
+
|
17
21
|
json_dict = json.loads(json_str)
|
18
22
|
return json_dict # type: ignore
|
19
23
|
except json.JSONDecodeError:
|
@@ -21,29 +25,37 @@ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
|
|
21
25
|
return None
|
22
26
|
|
23
27
|
|
28
|
+
def _find_markdown_json(json_str: str) -> str:
|
29
|
+
pattern = r"```json(.*?)```"
|
30
|
+
match = re.search(pattern, json_str, re.DOTALL)
|
31
|
+
if match:
|
32
|
+
return match.group(1).strip()
|
33
|
+
return json_str
|
34
|
+
|
35
|
+
|
36
|
+
def _strip_markdown_code(inp_str: str) -> str:
|
37
|
+
pattern = r"```python.*?```"
|
38
|
+
cleaned_str = re.sub(pattern, "", inp_str, flags=re.DOTALL)
|
39
|
+
return cleaned_str
|
40
|
+
|
41
|
+
|
24
42
|
def extract_json(json_str: str) -> Dict[str, Any]:
|
43
|
+
json_str = json_str.replace("\n", " ").strip()
|
44
|
+
|
25
45
|
try:
|
26
|
-
|
27
|
-
json_dict = json.loads(json_str)
|
46
|
+
return json.loads(json_str) # type: ignore
|
28
47
|
except json.JSONDecodeError:
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
try:
|
37
|
-
json_dict = json.loads(json_str)
|
38
|
-
except json.JSONDecodeError as e:
|
39
|
-
json_dict = _extract_sub_json(json_str)
|
40
|
-
if json_dict is not None:
|
41
|
-
return json_dict # type: ignore
|
42
|
-
error_msg = f"Could not extract JSON from the given str: {json_str}"
|
48
|
+
json_orig = json_str
|
49
|
+
json_str = _strip_markdown_code(json_str)
|
50
|
+
json_str = _find_markdown_json(json_str)
|
51
|
+
json_dict = _extract_sub_json(json_str)
|
52
|
+
|
53
|
+
if json_dict is None:
|
54
|
+
error_msg = f"Could not extract JSON from the given str: {json_orig}"
|
43
55
|
_LOGGER.exception(error_msg)
|
44
|
-
raise ValueError(error_msg)
|
56
|
+
raise ValueError(error_msg)
|
45
57
|
|
46
|
-
|
58
|
+
return json_dict
|
47
59
|
|
48
60
|
|
49
61
|
def extract_code(code: str) -> str:
|
@@ -3,7 +3,7 @@ import logging
|
|
3
3
|
import os
|
4
4
|
import tempfile
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable
|
7
7
|
|
8
8
|
from vision_agent.agent import Agent
|
9
9
|
from vision_agent.agent.agent_utils import extract_json
|
@@ -13,7 +13,7 @@ from vision_agent.agent.vision_agent_prompts import (
|
|
13
13
|
VA_CODE,
|
14
14
|
)
|
15
15
|
from vision_agent.lmm import LMM, Message, OpenAILMM
|
16
|
-
from vision_agent.tools import META_TOOL_DOCSTRING
|
16
|
+
from vision_agent.tools import META_TOOL_DOCSTRING, save_image, load_image
|
17
17
|
from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args
|
18
18
|
from vision_agent.utils import CodeInterpreterFactory
|
19
19
|
from vision_agent.utils.execute import CodeInterpreter, Execution
|
@@ -123,6 +123,7 @@ class VisionAgent(Agent):
|
|
123
123
|
verbosity: int = 0,
|
124
124
|
local_artifacts_path: Optional[Union[str, Path]] = None,
|
125
125
|
code_sandbox_runtime: Optional[str] = None,
|
126
|
+
callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
|
126
127
|
) -> None:
|
127
128
|
"""Initialize the VisionAgent.
|
128
129
|
|
@@ -141,6 +142,7 @@ class VisionAgent(Agent):
|
|
141
142
|
self.max_iterations = 100
|
142
143
|
self.verbosity = verbosity
|
143
144
|
self.code_sandbox_runtime = code_sandbox_runtime
|
145
|
+
self.callback_message = callback_message
|
144
146
|
if self.verbosity >= 1:
|
145
147
|
_LOGGER.setLevel(logging.INFO)
|
146
148
|
self.local_artifacts_path = cast(
|
@@ -220,7 +222,14 @@ class VisionAgent(Agent):
|
|
220
222
|
for chat_i in int_chat:
|
221
223
|
if "media" in chat_i:
|
222
224
|
for media in chat_i["media"]:
|
223
|
-
media
|
225
|
+
if type(media) is str and media.startswith(("http", "https")):
|
226
|
+
# TODO: Ideally we should not call VA.tools here; we should revisit how to better support remote images later
|
227
|
+
file_path = Path(media).name
|
228
|
+
ndarray = load_image(media)
|
229
|
+
save_image(ndarray, file_path)
|
230
|
+
media = file_path
|
231
|
+
else:
|
232
|
+
media = cast(str, media)
|
224
233
|
artifacts.artifacts[Path(media).name] = open(media, "rb").read()
|
225
234
|
|
226
235
|
media_remote_path = (
|
@@ -262,6 +271,7 @@ class VisionAgent(Agent):
|
|
262
271
|
artifacts_loaded = artifacts.show()
|
263
272
|
int_chat.append({"role": "observation", "content": artifacts_loaded})
|
264
273
|
orig_chat.append({"role": "observation", "content": artifacts_loaded})
|
274
|
+
self.streaming_message({"role": "observation", "content": artifacts_loaded})
|
265
275
|
|
266
276
|
while not finished and iterations < self.max_iterations:
|
267
277
|
response = run_conversation(self.agent, int_chat)
|
@@ -274,6 +284,8 @@ class VisionAgent(Agent):
|
|
274
284
|
if last_response == response:
|
275
285
|
response["let_user_respond"] = True
|
276
286
|
|
287
|
+
self.streaming_message({"role": "assistant", "content": response})
|
288
|
+
|
277
289
|
if response["let_user_respond"]:
|
278
290
|
break
|
279
291
|
|
@@ -293,6 +305,13 @@ class VisionAgent(Agent):
|
|
293
305
|
orig_chat.append(
|
294
306
|
{"role": "observation", "content": obs, "execution": result}
|
295
307
|
)
|
308
|
+
self.streaming_message(
|
309
|
+
{
|
310
|
+
"role": "observation",
|
311
|
+
"content": obs,
|
312
|
+
"execution": result,
|
313
|
+
}
|
314
|
+
)
|
296
315
|
|
297
316
|
iterations += 1
|
298
317
|
last_response = response
|
@@ -305,5 +324,9 @@ class VisionAgent(Agent):
|
|
305
324
|
artifacts.save()
|
306
325
|
return orig_chat, artifacts
|
307
326
|
|
327
|
+
def streaming_message(self, message: Dict[str, Any]) -> None:
|
328
|
+
if self.callback_message:
|
329
|
+
self.callback_message(message)
|
330
|
+
|
308
331
|
def log_progress(self, data: Dict[str, Any]) -> None:
|
309
332
|
pass
|
@@ -27,7 +27,14 @@ from vision_agent.agent.vision_agent_coder_prompts import (
|
|
27
27
|
TEST_PLANS,
|
28
28
|
USER_REQ,
|
29
29
|
)
|
30
|
-
from vision_agent.lmm import
|
30
|
+
from vision_agent.lmm import (
|
31
|
+
LMM,
|
32
|
+
AzureOpenAILMM,
|
33
|
+
ClaudeSonnetLMM,
|
34
|
+
Message,
|
35
|
+
OllamaLMM,
|
36
|
+
OpenAILMM,
|
37
|
+
)
|
31
38
|
from vision_agent.tools.meta_tools import get_diff
|
32
39
|
from vision_agent.utils import CodeInterpreterFactory, Execution
|
33
40
|
from vision_agent.utils.execute import CodeInterpreter
|
@@ -167,9 +174,10 @@ def pick_plan(
|
|
167
174
|
}
|
168
175
|
)
|
169
176
|
tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
|
170
|
-
|
171
|
-
|
172
|
-
|
177
|
+
# Because of the way we trace function calls the trace information ends up in the
|
178
|
+
# results. We don't want to show this info to the LLM so we don't include it in the
|
179
|
+
# tool_output_str.
|
180
|
+
tool_output_str = tool_output.text(include_results=False).strip()
|
173
181
|
|
174
182
|
if verbosity == 2:
|
175
183
|
_print_code("Initial code and tests:", code)
|
@@ -196,7 +204,7 @@ def pick_plan(
|
|
196
204
|
docstring=tool_info,
|
197
205
|
plans=plan_str,
|
198
206
|
previous_attempts=PREVIOUS_FAILED.format(
|
199
|
-
code=code, error=
|
207
|
+
code=code, error="\n".join(tool_output_str.splitlines()[-50:])
|
200
208
|
),
|
201
209
|
media=media,
|
202
210
|
)
|
@@ -225,11 +233,11 @@ def pick_plan(
|
|
225
233
|
"status": "completed" if tool_output.success else "failed",
|
226
234
|
}
|
227
235
|
)
|
228
|
-
tool_output_str = tool_output.text().strip()
|
236
|
+
tool_output_str = tool_output.text(include_results=False).strip()
|
229
237
|
|
230
238
|
if verbosity == 2:
|
231
239
|
_print_code("Code and test after attempted fix:", code)
|
232
|
-
_LOGGER.info(f"Code execution result after attempt {count}")
|
240
|
+
_LOGGER.info(f"Code execution result after attempt {count + 1}")
|
233
241
|
|
234
242
|
count += 1
|
235
243
|
|
@@ -387,7 +395,6 @@ def write_and_test_code(
|
|
387
395
|
"code": DefaultImports.prepend_imports(code),
|
388
396
|
"payload": {
|
389
397
|
"test": test,
|
390
|
-
# "result": result.to_json(),
|
391
398
|
},
|
392
399
|
}
|
393
400
|
)
|
@@ -406,6 +413,7 @@ def write_and_test_code(
|
|
406
413
|
working_memory,
|
407
414
|
debugger,
|
408
415
|
code_interpreter,
|
416
|
+
tool_info,
|
409
417
|
code,
|
410
418
|
test,
|
411
419
|
result,
|
@@ -431,6 +439,7 @@ def debug_code(
|
|
431
439
|
working_memory: List[Dict[str, str]],
|
432
440
|
debugger: LMM,
|
433
441
|
code_interpreter: CodeInterpreter,
|
442
|
+
tool_info: str,
|
434
443
|
code: str,
|
435
444
|
test: str,
|
436
445
|
result: Execution,
|
@@ -451,17 +460,38 @@ def debug_code(
|
|
451
460
|
count = 0
|
452
461
|
while not success and count < 3:
|
453
462
|
try:
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
463
|
+
# LLMs write worse code when it's in JSON, so we have it write JSON
|
464
|
+
# followed by code each wrapped in markdown blocks.
|
465
|
+
fixed_code_and_test_str = debugger(
|
466
|
+
FIX_BUG.format(
|
467
|
+
docstring=tool_info,
|
468
|
+
code=code,
|
469
|
+
tests=test,
|
470
|
+
# Because of the way we trace function calls the trace information
|
471
|
+
# ends up in the results. We don't want to show this info to the
|
472
|
+
# LLM so we don't include it in the tool_output_str.
|
473
|
+
result="\n".join(
|
474
|
+
result.text(include_results=False).splitlines()[-50:]
|
461
475
|
),
|
462
|
-
|
463
|
-
)
|
476
|
+
feedback=format_memory(working_memory + new_working_memory),
|
477
|
+
),
|
478
|
+
stream=False,
|
464
479
|
)
|
480
|
+
fixed_code_and_test_str = cast(str, fixed_code_and_test_str)
|
481
|
+
fixed_code_and_test = extract_json(fixed_code_and_test_str)
|
482
|
+
code = extract_code(fixed_code_and_test_str)
|
483
|
+
if (
|
484
|
+
"which_code" in fixed_code_and_test
|
485
|
+
and fixed_code_and_test["which_code"] == "test"
|
486
|
+
):
|
487
|
+
fixed_code_and_test["code"] = ""
|
488
|
+
fixed_code_and_test["test"] = code
|
489
|
+
else: # for everything else always assume it's updating code
|
490
|
+
fixed_code_and_test["code"] = code
|
491
|
+
fixed_code_and_test["test"] = ""
|
492
|
+
if "which_code" in fixed_code_and_test:
|
493
|
+
del fixed_code_and_test["which_code"]
|
494
|
+
|
465
495
|
success = True
|
466
496
|
except Exception as e:
|
467
497
|
_LOGGER.exception(f"Error while extracting JSON: {e}")
|
@@ -472,9 +502,9 @@ def debug_code(
|
|
472
502
|
old_test = test
|
473
503
|
|
474
504
|
if fixed_code_and_test["code"].strip() != "":
|
475
|
-
code =
|
505
|
+
code = fixed_code_and_test["code"]
|
476
506
|
if fixed_code_and_test["test"].strip() != "":
|
477
|
-
test =
|
507
|
+
test = fixed_code_and_test["test"]
|
478
508
|
|
479
509
|
new_working_memory.append(
|
480
510
|
{
|
@@ -628,9 +658,7 @@ class VisionAgentCoder(Agent):
|
|
628
658
|
)
|
629
659
|
self.coder = OpenAILMM(temperature=0.0) if coder is None else coder
|
630
660
|
self.tester = OpenAILMM(temperature=0.0) if tester is None else tester
|
631
|
-
self.debugger = (
|
632
|
-
OpenAILMM(temperature=0.0, json_mode=True) if debugger is None else debugger
|
633
|
-
)
|
661
|
+
self.debugger = OpenAILMM(temperature=0.0) if debugger is None else debugger
|
634
662
|
self.verbosity = verbosity
|
635
663
|
if self.verbosity > 0:
|
636
664
|
_LOGGER.setLevel(logging.INFO)
|
@@ -876,6 +904,40 @@ class VisionAgentCoder(Agent):
|
|
876
904
|
)
|
877
905
|
|
878
906
|
|
907
|
+
class ClaudeVisionAgentCoder(VisionAgentCoder):
|
908
|
+
def __init__(
|
909
|
+
self,
|
910
|
+
planner: Optional[LMM] = None,
|
911
|
+
coder: Optional[LMM] = None,
|
912
|
+
tester: Optional[LMM] = None,
|
913
|
+
debugger: Optional[LMM] = None,
|
914
|
+
tool_recommender: Optional[Sim] = None,
|
915
|
+
verbosity: int = 0,
|
916
|
+
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
|
917
|
+
code_sandbox_runtime: Optional[str] = None,
|
918
|
+
) -> None:
|
919
|
+
# NOTE: Claude doesn't have an official JSON mode
|
920
|
+
self.planner = ClaudeSonnetLMM(temperature=0.0) if planner is None else planner
|
921
|
+
self.coder = ClaudeSonnetLMM(temperature=0.0) if coder is None else coder
|
922
|
+
self.tester = ClaudeSonnetLMM(temperature=0.0) if tester is None else tester
|
923
|
+
self.debugger = (
|
924
|
+
ClaudeSonnetLMM(temperature=0.0) if debugger is None else debugger
|
925
|
+
)
|
926
|
+
self.verbosity = verbosity
|
927
|
+
if self.verbosity > 0:
|
928
|
+
_LOGGER.setLevel(logging.INFO)
|
929
|
+
|
930
|
+
# Anthropic does not offer any embedding models and instead recommends Voyage,
|
931
|
+
# we're using OpenAI's embedder for now.
|
932
|
+
self.tool_recommender = (
|
933
|
+
Sim(T.TOOLS_DF, sim_key="desc")
|
934
|
+
if tool_recommender is None
|
935
|
+
else tool_recommender
|
936
|
+
)
|
937
|
+
self.report_progress_callback = report_progress_callback
|
938
|
+
self.code_sandbox_runtime = code_sandbox_runtime
|
939
|
+
|
940
|
+
|
879
941
|
class OllamaVisionAgentCoder(VisionAgentCoder):
|
880
942
|
"""VisionAgentCoder that uses Ollama models for planning, coding, testing.
|
881
943
|
|
@@ -920,7 +982,7 @@ class OllamaVisionAgentCoder(VisionAgentCoder):
|
|
920
982
|
else tester
|
921
983
|
),
|
922
984
|
debugger=(
|
923
|
-
OllamaLMM(model_name="llama3.1", temperature=0.0
|
985
|
+
OllamaLMM(model_name="llama3.1", temperature=0.0)
|
924
986
|
if debugger is None
|
925
987
|
else debugger
|
926
988
|
),
|
@@ -983,9 +1045,7 @@ class AzureVisionAgentCoder(VisionAgentCoder):
|
|
983
1045
|
coder=AzureOpenAILMM(temperature=0.0) if coder is None else coder,
|
984
1046
|
tester=AzureOpenAILMM(temperature=0.0) if tester is None else tester,
|
985
1047
|
debugger=(
|
986
|
-
AzureOpenAILMM(temperature=0.0
|
987
|
-
if debugger is None
|
988
|
-
else debugger
|
1048
|
+
AzureOpenAILMM(temperature=0.0) if debugger is None else debugger
|
989
1049
|
),
|
990
1050
|
tool_recommender=(
|
991
1051
|
AzureSim(T.TOOLS_DF, sim_key="desc")
|
{vision_agent-0.2.131 → vision_agent-0.2.133}/vision_agent/agent/vision_agent_coder_prompts.py
RENAMED
@@ -63,6 +63,7 @@ This is the documentation for the functions you have access to. You may call any
|
|
63
63
|
**Plans**:
|
64
64
|
{plans}
|
65
65
|
|
66
|
+
**Previous Attempts**:
|
66
67
|
{previous_attempts}
|
67
68
|
|
68
69
|
**Instructions**:
|
@@ -108,16 +109,27 @@ plan2:
|
|
108
109
|
- Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video.
|
109
110
|
plan3:
|
110
111
|
- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool.
|
111
|
-
- Use the '
|
112
|
+
- Use the 'florence2_sam2_video_tracking' tool with the prompt 'person' to detect where the people are in the video.
|
112
113
|
|
113
114
|
|
114
115
|
```python
|
115
|
-
|
116
|
+
import numpy as np
|
117
|
+
from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, florence2_sam2_video_tracking
|
116
118
|
|
117
119
|
# sample at 1 FPS and use the first 10 frames to reduce processing time
|
118
120
|
frames = extract_frames("video.mp4", 1)
|
119
121
|
frames = [f[0] for f in frames][:10]
|
120
122
|
|
123
|
+
def remove_arrays(o):
|
124
|
+
if isinstance(o, list):
|
125
|
+
return [remove_arrays(e) for e in o]
|
126
|
+
elif isinstance(o, dict):
|
127
|
+
return {{k: remove_arrays(v) for k, v in o.items()}}
|
128
|
+
elif isinstance(o, np.ndarray):
|
129
|
+
return "array: " + str(o.shape)
|
130
|
+
else:
|
131
|
+
return o
|
132
|
+
|
121
133
|
# plan1
|
122
134
|
owl_v2_out = [owl_v2_image("person", f) for f in frames]
|
123
135
|
|
@@ -125,9 +137,10 @@ owl_v2_out = [owl_v2_image("person", f) for f in frames]
|
|
125
137
|
florence2_out = [florence2_phrase_grounding("person", f) for f in frames]
|
126
138
|
|
127
139
|
# plan3
|
128
|
-
|
140
|
+
f2s2_tracking_out = florence2_sam2_video_tracking("person", frames)
|
141
|
+
remove_arrays(f2s2_tracking_out)
|
129
142
|
|
130
|
-
final_out = {{"owl_v2_image": owl_v2_out, "
|
143
|
+
final_out = {{"owl_v2_image": owl_v2_out, "florence2_phrase_grounding": florence2_out, "florence2_sam2_video_tracking": f2s2_tracking_out}}
|
131
144
|
print(final_out)
|
132
145
|
```
|
133
146
|
"""
|
@@ -161,9 +174,10 @@ PICK_PLAN = """
|
|
161
174
|
|
162
175
|
**Instructions**:
|
163
176
|
1. Given the plans, image, and tool outputs, decide which plan is the best to achieve the user request.
|
164
|
-
2.
|
177
|
+
2. Solve the problem yourself given the image and pick the plan that matches your solution the best.
|
165
178
|
3. Output a JSON object with the following format:
|
166
179
|
{{
|
180
|
+
"predicted_answer": str # the answer you would expect from the best plan
|
167
181
|
"thoughts": str # your thought process for choosing the best plan
|
168
182
|
"best_plan": str # the best plan you have chosen
|
169
183
|
}}
|
@@ -311,6 +325,11 @@ This is the documentation for the functions you have access to. You may call any
|
|
311
325
|
FIX_BUG = """
|
312
326
|
**Role** As a coder, your job is to find the error in the code and fix it. You are running in a notebook setting so you can run !pip install to install missing packages.
|
313
327
|
|
328
|
+
**Documentation**:
|
329
|
+
This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task. They are available through importing `from vision_agent.tools import *`.
|
330
|
+
|
331
|
+
{docstring}
|
332
|
+
|
314
333
|
**Instructions**:
|
315
334
|
Please re-complete the code to fix the error message. Here is the previous version:
|
316
335
|
```python
|
@@ -323,17 +342,24 @@ When we run this test code:
|
|
323
342
|
```
|
324
343
|
|
325
344
|
It raises this error:
|
345
|
+
```
|
326
346
|
{result}
|
347
|
+
```
|
327
348
|
|
328
349
|
This is previous feedback provided on the code:
|
329
350
|
{feedback}
|
330
351
|
|
331
|
-
Please fix the bug by
|
352
|
+
Please fix the bug by correcting the error. Return the following JSON object followed by the fixed code in the below format:
|
353
|
+
```json
|
332
354
|
{{
|
333
355
|
"reflections": str # any thoughts you have about the bug and how you fixed it
|
334
|
-
"
|
335
|
-
"test": str # the fixed test code if any, else an empty string
|
356
|
+
"which_code": str # the code that was fixed, can only be 'code' or 'test'
|
336
357
|
}}
|
358
|
+
```
|
359
|
+
|
360
|
+
```python
|
361
|
+
# Your fixed code here
|
362
|
+
```
|
337
363
|
"""
|
338
364
|
|
339
365
|
|
@@ -468,7 +468,7 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
|
|
468
468
|
|
469
469
|
pil_image = Image.fromarray(image).convert("RGB")
|
470
470
|
image_size = pil_image.size[::-1]
|
471
|
-
if image_size[0] < 1
|
471
|
+
if image_size[0] < 1 or image_size[1] < 1:
|
472
472
|
return []
|
473
473
|
image_buffer = io.BytesIO()
|
474
474
|
pil_image.save(image_buffer, format="PNG")
|
@@ -781,6 +781,44 @@ def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
|
|
781
781
|
return cast(str, data["answer"])
|
782
782
|
|
783
783
|
|
784
|
+
def ixc25_temporal_localization(prompt: str, frames: List[np.ndarray]) -> List[bool]:
|
785
|
+
"""'ixc25_temporal_localization' uses ixc25_video_vqa to temporally segment a video
|
786
|
+
given a prompt that can be either an object or a phrase. It returns a list of
|
787
|
+
boolean values indicating whether the object or phrase is present in the
|
788
|
+
corresponding frame.
|
789
|
+
|
790
|
+
Parameters:
|
791
|
+
prompt (str): The question about the video
|
792
|
+
frames (List[np.ndarray]): The reference frames used for the question
|
793
|
+
|
794
|
+
Returns:
|
795
|
+
List[bool]: A list of boolean values indicating whether the object or phrase is
|
796
|
+
present in the corresponding frame.
|
797
|
+
|
798
|
+
Example
|
799
|
+
-------
|
800
|
+
>>> output = ixc25_temporal_localization('soccer goal', frames)
|
801
|
+
>>> print(output)
|
802
|
+
[False, False, False, True, True, True, False, False, False, False]
|
803
|
+
>>> save_video([f for i, f in enumerate(frames) if output[i]], 'output.mp4')
|
804
|
+
"""
|
805
|
+
|
806
|
+
buffer_bytes = frames_to_bytes(frames)
|
807
|
+
files = [("video", buffer_bytes)]
|
808
|
+
payload = {
|
809
|
+
"prompt": prompt,
|
810
|
+
"chunk_length": 2,
|
811
|
+
"function_name": "ixc25_temporal_localization",
|
812
|
+
}
|
813
|
+
data: List[int] = send_inference_request(
|
814
|
+
payload, "video-temporal-localization", files=files, v2=True
|
815
|
+
)
|
816
|
+
chunk_size = round(len(frames) / len(data))
|
817
|
+
data_explode = [[elt] * chunk_size for elt in data]
|
818
|
+
data_bool = [bool(elt) for sublist in data_explode for elt in sublist]
|
819
|
+
return data_bool[: len(frames)]
|
820
|
+
|
821
|
+
|
784
822
|
def gpt4o_image_vqa(prompt: str, image: np.ndarray) -> str:
|
785
823
|
"""'gpt4o_image_vqa' is a tool that can answer any questions about arbitrary images
|
786
824
|
including regular images or images of documents or presentations. It returns text
|
@@ -1112,6 +1150,8 @@ def florence2_ocr(image: np.ndarray) -> List[Dict[str, Any]]:
|
|
1112
1150
|
"""
|
1113
1151
|
|
1114
1152
|
image_size = image.shape[:2]
|
1153
|
+
if image_size[0] < 1 or image_size[1] < 1:
|
1154
|
+
return []
|
1115
1155
|
image_b64 = convert_to_b64(image)
|
1116
1156
|
data = {
|
1117
1157
|
"image": image_b64,
|
@@ -1467,7 +1507,7 @@ def extract_frames(
|
|
1467
1507
|
Parameters:
|
1468
1508
|
video_uri (Union[str, Path]): The path to the video file, url or youtube link
|
1469
1509
|
fps (float, optional): The frame rate per second to extract the frames. Defaults
|
1470
|
-
to
|
1510
|
+
to 1.
|
1471
1511
|
|
1472
1512
|
Returns:
|
1473
1513
|
List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
|
@@ -292,7 +292,7 @@ class Execution(BaseModel):
|
|
292
292
|
error: Optional[Error] = None
|
293
293
|
"Error object if an error occurred, None otherwise."
|
294
294
|
|
295
|
-
def text(self, include_logs: bool = True) -> str:
|
295
|
+
def text(self, include_logs: bool = True, include_results: bool = True) -> str:
|
296
296
|
"""Returns the text representation of this object, i.e. including the main
|
297
297
|
result or the error traceback, optionally along with the logs (stdout, stderr).
|
298
298
|
"""
|
@@ -300,15 +300,17 @@ class Execution(BaseModel):
|
|
300
300
|
if self.error:
|
301
301
|
return prefix + "\n----- Error -----\n" + self.error.traceback
|
302
302
|
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
303
|
+
if include_results:
|
304
|
+
result_str = [
|
305
|
+
(
|
306
|
+
f"----- Final output -----\n{res.text}"
|
307
|
+
if res.is_main_result
|
308
|
+
else f"----- Intermediate output-----\n{res.text}"
|
309
|
+
)
|
310
|
+
for res in self.results
|
311
|
+
]
|
312
|
+
return prefix + "\n" + "\n".join(result_str)
|
313
|
+
return prefix
|
312
314
|
|
313
315
|
@property
|
314
316
|
def success(self) -> bool:
|
@@ -7,7 +7,6 @@ from typing import List, Optional, Tuple
|
|
7
7
|
import av # type: ignore
|
8
8
|
import cv2
|
9
9
|
import numpy as np
|
10
|
-
from decord import VideoReader # type: ignore
|
11
10
|
|
12
11
|
_LOGGER = logging.getLogger(__name__)
|
13
12
|
# The maximum length of the clip to extract frames from, in seconds
|
@@ -103,7 +102,7 @@ def frames_to_bytes(
|
|
103
102
|
def extract_frames_from_video(
|
104
103
|
video_uri: str, fps: float = 1.0
|
105
104
|
) -> List[Tuple[np.ndarray, float]]:
|
106
|
-
"""Extract frames from a video
|
105
|
+
"""Extract frames from a video along with the timestamp in seconds.
|
107
106
|
|
108
107
|
Parameters:
|
109
108
|
video_uri (str): the path to the video file or a video file url
|
@@ -115,12 +114,24 @@ def extract_frames_from_video(
|
|
115
114
|
from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
|
116
115
|
the video. The frames are sorted by the timestamp in ascending order.
|
117
116
|
"""
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
117
|
+
|
118
|
+
cap = cv2.VideoCapture(video_uri)
|
119
|
+
orig_fps = cap.get(cv2.CAP_PROP_FPS)
|
120
|
+
orig_frame_time = 1 / orig_fps
|
121
|
+
targ_frame_time = 1 / fps
|
122
|
+
frames: List[Tuple[np.ndarray, float]] = []
|
123
|
+
i = 0
|
124
|
+
elapsed_time = 0.0
|
125
|
+
while cap.isOpened():
|
126
|
+
ret, frame = cap.read()
|
127
|
+
if not ret:
|
128
|
+
break
|
129
|
+
|
130
|
+
elapsed_time += orig_frame_time
|
131
|
+
if elapsed_time >= targ_frame_time:
|
132
|
+
frames.append((cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), i / orig_fps))
|
133
|
+
elapsed_time -= targ_frame_time
|
134
|
+
|
135
|
+
i += 1
|
136
|
+
cap.release()
|
137
|
+
return frames
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|