vision-agent 1.0.4__py3-none-any.whl → 1.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,11 @@
1
1
  import difflib
2
- import json
3
2
  import os
4
3
  import re
5
- import subprocess
6
- import tempfile
7
4
  from pathlib import Path
8
- from typing import Any, Dict, List, Optional, Tuple, Union
5
+ from typing import Any, Dict, List, Tuple, Union
9
6
 
10
- import libcst as cst
11
7
  from IPython.display import display
12
8
 
13
- import vision_agent as va
14
- from vision_agent.models import Message
15
9
  from vision_agent.tools.tools import get_tools_descriptions as _get_tool_descriptions
16
10
  from vision_agent.utils.execute import Execution, MimeType
17
11
  from vision_agent.utils.tools_doc import get_tool_documentation
@@ -152,392 +146,6 @@ def view_lines(
152
146
  return return_str
153
147
 
154
148
 
155
- def open_code_artifact(
156
- artifacts: Artifacts, name: str, line_num: int = 0, window_size: int = 100
157
- ) -> str:
158
- """Opens the provided code artifact. If `line_num` is provided, the window will be
159
- moved to include that line. It only shows the first 100 lines by default! Max
160
- `window_size` supported is 2000.
161
-
162
- Parameters:
163
- artifacts (Artifacts): The artifacts object to open the artifact from.
164
- name (str): The name of the artifact to open.
165
- line_num (int): The line number to move the window to.
166
- window_size (int): The number of lines to show above and below the line.
167
- """
168
- if name not in artifacts:
169
- return f"[Artifact {name} does not exist]"
170
-
171
- total_lines = len(artifacts[name].splitlines())
172
- window_size = min(window_size, 2000)
173
- window_size = window_size // 2
174
- if line_num - window_size < 0:
175
- line_num = window_size
176
- elif line_num >= total_lines:
177
- line_num = total_lines - 1 - window_size
178
-
179
- lines = artifacts[name].splitlines(keepends=True)
180
-
181
- return view_lines(lines, line_num, window_size, name, total_lines)
182
-
183
-
184
- def create_code_artifact(artifacts: Artifacts, name: str) -> str:
185
- """Creates a new code artifiact with the given name.
186
-
187
- Parameters:
188
- artifacts (Artifacts): The artifacts object to add the new artifact to.
189
- name (str): The name of the new artifact.
190
- """
191
- if name in artifacts:
192
- return_str = f"[Artifact {name} already exists]"
193
- else:
194
- artifacts[name] = ""
195
- return_str = f"[Artifact {name} created]"
196
- print(return_str)
197
-
198
- display(
199
- {
200
- MimeType.APPLICATION_ARTIFACT: json.dumps(
201
- {
202
- "name": name,
203
- "content": artifacts[name],
204
- "action": "create",
205
- }
206
- )
207
- },
208
- raw=True,
209
- )
210
- return return_str
211
-
212
-
213
- def edit_code_artifact(
214
- artifacts: Artifacts, name: str, start: int, end: int, content: str
215
- ) -> str:
216
- """Edits the given code artifact with the provided content. The content will be
217
- inserted between the `start` and `end` line numbers. If the `start` and `end` are
218
- the same, the content will be inserted at the `start` line number. If the `end` is
219
- greater than the total number of lines in the file, the content will be inserted at
220
- the end of the file. If the `start` or `end` are negative, the function will return
221
- an error message.
222
-
223
- Parameters:
224
- artifacts (Artifacts): The artifacts object to edit the artifact from.
225
- name (str): The name of the artifact to edit.
226
- start (int): The line number to start the edit, can be in [-1, total_lines]
227
- where -1 represents the end of the file.
228
- end (int): The line number to end the edit, can be in [-1, total_lines] where
229
- -1 represents the end of the file.
230
- content (str): The content to insert.
231
- """
232
- # just make the artifact if it doesn't exist instead of forcing agent to call
233
- # create_artifact
234
- if name not in artifacts:
235
- artifacts[name] = ""
236
-
237
- total_lines = len(artifacts[name].splitlines())
238
- if start == -1:
239
- start = total_lines
240
- if end == -1:
241
- end = total_lines
242
-
243
- if start < 0 or end < 0 or start > end or end > total_lines:
244
- print("[Invalid line range]")
245
- return "[Invalid line range]"
246
-
247
- new_content_lines = content.splitlines(keepends=True)
248
- new_content_lines = [
249
- line if line.endswith("\n") else line + "\n" for line in new_content_lines
250
- ]
251
- lines = artifacts[name].splitlines(keepends=True)
252
- lines = [line if line.endswith("\n") else line + "\n" for line in lines]
253
- edited_lines = lines[:start] + new_content_lines + lines[end:]
254
-
255
- cur_line = start + len(content.split("\n")) // 2
256
- with tempfile.NamedTemporaryFile(delete=True) as f:
257
- with open(f.name, "w") as f: # type: ignore
258
- f.writelines(edited_lines)
259
-
260
- process = subprocess.Popen(
261
- [
262
- "flake8",
263
- "--isolated",
264
- "--select=F821,F822,F831,E111,E112,E113,E999,E902",
265
- f.name,
266
- ],
267
- stdout=subprocess.PIPE,
268
- stderr=subprocess.PIPE,
269
- text=True,
270
- )
271
- stdout, _ = process.communicate()
272
-
273
- if stdout != "":
274
- stdout = stdout.replace(f.name, name)
275
- error_msg = "[Edit failed with the following status]\n" + stdout
276
- original_view = view_lines(
277
- lines,
278
- start + ((end - start) // 2),
279
- DEFAULT_WINDOW_SIZE,
280
- name,
281
- total_lines,
282
- print_output=False,
283
- )
284
- total_lines_edit = sum(1 for _ in edited_lines)
285
- edited_view = view_lines(
286
- edited_lines,
287
- cur_line,
288
- DEFAULT_WINDOW_SIZE,
289
- name,
290
- total_lines_edit,
291
- print_output=False,
292
- )
293
-
294
- error_msg += f"\n[This is how your edit would have looked like if applied]\n{edited_view}\n\n[This is the original code before your edit]\n{original_view}"
295
- print(error_msg)
296
- return error_msg
297
-
298
- artifacts[name] = "".join(edited_lines)
299
-
300
- display(
301
- {
302
- MimeType.APPLICATION_ARTIFACT: json.dumps(
303
- {
304
- "name": name,
305
- "content": artifacts[name],
306
- "action": "edit",
307
- }
308
- )
309
- },
310
- raw=True,
311
- )
312
- return open_code_artifact(artifacts, name, cur_line)
313
-
314
-
315
- def generate_vision_plan(
316
- artifacts: Artifacts,
317
- name: str,
318
- chat: str,
319
- media: List[str],
320
- test_multi_plan: bool = True,
321
- custom_tool_names: Optional[List[str]] = None,
322
- ) -> str:
323
- """Generates a plan to solve vision based tasks.
324
-
325
- Parameters:
326
- artifacts (Artifacts): The artifacts object to save the plan to.
327
- name (str): The name of the artifact to save the plan context to.
328
- chat (str): The chat message from the user.
329
- media (List[str]): The media files to use.
330
- test_multi_plan (bool): Do not change this parameter.
331
- custom_tool_names (Optional[List[str]]): Do not change this parameter.
332
-
333
- Returns:
334
- str: The generated plan.
335
-
336
- Examples
337
- --------
338
- >>> generate_vision_plan(artifacts, "plan.json", "Can you detect the dogs in this image?", ["image.jpg"])
339
- [Start Plan Context]
340
- plan1: This is a plan to detect dogs in an image
341
- -load image
342
- -detect dogs
343
- -return detections
344
- [End Plan Context]
345
- """
346
-
347
- # verbosity is set to 0 to avoid adding extra content to the VisionAgent conversation
348
- if ZMQ_PORT is not None:
349
- agent = va.agent.VisionAgentPlanner(
350
- report_progress_callback=lambda inp: report_progress_callback(
351
- int(ZMQ_PORT), inp
352
- ),
353
- verbosity=0,
354
- )
355
- else:
356
- agent = va.agent.VisionAgentPlanner(verbosity=0)
357
-
358
- fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}]
359
- response = agent.generate_plan(
360
- fixed_chat,
361
- test_multi_plan=test_multi_plan,
362
- custom_tool_names=custom_tool_names,
363
- )
364
- if response.test_results is not None:
365
- redisplay_results(response.test_results)
366
- response.test_results = None
367
- artifacts[name] = response.model_dump_json()
368
-
369
- output_str = f"[Start Plan Context, saved at {name}]"
370
- for plan in response.plans.keys():
371
- output_str += f"\n{plan}: {response.plans[plan]['thoughts'].strip()}\n" # type: ignore
372
- output_str += " -" + "\n -".join(
373
- e.strip() for e in response.plans[plan]["instructions"]
374
- )
375
-
376
- output_str += f"\nbest plan: {response.best_plan}\n"
377
- output_str += "thoughts: " + response.plan_thoughts.strip() + "\n"
378
- output_str += "[End Plan Context]"
379
- print(output_str)
380
- return output_str
381
-
382
-
383
- def generate_vision_code(
384
- artifacts: Artifacts,
385
- name: str,
386
- chat: str,
387
- media: List[str],
388
- test_multi_plan: bool = True,
389
- custom_tool_names: Optional[List[str]] = None,
390
- ) -> str:
391
- """Generates python code to solve vision based tasks.
392
-
393
- Parameters:
394
- artifacts (Artifacts): The artifacts object to save the code to.
395
- name (str): The name of the artifact to save the code to.
396
- chat (str): The chat message from the user.
397
- media (List[str]): The media files to use.
398
- test_multi_plan (bool): Do not change this parameter.
399
- custom_tool_names (Optional[List[str]]): Do not change this parameter.
400
-
401
- Returns:
402
- str: The generated code.
403
-
404
- Examples
405
- --------
406
- >>> generate_vision_code(artifacts, "code.py", "Can you detect the dogs in this image?", ["image.jpg"])
407
- from vision_agent.tools import load_image, owl_v2
408
- def detect_dogs(image_path: str):
409
- image = load_image(image_path)
410
- dogs = owl_v2("dog", image)
411
- return dogs
412
- """
413
- # verbosity is set to 0 to avoid adding extra content to the VisionAgent conversation
414
- if ZMQ_PORT is not None:
415
- agent = va.agent.VisionAgentCoder(
416
- report_progress_callback=lambda inp: report_progress_callback(
417
- int(ZMQ_PORT), inp
418
- ),
419
- verbosity=0,
420
- )
421
- else:
422
- agent = va.agent.VisionAgentCoder(verbosity=0)
423
-
424
- fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}]
425
- response = agent.generate_code(
426
- fixed_chat,
427
- test_multi_plan=test_multi_plan,
428
- custom_tool_names=custom_tool_names,
429
- )
430
-
431
- redisplay_results(response["test_result"])
432
- code = response["code"]
433
- artifacts[name] = code
434
- code_lines = code.splitlines(keepends=True)
435
- total_lines = len(code_lines)
436
-
437
- display(
438
- {
439
- MimeType.APPLICATION_ARTIFACT: json.dumps(
440
- {
441
- "name": name,
442
- "content": code,
443
- "contentType": "vision_code",
444
- "action": "create",
445
- }
446
- )
447
- },
448
- raw=True,
449
- )
450
- return view_lines(code_lines, 0, total_lines, name, total_lines)
451
-
452
-
453
- def edit_vision_code(
454
- artifacts: Artifacts,
455
- name: str,
456
- chat_history: List[str],
457
- media: List[str],
458
- custom_tool_names: Optional[List[str]] = None,
459
- ) -> str:
460
- """Edits python code to solve a vision based task.
461
-
462
- Parameters:
463
- artifacts (Artifacts): The artifacts object to save the code to.
464
- name (str): The file path to the code.
465
- chat_history (List[str]): The chat history to used to generate the code.
466
- custom_tool_names (Optional[List[str]]): Do not change this parameter.
467
-
468
- Returns:
469
- str: The edited code.
470
-
471
- Examples
472
- --------
473
- >>> edit_vision_code(
474
- >>> artifacts,
475
- >>> "code.py",
476
- >>> ["Can you detect the dogs in this image?", "Can you use a higher threshold?"],
477
- >>> ["dog.jpg"],
478
- >>> )
479
- from vision_agent.tools import load_image, owl_v2
480
- def detect_dogs(image_path: str):
481
- image = load_image(image_path)
482
- dogs = owl_v2("dog", image, threshold=0.8)
483
- return dogs
484
- """
485
-
486
- # verbosity is set to 0 to avoid adding extra content to the VisionAgent conversation
487
- agent = va.agent.VisionAgentCoder(verbosity=0)
488
- if name not in artifacts:
489
- print(f"[Artifact {name} does not exist]")
490
- return f"[Artifact {name} does not exist]"
491
-
492
- code = artifacts[name]
493
-
494
- # Append latest code to second to last message from assistant
495
- fixed_chat_history: List[Message] = []
496
- user_message = "Previous user requests:"
497
- for i, chat in enumerate(chat_history):
498
- if i < len(chat_history) - 1:
499
- user_message += " " + chat
500
- else:
501
- fixed_chat_history.append(
502
- {"role": "user", "content": user_message, "media": media}
503
- )
504
- fixed_chat_history.append({"role": "assistant", "content": code})
505
- fixed_chat_history.append({"role": "user", "content": chat})
506
-
507
- response = agent.generate_code(
508
- fixed_chat_history,
509
- test_multi_plan=False,
510
- custom_tool_names=custom_tool_names,
511
- )
512
-
513
- redisplay_results(response["test_result"])
514
- code = response["code"]
515
- artifacts[name] = code
516
- code_lines = code.splitlines(keepends=True)
517
- total_lines = len(code_lines)
518
-
519
- display(
520
- {
521
- MimeType.APPLICATION_ARTIFACT: json.dumps(
522
- {
523
- "name": name,
524
- "content": code,
525
- "action": "edit",
526
- }
527
- )
528
- },
529
- raw=True,
530
- )
531
- return view_lines(code_lines, 0, total_lines, name, total_lines)
532
-
533
-
534
- def list_artifacts(artifacts: Artifacts) -> str:
535
- """Lists all the artifacts that have been loaded into the artifacts object."""
536
- output_str = artifacts.show()
537
- print(output_str)
538
- return output_str
539
-
540
-
541
149
  def check_and_load_image(code: str) -> List[str]:
542
150
  if not code.strip():
543
151
  return []
@@ -584,108 +192,9 @@ def get_diff_with_prompts(name: str, before: str, after: str) -> str:
584
192
  return f"[Artifact {name} edits]\n{diff}\n[End of edits]"
585
193
 
586
194
 
587
- def use_extra_vision_agent_args(
588
- code: Optional[str],
589
- test_multi_plan: bool = True,
590
- custom_tool_names: Optional[List[str]] = None,
591
- ) -> Optional[str]:
592
- """This is for forcing arguments passed by the user to VisionAgent into the
593
- VisionAgentCoder call.
594
-
595
- Parameters:
596
- code (str): The code to edit.
597
- test_multi_plan (bool): Do not change this parameter.
598
- custom_tool_names (Optional[List[str]]): Do not change this parameter.
599
-
600
- Returns:
601
- str: The edited code.
602
- """
603
- if code is None:
604
- return None
605
-
606
- class VisionAgentTransformer(cst.CSTTransformer):
607
- def __init__(
608
- self, test_multi_plan: bool, custom_tool_names: Optional[List[str]]
609
- ):
610
- self.test_multi_plan = test_multi_plan
611
- self.custom_tool_names = custom_tool_names
612
-
613
- def leave_Call(
614
- self, original_node: cst.Call, updated_node: cst.Call
615
- ) -> cst.Call:
616
- # Check if the function being called is generate_vision_code or edit_vision_code
617
- if isinstance(updated_node.func, cst.Name) and updated_node.func.value in [
618
- "generate_vision_code",
619
- "edit_vision_code",
620
- ]:
621
- # Add test_multi_plan argument to generate_vision_code calls
622
- if updated_node.func.value == "generate_vision_code":
623
- new_arg = cst.Arg(
624
- keyword=cst.Name("test_multi_plan"),
625
- value=cst.Name(str(self.test_multi_plan)),
626
- equal=cst.AssignEqual(
627
- whitespace_before=cst.SimpleWhitespace(""),
628
- whitespace_after=cst.SimpleWhitespace(""),
629
- ),
630
- )
631
- updated_node = updated_node.with_changes(
632
- args=[*updated_node.args, new_arg]
633
- )
634
-
635
- # Add custom_tool_names if provided
636
- if self.custom_tool_names is not None:
637
- list_arg = []
638
- for i, tool_name in enumerate(self.custom_tool_names):
639
- if i < len(self.custom_tool_names) - 1:
640
- list_arg.append(
641
- cst._nodes.expression.Element(
642
- value=cst.SimpleString(value=f'"{tool_name}"'),
643
- comma=cst.Comma(
644
- whitespace_before=cst.SimpleWhitespace(""),
645
- whitespace_after=cst.SimpleWhitespace(" "),
646
- ),
647
- )
648
- )
649
- else:
650
- list_arg.append(
651
- cst._nodes.expression.Element(
652
- value=cst.SimpleString(value=f'"{tool_name}"'),
653
- )
654
- )
655
- new_arg = cst.Arg(
656
- keyword=cst.Name("custom_tool_names"),
657
- value=cst.List(list_arg),
658
- equal=cst.AssignEqual(
659
- whitespace_before=cst.SimpleWhitespace(""),
660
- whitespace_after=cst.SimpleWhitespace(""),
661
- ),
662
- )
663
- updated_node = updated_node.with_changes(
664
- args=[*updated_node.args, new_arg]
665
- )
666
-
667
- return updated_node
668
-
669
- # Parse the input code into a CST node
670
- tree = cst.parse_module(code)
671
-
672
- # Apply the transformer to modify the CST
673
- transformer = VisionAgentTransformer(test_multi_plan, custom_tool_names)
674
- modified_tree = tree.visit(transformer)
675
-
676
- # Return the modified code as a string
677
- return modified_tree.code
678
-
679
-
680
195
  META_TOOL_DOCSTRING = get_tool_documentation(
681
196
  [
682
197
  get_tool_descriptions,
683
- open_code_artifact,
684
- create_code_artifact,
685
- edit_code_artifact,
686
- generate_vision_code,
687
- edit_vision_code,
688
198
  view_media_artifact,
689
- list_artifacts,
690
199
  ]
691
200
  )
@@ -236,7 +236,7 @@ def retrieve_tool_docs(lmm: LMM, task: str, exclude_tools: Optional[List[str]])
236
236
  all_tool_docs = []
237
237
  all_tool_doc_names = set()
238
238
  exclude_tools = [] if exclude_tools is None else exclude_tools
239
- for category in categories:
239
+ for category in categories + [task]:
240
240
  tool_docs = sim.top_k(category, k=3, thresh=0.3)
241
241
 
242
242
  for tool_doc in tool_docs:
@@ -248,9 +248,7 @@ def retrieve_tool_docs(lmm: LMM, task: str, exclude_tools: Optional[List[str]])
248
248
  all_tool_doc_names.add(tool_doc["name"])
249
249
 
250
250
  tool_docs_str = explanation + "\n\n" + "\n".join([e["doc"] for e in all_tool_docs])
251
- tool_docs_str += (
252
- "\n" + get_load_tools_docstring() + get_tool_documentation([judge_od_results])
253
- )
251
+ tool_docs_str += get_load_tools_docstring()
254
252
  return tool_docs_str
255
253
 
256
254
 
@@ -346,22 +344,22 @@ def get_tool_for_task(
346
344
  and output signatures are.
347
345
 
348
346
  Parameters:
349
- task: str: The task to accomplish.
350
- images: Union[Dict[str, List[np.ndarray]], List[np.ndarray]]: The images to use
347
+ task (str): The task to accomplish.
348
+ images (Union[Dict[str, List[np.ndarray]], List[np.ndarray]]): The images to use
351
349
  for the task. If a key is provided, it is used as the file name.
352
- exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
350
+ exclude_tools (Optional[List[str]]): A list of tool names to exclude from the
353
351
  recommendations. This is helpful if you are calling get_tool_for_task twice
354
352
  and do not want the same tool recommended.
355
353
 
356
354
  Returns:
357
- The tool to use for the task is printed to stdout
355
+ None: The function does not return the tool but prints it to stdout.
358
356
 
359
357
  Examples
360
358
  --------
361
359
  >>> get_tool_for_task(
362
360
  >>> "Give me an OCR model that can find 'hot chocolate' in the image",
363
361
  >>> {"image": [image]})
364
- >>> get_tool_for_taks(
362
+ >>> get_tool_for_task(
365
363
  >>> "I need a tool that can paint a background for this image and maks",
366
364
  >>> {"image": [image], "mask": [mask]})
367
365
  """
@@ -497,8 +495,8 @@ def finalize_plan(user_request: str, chain_of_thoughts: str) -> str:
497
495
  return finalized_plan
498
496
 
499
497
 
500
- def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
501
- """Asks the Claude-3.5 model a question about the given media and returns an answer.
498
+ def vqa(prompt: str, medias: List[np.ndarray]) -> None:
499
+ """Asks the VQA model a question about the given media and returns an answer.
502
500
 
503
501
  Parameters:
504
502
  prompt: str: The question to ask the model.
@@ -515,13 +513,14 @@ def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
515
513
  ]
516
514
 
517
515
  response = cast(str, vqa.generate(prompt, media=all_media_b64))
518
- print(f"[claude35_vqa output]\n{response}\n[end of claude35_vqa output]")
516
+ print(f"[vqa output]\n{response}\n[end of vqa output]")
519
517
 
520
518
 
521
519
  def suggestion(prompt: str, medias: List[np.ndarray]) -> None:
522
520
  """Given your problem statement and the images, this will provide you with a
523
521
  suggested plan on how to proceed. Always call suggestion when starting to solve
524
- a problem.
522
+ a problem. 'suggestion' will only print pseudo code for you to execute, it will not
523
+ execute the code for you.
525
524
 
526
525
  Parameters:
527
526
  prompt: str: The problem statement, provide a detailed description of the
@@ -538,7 +537,7 @@ def suggestion(prompt: str, medias: List[np.ndarray]) -> None:
538
537
 
539
538
 
540
539
  PLANNER_TOOLS = [
541
- claude35_vqa,
540
+ vqa,
542
541
  suggestion,
543
542
  get_tool_for_task,
544
543
  ]