vision-agent 0.1.5__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
vision_agent/__init__.py CHANGED
@@ -1,5 +1,3 @@
1
1
  from .agent import Agent
2
- from .data import DataStore, build_data_store
3
- from .emb import Embedder, OpenAIEmb, SentenceTransformerEmb, get_embedder
4
2
  from .llm import LLM, OpenAILLM
5
3
  from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm
vision_agent/agent/vision_agent.py CHANGED
@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
8
8
  from PIL import Image
9
9
  from tabulate import tabulate
10
10
 
11
- from vision_agent.image_utils import overlay_bboxes, overlay_masks
11
+ from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
12
12
  from vision_agent.llm import LLM, OpenAILLM
13
13
  from vision_agent.lmm import LMM, OpenAILMM
14
14
  from vision_agent.tools import TOOLS
@@ -336,7 +336,9 @@ def _handle_viz_tools(
336
336
 
337
337
  for param, call_result in zip(parameters, tool_result["call_results"]):
338
338
  # calls can fail, so we need to check if the call was successful
339
- if not isinstance(call_result, dict) or "bboxes" not in call_result:
339
+ if not isinstance(call_result, dict) or (
340
+ "bboxes" not in call_result and "masks" not in call_result
341
+ ):
340
342
  return image_to_data
341
343
 
342
344
  # if the call was successful, then we can add the image data
@@ -349,11 +351,12 @@ def _handle_viz_tools(
349
351
  "scores": [],
350
352
  }
351
353
 
352
- image_to_data[image]["bboxes"].extend(call_result["bboxes"])
353
- image_to_data[image]["labels"].extend(call_result["labels"])
354
- image_to_data[image]["scores"].extend(call_result["scores"])
355
- if "masks" in call_result:
356
- image_to_data[image]["masks"].extend(call_result["masks"])
354
+ image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
355
+ image_to_data[image]["labels"].extend(call_result.get("labels", []))
356
+ image_to_data[image]["scores"].extend(call_result.get("scores", []))
357
+ image_to_data[image]["masks"].extend(call_result.get("masks", []))
358
+ if "mask_shape" in call_result:
359
+ image_to_data[image]["mask_shape"] = call_result["mask_shape"]
357
360
 
358
361
  return image_to_data
359
362
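For orientation, a minimal sketch (illustrative values, not taken from the diff) of the accumulator this helper returns after the change: every key now falls back to an empty list when a tool call did not produce it, and "mask_shape" is only set when the call reports one.

```python
# Hypothetical contents of image_to_data after _handle_viz_tools runs on one image.
image_to_data = {
    "lids.jpg": {
        "bboxes": [[0.1, 0.2, 0.4, 0.5]],  # boxes as returned by the detection call
        "labels": ["lid"],
        "scores": [0.92],
        "masks": [],                        # empty if the call returned no masks
        # "mask_shape": (300, 500),         # only present when the call reports it
    }
}
```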
 
@@ -367,6 +370,8 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
367
370
  "grounding_dino_",
368
371
  "extract_frames_",
369
372
  "dinov_",
373
+ "zero_shot_counting_",
374
+ "visual_prompt_counting_",
370
375
  ]:
371
376
  continue
372
377
 
@@ -379,8 +384,11 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
379
384
  for image_str in image_to_data:
380
385
  image_path = Path(image_str)
381
386
  image_data = image_to_data[image_str]
382
- image = overlay_masks(image_path, image_data)
383
- image = overlay_bboxes(image, image_data)
387
+ if "_counting_" in tool_result["tool_name"]:
388
+ image = overlay_heat_map(image_path, image_data)
389
+ else:
390
+ image = overlay_masks(image_path, image_data)
391
+ image = overlay_bboxes(image, image_data)
384
392
  with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
385
393
  image.save(f.name)
386
394
  visualized_images.append(f.name)
@@ -484,11 +492,21 @@ class VisionAgent(Agent):
484
492
  if image:
485
493
  question += f" Image name: {image}"
486
494
  if reference_data:
487
- if not ("image" in reference_data and "mask" in reference_data):
495
+ if not (
496
+ "image" in reference_data
497
+ and ("mask" in reference_data or "bbox" in reference_data)
498
+ ):
488
499
  raise ValueError(
489
- f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
500
+ f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. but got {reference_data}"
490
501
  )
491
- question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
502
+ visual_prompt_data = (
503
+ f"Reference mask: {reference_data['mask']}"
504
+ if "mask" in reference_data
505
+ else f"Reference bbox: {reference_data['bbox']}"
506
+ )
507
+ question += (
508
+ f" Reference image: {reference_data['image']}, {visual_prompt_data}"
509
+ )
492
510
 
493
511
  reflections = ""
494
512
  final_answer = ""
@@ -531,7 +549,6 @@ class VisionAgent(Agent):
531
549
  final_answer = answer_summarize(
532
550
  self.answer_model, question, answers, reflections
533
551
  )
534
-
535
552
  visualized_output = visualize_result(all_tool_results)
536
553
  all_tool_results.append({"visualized_output": visualized_output})
537
554
  if len(visualized_output) > 0:
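The hunk above only shows the validation and prompt construction, not the surrounding method signature, so the value formats below are illustrative; it sketches the reference_data shapes that now pass versus fail.

```python
# Accepted: a reference image plus a mask ...
reference_data = {"image": "shirt.jpg", "mask": "shirt_mask.png"}

# ... or, new in this version, a reference image plus a bounding box.
reference_data = {"image": "shirt.jpg", "bbox": "0.1, 0.1, 0.4, 0.42"}

# Still rejected: an image with no visual prompt raises ValueError.
reference_data = {"image": "shirt.jpg"}
```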
vision_agent/image_utils.py CHANGED
@@ -4,7 +4,7 @@ import base64
4
4
  from importlib import resources
5
5
  from io import BytesIO
6
6
  from pathlib import Path
7
- from typing import Dict, Tuple, Union
7
+ from typing import Dict, Tuple, Union, List
8
8
 
9
9
  import numpy as np
10
10
  from PIL import Image, ImageDraw, ImageFont
@@ -34,6 +34,35 @@ COLORS = [
34
34
  ]
35
35
 
36
36
 
37
+ def normalize_bbox(
38
+ bbox: List[Union[int, float]], image_size: Tuple[int, ...]
39
+ ) -> List[float]:
40
+ r"""Normalize the bounding box coordinates to be between 0 and 1."""
41
+ x1, y1, x2, y2 = bbox
42
+ x1 = round(x1 / image_size[1], 2)
43
+ y1 = round(y1 / image_size[0], 2)
44
+ x2 = round(x2 / image_size[1], 2)
45
+ y2 = round(y2 / image_size[0], 2)
46
+ return [x1, y1, x2, y2]
47
+
48
+
49
+ def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
50
+ r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
51
+
52
+ Parameters:
53
+ mask_rle: Run-length as a string formatted as (start length) pairs
54
+ shape: The (height, width) of array to return
55
+ """
56
+ s = mask_rle.split()
57
+ starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
58
+ starts -= 1
59
+ ends = starts + lengths
60
+ img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
61
+ for lo, hi in zip(starts, ends):
62
+ img[lo:hi] = 1
63
+ return img.reshape(shape)
64
+
65
+
37
66
  def b64_to_pil(b64_str: str) -> ImageType:
38
67
  r"""Convert a base64 string to a PIL Image.
39
68
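These two helpers were moved here from tools.py (see the removal further down). A quick worked example, assuming the 0.2.1 module layout; note that image_size is (height, width) and the run-length starts are 1-based over the flattened, row-major mask.

```python
from vision_agent.image_utils import normalize_bbox, rle_decode

# xyxy pixel box on a 400x600 (height x width) image -> normalized, rounded to 2 decimals
print(normalize_bbox([50, 100, 150, 200], (400, 600)))  # [0.08, 0.25, 0.25, 0.5]

# "start length" pairs: pixels 1-3 and 10-11 (1-based) become 1s in a 4x4 mask
print(rle_decode("1 3 10 2", (4, 4)))
# [[1 1 1 0]
#  [0 0 0 0]
#  [0 1 1 0]
#  [0 0 0 0]]
```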
 
@@ -86,6 +115,26 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
86
115
  return base64.b64encode(arr_bytes).decode("utf-8")
87
116
 
88
117
 
118
+ def denormalize_bbox(
119
+ bbox: List[Union[int, float]], image_size: Tuple[int, ...]
120
+ ) -> List[float]:
121
+ r"""DeNormalize the bounding box coordinates so that they are in absolute values."""
122
+
123
+ if len(bbox) != 4:
124
+ raise ValueError("Bounding box must be of length 4.")
125
+
126
+ arr = np.array(bbox)
127
+ if np.all((arr >= 0) & (arr <= 1)):
128
+ x1, y1, x2, y2 = bbox
129
+ x1 = round(x1 * image_size[1])
130
+ y1 = round(y1 * image_size[0])
131
+ x2 = round(x2 * image_size[1])
132
+ y2 = round(y2 * image_size[0])
133
+ return [x1, y1, x2, y2]
134
+ else:
135
+ return bbox
136
+
137
+
89
138
  def overlay_bboxes(
90
139
  image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
91
140
  ) -> ImageType:
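A small companion example for the helper above (illustrative values): already-normalized boxes are scaled back to pixels against the (height, width) image size, while boxes not entirely within [0, 1] are returned untouched.

```python
from vision_agent.image_utils import denormalize_bbox

print(denormalize_bbox([0.1, 0.1, 0.4, 0.42], (300, 500)))  # [50, 30, 200, 126]
print(denormalize_bbox([50, 30, 200, 126], (300, 500)))     # unchanged: already absolute
```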
@@ -103,6 +152,9 @@ def overlay_bboxes(
103
152
  elif isinstance(image, np.ndarray):
104
153
  image = Image.fromarray(image)
105
154
 
155
+ if "bboxes" not in bboxes:
156
+ return image.convert("RGB")
157
+
106
158
  color = {
107
159
  label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))
108
160
  }
@@ -114,8 +166,6 @@ def overlay_bboxes(
114
166
  str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
115
167
  fontsize,
116
168
  )
117
- if "bboxes" not in bboxes:
118
- return image.convert("RGB")
119
169
 
120
170
  for label, box, scores in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
121
171
  box = [
@@ -150,11 +200,15 @@ def overlay_masks(
150
200
  elif isinstance(image, np.ndarray):
151
201
  image = Image.fromarray(image)
152
202
 
203
+ if "masks" not in masks:
204
+ return image.convert("RGB")
205
+
206
+ if "labels" not in masks:
207
+ masks["labels"] = [""] * len(masks["masks"])
208
+
153
209
  color = {
154
210
  label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
155
211
  }
156
- if "masks" not in masks:
157
- return image.convert("RGB")
158
212
 
159
213
  for label, mask in zip(masks["labels"], masks["masks"]):
160
214
  if isinstance(mask, str):
@@ -164,3 +218,40 @@ def overlay_masks(
164
218
  mask_img = Image.fromarray(np_mask.astype(np.uint8))
165
219
  image = Image.alpha_composite(image.convert("RGBA"), mask_img)
166
220
  return image.convert("RGB")
221
+
222
+
223
+ def overlay_heat_map(
224
+ image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
225
+ ) -> ImageType:
226
+ r"""Plots heat map on to an image.
227
+
228
+ Parameters:
229
+ image: the input image
230
+ masks: the heatmap to overlay
231
+ alpha: the transparency of the overlay
232
+
233
+ Returns:
234
+ The image with the heat map overlaid
235
+ """
236
+ if isinstance(image, (str, Path)):
237
+ image = Image.open(image)
238
+ elif isinstance(image, np.ndarray):
239
+ image = Image.fromarray(image)
240
+
241
+ if "masks" not in masks:
242
+ return image.convert("RGB")
243
+
244
+ # Only one heat map per image, so no need to loop through masks
245
+ image = image.convert("L")
246
+
247
+ if isinstance(masks["masks"][0], str):
248
+ mask = b64_to_pil(masks["masks"][0])
249
+
250
+ overlay = Image.new("RGBA", mask.size)
251
+ odraw = ImageDraw.Draw(overlay)
252
+ odraw.bitmap(
253
+ (0, 0), mask, fill=(255, 0, 0, round(alpha * 255))
254
+ ) # fill=(R, G, B, Alpha)
255
+ combined = Image.alpha_composite(image.convert("RGBA"), overlay.resize(image.size))
256
+
257
+ return combined.convert("RGB")
vision_agent/llm/llm.py CHANGED
@@ -11,6 +11,7 @@ from vision_agent.tools import (
11
11
  SYSTEM_PROMPT,
12
12
  GroundingDINO,
13
13
  GroundingSAM,
14
+ ZeroShotCounting,
14
15
  )
15
16
 
16
17
 
@@ -127,6 +128,9 @@ class OpenAILLM(LLM):
127
128
 
128
129
  return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
129
130
 
131
+ def generate_zero_shot_counter(self, question: str) -> Callable:
132
+ return lambda x: ZeroShotCounting()(**{"image": x})
133
+
130
134
 
131
135
  class AzureOpenAILLM(OpenAILLM):
132
136
  def __init__(
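A minimal sketch of the new helper, assuming OpenAILLM can be constructed with its defaults and that an OpenAI API key is configured; the returned callable simply runs ZeroShotCounting on whichever image it is given.

```python
from vision_agent.llm import OpenAILLM

llm = OpenAILLM()
counter = llm.generate_zero_shot_counter("how many lids are there?")
result = counter("lids.jpg")  # e.g. {"count": 45}
```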
vision_agent/lmm/lmm.py CHANGED
@@ -15,6 +15,7 @@ from vision_agent.tools import (
15
15
  SYSTEM_PROMPT,
16
16
  GroundingDINO,
17
17
  GroundingSAM,
18
+ ZeroShotCounting,
18
19
  )
19
20
 
20
21
  _LOGGER = logging.getLogger(__name__)
@@ -272,6 +273,9 @@ class OpenAILMM(LMM):
272
273
 
273
274
  return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
274
275
 
276
+ def generate_zero_shot_counter(self, question: str) -> Callable:
277
+ return lambda x: ZeroShotCounting()(**{"image": x})
278
+
275
279
 
276
280
  class AzureOpenAILMM(OpenAILMM):
277
281
  def __init__(
vision_agent/tools/__init__.py CHANGED
@@ -11,6 +11,8 @@ from .tools import ( # Counter,
11
11
  GroundingDINO,
12
12
  GroundingSAM,
13
13
  ImageCaption,
14
+ ZeroShotCounting,
15
+ VisualPromptCounting,
14
16
  SegArea,
15
17
  SegIoU,
16
18
  Tool,
vision_agent/tools/tools.py CHANGED
@@ -9,7 +9,13 @@ import requests
9
9
  from PIL import Image
10
10
  from PIL.Image import Image as ImageType
11
11
 
12
- from vision_agent.image_utils import convert_to_b64, get_image_size
12
+ from vision_agent.image_utils import (
13
+ convert_to_b64,
14
+ get_image_size,
15
+ rle_decode,
16
+ normalize_bbox,
17
+ denormalize_bbox,
18
+ )
13
19
  from vision_agent.tools.video import extract_frames_from_video
14
20
  from vision_agent.type_defs import LandingaiAPIKey
15
21
 
@@ -18,35 +24,6 @@ _LND_API_KEY = LandingaiAPIKey().api_key
18
24
  _LND_API_URL = "https://api.dev.landing.ai/v1/agent"
19
25
 
20
26
 
21
- def normalize_bbox(
22
- bbox: List[Union[int, float]], image_size: Tuple[int, ...]
23
- ) -> List[float]:
24
- r"""Normalize the bounding box coordinates to be between 0 and 1."""
25
- x1, y1, x2, y2 = bbox
26
- x1 = round(x1 / image_size[1], 2)
27
- y1 = round(y1 / image_size[0], 2)
28
- x2 = round(x2 / image_size[1], 2)
29
- y2 = round(y2 / image_size[0], 2)
30
- return [x1, y1, x2, y2]
31
-
32
-
33
- def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
34
- r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
35
-
36
- Parameters:
37
- mask_rle: Run-length as string formated (start length)
38
- shape: The (height, width) of array to return
39
- """
40
- s = mask_rle.split()
41
- starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
42
- starts -= 1
43
- ends = starts + lengths
44
- img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
45
- for lo, hi in zip(starts, ends):
46
- img[lo:hi] = 1
47
- return img.reshape(shape)
48
-
49
-
50
27
  class Tool(ABC):
51
28
  name: str
52
29
  description: str
@@ -489,6 +466,130 @@ class AgentGroundingSAM(GroundingSAM):
489
466
  return rets
490
467
 
491
468
 
469
+ class ZeroShotCounting(Tool):
470
+ r"""ZeroShotCounting is a tool that can count total number of instances of an object
471
+ present in an image, all belonging to the same class, without a text or visual prompt.
472
+
473
+ Example
474
+ -------
475
+ >>> import vision_agent as va
476
+ >>> zshot_count = va.tools.ZeroShotCounting()
477
+ >>> zshot_count("image1.jpg")
478
+ {'count': 45}
479
+ """
480
+
481
+ name = "zero_shot_counting_"
482
+ description = "'zero_shot_counting_' is a tool that counts and returns the total number of instances of an object present in an image belonging to the same class without a text or visual prompt."
483
+
484
+ usage = {
485
+ "required_parameters": [
486
+ {"name": "image", "type": "str"},
487
+ ],
488
+ "examples": [
489
+ {
490
+ "scenario": "Can you count the lids in the image ? Image name: lids.jpg",
491
+ "parameters": {"image": "lids.jpg"},
492
+ },
493
+ {
494
+ "scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg",
495
+ "parameters": {"image": "tray.jpg"},
496
+ },
497
+ {
498
+ "scenario": "Can you build me an object counting tool ? Image name: shirts.jpg",
499
+ "parameters": {
500
+ "image": "shirts.jpg",
501
+ },
502
+ },
503
+ ],
504
+ }
505
+
506
+ # TODO: Add support for multiple input images, which aligns with the output type.
507
+ def __call__(self, image: Union[str, ImageType]) -> Dict:
508
+ """Invoke the Image captioning model.
509
+
510
+ Parameters:
511
+ image: the input image.
512
+
513
+ Returns:
514
+ A dictionary containing the key 'count' and the count as its value, e.g. {'count': 12}
515
+ """
516
+ image_b64 = convert_to_b64(image)
517
+ data = {
518
+ "image": image_b64,
519
+ "tool": "zero_shot_counting",
520
+ }
521
+ return _send_inference_request(data, "tools")
522
+
523
+
524
+ class VisualPromptCounting(Tool):
525
+ r"""VisualPromptCounting is a tool that can count total number of instances of an object
526
+ present in an image, all belonging to the same class, with the help of a visual prompt given as a bounding box.
527
+
528
+ Example
529
+ -------
530
+ >>> import vision_agent as va
531
+ >>> prompt_count = va.tools.VisualPromptCounting()
532
+ >>> prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42")
533
+ {'count': 23}
534
+ """
535
+
536
+ name = "visual_prompt_counting_"
537
+ description = "'visual_prompt_counting_' is a tool that can count and return total number of instances of an object present in an image belonging to the same class given an example bounding box."
538
+
539
+ usage = {
540
+ "required_parameters": [
541
+ {"name": "image", "type": "str"},
542
+ {"name": "prompt", "type": "str"},
543
+ ],
544
+ "examples": [
545
+ {
546
+ "scenario": "Here is an example of a lid '0.1, 0.1, 0.14, 0.2', Can you count the lids in the image ? Image name: lids.jpg",
547
+ "parameters": {"image": "lids.jpg", "prompt": "0.1, 0.1, 0.14, 0.2"},
548
+ },
549
+ {
550
+ "scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg",
551
+ "parameters": {"image": "tray.jpg", "prompt": "0.1, 0.1, 0.2, 0.25"},
552
+ },
553
+ {
554
+ "scenario": "Can you build me a few shot object counting tool ? Image name: shirts.jpg",
555
+ "parameters": {
556
+ "image": "shirts.jpg",
557
+ "prompt": "0.1, 0.15, 0.2, 0.2",
558
+ },
559
+ },
560
+ {
561
+ "scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg",
562
+ "parameters": {
563
+ "image": "shoes.jpg",
564
+ "prompt": "0.1, 0.1, 0.6, 0.65",
565
+ },
566
+ },
567
+ ],
568
+ }
569
+
570
+ # TODO: Add support for multiple input images, which aligns with the output type.
571
+ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
572
+ """Invoke the Image captioning model.
573
+
574
+ Parameters:
575
+ image: the input image.
+ prompt: the visual prompt, a bounding box given as a comma-separated string of normalized coordinates.
576
+
577
+ Returns:
578
+ A dictionary containing the key 'count' and the count as its value, e.g. {'count': 12}
579
+ """
580
+ image_size = get_image_size(image)
581
+ bbox = [float(x) for x in prompt.split(",")]
582
+ prompt = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
583
+ image_b64 = convert_to_b64(image)
584
+
585
+ data = {
586
+ "image": image_b64,
587
+ "prompt": prompt,
588
+ "tool": "few_shot_counting",
589
+ }
590
+ return _send_inference_request(data, "tools")
591
+
592
+
492
593
  class Crop(Tool):
493
594
  r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""
494
595
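Putting the two new tools together, a usage sketch taken from their docstring examples (counts are illustrative); the visual prompt is a normalized "x1, y1, x2, y2" string that the tool denormalizes against the image size before calling the inference endpoint.

```python
import vision_agent as va

zshot_count = va.tools.ZeroShotCounting()
zshot_count("image1.jpg")  # {'count': 45}

prompt_count = va.tools.VisualPromptCounting()
prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42")  # {'count': 23}
```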
 
@@ -798,6 +899,8 @@ TOOLS = {
798
899
  ImageCaption,
799
900
  GroundingDINO,
800
901
  AgentGroundingSAM,
902
+ ZeroShotCounting,
903
+ VisualPromptCounting,
801
904
  AgentDINOv,
802
905
  ExtractFrames,
803
906
  Crop,
vision_agent-0.1.5.dist-info/METADATA → vision_agent-0.2.1.dist-info/METADATA CHANGED
@@ -1,15 +1,14 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vision-agent
3
- Version: 0.1.5
3
+ Version: 0.2.1
4
4
  Summary: Toolset for Vision Agent
5
5
  Author: Landing AI
6
6
  Author-email: dev@landing.ai
7
- Requires-Python: >=3.9,<3.12
7
+ Requires-Python: >=3.9
8
8
  Classifier: Programming Language :: Python :: 3
9
9
  Classifier: Programming Language :: Python :: 3.9
10
10
  Classifier: Programming Language :: Python :: 3.10
11
11
  Classifier: Programming Language :: Python :: 3.11
12
- Requires-Dist: faiss-cpu (>=1.0.0,<2.0.0)
13
12
  Requires-Dist: moviepy (>=1.0.0,<2.0.0)
14
13
  Requires-Dist: numpy (>=1.21.0,<2.0.0)
15
14
  Requires-Dist: openai (>=1.0.0,<2.0.0)
@@ -18,9 +17,7 @@ Requires-Dist: pandas (>=2.0.0,<3.0.0)
18
17
  Requires-Dist: pillow (>=10.0.0,<11.0.0)
19
18
  Requires-Dist: pydantic-settings (>=2.2.1,<3.0.0)
20
19
  Requires-Dist: requests (>=2.0.0,<3.0.0)
21
- Requires-Dist: sentence-transformers (>=2.0.0,<3.0.0)
22
20
  Requires-Dist: tabulate (>=0.9.0,<0.10.0)
23
- Requires-Dist: torch (>=2.1.0,<2.2.0)
24
21
  Requires-Dist: tqdm (>=4.64.0,<5.0.0)
25
22
  Requires-Dist: typing_extensions (>=4.0.0,<5.0.0)
26
23
  Project-URL: Homepage, https://landing.ai
@@ -41,7 +38,7 @@ Description-Content-Type: text/markdown
41
38
 
42
39
  Vision Agent is a library that helps you utilize agent frameworks for your vision tasks.
43
40
  Many current vision problems can easily take hours or days to solve, you need to find the
44
- right model, figure out how to use it, possibly write programming logic around it to
41
+ right model, figure out how to use it, possibly write programming logic around it to
45
42
  accomplish the task you want or even more expensive, train your own model. Vision Agent
46
43
  aims to provide an in-seconds experience by allowing users to describe their problem in
47
44
  text and utilizing agent frameworks to solve the task for them. Check out our discord
@@ -138,6 +135,9 @@ you. For example:
138
135
  | BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
139
136
  | SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
140
137
  | ExtractFrames | ExtractFrames extracts frames with motion from a video. |
139
+ | ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image |
140
+ | VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt |
141
141
 
142
142
 
143
143
  It also has a basic set of calculate tools such as add, subtract, multiply and divide.
vision_agent-0.1.5.dist-info/RECORD → vision_agent-0.2.1.dist-info/RECORD CHANGED
@@ -1,29 +1,25 @@
1
- vision_agent/__init__.py,sha256=wD1cssVTAJ55uTViNfBGooqJUV0p9fmVAuTMHHrmUBU,229
1
+ vision_agent/__init__.py,sha256=GVLHCeK_R-zgldpbcPmOzJat-BkadvkuRCMxDvTIcXs,108
2
2
  vision_agent/agent/__init__.py,sha256=B4JVrbY4IRVCJfjmrgvcp7h1mTUEk8MZvL0Zmej4Ka0,127
3
3
  vision_agent/agent/agent.py,sha256=X7kON-g9ePUKumCDaYfQNBX_MEFE-ax5PnRp7-Cc5Wo,529
4
4
  vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMVg,11511
5
5
  vision_agent/agent/easytool_prompts.py,sha256=zdQQw6WpXOmvwOMtlBlNKY5a3WNlr65dbUvMIGiqdeo,4526
6
6
  vision_agent/agent/reflexion.py,sha256=4gz30BuFMeGxSsTzoDV4p91yE0R8LISXp28IaOI6wdM,10506
7
7
  vision_agent/agent/reflexion_prompts.py,sha256=G7UAeNz_g2qCb2yN6OaIC7bQVUkda4m3z42EG8wAyfE,9342
8
- vision_agent/agent/vision_agent.py,sha256=Deuj28hqRq4wHnD08pU_7fok_EicvlGnDoINYh5hw1k,22853
8
+ vision_agent/agent/vision_agent.py,sha256=MTxeV5_Sghqoe2aOW9EbNgiq61sVCcF3ZndJ7BZl6x0,23588
9
9
  vision_agent/agent/vision_agent_prompts.py,sha256=W3Z72FpUt71UIJSkjAcgtQqxeMqkYuATqHAN5fYY26c,7342
10
- vision_agent/data/__init__.py,sha256=YU-5g3LbEQ6a4drz0RLGTagXMVU2Z4Xr3RlfWE-R0jU,46
11
- vision_agent/data/data.py,sha256=Z2l76OrT0GgyuN52OeJqDitUcP0q1rhfdXd1of3GsVo,5128
12
- vision_agent/emb/__init__.py,sha256=YmCkGrJBtXb6X6Z3lnKiFoQYKXMgHMJp8JJyMLVvqcI,75
13
- vision_agent/emb/emb.py,sha256=la9lhEzk7jqUCjYYQ5oRgVNSnC9_EJBJIpE_B9c6PJo,1375
14
10
  vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
11
  vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
16
- vision_agent/image_utils.py,sha256=qRN_Y1XXBm9EL6V53OZUq21h0spIa1J6X9YDbe6B87o,4805
12
+ vision_agent/image_utils.py,sha256=Cg4aKO1tQiETT1gdsZ50XzORBtJnBFfMG2cKJyjaY6Q,7555
17
13
  vision_agent/llm/__init__.py,sha256=BoUm_zSAKnLlE8s-gKTSQugXDqVZKPqYlWwlTLdhcz4,48
18
- vision_agent/llm/llm.py,sha256=Jty_RHdqVmIM0Mm31JNk50c882Tx7hHtkmh0WyXeJd8,5016
14
+ vision_agent/llm/llm.py,sha256=gwDQ9-p9wEn24xi1019e5jzTGQg4xWDSqBCsqIqGcU4,5168
19
15
  vision_agent/lmm/__init__.py,sha256=nnNeKD1k7q_4vLb1x51O_EUTYaBgGfeiCx5F433gr3M,67
20
- vision_agent/lmm/lmm.py,sha256=1E7e_S_0fOKnf6mSsEdkXvsIjGmhBGl5XW4By2jvhbY,10045
21
- vision_agent/tools/__init__.py,sha256=dkzk9amNzTEKULMB1xRJspqEGpzNPGuccWeXrv1xI0U,280
16
+ vision_agent/lmm/lmm.py,sha256=FjxCuIk0KXuWnfY4orVmdyhJW2I4C6i5QNNEXk7gybk,10197
17
+ vision_agent/tools/__init__.py,sha256=BlfxqbYkB0oODhnSmQg1UyzQm73AvvjCjrIiOWBIYDs,328
22
18
  vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
23
- vision_agent/tools/tools.py,sha256=WIodfggPkz_2LSWn_Kqm9uvQUtCgKy3jmMoPVTwf1bA,31181
19
+ vision_agent/tools/tools.py,sha256=gCjHs5vJuGNBFsnJWFT7PX3wTyfHgtrgX1Eq9vqknN0,34979
24
20
  vision_agent/tools/video.py,sha256=xTElFSFp1Jw4ulOMnk81Vxsh-9dTxcWUO6P9fzEi3AM,7653
25
21
  vision_agent/type_defs.py,sha256=4LTnTL4HNsfYqCrDn9Ppjg9bSG2ZGcoKSSd9YeQf4Bw,1792
26
- vision_agent-0.1.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
27
- vision_agent-0.1.5.dist-info/METADATA,sha256=ubzhbZW7oT9sIaIkuM6QObXINZGz5Zcvgjdp7sUcsJE,6233
28
- vision_agent-0.1.5.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
29
- vision_agent-0.1.5.dist-info/RECORD,,
22
+ vision_agent-0.2.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
23
+ vision_agent-0.2.1.dist-info/METADATA,sha256=RAD8NCAo5N12sccgSC5Q0j4hKwU_rVKg5p_eLE-Njdc,6434
24
+ vision_agent-0.2.1.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
25
+ vision_agent-0.2.1.dist-info/RECORD,,
vision_agent/data/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .data import DataStore, build_data_store
vision_agent/data/data.py DELETED
@@ -1,142 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import uuid
4
- from pathlib import Path
5
- from typing import Callable, Dict, List, Optional, Union, cast
6
-
7
- import faiss
8
- import numpy as np
9
- import numpy.typing as npt
10
- import pandas as pd
11
- from faiss import read_index, write_index
12
- from tqdm import tqdm
13
- from typing_extensions import Self
14
-
15
- from vision_agent.emb import Embedder
16
- from vision_agent.lmm import LMM
17
-
18
- tqdm.pandas()
19
-
20
-
21
- class DataStore:
22
- r"""A class to store and manage image data along with its generated metadata from an LMM."""
23
-
24
- def __init__(self, df: pd.DataFrame):
25
- r"""Initializes the DataStore with a DataFrame containing image paths and image
26
- IDs. If the image IDs are not present, they are generated using UUID4. The
27
- DataFrame must contain an 'image_paths' column.
28
-
29
- Args:
30
- df: The DataFrame containing "image_paths" and "image_id" columns.
31
- """
32
- self.df = df
33
- self.lmm: Optional[LMM] = None
34
- self.emb: Optional[Embedder] = None
35
- self.index: Optional[faiss.IndexFlatIP] = None # type: ignore
36
- if "image_paths" not in self.df.columns:
37
- raise ValueError("image_paths column must be present in DataFrame")
38
- if "image_id" not in self.df.columns:
39
- self.df["image_id"] = [str(uuid.uuid4()) for _ in range(len(df))]
40
-
41
- def add_embedder(self, emb: Embedder) -> Self:
42
- self.emb = emb
43
- return self
44
-
45
- def add_lmm(self, lmm: LMM) -> Self:
46
- self.lmm = lmm
47
- return self
48
-
49
- def add_column(
50
- self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None
51
- ) -> Self:
52
- r"""Adds a new column to the DataFrame containing the generated metadata from
53
- the LMM.
54
-
55
- Args:
56
- name: The name of the column to be added.
57
- prompt: The prompt to be used to generate the metadata.
58
- func: A Python function to be applied on the output of `lmm.generate`.
59
- Defaults to None.
60
- """
61
- if self.lmm is None:
62
- raise ValueError("LMM not set yet")
63
-
64
- self.df[name] = self.df["image_paths"].progress_apply( # type: ignore
65
- lambda x: (
66
- func(self.lmm.generate(prompt, images=[x]))
67
- if func
68
- else self.lmm.generate(prompt, images=[x])
69
- )
70
- )
71
- return self
72
-
73
- def build_index(self, target_col: str) -> Self:
74
- r"""This will generate embeddings for the `target_col` and build a searchable
75
- index over them, so next time you run search it will search over this index.
76
-
77
- Args:
78
- target_col: The column name containing the data to be indexed."""
79
- if self.emb is None:
80
- raise ValueError("Embedder not set yet")
81
-
82
- embeddings: pd.Series = self.df[target_col].progress_apply(lambda x: self.emb.embed(x)) # type: ignore
83
- embeddings_np = np.array(embeddings.tolist()).astype(np.float32)
84
- self.index = faiss.IndexFlatIP(embeddings_np.shape[1])
85
- self.index.add(embeddings_np)
86
- return self
87
-
88
- def get_embeddings(self) -> npt.NDArray[np.float32]:
89
- if self.index is None:
90
- raise ValueError("Index not built yet")
91
-
92
- ntotal = self.index.ntotal
93
- d: int = self.index.d
94
- return cast(
95
- npt.NDArray[np.float32],
96
- faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d),
97
- )
98
-
99
- def search(self, query: str, top_k: int = 10) -> List[Dict]:
100
- r"""Searches the index for the most similar images to the query and returns
101
- the top_k results.
102
-
103
- Args:
104
- query: The query to search for.
105
- top_k: The number of results to return. Defaults to 10."""
106
- if self.index is None:
107
- raise ValueError("Index not built yet")
108
- if self.emb is None:
109
- raise ValueError("Embedder not set yet")
110
-
111
- query_embedding: npt.NDArray[np.float32] = self.emb.embed(query)
112
- _, idx = self.index.search(query_embedding.reshape(1, -1), top_k)
113
- return cast(List[Dict], self.df.iloc[idx[0]].to_dict(orient="records"))
114
-
115
- def save(self, path: Union[str, Path]) -> None:
116
- path = Path(path)
117
- path.mkdir(parents=True)
118
- self.df.to_csv(path / "data.csv")
119
- if self.index is not None:
120
- write_index(self.index, str(path / "data.index"))
121
-
122
- @classmethod
123
- def load(cls, path: Union[str, Path]) -> DataStore:
124
- path = Path(path)
125
- df = pd.read_csv(path / "data.csv", index_col=0)
126
- ds = DataStore(df)
127
- if Path(path / "data.index").exists():
128
- ds.index = read_index(str(path / "data.index"))
129
- return ds
130
-
131
-
132
- def build_data_store(data: Union[str, Path, list[Union[str, Path]]]) -> DataStore:
133
- if isinstance(data, Path) or isinstance(data, str):
134
- data = Path(data)
135
- data_files = list(Path(data).glob("*"))
136
- elif isinstance(data, list):
137
- data_files = [Path(d) for d in data]
138
-
139
- df = pd.DataFrame()
140
- df["image_paths"] = data_files
141
- df["image_id"] = [uuid.uuid4() for _ in range(len(data_files))]
142
- return DataStore(df)
vision_agent/emb/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .emb import Embedder, OpenAIEmb, SentenceTransformerEmb, get_embedder
vision_agent/emb/emb.py DELETED
@@ -1,47 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- from typing import cast
3
-
4
- import numpy as np
5
- import numpy.typing as npt
6
-
7
-
8
- class Embedder(ABC):
9
- @abstractmethod
10
- def embed(self, text: str) -> npt.NDArray[np.float32]:
11
- pass
12
-
13
-
14
- class SentenceTransformerEmb(Embedder):
15
- def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5"):
16
- from sentence_transformers import SentenceTransformer
17
-
18
- self.model = SentenceTransformer(model_name)
19
-
20
- def embed(self, text: str) -> npt.NDArray[np.float32]:
21
- return cast(
22
- npt.NDArray[np.float32],
23
- self.model.encode([text]).flatten().astype(np.float32),
24
- )
25
-
26
-
27
- class OpenAIEmb(Embedder):
28
- def __init__(self, model_name: str = "text-embedding-3-small"):
29
- from openai import OpenAI
30
-
31
- self.client = OpenAI()
32
- self.model_name = model_name
33
-
34
- def embed(self, text: str) -> npt.NDArray[np.float32]:
35
- response = self.client.embeddings.create(input=text, model=self.model_name)
36
- return np.array(response.data[0].embedding).astype(np.float32)
37
-
38
-
39
- def get_embedder(name: str) -> Embedder:
40
- if name == "sentence-transformer":
41
- return SentenceTransformerEmb()
42
- elif name == "openai":
43
- return OpenAIEmb()
44
- else:
45
- raise ValueError(
46
- f"Unknown embedder name: {name}, currently support sentence-transformer, openai."
47
- )