vision-agent 0.2.173__tar.gz → 0.2.175__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (35) hide show
  1. {vision_agent-0.2.173 → vision_agent-0.2.175}/PKG-INFO +2 -2
  2. {vision_agent-0.2.173 → vision_agent-0.2.175}/pyproject.toml +2 -2
  3. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/agent/agent_utils.py +3 -1
  4. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/agent/vision_agent.py +63 -68
  5. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/agent/vision_agent_coder.py +101 -31
  6. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/agent/vision_agent_prompts.py +75 -14
  7. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/tools/meta_tools.py +141 -111
  8. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/utils/execute.py +2 -2
  9. {vision_agent-0.2.173 → vision_agent-0.2.175}/LICENSE +0 -0
  10. {vision_agent-0.2.173 → vision_agent-0.2.175}/README.md +0 -0
  11. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/__init__.py +0 -0
  12. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/agent/__init__.py +0 -0
  13. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/agent/agent.py +0 -0
  14. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/agent/vision_agent_coder_prompts.py +0 -0
  15. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/agent/vision_agent_planner.py +0 -0
  16. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/agent/vision_agent_planner_prompts.py +0 -0
  17. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/clients/__init__.py +0 -0
  18. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/clients/http.py +0 -0
  19. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/clients/landing_public_api.py +0 -0
  20. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/fonts/__init__.py +0 -0
  21. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
  22. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/lmm/__init__.py +0 -0
  23. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/lmm/lmm.py +0 -0
  24. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/lmm/types.py +0 -0
  25. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/tools/__init__.py +0 -0
  26. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/tools/prompts.py +0 -0
  27. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/tools/tool_utils.py +0 -0
  28. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/tools/tools.py +0 -0
  29. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/tools/tools_types.py +0 -0
  30. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/utils/__init__.py +0 -0
  31. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/utils/exceptions.py +0 -0
  32. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/utils/image_utils.py +0 -0
  33. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/utils/sim.py +0 -0
  34. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/utils/type_defs.py +0 -0
  35. {vision_agent-0.2.173 → vision_agent-0.2.175}/vision_agent/utils/video.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vision-agent
3
- Version: 0.2.173
3
+ Version: 0.2.175
4
4
  Summary: Toolset for Vision Agent
5
5
  Author: Landing AI
6
6
  Author-email: dev@landing.ai
@@ -16,6 +16,7 @@ Requires-Dist: e2b-code-interpreter (==0.0.11a37)
16
16
  Requires-Dist: flake8 (>=7.0.0,<8.0.0)
17
17
  Requires-Dist: ipykernel (>=6.29.4,<7.0.0)
18
18
  Requires-Dist: langsmith (>=0.1.58,<0.2.0)
19
+ Requires-Dist: libcst (>=1.5.0,<2.0.0)
19
20
  Requires-Dist: nbclient (>=0.10.0,<0.11.0)
20
21
  Requires-Dist: nbformat (>=5.10.4,<6.0.0)
21
22
  Requires-Dist: numpy (>=1.21.0,<2.0.0)
@@ -27,7 +28,6 @@ Requires-Dist: pillow-heif (>=0.16.0,<0.17.0)
27
28
  Requires-Dist: pydantic (==2.7.4)
28
29
  Requires-Dist: pydantic-settings (>=2.2.1,<3.0.0)
29
30
  Requires-Dist: pytube (==15.0.0)
30
- Requires-Dist: redbaron (>=0.9.2,<0.10.0)
31
31
  Requires-Dist: requests (>=2.0.0,<3.0.0)
32
32
  Requires-Dist: rich (>=13.7.1,<14.0.0)
33
33
  Requires-Dist: scipy (>=1.13.0,<1.14.0)
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "vision-agent"
7
- version = "0.2.173"
7
+ version = "0.2.175"
8
8
  description = "Toolset for Vision Agent"
9
9
  authors = ["Landing AI <dev@landing.ai>"]
10
10
  readme = "README.md"
@@ -43,7 +43,7 @@ pytube = "15.0.0"
43
43
  anthropic = "^0.31.0"
44
44
  pydantic = "2.7.4"
45
45
  av = "^11.0.0"
46
- redbaron = "^0.9.2"
46
+ libcst = "^1.5.0"
47
47
 
48
48
  [tool.poetry.group.dev.dependencies]
49
49
  autoflake = "1.*"
@@ -72,7 +72,9 @@ def extract_json(json_str: str) -> Dict[str, Any]:
72
72
  if json_dict is None:
73
73
  error_msg = f"Could not extract JSON from the given str: {json_orig}"
74
74
  _LOGGER.exception(error_msg)
75
- raise ValueError(error_msg)
75
+ raise json.JSONDecodeError(
76
+ msg="Could not extract JSON", doc=json_orig, pos=0
77
+ )
76
78
 
77
79
  return json_dict
78
80
 
@@ -2,7 +2,6 @@ import copy
2
2
  import json
3
3
  import logging
4
4
  import os
5
- import tempfile
6
5
  from pathlib import Path
7
6
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
8
7
 
@@ -12,6 +11,7 @@ from vision_agent.agent.vision_agent_prompts import (
12
11
  EXAMPLES_CODE1,
13
12
  EXAMPLES_CODE2,
14
13
  EXAMPLES_CODE3,
14
+ EXAMPLES_CODE3_EXTRA2,
15
15
  VA_CODE,
16
16
  )
17
17
  from vision_agent.lmm import LMM, AnthropicLMM, Message, OpenAILMM
@@ -19,7 +19,6 @@ from vision_agent.tools.meta_tools import (
19
19
  META_TOOL_DOCSTRING,
20
20
  Artifacts,
21
21
  check_and_load_image,
22
- extract_and_save_files_to_artifacts,
23
22
  use_extra_vision_agent_args,
24
23
  )
25
24
  from vision_agent.utils import CodeInterpreterFactory
@@ -37,11 +36,12 @@ class BoilerplateCode:
37
36
  pre_code = [
38
37
  "from typing import *",
39
38
  "from vision_agent.utils.execute import CodeInterpreter",
40
- "from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, view_media_artifact, object_detection_fine_tuning, use_object_detection_fine_tuning",
41
- "artifacts = Artifacts('{remote_path}')",
39
+ "from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, view_media_artifact, object_detection_fine_tuning, use_object_detection_fine_tuning, list_artifacts, capture_files_into_artifacts",
40
+ "artifacts = Artifacts('{remote_path}', '{remote_path}')",
42
41
  "artifacts.load('{remote_path}')",
43
42
  ]
44
43
  post_code = [
44
+ "capture_files_into_artifacts(artifacts)",
45
45
  "artifacts.save()",
46
46
  ]
47
47
 
@@ -97,8 +97,9 @@ def _clean_response(response: str) -> str:
97
97
  def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
98
98
  chat = copy.deepcopy(chat)
99
99
 
100
+ # only add 10 most recent messages in the chat to not go over token limit
100
101
  conversation = ""
101
- for chat_i in chat:
102
+ for chat_i in chat[-10:]:
102
103
  if chat_i["role"] == "user":
103
104
  conversation += f"USER: {chat_i['content']}\n\n"
104
105
  elif chat_i["role"] == "observation":
@@ -110,7 +111,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
110
111
 
111
112
  prompt = VA_CODE.format(
112
113
  documentation=META_TOOL_DOCSTRING,
113
- examples=f"{EXAMPLES_CODE1}\n{EXAMPLES_CODE2}\n{EXAMPLES_CODE3}",
114
+ examples=f"{EXAMPLES_CODE1}\n{EXAMPLES_CODE2}\n{EXAMPLES_CODE3}\n{EXAMPLES_CODE3_EXTRA2}",
114
115
  conversation=conversation,
115
116
  )
116
117
  message: Message = {"role": "user", "content": prompt}
@@ -120,7 +121,9 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
120
121
  and "media" in chat[-1]
121
122
  and len(chat[-1]["media"]) > 0 # type: ignore
122
123
  ):
123
- message["media"] = chat[-1]["media"]
124
+ media_obs = [media for media in chat[-1]["media"] if Path(media).exists()] # type: ignore
125
+ if len(media_obs) > 0:
126
+ message["media"] = media_obs # type: ignore
124
127
  conv_resp = cast(str, orch([message], stream=False))
125
128
 
126
129
  # clean the response first, if we are executing code, do not respond or end
@@ -144,16 +147,16 @@ def execute_code_action(
144
147
  artifacts: Artifacts,
145
148
  code: str,
146
149
  code_interpreter: CodeInterpreter,
147
- artifact_remote_path: str,
148
150
  ) -> Tuple[Execution, str]:
149
151
  result = code_interpreter.exec_isolation(
150
- BoilerplateCode.add_boilerplate(code, remote_path=artifact_remote_path)
152
+ BoilerplateCode.add_boilerplate(
153
+ code, remote_path=str(artifacts.remote_save_path)
154
+ )
151
155
  )
152
156
 
153
157
  obs = str(result.logs)
154
158
  if result.error:
155
159
  obs += f"\n{result.error}"
156
- extract_and_save_files_to_artifacts(artifacts, code, obs)
157
160
  return result, obs
158
161
 
159
162
 
@@ -161,7 +164,6 @@ def execute_user_code_action(
161
164
  artifacts: Artifacts,
162
165
  last_user_message: Message,
163
166
  code_interpreter: CodeInterpreter,
164
- artifact_remote_path: str,
165
167
  ) -> Tuple[Optional[Execution], Optional[str]]:
166
168
  user_result = None
167
169
  user_obs = None
@@ -178,11 +180,10 @@ def execute_user_code_action(
178
180
  if user_code_action is not None:
179
181
  user_code_action = use_extra_vision_agent_args(user_code_action, False)
180
182
  user_result, user_obs = execute_code_action(
181
- artifacts, user_code_action, code_interpreter, artifact_remote_path
183
+ artifacts, user_code_action, code_interpreter
182
184
  )
183
185
  if user_result.error:
184
186
  user_obs += f"\n{user_result.error}"
185
- extract_and_save_files_to_artifacts(artifacts, user_code_action, user_obs)
186
187
  return user_result, user_obs
187
188
 
188
189
 
@@ -231,9 +232,18 @@ def old_format_to_new_format(old_format_str: str) -> str:
231
232
  except json.JSONDecodeError:
232
233
  return old_format_str
233
234
 
234
- thinking = old_format["thoughts"] if old_format["thoughts"].strip() != "" else None
235
- let_user_respond = old_format["let_user_respond"]
236
- if "<execute_python>" in old_format["response"]:
235
+ if "thoughts" in old_format:
236
+ thinking = (
237
+ old_format["thoughts"] if old_format["thoughts"].strip() != "" else None
238
+ )
239
+ else:
240
+ thinking = None
241
+
242
+ let_user_respond = (
243
+ old_format["let_user_respond"] if "let_user_respond" in old_format else True
244
+ )
245
+
246
+ if "response" in old_format and "<execute_python>" in old_format["response"]:
237
247
  execute_python = extract_tag(old_format["response"], "execute_python")
238
248
  response = (
239
249
  old_format["response"]
@@ -244,7 +254,7 @@ def old_format_to_new_format(old_format_str: str) -> str:
244
254
  )
245
255
  else:
246
256
  execute_python = None
247
- response = old_format["response"]
257
+ response = old_format["response"] if "response" in old_format else None
248
258
 
249
259
  return json.dumps(
250
260
  {
@@ -275,7 +285,6 @@ class VisionAgent(Agent):
275
285
  self,
276
286
  agent: Optional[LMM] = None,
277
287
  verbosity: int = 0,
278
- local_artifacts_path: Optional[Union[str, Path]] = None,
279
288
  callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
280
289
  code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
281
290
  ) -> None:
@@ -285,8 +294,6 @@ class VisionAgent(Agent):
285
294
  agent (Optional[LMM]): The agent to use for conversation and orchestration
286
295
  of other agents.
287
296
  verbosity (int): The verbosity level of the agent.
288
- local_artifacts_path (Optional[Union[str, Path]]): The path to the local
289
- artifacts file.
290
297
  callback_message (Optional[Callable[[Dict[str, Any]], None]]): Callback
291
298
  function to send intermediate update messages.
292
299
  code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
@@ -302,14 +309,6 @@ class VisionAgent(Agent):
302
309
  self.callback_message = callback_message
303
310
  if self.verbosity >= 1:
304
311
  _LOGGER.setLevel(logging.INFO)
305
- self.local_artifacts_path = cast(
306
- str,
307
- (
308
- Path(local_artifacts_path)
309
- if local_artifacts_path is not None
310
- else Path(tempfile.NamedTemporaryFile(delete=False).name)
311
- ),
312
- )
313
312
 
314
313
  def __call__(
315
314
  self,
@@ -386,7 +385,7 @@ class VisionAgent(Agent):
386
385
 
387
386
  if not artifacts:
388
387
  # this is setting remote artifacts path
389
- artifacts = Artifacts(WORKSPACE / "artifacts.pkl")
388
+ artifacts = Artifacts("", "")
390
389
 
391
390
  # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
392
391
  code_interpreter = (
@@ -395,8 +394,15 @@ class VisionAgent(Agent):
395
394
  and not isinstance(self.code_interpreter, str)
396
395
  else CodeInterpreterFactory.new_instance(
397
396
  code_sandbox_runtime=self.code_interpreter,
397
+ remote_path=artifacts.remote_save_path.parent,
398
398
  )
399
399
  )
400
+
401
+ if code_interpreter.remote_path != artifacts.remote_save_path.parent:
402
+ raise ValueError(
403
+ f"Code interpreter remote path {code_interpreter.remote_path} does not match {artifacts.remote_save_path.parent}"
404
+ )
405
+
400
406
  with code_interpreter:
401
407
  orig_chat = copy.deepcopy(chat)
402
408
  int_chat = copy.deepcopy(chat)
@@ -436,15 +442,13 @@ class VisionAgent(Agent):
436
442
 
437
443
  # Save the current state of artifacts, will include any images the user
438
444
  # passed in.
439
- artifacts.save(self.local_artifacts_path)
445
+ artifacts.save()
440
446
 
441
447
  # Upload artifacts to remote location and show where they are going
442
448
  # to be loaded to. The actual loading happens in BoilerplateCode as
443
449
  # part of the pre_code.
444
- remote_artifacts_path = code_interpreter.upload_file(
445
- self.local_artifacts_path
446
- )
447
- artifacts_loaded = artifacts.show(code_interpreter.remote_path)
450
+ code_interpreter.upload_file(artifacts.local_save_path)
451
+ artifacts_loaded = artifacts.show(artifacts.remote_save_path.parent)
448
452
  int_chat.append({"role": "observation", "content": artifacts_loaded})
449
453
  orig_chat.append({"role": "observation", "content": artifacts_loaded})
450
454
  self.streaming_message({"role": "observation", "content": artifacts_loaded})
@@ -453,7 +457,6 @@ class VisionAgent(Agent):
453
457
  artifacts,
454
458
  last_user_message,
455
459
  code_interpreter,
456
- str(remote_artifacts_path),
457
460
  )
458
461
  finished = user_result is not None and user_obs is not None
459
462
  if user_result is not None and user_obs is not None:
@@ -472,7 +475,16 @@ class VisionAgent(Agent):
472
475
  )
473
476
 
474
477
  while not finished and iterations < self.max_iterations:
478
+ # ensure we upload the artifacts before each turn, so any local
479
+ # modifications we made to it will be reflected in the remote
480
+ code_interpreter.upload_file(artifacts.local_save_path)
481
+
475
482
  response = run_conversation(self.agent, int_chat)
483
+ code_action = use_extra_vision_agent_args(
484
+ response.get("execute_python", None),
485
+ test_multi_plan,
486
+ custom_tool_names,
487
+ )
476
488
  if self.verbosity >= 1:
477
489
  _LOGGER.info(response)
478
490
  int_chat.append(
@@ -532,31 +544,20 @@ class VisionAgent(Agent):
532
544
  artifacts,
533
545
  code_action,
534
546
  code_interpreter,
535
- str(remote_artifacts_path),
536
547
  )
537
-
538
- media_obs = check_and_load_image(code_action)
539
-
540
- if self.verbosity >= 1:
541
- _LOGGER.info(obs)
542
-
543
548
  obs_chat_elt: Message = {"role": "observation", "content": obs}
549
+ media_obs = check_and_load_image(code_action)
544
550
  if media_obs and result.success:
545
- # for view_media_artifact, we need to ensure the media is loaded
546
- # locally so the conversation agent can actually see it
547
- code_interpreter.download_file(
548
- str(remote_artifacts_path.name),
549
- str(self.local_artifacts_path),
550
- )
551
- artifacts.load(
552
- self.local_artifacts_path,
553
- Path(self.local_artifacts_path).parent,
554
- )
551
+ # media paths will be under the local_save_path when we download
552
+ # them after each turn
555
553
  obs_chat_elt["media"] = [
556
- Path(self.local_artifacts_path).parent / media_ob
554
+ artifacts.local_save_path.parent / media_ob
557
555
  for media_ob in media_obs
558
556
  ]
559
557
 
558
+ if self.verbosity >= 1:
559
+ _LOGGER.info(obs)
560
+
560
561
  # don't add execution results to internal chat
561
562
  int_chat.append(obs_chat_elt)
562
563
  obs_chat_elt["execution"] = result
@@ -573,13 +574,15 @@ class VisionAgent(Agent):
573
574
  iterations += 1
574
575
  last_response = response
575
576
 
576
- # after running the agent, download the artifacts locally
577
- code_interpreter.download_file(
578
- str(remote_artifacts_path.name), str(self.local_artifacts_path)
579
- )
580
- artifacts.load(
581
- self.local_artifacts_path, Path(self.local_artifacts_path).parent
582
- )
577
+ # after each turn, download the artifacts locally
578
+ code_interpreter.download_file(
579
+ str(artifacts.remote_save_path.name),
580
+ str(artifacts.local_save_path),
581
+ )
582
+ artifacts.load(
583
+ artifacts.local_save_path, artifacts.local_save_path.parent
584
+ )
585
+
583
586
  return orig_chat, artifacts
584
587
 
585
588
  def streaming_message(self, message: Dict[str, Any]) -> None:
@@ -595,7 +598,6 @@ class OpenAIVisionAgent(VisionAgent):
595
598
  self,
596
599
  agent: Optional[LMM] = None,
597
600
  verbosity: int = 0,
598
- local_artifacts_path: Optional[Union[str, Path]] = None,
599
601
  callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
600
602
  code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
601
603
  ) -> None:
@@ -605,8 +607,6 @@ class OpenAIVisionAgent(VisionAgent):
605
607
  agent (Optional[LMM]): The agent to use for conversation and orchestration
606
608
  of other agents.
607
609
  verbosity (int): The verbosity level of the agent.
608
- local_artifacts_path (Optional[Union[str, Path]]): The path to the local
609
- artifacts file.
610
610
  callback_message (Optional[Callable[[Dict[str, Any]], None]]): Callback
611
611
  function to send intermediate update messages.
612
612
  code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
@@ -619,7 +619,6 @@ class OpenAIVisionAgent(VisionAgent):
619
619
  super().__init__(
620
620
  agent,
621
621
  verbosity,
622
- local_artifacts_path,
623
622
  callback_message,
624
623
  code_interpreter,
625
624
  )
@@ -630,7 +629,6 @@ class AnthropicVisionAgent(VisionAgent):
630
629
  self,
631
630
  agent: Optional[LMM] = None,
632
631
  verbosity: int = 0,
633
- local_artifacts_path: Optional[Union[str, Path]] = None,
634
632
  callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
635
633
  code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
636
634
  ) -> None:
@@ -640,8 +638,6 @@ class AnthropicVisionAgent(VisionAgent):
640
638
  agent (Optional[LMM]): The agent to use for conversation and orchestration
641
639
  of other agents.
642
640
  verbosity (int): The verbosity level of the agent.
643
- local_artifacts_path (Optional[Union[str, Path]]): The path to the local
644
- artifacts file.
645
641
  callback_message (Optional[Callable[[Dict[str, Any]], None]]): Callback
646
642
  function to send intermediate update messages.
647
643
  code_interpreter (Optional[Union[str, CodeInterpreter]]): For string values
@@ -654,7 +650,6 @@ class AnthropicVisionAgent(VisionAgent):
654
650
  super().__init__(
655
651
  agent,
656
652
  verbosity,
657
- local_artifacts_path,
658
653
  callback_message,
659
654
  code_interpreter,
660
655
  )
@@ -5,7 +5,7 @@ import sys
5
5
  from pathlib import Path
6
6
  from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
7
7
 
8
- from redbaron import RedBaron # type: ignore
8
+ import libcst as cst
9
9
  from tabulate import tabulate
10
10
 
11
11
  import vision_agent.tools as T
@@ -49,42 +49,112 @@ WORKSPACE = Path(os.getenv("WORKSPACE", ""))
49
49
  _LOGGER = logging.getLogger(__name__)
50
50
 
51
51
 
52
- def strip_function_calls(code: str, exclusions: Optional[List[str]] = None) -> str:
52
+ def strip_function_calls( # noqa: C901
53
+ code: str, exclusions: Optional[List[str]] = None
54
+ ) -> str:
53
55
  """This will strip out all code that calls functions except for functions included
54
56
  in exclusions.
55
57
  """
56
58
  if exclusions is None:
57
59
  exclusions = []
58
60
 
59
- red = RedBaron(code)
60
- nodes_to_remove = []
61
- for node in red:
62
- if node.type == "def":
63
- continue
64
- elif node.type == "import" or node.type == "from_import":
65
- continue
66
- elif node.type == "call":
67
- if node.value and node.value[0].value in exclusions:
68
- continue
69
- nodes_to_remove.append(node)
70
- elif node.type == "atomtrailers":
71
- if node[0].value in exclusions:
72
- continue
73
- nodes_to_remove.append(node)
74
- elif node.type == "assignment":
75
- if node.value.type == "call" or node.value.type == "atomtrailers":
76
- func_name = node.value[0].value
77
- if func_name in exclusions:
78
- continue
79
- nodes_to_remove.append(node)
80
- elif node.type == "endl":
81
- continue
82
- else:
83
- nodes_to_remove.append(node)
84
- for node in nodes_to_remove:
85
- node.parent.remove(node)
86
- cleaned_code = red.dumps().strip()
87
- return cleaned_code if isinstance(cleaned_code, str) else code
61
+ def check_and_remove_node(node: cst.CSTNode, exclusions: List[str]) -> cst.CSTNode:
62
+ if hasattr(node, "value") and isinstance(node.value, cst.Call):
63
+ if (
64
+ isinstance(node.value.func, cst.Name)
65
+ and node.value.func.value in exclusions
66
+ ):
67
+ return node
68
+ return cst.RemoveFromParent() # type: ignore
69
+ return node
70
+
71
+ class StripFunctionCallsTransformer(cst.CSTTransformer):
72
+ def __init__(self, exclusions: List[str]):
73
+ # Store exclusions to skip removing certain function calls
74
+ self.exclusions = exclusions
75
+ self.in_function_or_class = False
76
+
77
+ def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
78
+ self.in_function_or_class = True
79
+ return True
80
+
81
+ def leave_FunctionDef(
82
+ self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
83
+ ) -> cst.BaseStatement:
84
+ self.in_function_or_class = False
85
+ return updated_node
86
+
87
+ def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
88
+ self.in_function_or_class = True
89
+ return True
90
+
91
+ def leave_ClassDef(
92
+ self, node: cst.ClassDef, updated_node: cst.ClassDef
93
+ ) -> cst.BaseStatement:
94
+ self.in_function_or_class = False
95
+ return updated_node
96
+
97
+ def leave_Expr(
98
+ self, original_node: cst.Expr, updated_node: cst.Expr
99
+ ) -> cst.Expr:
100
+ if not self.in_function_or_class:
101
+ return cast(
102
+ cst.Expr, check_and_remove_node(updated_node, self.exclusions)
103
+ )
104
+ return updated_node
105
+
106
+ def leave_Assign(
107
+ self, original_node: cst.Assign, updated_node: cst.Assign
108
+ ) -> cst.Assign:
109
+ if not self.in_function_or_class:
110
+ return cast(
111
+ cst.Assign, check_and_remove_node(updated_node, self.exclusions)
112
+ )
113
+ return updated_node
114
+
115
+ def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
116
+ if not self.in_function_or_class:
117
+ return cast(
118
+ cst.If, check_and_remove_node(updated_node, self.exclusions)
119
+ )
120
+ return updated_node
121
+
122
+ def leave_For(self, original_node: cst.For, updated_node: cst.For) -> cst.For:
123
+ if not self.in_function_or_class:
124
+ return cast(
125
+ cst.For, check_and_remove_node(updated_node, self.exclusions)
126
+ )
127
+ return updated_node
128
+
129
+ def leave_While(
130
+ self, original_node: cst.While, updated_node: cst.While
131
+ ) -> cst.While:
132
+ if not self.in_function_or_class:
133
+ return cast(
134
+ cst.While, check_and_remove_node(updated_node, self.exclusions)
135
+ )
136
+ return updated_node
137
+
138
+ def leave_With(
139
+ self, original_node: cst.With, updated_node: cst.With
140
+ ) -> cst.With:
141
+ if not self.in_function_or_class:
142
+ return cast(
143
+ cst.With, check_and_remove_node(updated_node, self.exclusions)
144
+ )
145
+ return updated_node
146
+
147
+ def leave_Try(self, original_node: cst.Try, updated_node: cst.Try) -> cst.Try:
148
+ if not self.in_function_or_class:
149
+ return cast(
150
+ cst.Try, check_and_remove_node(updated_node, self.exclusions)
151
+ )
152
+ return updated_node
153
+
154
+ tree = cst.parse_module(code)
155
+ transformer = StripFunctionCallsTransformer(exclusions)
156
+ modified_tree = tree.visit(transformer)
157
+ return modified_tree.code
88
158
 
89
159
 
90
160
  def write_code(
@@ -1,7 +1,7 @@
1
1
  VA_CODE = """
2
2
  **Role**: You are a helpful agent that assists users with writing code.
3
3
 
4
- **Taks**: As a conversational agent, you are required to understand the user's request and provide a helpful response. Use a Chain-of-Thought approach to break down the problem, create a plan, and then provide a response. Ensure that your response is clear, concise, and helpful. You can use an interactive Python (Jupyter Notebook) environment, executing code with <execution_python>. You are given access to an `artifacts` object which contains files shared between you and the user. `artifacts` will be automatically saved everytime you execute python code.
4
+ **Task**: As a conversational agent, you are required to understand the user's request and provide a helpful response. Use a Chain-of-Thought approach to break down the problem, create a plan, and then provide a response. Ensure that your response is clear, concise, and helpful. You can use an interactive Python (Jupyter Notebook) environment, executing code with <execute_python>. You are given access to an `artifacts` object which contains files shared between you and the user. `artifacts` will be saved only AFTER you execute python code with `save_image` or `save_video`. The user can see all `artifacts`.
5
5
 
6
6
  <execute_python>
7
7
  print("Hello World!")
@@ -26,10 +26,11 @@ Here is the current conversation so far:
26
26
 
27
27
  **Instructions**:
28
28
  1. **Understand and Clarify**: Make sure you understand the task, ask clarifying questions if the task is not clear.
29
- 2. **Code Generation**: Only use code provided in the Documentation in your <execute_python> tags. Only use `edit_vision_code` to modify code written by `generate_vision_code`.
29
+ 2. **Code Generation**: Only use code provided in the Documentation in your <execute_python> tags. Only use `edit_vision_code` to modify code written by `generate_vision_code`. DO NOT run `edit_vision_code` or `edit_code_artifact` more than 2 times in a row and instead ask the user for help.
30
30
  3. **Execute**: Do only what the user asked you to do and no more. If you need to ask the user a question or show your results to the user, set <let_user_respond> to `true`.
31
- 4. **Response**: Keep your responses short and concise. Provide the user only with the information they need to continue the conversation.
32
- 5. **Output**: You can only respond with <thinking>, <response>, <execute_python>, and <let_user_respond> tags.
31
+ 4. **Artifacts**: Files are only saved in `artifacts` after <execute_python>, do not try to access artifacts until you observe that they are loaded.
32
+ 5. **Response**: Keep your responses short and concise. Provide the user only with the information they need to continue the conversation.
33
+ 6. **Output**: You can only respond with <thinking>, <response>, <execute_python>, and <let_user_respond> tags.
33
34
 
34
35
  <thinking>Your thoughts here...</thinking>
35
36
  <response>Your response to the user here...</response>
@@ -53,7 +54,7 @@ generate_vision_code(artifacts, 'dog_detector.py', 'Can you write code to detect
53
54
  <let_user_respond>false</let_user_respond>
54
55
 
55
56
  OBSERVATION:
56
- [Artifact dog_detector.py]
57
+ [Artifact dog_detector.py (5 lines total)]
57
58
  0|from vision_agent.tools import load_image, owl_v2
58
59
  1|def detect_dogs(image_path: str):
59
60
  2| image = load_image(image_path)
@@ -94,7 +95,7 @@ edit_vision_code(artifacts, 'dog_detector.py', ['Can you write code to detect do
94
95
  <let_user_respond>false</let_user_respond>
95
96
 
96
97
  OBSERVATION:
97
- [Artifact dog_detector.py]
98
+ [Artifact dog_detector.py (5 lines total)]
98
99
  0|from vision_agent.tools import load_image, owl_v2
99
100
  1|def detect_dogs(image_path: str):
100
101
  2| image = load_image(image_path)
@@ -142,7 +143,7 @@ AGENT: <thinking>The image shows a cat and a dog sitting on the couch, I will te
142
143
 
143
144
 
144
145
  EXAMPLES_CODE3 = """
145
- USER: Can you create a function to count workers with helmets, return the count and save a visaulization of the bounding boxes?
146
+ USER: Can you create a function to count workers with helmets, return the count and a visualization of the bounding boxes?
146
147
 
147
148
  OBSERVATION:
148
149
  [Artifacts loaded]
@@ -166,9 +167,9 @@ generate_vision_code(artifacts, 'code.py', 'Can you create a function to count w
166
167
  <let_user_respond>false</let_user_respond>
167
168
 
168
169
  OBSERVATION:
169
- [Artifact code.py]
170
+ [Artifact code.py (17 lines total)]
170
171
  0|from vision_agent.tools import load_image, florence2_phrase_grounding, closest_box_distance, overlay_bounding_boxes, save_image
171
- 1|def count_workers_with_helmets(image_path: str, output_path: str):
172
+ 1|def count_workers_with_helmets(image_path: str):
172
173
  2| image = load_image(image_path)
173
174
  3| detections = florence2_phrase_grounding("worker, helmet", image)
174
175
  4| workers = [d for d in detections if d['label'] == 'worker']
@@ -182,15 +183,16 @@ OBSERVATION:
182
183
  11| break
183
184
  11| if person_has_helmet:
184
185
  13| count += 1
185
- 14| overlay_bounding_boxes(image, detections)
186
- 15| save_image(output_path, image)
187
- 16| return count
186
+ 14| viz = overlay_bounding_boxes(image, detections)
187
+ 16| return count, viz
188
188
  [End of artifact]
189
189
 
190
- AGENT: <thinking>I have generated the code to count the workers with helmets in the image, I must now run the code and print the output.</thinking>
190
+ AGENT: <thinking>I have generated the code to count the workers with helmets in the image, I must now run the code, print the counts and save the visualization.</thinking>
191
191
  <execute_python>
192
192
  from code import count_workers_with_helmets
193
- print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))
193
+ count, viz = count_workers_with_helmets('/path/to/images/workers.png')
194
+ save_image(viz, 'workers_viz.png')
195
+ print(count)
194
196
  </execute_python>
195
197
  <let_user_respond>false</let_user_respond>
196
198
 
@@ -201,7 +203,9 @@ OBSERVATION:
201
203
  AGENT: <thinking>Two workers with helmets are detected, I will show this to the user and ask them if the result looks good.</thinking>
202
204
  <response>The code to detect workers with helmets is saved in code.py and the visualization under 'workers_viz.png'.</response>
203
205
  <let_user_respond>true</let_user_respond>
206
+ """
204
207
 
208
+ EXAMPLES_CODE3_EXTRA1 = """
205
209
  USER: The detections are slightly off. Can you fine tune florence2 using these labels? "[{'image_path': 'image1.png': 'labels': ['worker', 'helmet'], 'bboxes': [[235, 118, 294, 241], [232, 118, 294, 128]]}, {'image_path': 'image2.png': 'labels': ['worker', 'helmet'], 'bboxes': [[235, 118, 294, 241], [232, 118, 294, 128]]}]"
206
210
 
207
211
  AGENT: <thinking>Because the user has supplied me with labels I can call object_detection_fine_tuning on their behalf to fine tune the model</thinking>
@@ -249,3 +253,60 @@ AGENT: <thinking>Three workers with helmets are detected, I will show this to th
249
253
  <response>I have updated the code using the fine tuned florence2_phrase_grounding model to count the workers wearing helmets in code.py and saved the visualization under 'workers_viz.png'.</response>
250
254
  <let_user_respond>true</let_user_respond>
251
255
  """
256
+
257
+ EXAMPLES_CODE3_EXTRA2 = """
258
+ USER: Can you write a streamlit app that allows me to upload an image and count the workers with helmets?
259
+
260
+ AGENT: <thinking>I will use edit_code_artifact to generate a streamlit app that allows the user to upload an image and count the workers with helmets.</thinking>
261
+ <execute_python>
262
+ CODE = '''
263
+ import streamlit as st
264
+ from code import count_workers_with_helmets
265
+ import os
266
+
267
+ def main():
268
+ st.title("Worker with Helmet Counter")
269
+ uploaded_file = st.file_uploader("Choose image...", type=["jpeg", "jpg", "png"])
270
+ if uploaded_file is not None:
271
+ with open("temp_image.jpg", "wb") as f:
272
+ f.write(uploaded_file.read())
273
+
274
+ count = count_workers_with_helmets("temp_image.jpg", "temp_image_viz.jpg")
275
+ st.image("temp_image_viz.jpg")
276
+ st.write(f"Number of workers with helmets: {count}")
277
+ os.remove("temp_image.jpg")
278
+ os.remove("temp_image_viz.jpg")
279
+
280
+ if __name__ == "__main__":
281
+ main()
282
+ '''
283
+ edit_code_artifact(artifacts, 'streamlit_app.py', CODE, 0, 0)
284
+ </execute_python>
285
+
286
+ OBSERVATION:
287
+ [Artifact streamlit_app.py (19 lines total)]
288
+ 0|import streamlit as st
289
+ 1|from code import count_workers_with_helmets
290
+ 2|import os
291
+ 3|
292
+ 4|def main():
293
+ 5| st.title("Worker with Helmet Counter")
294
+ 6| uploaded_file = st.file_uploader("Choose image...", type=["jpeg", "jpg", "png"])
295
+ 7| if uploaded_file is not None:
296
+ 8| with open("temp_image.jpg", "wb") as f:
297
+ 9| f.write(uploaded_file.read())
298
+ 10|
299
+ 11| count = count_workers_with_helmets("temp_image.jpg", "temp_image_viz.jpg")
300
+ 12| st.image("temp_image_viz.jpg")
301
+ 13| st.write(f"Number of workers with helmets: {count}")
302
+ 14| os.remove("temp_image.jpg")
303
+ 15| os.remove("temp_image_viz.jpg")
304
+ 16|
305
+ 17|if __name__ == "__main__":
306
+ 18| main()
307
+ [End of artifact]
308
+
309
+ AGENT: <thinking>I have generated the streamlit app that allows the user to upload an image and count the workers</thinking>
310
+ <response>The streamlit app is saved in streamlit_app.py, you can run the app by running `streamlit run streamlit_app.py`.</response>
311
+ <let_user_respond>true</let_user_respond>
312
+ """
@@ -6,15 +6,13 @@ import re
6
6
  import subprocess
7
7
  import tempfile
8
8
  from pathlib import Path
9
- from typing import Any, Dict, List, Optional, Union
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
10
 
11
+ import libcst as cst
11
12
  from IPython.display import display
12
- from redbaron import RedBaron # type: ignore
13
13
 
14
14
  import vision_agent as va
15
- from vision_agent.agent.agent_utils import extract_json
16
15
  from vision_agent.clients.landing_public_api import LandingPublicAPI
17
- from vision_agent.lmm import AnthropicLMM
18
16
  from vision_agent.lmm.types import Message
19
17
  from vision_agent.tools.tool_utils import get_tool_documentation
20
18
  from vision_agent.tools.tools import TOOL_DESCRIPTIONS
@@ -26,7 +24,6 @@ CURRENT_FILE = None
26
24
  CURRENT_LINE = 0
27
25
  DEFAULT_WINDOW_SIZE = 100
28
26
  ZMQ_PORT = os.environ.get("ZMQ_PORT", None)
29
- VERBOSITY = os.environ.get("VERBOSITY", 0)
30
27
 
31
28
 
32
29
  def report_progress_callback(port: int, inp: Dict[str, Any]) -> None:
@@ -38,16 +35,6 @@ def report_progress_callback(port: int, inp: Dict[str, Any]) -> None:
38
35
  socket.send_json(inp)
39
36
 
40
37
 
41
- def filter_file(file_name: Union[str, Path]) -> bool:
42
- file_name_p = Path(file_name)
43
- return (
44
- file_name_p.is_file()
45
- and "__pycache__" not in str(file_name_p)
46
- and file_name_p.suffix in [".py", ".txt"]
47
- and not file_name_p.name.startswith(".")
48
- )
49
-
50
-
51
38
  def redisplay_results(execution: Execution) -> None:
52
39
  """This function is used to add previous execution results to the current output.
53
40
  This is handy if you are inside a notebook environment, call it notebook1, and you
@@ -86,8 +73,11 @@ class Artifacts:
86
73
  need to be in sync with the remote environment the VisionAgent is running in.
87
74
  """
88
75
 
89
- def __init__(self, remote_save_path: Union[str, Path]) -> None:
76
+ def __init__(
77
+ self, remote_save_path: Union[str, Path], local_save_path: Union[str, Path]
78
+ ) -> None:
90
79
  self.remote_save_path = Path(remote_save_path)
80
+ self.local_save_path = Path(local_save_path)
91
81
  self.artifacts: Dict[str, Any] = {}
92
82
 
93
83
  self.code_sandbox_runtime = None
@@ -131,9 +121,7 @@ class Artifacts:
131
121
  return output_str
132
122
 
133
123
  def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
134
- save_path = (
135
- Path(local_path) if local_path is not None else self.remote_save_path
136
- )
124
+ save_path = Path(local_path) if local_path is not None else self.local_save_path
137
125
  with open(save_path, "wb") as f:
138
126
  pkl.dump(self.artifacts, f)
139
127
 
@@ -150,6 +138,38 @@ class Artifacts:
150
138
  return name in self.artifacts
151
139
 
152
140
 
141
+ def filter_file(file_name: Union[str, Path]) -> Tuple[bool, bool]:
142
+ file_name_p = Path(file_name)
143
+ return (
144
+ file_name_p.is_file()
145
+ and "__pycache__" not in str(file_name_p)
146
+ and not file_name_p.name.startswith(".")
147
+ and file_name_p.suffix
148
+ in [".png", ".jpeg", ".jpg", ".mp4", ".txt", ".json", ".csv"]
149
+ ), file_name_p.suffix in [".png", ".jpeg", ".jpg", ".mp4"]
150
+
151
+
152
+ def capture_files_into_artifacts(artifacts: Artifacts) -> None:
153
+ """This function is used to capture all files in the current directory into an
154
+ artifact object. This is useful if you want to capture all files in the current
155
+ directory and use them in a different environment where you don't have access to
156
+ the file system.
157
+
158
+ Parameters:
159
+ artifacts (Artifacts): The artifacts object to save the files to.
160
+ """
161
+ for file in Path(".").glob("**/*"):
162
+ usable_file, is_media = filter_file(file)
163
+ mode = "rb" if is_media else "r"
164
+ if usable_file:
165
+ file_name = file.name
166
+ if file_name.startswith(str(Path(artifacts.remote_save_path).parents)):
167
+ idx = len(Path(artifacts.remote_save_path).parents)
168
+ file_name = file_name[idx:]
169
+ with open(file, mode) as f:
170
+ artifacts[file_name] = f.read()
171
+
172
+
153
173
  # These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent
154
174
 
155
175
 
@@ -174,9 +194,9 @@ def view_lines(
174
194
  f"[Artifact: {name} ({total_lines} lines total)]\n"
175
195
  + format_lines(lines[start:end], start)
176
196
  + (
177
- "[End of artifact]"
197
+ "\n[End of artifact]"
178
198
  if end == len(lines)
179
- else f"[{len(lines) - end} more lines]"
199
+ else f"\n[{len(lines) - end} more lines]"
180
200
  )
181
201
  )
182
202
 
@@ -256,8 +276,10 @@ def edit_code_artifact(
256
276
  Parameters:
257
277
  artifacts (Artifacts): The artifacts object to edit the artifact from.
258
278
  name (str): The name of the artifact to edit.
259
- start (int): The line number to start the edit.
260
- end (int): The line number to end the edit.
279
+ start (int): The line number to start the edit, can be in [-1, total_lines]
280
+ where -1 represents the end of the file.
281
+ end (int): The line number to end the edit, can be in [-1, total_lines] where
282
+ -1 represents the end of the file.
261
283
  content (str): The content to insert.
262
284
  """
263
285
  # just make the artifact if it doesn't exist instead of forcing agent to call
@@ -266,17 +288,21 @@ def edit_code_artifact(
266
288
  artifacts[name] = ""
267
289
 
268
290
  total_lines = len(artifacts[name].splitlines())
291
+ if start == -1:
292
+ start = total_lines
293
+ if end == -1:
294
+ end = total_lines
295
+
269
296
  if start < 0 or end < 0 or start > end or end > total_lines:
270
297
  print("[Invalid line range]")
271
298
  return "[Invalid line range]"
272
- if start == end:
273
- end += 1
274
299
 
275
300
  new_content_lines = content.splitlines(keepends=True)
276
301
  new_content_lines = [
277
302
  line if line.endswith("\n") else line + "\n" for line in new_content_lines
278
303
  ]
279
304
  lines = artifacts[name].splitlines(keepends=True)
305
+ lines = [line if line.endswith("\n") else line + "\n" for line in lines]
280
306
  edited_lines = lines[:start] + new_content_lines + lines[end:]
281
307
 
282
308
  cur_line = start + len(content.split("\n")) // 2
@@ -371,14 +397,16 @@ def generate_vision_plan(
371
397
  [End Plan Context]
372
398
  """
373
399
 
400
+ # verbosity is set to 0 to avoid adding extra content to the VisionAgent conversation
374
401
  if ZMQ_PORT is not None:
375
402
  agent = va.agent.VisionAgentPlanner(
376
403
  report_progress_callback=lambda inp: report_progress_callback(
377
404
  int(ZMQ_PORT), inp
378
- )
405
+ ),
406
+ verbosity=0,
379
407
  )
380
408
  else:
381
- agent = va.agent.VisionAgentPlanner()
409
+ agent = va.agent.VisionAgentPlanner(verbosity=0)
382
410
 
383
411
  fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}]
384
412
  response = agent.generate_plan(
@@ -435,14 +463,16 @@ def generate_vision_code(
435
463
  dogs = owl_v2("dog", image)
436
464
  return dogs
437
465
  """
466
+ # verbosity is set to 0 to avoid adding extra content to the VisionAgent conversation
438
467
  if ZMQ_PORT is not None:
439
468
  agent = va.agent.VisionAgentCoder(
440
469
  report_progress_callback=lambda inp: report_progress_callback(
441
470
  int(ZMQ_PORT), inp
442
- )
471
+ ),
472
+ verbosity=0,
443
473
  )
444
474
  else:
445
- agent = va.agent.VisionAgentCoder(verbosity=int(VERBOSITY))
475
+ agent = va.agent.VisionAgentCoder(verbosity=0)
446
476
 
447
477
  fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}]
448
478
  response = agent.generate_code(
@@ -506,7 +536,8 @@ def edit_vision_code(
506
536
  return dogs
507
537
  """
508
538
 
509
- agent = va.agent.VisionAgentCoder(verbosity=int(VERBOSITY))
539
+ # verbosity is set to 0 to avoid adding extra content to the VisionAgent conversation
540
+ agent = va.agent.VisionAgentCoder(verbosity=0)
510
541
  if name not in artifacts:
511
542
  print(f"[Artifact {name} does not exist]")
512
543
  return f"[Artifact {name} does not exist]"
@@ -570,8 +601,9 @@ def check_and_load_image(code: str) -> List[str]:
570
601
 
571
602
 
572
603
  def view_media_artifact(artifacts: Artifacts, name: str) -> str:
573
- """Allows you to view the media artifact with the given name. This does not show
574
- the media to the user, the user can already see all media saved in the artifacts.
604
+ """Allows only the agent to view the media artifact with the given name. DO NOT use
605
+ this to show media to the user, the user can already see all media saved in the
606
+ artifacts.
575
607
 
576
608
  Parameters:
577
609
  artifacts (Artifacts): The artifacts object to show the image from.
@@ -648,10 +680,10 @@ def get_diff_with_prompts(name: str, before: str, after: str) -> str:
648
680
 
649
681
 
650
682
  def use_extra_vision_agent_args(
651
- code: str,
683
+ code: Optional[str],
652
684
  test_multi_plan: bool = True,
653
685
  custom_tool_names: Optional[List[str]] = None,
654
- ) -> str:
686
+ ) -> Optional[str]:
655
687
  """This is for forcing arguments passed by the user to VisionAgent into the
656
688
  VisionAgentCoder call.
657
689
 
@@ -663,21 +695,81 @@ def use_extra_vision_agent_args(
663
695
  Returns:
664
696
  str: The edited code.
665
697
  """
666
- red = RedBaron(code)
667
- for node in red:
668
- # seems to always be atomtrailers not call type
669
- if node.type == "atomtrailers":
670
- if node.name.value == "generate_vision_code":
671
- node.value[1].value.append(f"test_multi_plan={test_multi_plan}")
672
-
673
- if (
674
- node.name.value == "generate_vision_code"
675
- or node.name.value == "edit_vision_code"
676
- ):
677
- if custom_tool_names is not None:
678
- node.value[1].value.append(f"custom_tool_names={custom_tool_names}")
679
- cleaned_code = red.dumps().strip()
680
- return cleaned_code if isinstance(cleaned_code, str) else code
698
+ if code is None:
699
+ return None
700
+
701
+ class VisionAgentTransformer(cst.CSTTransformer):
702
+ def __init__(
703
+ self, test_multi_plan: bool, custom_tool_names: Optional[List[str]]
704
+ ):
705
+ self.test_multi_plan = test_multi_plan
706
+ self.custom_tool_names = custom_tool_names
707
+
708
+ def leave_Call(
709
+ self, original_node: cst.Call, updated_node: cst.Call
710
+ ) -> cst.Call:
711
+ # Check if the function being called is generate_vision_code or edit_vision_code
712
+ if isinstance(updated_node.func, cst.Name) and updated_node.func.value in [
713
+ "generate_vision_code",
714
+ "edit_vision_code",
715
+ ]:
716
+ # Add test_multi_plan argument to generate_vision_code calls
717
+ if updated_node.func.value == "generate_vision_code":
718
+ new_arg = cst.Arg(
719
+ keyword=cst.Name("test_multi_plan"),
720
+ value=cst.Name(str(self.test_multi_plan)),
721
+ equal=cst.AssignEqual(
722
+ whitespace_before=cst.SimpleWhitespace(""),
723
+ whitespace_after=cst.SimpleWhitespace(""),
724
+ ),
725
+ )
726
+ updated_node = updated_node.with_changes(
727
+ args=[*updated_node.args, new_arg]
728
+ )
729
+
730
+ # Add custom_tool_names if provided
731
+ if self.custom_tool_names is not None:
732
+ list_arg = []
733
+ for i, tool_name in enumerate(self.custom_tool_names):
734
+ if i < len(self.custom_tool_names) - 1:
735
+ list_arg.append(
736
+ cst._nodes.expression.Element(
737
+ value=cst.SimpleString(value=f'"{tool_name}"'),
738
+ comma=cst.Comma(
739
+ whitespace_before=cst.SimpleWhitespace(""),
740
+ whitespace_after=cst.SimpleWhitespace(" "),
741
+ ),
742
+ )
743
+ )
744
+ else:
745
+ list_arg.append(
746
+ cst._nodes.expression.Element(
747
+ value=cst.SimpleString(value=f'"{tool_name}"'),
748
+ )
749
+ )
750
+ new_arg = cst.Arg(
751
+ keyword=cst.Name("custom_tool_names"),
752
+ value=cst.List(list_arg),
753
+ equal=cst.AssignEqual(
754
+ whitespace_before=cst.SimpleWhitespace(""),
755
+ whitespace_after=cst.SimpleWhitespace(""),
756
+ ),
757
+ )
758
+ updated_node = updated_node.with_changes(
759
+ args=[*updated_node.args, new_arg]
760
+ )
761
+
762
+ return updated_node
763
+
764
+ # Parse the input code into a CST node
765
+ tree = cst.parse_module(code)
766
+
767
+ # Apply the transformer to modify the CST
768
+ transformer = VisionAgentTransformer(test_multi_plan, custom_tool_names)
769
+ modified_tree = tree.visit(transformer)
770
+
771
+ # Return the modified code as a string
772
+ return modified_tree.code
681
773
 
682
774
 
683
775
  def use_object_detection_fine_tuning(
@@ -757,74 +849,12 @@ def use_object_detection_fine_tuning(
757
849
  return diff
758
850
 
759
851
 
760
- def extract_and_save_files_to_artifacts(
761
- artifacts: Artifacts, code: str, obs: str
762
- ) -> None:
763
- """Extracts and saves files used in the code to the artifacts object.
764
-
765
- Parameters:
766
- artifacts (Artifacts): The artifacts object to save the files to.
767
- code (str): The code to extract the files from.
768
- """
769
- try:
770
- response = extract_json(
771
- AnthropicLMM()( # type: ignore
772
- f"""You are a helpful AI assistant. Your job is to look at a snippet of code and the output of running that code and return the file paths that are being saved in the file. Below is the code snippet:
773
-
774
- ```python
775
- {code}
776
- ```
777
-
778
- ```output
779
- {obs}
780
- ```
781
-
782
- Return the file paths in the following JSON format:
783
- {{"file_paths": ["/path/to/image1.jpg", "/other/path/to/data.json"]}}"""
784
- )
785
- )
786
- except json.JSONDecodeError:
787
- return
788
-
789
- text_file_ext = [
790
- ".txt",
791
- ".md",
792
- "rtf",
793
- ".html",
794
- ".htm",
795
- "xml",
796
- ".json",
797
- ".csv",
798
- ".tsv",
799
- ".yaml",
800
- ".yml",
801
- ".toml",
802
- ".conf",
803
- ".env" ".ini",
804
- ".log",
805
- ".py",
806
- ".java",
807
- ".js",
808
- ".cpp",
809
- ".c" ".sql",
810
- ".sh",
811
- ]
812
-
813
- if "file_paths" in response and isinstance(response["file_paths"], list):
814
- for file_path in response["file_paths"]:
815
- read_mode = "r" if Path(file_path).suffix in text_file_ext else "rb"
816
- if Path(file_path).is_file():
817
- with open(file_path, read_mode) as f:
818
- artifacts[Path(file_path).name] = f.read()
819
-
820
-
821
852
  META_TOOL_DOCSTRING = get_tool_documentation(
822
853
  [
823
854
  get_tool_descriptions,
824
855
  open_code_artifact,
825
856
  create_code_artifact,
826
857
  edit_code_artifact,
827
- generate_vision_plan,
828
858
  generate_vision_code,
829
859
  edit_vision_code,
830
860
  view_media_artifact,
@@ -575,6 +575,7 @@ class LocalCodeInterpreter(CodeInterpreter):
575
575
  super().__init__(timeout=timeout)
576
576
  self.nb = nbformat.v4.new_notebook()
577
577
  # Set the notebook execution path to the remote path
578
+ self.remote_path = Path(remote_path if remote_path is not None else WORKSPACE)
578
579
  self.resources = {"metadata": {"path": str(self.remote_path)}}
579
580
  self.nb_client = NotebookClient(
580
581
  self.nb,
@@ -591,7 +592,6 @@ Timeout: {self.timeout}"""
591
592
  )
592
593
  sleep(1)
593
594
  self._new_kernel()
594
- self.remote_path = Path(remote_path if remote_path is not None else WORKSPACE)
595
595
 
596
596
  def _new_kernel(self) -> None:
597
597
  if self.nb_client.kc is None or not run_sync(self.nb_client.kc.is_alive)(): # type: ignore
@@ -659,7 +659,7 @@ Timeout: {self.timeout}"""
659
659
  def download_file(
660
660
  self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]
661
661
  ) -> Path:
662
- with open(self.remote_path / remote_file_path, "rb") as f:
662
+ with open(self.remote_path / Path(remote_file_path).name, "rb") as f:
663
663
  contents = f.read()
664
664
  with open(local_file_path, "wb") as f:
665
665
  f.write(contents)
File without changes
File without changes