vision-agent 0.0.51__tar.gz → 0.0.53__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {vision_agent-0.0.51 → vision_agent-0.0.53}/PKG-INFO +3 -2
  2. {vision_agent-0.0.51 → vision_agent-0.0.53}/README.md +2 -1
  3. {vision_agent-0.0.51 → vision_agent-0.0.53}/pyproject.toml +1 -1
  4. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/vision_agent.py +35 -14
  5. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/vision_agent_prompts.py +1 -3
  6. vision_agent-0.0.53/vision_agent/fonts/__init__.py +0 -0
  7. vision_agent-0.0.53/vision_agent/fonts/default_font_ch_en.ttf +0 -0
  8. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/image_utils.py +22 -10
  9. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/tools/__init__.py +1 -0
  10. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/tools/tools.py +109 -90
  11. {vision_agent-0.0.51 → vision_agent-0.0.53}/LICENSE +0 -0
  12. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/__init__.py +0 -0
  13. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/__init__.py +0 -0
  14. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/agent.py +0 -0
  15. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/easytool.py +0 -0
  16. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/easytool_prompts.py +0 -0
  17. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/reflexion.py +0 -0
  18. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/reflexion_prompts.py +0 -0
  19. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/data/__init__.py +0 -0
  20. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/data/data.py +0 -0
  21. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/emb/__init__.py +0 -0
  22. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/emb/emb.py +0 -0
  23. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/llm/__init__.py +0 -0
  24. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/llm/llm.py +0 -0
  25. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/lmm/__init__.py +0 -0
  26. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/lmm/lmm.py +0 -0
  27. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/tools/prompts.py +0 -0
  28. {vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/tools/video.py +0 -0

{vision_agent-0.0.51 → vision_agent-0.0.53}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vision-agent
-Version: 0.0.51
+Version: 0.0.53
 Summary: Toolset for Vision Agent
 Author: Landing AI
 Author-email: dev@landing.ai
@@ -103,7 +103,8 @@ the individual steps and tools to get the answer:
     }
 ]],
 "answer": "The jar is located at [0.58, 0.2, 0.72, 0.45].",
-}]
+},
+{"visualize_output": "final_output.png"}]
 ```
 
 ### Tools

{vision_agent-0.0.51 → vision_agent-0.0.53}/README.md
@@ -74,7 +74,8 @@ the individual steps and tools to get the answer:
     }
 ]],
 "answer": "The jar is located at [0.58, 0.2, 0.72, 0.45].",
-}]
+},
+{"visualize_output": "final_output.png"}]
 ```
 
 ### Tools

{vision_agent-0.0.51 → vision_agent-0.0.53}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "vision-agent"
-version = "0.0.51"
+version = "0.0.53"
 description = "Toolset for Vision Agent"
 authors = ["Landing AI <dev@landing.ai>"]
 readme = "README.md"

{vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/vision_agent.py
@@ -5,6 +5,7 @@ import tempfile
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+from PIL import Image
 from tabulate import tabulate
 
 from vision_agent.image_utils import overlay_bboxes, overlay_masks
@@ -288,9 +289,8 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
                 continue
             parameters = [parameters]
         elif isinstance(tool_result["parameters"], list):
-            if (
-                len(tool_result["parameters"]) < 1
-                and "image" not in tool_result["parameters"][0]
+            if len(tool_result["parameters"]) < 1 or (
+                "image" not in tool_result["parameters"][0]
             ):
                 continue
 
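
The guard rewrite above is a behavioral fix, not a tidy-up: with `and`, an empty parameter list still evaluated `parameters[0]` and raised an IndexError, while a non-empty list without an `"image"` key was never skipped. A minimal sketch of both cases, using hypothetical stand-ins for `tool_result["parameters"]` rather than code from the package:

```python
# Hypothetical stand-ins for tool_result["parameters"]; not package code.
cases = [
    [],                       # old `and` guard evaluated parameters[0] anyway -> IndexError
    [{"video": "clip.mp4"}],  # old guard came out False here, so the
                              # image-less entry was never skipped
]

for parameters in cases:
    # New guard: `or` short-circuits before indexing the empty list and
    # correctly skips parameter lists that carry no image.
    if len(parameters) < 1 or ("image" not in parameters[0]):
        continue
    print("would visualize", parameters[0]["image"])
```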
@@ -304,10 +304,16 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
         # if the call was successful, then we can add the image data
         image = param["image"]
         if image not in image_to_data:
-            image_to_data[image] = {"bboxes": [], "masks": [], "labels": []}
+            image_to_data[image] = {
+                "bboxes": [],
+                "masks": [],
+                "labels": [],
+                "scores": [],
+            }
 
         image_to_data[image]["bboxes"].extend(call_result["bboxes"])
         image_to_data[image]["labels"].extend(call_result["labels"])
+        image_to_data[image]["scores"].extend(call_result["scores"])
         if "masks" in call_result:
             image_to_data[image]["masks"].extend(call_result["masks"])
 
@@ -380,6 +386,7 @@ class VisionAgent(Agent):
         self,
         input: Union[List[Dict[str, str]], str],
         image: Optional[Union[str, Path]] = None,
+        visualize_output: Optional[bool] = False,
     ) -> str:
         """Invoke the vision agent.
 
@@ -393,7 +400,7 @@ class VisionAgent(Agent):
         """
         if isinstance(input, str):
             input = [{"role": "user", "content": input}]
-        return self.chat(input, image=image)
+        return self.chat(input, image=image, visualize_output=visualize_output)
 
     def log_progress(self, description: str) -> None:
         _LOGGER.info(description)
@@ -401,7 +408,10 @@ class VisionAgent(Agent):
             self.report_progress_callback(description)
 
     def chat_with_workflow(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        image: Optional[Union[str, Path]] = None,
+        visualize_output: Optional[bool] = False,
     ) -> Tuple[str, List[Dict]]:
         question = chat[0]["content"]
         if image:
@@ -449,31 +459,42 @@ class VisionAgent(Agent):
                 self.answer_model, question, answers, reflections
             )
 
-            visualized_images = visualize_result(all_tool_results)
-            all_tool_results.append({"visualized_images": visualized_images})
+            visualized_output = visualize_result(all_tool_results)
+            all_tool_results.append({"visualized_output": visualized_output})
             reflection = self_reflect(
                 self.reflect_model,
                 question,
                 self.tools,
                 all_tool_results,
                 final_answer,
-                visualized_images[0] if len(visualized_images) > 0 else image,
+                visualized_output[0] if len(visualized_output) > 0 else image,
             )
             self.log_progress(f"Reflection: {reflection}")
             if parse_reflect(reflection):
                 break
             else:
-                reflections += reflection
-        # '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
+                reflections += "\n" + reflection
+        # '<END>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
         self.log_progress(
-            f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
+            f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</<ANSWER>"
         )
+
+        if visualize_output:
+            visualized_output = all_tool_results[-1]["visualized_output"]
+            for image in visualized_output:
+                Image.open(image).show()
+
         return final_answer, all_tool_results
 
     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        image: Optional[Union[str, Path]] = None,
+        visualize_output: Optional[bool] = False,
     ) -> str:
-        answer, _ = self.chat_with_workflow(chat, image=image)
+        answer, _ = self.chat_with_workflow(
+            chat, image=image, visualize_output=visualize_output
+        )
         return answer
 
     def retrieval(
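
Taken together, these vision_agent.py hunks thread one new flag, `visualize_output`, from `__call__` through `chat` into `chat_with_workflow`, which now stores the annotated images under the `"visualized_output"` key and, when the flag is set, opens each one with PIL's `Image.show()`. A hedged usage sketch; the re-export from `vision_agent.agent`, the bare `VisionAgent()` constructor, and the image path are assumptions, not something this diff shows:

```python
# A hedged sketch, assuming VisionAgent is re-exported from
# vision_agent.agent and accepts default constructor arguments;
# "people.jpg" is a made-up example image.
from vision_agent.agent import VisionAgent

agent = VisionAgent()
answer = agent(
    "How many people are in this image?",
    image="people.jpg",
    visualize_output=True,  # new in 0.0.53: opens the annotated image(s)
)
print(answer)
```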

{vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/agent/vision_agent_prompts.py
@@ -1,4 +1,4 @@
-VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, high level plan that aims to mitigate the same failure with the tools available. Use complete sentences.
+VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. Do not make vague steps like re-evaluate the threshold, instead make concrete steps like use a threshold of 0.5 or whatever threshold you think would fix this issue. If the task cannot be completed with the existing tools, respond with Finish. Use complete sentences.
 
 User's question: {question}
 
@@ -49,7 +49,6 @@ Output: """
 
 CHOOSE_TOOL = """This is the user's question: {question}
 These are the tools you can select to solve the question:
-
 {tools}
 
 Please note that:
@@ -63,7 +62,6 @@ Output: """
 
 CHOOSE_TOOL_DEPENDS = """This is the user's question: {question}
 These are the tools you can select to solve the question:
-
 {tools}
 
 This is a reflection from a previous failed attempt:

{vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/image_utils.py
@@ -1,6 +1,7 @@
 """Utility functions for image processing."""
 
 import base64
+from importlib import resources
 from io import BytesIO
 from pathlib import Path
 from typing import Dict, Tuple, Union
@@ -104,19 +105,28 @@ def overlay_bboxes(
 
     color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}
 
-    draw = ImageDraw.Draw(image)
-    font = ImageFont.load_default()
     width, height = image.size
+    fontsize = max(12, int(min(width, height) / 40))
+    draw = ImageDraw.Draw(image)
+    font = ImageFont.truetype(
+        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+        fontsize,
+    )
     if "bboxes" not in bboxes:
         return image.convert("RGB")
 
-    for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
-        box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
-        draw.rectangle(box, outline=color[label], width=3)
-        label = f"{label}"
-        text_box = draw.textbbox((box[0], box[1]), text=label, font=font)
-        draw.rectangle(text_box, fill=color[label])
-        draw.text((text_box[0], text_box[1]), label, fill="black", font=font)
+    for label, box, scores in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
+        box = [
+            int(box[0] * width),
+            int(box[1] * height),
+            int(box[2] * width),
+            int(box[3] * height),
+        ]
+        draw.rectangle(box, outline=color[label], width=4)
+        text = f"{label}: {scores:.2f}"
+        text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
+        draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
+        draw.text((box[0], box[1]), text, fill="black", font=font)
     return image.convert("RGB")
 
 
@@ -138,7 +148,9 @@ def overlay_masks(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
-    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])}
+    color = {
+        label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
+    }
     if "masks" not in masks:
         return image.convert("RGB")
 

{vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/tools/__init__.py
@@ -9,6 +9,7 @@ from .tools import (
     ExtractFrames,
     GroundingDINO,
     GroundingSAM,
+    ImageCaption,
     SegArea,
     SegIoU,
     Tool,

{vision_agent-0.0.51 → vision_agent-0.0.53}/vision_agent/tools/tools.py
@@ -53,9 +53,7 @@ class Tool(ABC):
 
 class NoOp(Tool):
     name = "noop_"
-    description = (
-        "'noop_' is a no-op tool that does nothing if you do not need to use a tool."
-    )
+    description = "'noop_' is a no-op tool that does nothing if you do not want answer the question directly and not use a tool."
     usage = {
         "required_parameters": [],
         "examples": [
@@ -85,7 +83,7 @@ class CLIP(Tool):
     _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
 
     name = "clip_"
-    description = "'clip_' is a tool that can classify or tag any image given a set of input classes or tags."
+    description = "'clip_' is a tool that can classify any image given a set of input names or tags. It returns a list of the input names along with their probability scores."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
@@ -146,6 +144,74 @@
         return resp_json["data"]  # type: ignore
 
 
+class ImageCaption(Tool):
+    r"""ImageCaption is a tool that can caption an image based on its contents
+    or tags.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> caption = va.tools.ImageCaption()
+        >>> caption("image1.jpg")
+        {'text': ['a box of orange and white socks']}
+    """
+
+    _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
+
+    name = "image_caption_"
+    description = "'image_caption_' is a tool that can caption an image based on its contents or tags. It returns a text describing the image"
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Can you describe this image ? Image name: cat.jpg",
+                "parameters": {"image": "cat.jpg"},
+            },
+            {
+                "scenario": "Can you caption this image with their main contents ? Image name: cat_dog.jpg",
+                "parameters": {"image": "cat_dog.jpg"},
+            },
+            {
+                "scenario": "Can you build me a image captioning tool ? Image name: shirts.jpg",
+                "parameters": {
+                    "image": "shirts.jpg",
+                },
+            },
+        ],
+    }
+
+    # TODO: Add support for input multiple images, which aligns with the output type.
+    def __call__(self, image: Union[str, ImageType]) -> Dict:
+        """Invoke the Image captioning model.
+
+        Parameters:
+            image: the input image to caption.
+
+        Returns:
+            A list of dictionaries containing the labels and scores. Each dictionary contains the classification result for an image. E.g. [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}]
+        """
+        image_b64 = convert_to_b64(image)
+        data = {
+            "image": image_b64,
+            "tool": "image_captioning",
+        }
+        res = requests.post(
+            self._ENDPOINT,
+            headers={"Content-Type": "application/json"},
+            json=data,
+        )
+        resp_json: Dict[str, Any] = res.json()
+        if (
+            "statusCode" in resp_json and resp_json["statusCode"] != 200
+        ) or "statusCode" not in resp_json:
+            _LOGGER.error(f"Request failed: {resp_json}")
+            raise ValueError(f"Request failed: {resp_json}")
+
+        return resp_json["data"]  # type: ignore
+
+
 class GroundingDINO(Tool):
     r"""Grounding DINO is a tool that can detect arbitrary objects with inputs such as
     category names or referring expressions.
@@ -163,7 +229,7 @@ class GroundingDINO(Tool):
     _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
 
     name = "grounding_dino_"
-    description = "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
+    description = "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions. It returns a list of bounding boxes, label names and associated probability scores."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
@@ -179,8 +245,11 @@ class GroundingDINO(Tool):
             "parameters": {"prompt": "car", "image": ""},
         },
         {
-            "scenario": "Can you detect the person on the left? Image name: person.jpg",
-            "parameters": {"prompt": "person on the left", "image": "person.jpg"},
+            "scenario": "Can you detect the person on the left and right? Image name: person.jpg",
+            "parameters": {
+                "prompt": "left person. right person",
+                "image": "person.jpg",
+            },
         },
         {
             "scenario": "Detect the red shirts and green shirst. Image name: shirts.jpg",
@@ -269,7 +338,7 @@ class GroundingSAM(Tool):
     _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
 
     name = "grounding_sam_"
-    description = "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
+    description = "'grounding_sam_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
@@ -285,8 +354,11 @@ class GroundingSAM(Tool):
             "parameters": {"prompt": "car", "image": ""},
         },
         {
-            "scenario": "Can you segment the person on the left? Image name: person.jpg",
-            "parameters": {"prompt": "person on the left", "image": "person.jpg"},
+            "scenario": "Can you segment the person on the left and right? Image name: person.jpg",
+            "parameters": {
+                "prompt": "left person. right person",
+                "image": "person.jpg",
+            },
         },
         {
             "scenario": "Can you build me a tool that segments red shirts and green shirts? Image name: shirts.jpg",
@@ -370,8 +442,9 @@ class AgentGroundingSAM(GroundingSAM):
         mask_files = []
         for mask in rets["masks"]:
             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
-                Image.fromarray(mask * 255).save(tmp)
-                mask_files.append(tmp.name)
+                file_name = Path(tmp.name).with_suffix(".mask.png")
+                Image.fromarray(mask * 255).save(file_name)
+                mask_files.append(str(file_name))
         rets["masks"] = mask_files
         return rets
 
@@ -380,7 +453,7 @@ class Counter(Tool):
     r"""Counter detects and counts the number of objects in an image given an input such as a category name or referring expression."""
 
     name = "counter_"
-    description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression."
+    description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression. It returns a dictionary containing the labels and their counts."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
@@ -400,14 +473,14 @@ class Counter(Tool):
 
     def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict:
         resp = GroundingDINO()(prompt, image)
-        return dict(CounterClass(resp[0]["labels"]))
+        return dict(CounterClass(resp["labels"]))
 
 
 class Crop(Tool):
     r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""
 
     name = "crop_"
-    description = "'crop_' crops an image given a bounding box and returns a file name of the cropped image."
+    description = "'crop_' crops an image given a bounding box and returns a file name of the cropped image. It returns a file with the cropped image."
     usage = {
         "required_parameters": [
             {"name": "bbox", "type": "List[float]"},
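
The one-line `Counter.__call__` fix tracks a change in what `GroundingDINO` returns: a single result dict for the image rather than a list of dicts, so the old `resp[0]["labels"]` would now raise a KeyError. A small sketch with an example payload shaped like the new return value:

```python
from collections import Counter as CounterClass

# Example payload shaped like GroundingDINO's per-image result dict.
resp = {"labels": ["person", "person", "dog"], "scores": [0.9, 0.8, 0.7]}
print(dict(CounterClass(resp["labels"])))
# {'person': 2, 'dog': 1}
```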
@@ -495,9 +568,7 @@ class SegArea(Tool):
 
 class BboxIoU(Tool):
     name = "bbox_iou_"
-    description = (
-        "'bbox_iou_' returns the intersection over union of two bounding boxes."
-    )
+    description = "'bbox_iou_' returns the intersection over union of two bounding boxes. This is a good tool for determining if two objects are overlapping."
     usage = {
         "required_parameters": [
             {"name": "bbox1", "type": "List[int]"},
@@ -591,85 +662,35 @@ class ExtractFrames(Tool):
         )
         for frame, ts in frames:
             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
-                Image.fromarray(frame).save(tmp)
-                result.append((tmp.name, ts))
+                file_name = Path(tmp.name).with_suffix(".frame.png")
+                Image.fromarray(frame).save(file_name)
+                result.append((str(file_name), ts))
         return result
 
 
-class Add(Tool):
-    r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places."""
-
-    name = "add_"
-    description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places."
-    usage = {
-        "required_parameters": [{"name": "input", "type": "List[int]"}],
-        "examples": [
-            {
-                "scenario": "If you want to calculate 2 + 4",
-                "parameters": {"input": [2, 4]},
-            }
-        ],
-    }
-
-    def __call__(self, input: List[int]) -> float:
-        return round(sum(input), 2)
-
-
-class Subtract(Tool):
-    r"""Subtract returns the difference of all the arguments passed to it, normalized to 2 decimal places."""
+class Calculator(Tool):
+    r"""Calculator is a tool that can perform basic arithmetic operations."""
 
-    name = "subtract_"
-    description = "'subtract_' returns the difference of all the arguments passed to it, normalized to 2 decimal places."
-    usage = {
-        "required_parameters": [{"name": "input", "type": "List[int]"}],
-        "examples": [
-            {
-                "scenario": "If you want to calculate 4 - 2",
-                "parameters": {"input": [4, 2]},
-            }
-        ],
-    }
-
-    def __call__(self, input: List[int]) -> float:
-        return round(input[0] - input[1], 2)
-
-
-class Multiply(Tool):
-    r"""Multiply returns the product of all the arguments passed to it, normalized to 2 decimal places."""
-
-    name = "multiply_"
-    description = "'multiply_' returns the product of all the arguments passed to it, normalized to 2 decimal places."
+    name = "calculator_"
+    description = (
+        "'calculator_' is a tool that can perform basic arithmetic operations."
+    )
     usage = {
-        "required_parameters": [{"name": "input", "type": "List[int]"}],
+        "required_parameters": [{"name": "equation", "type": "str"}],
         "examples": [
             {
-                "scenario": "If you want to calculate 2 * 4",
-                "parameters": {"input": [2, 4]},
-            }
-        ],
-    }
-
-    def __call__(self, input: List[int]) -> float:
-        return round(input[0] * input[1], 2)
-
-
-class Divide(Tool):
-    r"""Divide returns the division of all the arguments passed to it, normalized to 2 decimal places."""
-
-    name = "divide_"
-    description = "'divide_' returns the division of all the arguments passed to it, normalized to 2 decimal places."
-    usage = {
-        "required_parameters": [{"name": "input", "type": "List[int]"}],
-        "examples": [
+                "scenario": "If you want to calculate (2 * 3) + 4",
+                "parameters": {"equation": "2 + 4"},
+            },
             {
-                "scenario": "If you want to calculate 4 / 2",
-                "parameters": {"input": [4, 2]},
-            }
+                "scenario": "If you want to calculate (4 + 2.5) / 2.1",
+                "parameters": {"equation": "(4 + 2.5) / 2.1"},
+            },
         ],
     }
 
-    def __call__(self, input: List[int]) -> float:
-        return round(input[0] / input[1], 2)
+    def __call__(self, equation: str) -> float:
+        return cast(float, round(eval(equation), 2))
 
 
 TOOLS = {
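
The four fixed-arity arithmetic tools collapse into a single `Calculator` that evaluates a free-form equation string. A hedged usage sketch; it imports from the module path in the file list because the diff does not show `Calculator` being re-exported from `vision_agent.tools`:

```python
from vision_agent.tools.tools import Calculator

calc = Calculator()
print(calc("(4 + 2.5) / 2.1"))  # round(6.5 / 2.1, 2) == 3.1
```

Note the trade-off: `eval` executes whatever string the planner emits as Python, so the consolidated tool is more flexible but no longer safe against arbitrary expressions the way the old list-of-ints tools were.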
@@ -678,6 +699,7 @@ TOOLS = {
     [
         NoOp,
         CLIP,
+        ImageCaption,
         GroundingDINO,
         AgentGroundingSAM,
         ExtractFrames,
@@ -687,10 +709,7 @@ TOOLS = {
         SegArea,
         BboxIoU,
         SegIoU,
-        Add,
-        Subtract,
-        Multiply,
-        Divide,
+        Calculator,
     ]
 )
 if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
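
With the registry updated, `ImageCaption` and `Calculator` become selectable by the planner like any other tool. A hedged sketch of inspecting the registry; it assumes each `TOOLS` entry is a dict exposing the `name`/`description`/`usage` attributes the comprehension filters on:

```python
from vision_agent.tools.tools import TOOLS

# Assumption: TOOLS maps an integer id to each tool's metadata dict.
for idx, tool in TOOLS.items():
    print(idx, tool["name"])
```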