vision-agent 0.2.119__tar.gz → 0.2.121__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33) hide show
  1. {vision_agent-0.2.119 → vision_agent-0.2.121}/PKG-INFO +1 -1
  2. {vision_agent-0.2.119 → vision_agent-0.2.121}/pyproject.toml +2 -1
  3. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/agent/vision_agent_coder_prompts.py +4 -5
  4. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/lmm/lmm.py +0 -3
  5. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/tools/__init__.py +3 -0
  6. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/tools/tool_utils.py +95 -51
  7. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/tools/tools.py +182 -8
  8. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/tools/tools_types.py +15 -2
  9. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/utils/image_utils.py +1 -1
  10. {vision_agent-0.2.119 → vision_agent-0.2.121}/LICENSE +0 -0
  11. {vision_agent-0.2.119 → vision_agent-0.2.121}/README.md +0 -0
  12. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/__init__.py +0 -0
  13. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/agent/__init__.py +0 -0
  14. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/agent/agent.py +0 -0
  15. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/agent/agent_utils.py +0 -0
  16. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/agent/vision_agent.py +0 -0
  17. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/agent/vision_agent_coder.py +0 -0
  18. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/agent/vision_agent_prompts.py +0 -0
  19. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/clients/__init__.py +0 -0
  20. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/clients/http.py +0 -0
  21. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/clients/landing_public_api.py +0 -0
  22. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/fonts/__init__.py +0 -0
  23. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
  24. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/lmm/__init__.py +0 -0
  25. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/lmm/types.py +0 -0
  26. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/tools/meta_tools.py +0 -0
  27. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/tools/prompts.py +0 -0
  28. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/utils/__init__.py +0 -0
  29. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/utils/exceptions.py +0 -0
  30. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/utils/execute.py +0 -0
  31. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/utils/sim.py +0 -0
  32. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/utils/type_defs.py +0 -0
  33. {vision_agent-0.2.119 → vision_agent-0.2.121}/vision_agent/utils/video.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vision-agent
3
- Version: 0.2.119
3
+ Version: 0.2.121
4
4
  Summary: Toolset for Vision Agent
5
5
  Author: Landing AI
6
6
  Author-email: dev@landing.ai
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "vision-agent"
7
- version = "0.2.119"
7
+ version = "0.2.121"
8
8
  description = "Toolset for Vision Agent"
9
9
  authors = ["Landing AI <dev@landing.ai>"]
10
10
  readme = "README.md"
@@ -56,6 +56,7 @@ types-pillow = "^9.5.0.4"
56
56
  data-science-types = "^0.2.23"
57
57
  types-tqdm = "^4.65.0.1"
58
58
  setuptools = "^68.0.0"
59
+ griffe = "^0.45.3"
59
60
  mkdocs = "^1.5.3"
60
61
  mkdocstrings = {extras = ["python"], version = "^0.23.0"}
61
62
  mkdocs-material = "^9.4.2"
@@ -81,20 +81,19 @@ plan2:
81
81
  - Count the number of detected objects labeled as 'person'.
82
82
  plan3:
83
83
  - Load the image from the provided file path 'image.jpg'.
84
- - Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people.
84
+ - Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
85
85
 
86
86
  ```python
87
- from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting
87
+ from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
88
88
  image = load_image("image.jpg")
89
89
  owl_v2_out = owl_v2("person", image)
90
90
 
91
91
  gsam_out = grounding_sam("person", image)
92
92
  gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
93
93
 
94
- loca_out = loca_zero_shot_counting(image)
95
- loca_out = loca_out["count"]
94
+ cgd_out = countgd_counting("person", image)
96
95
 
97
- final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}}
96
+ final_out = {{"owl_v2": owl_v2_out, "grounding_sam": gsam_out, "countgd_counting": cgd_out}}
98
97
  print(final_out)
99
98
  ```
100
99
  """
@@ -286,9 +286,6 @@ class OpenAILMM(LMM):
286
286
 
287
287
  return lambda x: T.grounding_sam(params["prompt"], x)
288
288
 
289
- def generate_zero_shot_counter(self, question: str) -> Callable:
290
- return T.loca_zero_shot_counting
291
-
292
289
  def generate_image_qa_tool(self, question: str) -> Callable:
293
290
  return lambda x: T.git_vqa_v2(question, x)
294
291
 
@@ -37,10 +37,13 @@ from .tools import (
37
37
  load_image,
38
38
  loca_visual_prompt_counting,
39
39
  loca_zero_shot_counting,
40
+ countgd_counting,
41
+ countgd_example_based_counting,
40
42
  ocr,
41
43
  overlay_bounding_boxes,
42
44
  overlay_heat_map,
43
45
  overlay_segmentation_masks,
46
+ overlay_counting_results,
44
47
  owl_v2,
45
48
  save_image,
46
49
  save_json,
@@ -1,6 +1,6 @@
1
+ import os
1
2
  import inspect
2
3
  import logging
3
- import os
4
4
  from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple
5
5
 
6
6
  import pandas as pd
@@ -13,6 +13,7 @@ from urllib3.util.retry import Retry
13
13
  from vision_agent.utils.exceptions import RemoteToolCallFailed
14
14
  from vision_agent.utils.execute import Error, MimeType
15
15
  from vision_agent.utils.type_defs import LandingaiAPIKey
16
+ from vision_agent.tools.tools_types import BoundingBoxes
16
17
 
17
18
  _LOGGER = logging.getLogger(__name__)
18
19
  _LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
@@ -34,61 +35,58 @@ def send_inference_request(
34
35
  files: Optional[List[Tuple[Any, ...]]] = None,
35
36
  v2: bool = False,
36
37
  metadata_payload: Optional[Dict[str, Any]] = None,
37
- ) -> Dict[str, Any]:
38
+ ) -> Any:
38
39
  # TODO: runtime_tag and function_name should be metadata_payload and now included
39
40
  # in the service payload
40
- try:
41
- if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
42
- payload["runtime_tag"] = runtime_tag
41
+ if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
42
+ payload["runtime_tag"] = runtime_tag
43
+
44
+ url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
45
+ if "TOOL_ENDPOINT_URL" in os.environ:
46
+ url = os.environ["TOOL_ENDPOINT_URL"]
47
+
48
+ headers = {"apikey": _LND_API_KEY}
49
+ if "TOOL_ENDPOINT_AUTH" in os.environ:
50
+ headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
51
+ headers.pop("apikey")
52
+
53
+ session = _create_requests_session(
54
+ url=url,
55
+ num_retry=3,
56
+ headers=headers,
57
+ )
43
58
 
44
- url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
45
- if "TOOL_ENDPOINT_URL" in os.environ:
46
- url = os.environ["TOOL_ENDPOINT_URL"]
59
+ function_name = "unknown"
60
+ if "function_name" in payload:
61
+ function_name = payload["function_name"]
62
+ elif metadata_payload is not None and "function_name" in metadata_payload:
63
+ function_name = metadata_payload["function_name"]
47
64
 
48
- tool_call_trace = ToolCallTrace(
49
- endpoint_url=url,
50
- request=payload,
51
- response={},
52
- error=None,
53
- )
54
- headers = {"apikey": _LND_API_KEY}
55
- if "TOOL_ENDPOINT_AUTH" in os.environ:
56
- headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
57
- headers.pop("apikey")
58
-
59
- session = _create_requests_session(
60
- url=url,
61
- num_retry=3,
62
- headers=headers,
63
- )
65
+ response = _call_post(url, payload, session, files, function_name)
64
66
 
65
- if files is not None:
66
- res = session.post(url, data=payload, files=files)
67
- else:
68
- res = session.post(url, json=payload)
69
- if res.status_code != 200:
70
- tool_call_trace.error = Error(
71
- name="RemoteToolCallFailed",
72
- value=f"{res.status_code} - {res.text}",
73
- traceback_raw=[],
74
- )
75
- _LOGGER.error(f"Request failed: {res.status_code} {res.text}")
76
- # TODO: function_name should be in metadata_payload
77
- function_name = "unknown"
78
- if "function_name" in payload:
79
- function_name = payload["function_name"]
80
- elif metadata_payload is not None and "function_name" in metadata_payload:
81
- function_name = metadata_payload["function_name"]
82
- raise RemoteToolCallFailed(function_name, res.status_code, res.text)
83
-
84
- resp = res.json()
85
- tool_call_trace.response = resp
86
- # TODO: consider making the response schema the same between below two sources
87
- return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore
88
- finally:
89
- trace = tool_call_trace.model_dump()
90
- trace["type"] = "tool_call"
91
- display({MimeType.APPLICATION_JSON: trace}, raw=True)
67
+ # TODO: consider making the response schema the same between below two sources
68
+ return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
69
+
70
+
71
def send_task_inference_request(
    payload: Dict[str, Any],
    task_name: str,
    files: Optional[List[Tuple[Any, ...]]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Any:
    """Send a request to a task endpoint under the v2 LandingAI tool API.

    Parameters:
        payload (Dict[str, Any]): Request body. It is sent as form fields when
            ``files`` is given and as a JSON body otherwise (see ``_call_post``).
        task_name (str): Path segment appended to ``_LND_API_URL_v2``.
        files (Optional[List[Tuple[Any, ...]]]): Optional multipart files to
            upload together with ``payload``.
        metadata (Optional[Dict[str, Any]]): Optional extra information. Only its
            ``"function_name"`` entry is read, and it is used for error reporting.

    Returns:
        Any: The ``"data"`` field of the decoded JSON response.

    Raises:
        RemoteToolCallFailed: Propagated from ``_call_post`` on a non-200 reply.
    """
    endpoint = f"{_LND_API_URL_v2}/{task_name}"
    caller = (metadata or {}).get("function_name", "unknown")

    session = _create_requests_session(
        url=endpoint, num_retry=3, headers={"apikey": _LND_API_KEY}
    )
    result = _call_post(endpoint, payload, session, files, caller)
    return result["data"]
92
90
 
93
91
 
94
92
  def _create_requests_session(
@@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
195
193
  data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"
196
194
 
197
195
  return data
196
+
197
+
198
def _call_post(
    url: str,
    payload: dict[str, Any],
    session: Session,
    files: Optional[List[Tuple[Any, ...]]] = None,
    function_name: str = "unknown",
) -> Any:
    """POST ``payload`` to ``url`` and return the decoded JSON response.

    A ``ToolCallTrace`` describing the request, the response and any error is
    always emitted through ``display`` as an ``application/json`` payload tagged
    ``"type": "tool_call"``, whether the call succeeds or fails.

    Parameters:
        url (str): Endpoint to call.
        payload (dict[str, Any]): Request body. It is sent as form fields when
            ``files`` is provided and as a JSON body otherwise.
        session (Session): Pre-configured session used to send the request.
        files (Optional[List[Tuple[Any, ...]]]): Multipart files to upload.
        function_name (str): Name of the calling tool, reported in the raised error.

    Returns:
        Any: The parsed JSON body of a 200 response.

    Raises:
        RemoteToolCallFailed: If the response status code is not 200.
    """
    # Build the trace before entering ``try``. If construction fails, the
    # original exception then propagates unchanged, instead of the ``finally``
    # clause raising an UnboundLocalError on ``tool_call_trace``.
    tool_call_trace = ToolCallTrace(
        endpoint_url=url,
        request=payload,
        response={},
        error=None,
    )
    try:
        if files is not None:
            response = session.post(url, data=payload, files=files)
        else:
            response = session.post(url, json=payload)

        if response.status_code != 200:
            tool_call_trace.error = Error(
                name="RemoteToolCallFailed",
                value=f"{response.status_code} - {response.text}",
                traceback_raw=[],
            )
            _LOGGER.error(f"Request failed: {response.status_code} {response.text}")
            raise RemoteToolCallFailed(
                function_name, response.status_code, response.text
            )

        result = response.json()
        tool_call_trace.response = result
        return result
    finally:
        # Emit the trace on both the success and the failure paths.
        trace = tool_call_trace.model_dump()
        trace["type"] = "tool_call"
        display({MimeType.APPLICATION_JSON: trace}, raw=True)
236
+
237
+
238
def filter_bboxes_by_threshold(
    bboxes: BoundingBoxes, threshold: float
) -> BoundingBoxes:
    """Keep only the detections whose ``score`` is at least ``threshold``.

    Parameters:
        bboxes (BoundingBoxes): Detections to filter. Each item must expose ``.score``.
        threshold (float): Inclusive minimum score.

    Returns:
        BoundingBoxes: A new list of the surviving detections, in input order.
    """
    return [bbox for bbox in bboxes if bbox.score >= threshold]
@@ -13,7 +13,7 @@ import cv2
13
13
  import numpy as np
14
14
  import requests
15
15
  from moviepy.editor import ImageSequenceClip
16
- from PIL import Image, ImageDraw, ImageFont
16
+ from PIL import Image, ImageDraw, ImageFont, ImageEnhance
17
17
  from pillow_heif import register_heif_opener # type: ignore
18
18
  from pytube import YouTube # type: ignore
19
19
 
@@ -24,6 +24,8 @@ from vision_agent.tools.tool_utils import (
24
24
  get_tools_df,
25
25
  get_tools_info,
26
26
  send_inference_request,
27
+ send_task_inference_request,
28
+ filter_bboxes_by_threshold,
27
29
  )
28
30
  from vision_agent.tools.tools_types import (
29
31
  BboxInput,
@@ -32,6 +34,7 @@ from vision_agent.tools.tools_types import (
32
34
  Florencev2FtRequest,
33
35
  JobStatus,
34
36
  PromptTask,
37
+ ODResponseData,
35
38
  )
36
39
  from vision_agent.utils import extract_frames_from_video
37
40
  from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -455,7 +458,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
455
458
  "image": image_b64,
456
459
  "function_name": "loca_zero_shot_counting",
457
460
  }
458
- resp_data = send_inference_request(data, "loca", v2=True)
461
+ resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
459
462
  resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
460
463
  return resp_data
461
464
 
@@ -469,6 +472,8 @@ def loca_visual_prompt_counting(
469
472
 
470
473
  Parameters:
471
474
  image (np.ndarray): The image that contains lot of instances of a single object
475
+ visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
476
+ [xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.
472
477
 
473
478
  Returns:
474
479
  Dict[str, Any]: A dictionary containing the key 'count' and the count as a
@@ -496,11 +501,109 @@ def loca_visual_prompt_counting(
496
501
  "bbox": list(map(int, denormalize_bbox(bbox, image_size))),
497
502
  "function_name": "loca_visual_prompt_counting",
498
503
  }
499
- resp_data = send_inference_request(data, "loca", v2=True)
504
+ resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
500
505
  resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
501
506
  return resp_data
502
507
 
503
508
 
509
def countgd_counting(
    prompt: str,
    image: np.ndarray,
    box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
    """'countgd_counting' is a tool that can precisely count multiple instances of an
    object given a text prompt. It returns a list of bounding boxes with normalized
    coordinates, label names and associated confidence scores.

    Parameters:
        prompt (str): The object that needs to be counted.
        image (np.ndarray): The image that contains multiple instances of the object.
        box_threshold (float, optional): The threshold for detection. Defaults
            to 0.23.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box.

    Example
    -------
        >>> countgd_counting("flower", image)
        [
            {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
            {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
            {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
            {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
        ]
    """
    buffer_bytes = numpy_to_bytes(image)
    files = [("image", buffer_bytes)]
    # Comma-separated prompts have ", " rewritten to " ." before being sent.
    # NOTE(review): confirm this is the separator format the model expects.
    prompt = prompt.replace(", ", " .")
    payload = {"prompts": [prompt], "model": "countgd"}
    metadata = {"function_name": "countgd_counting"}
    resp_data = send_task_inference_request(
        payload, "text-to-object-detection", files=files, metadata=metadata
    )
    # Only the first entry of the response is used. Presumably these are the
    # detections for the single uploaded image; verify against the API schema.
    bboxes_per_frame = resp_data[0]
    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
    # model_dump() uses field names, so each result exposes its box under "bbox".
    return [bbox.model_dump() for bbox in filtered_bboxes]
553
+
554
+
555
def countgd_example_based_counting(
    visual_prompts: List[List[float]],
    image: np.ndarray,
    box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
    """'countgd_example_based_counting' is a tool that can precisely count multiple
    instances of an object given few visual example prompts. It returns a list of bounding
    boxes with normalized coordinates, label names and associated confidence scores.

    Parameters:
        visual_prompts (List[List[float]]): Bounding boxes of the object in format
            [xmin, ymin, xmax, ymax]. Up to 3 bounding boxes can be provided.
        image (np.ndarray): The image that contains multiple instances of the object.
        box_threshold (float, optional): The threshold for detection. Defaults
            to 0.23.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box.

    Example
    -------
        >>> countgd_example_based_counting(
            visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]],
            image=image
        )
        [
            {'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]},
            {'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5]},
            {'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52]},
            {'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58]},
        ]
    """
    buffer_bytes = numpy_to_bytes(image)
    files = [("image", buffer_bytes)]
    # Scale the example boxes by the image's (height, width). denormalize_bbox
    # only rescales boxes that look normalized, so pixel boxes pass through.
    pixel_prompts = [
        denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
    ]
    payload = {"visual_prompts": json.dumps(pixel_prompts), "model": "countgd"}
    metadata = {"function_name": "countgd_example_based_counting"}
    resp_data = send_task_inference_request(
        payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
    )
    # Only the first entry of the response is used. Presumably these are the
    # detections for the single uploaded image; verify against the API schema.
    bboxes_per_frame = resp_data[0]
    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
    # model_dump() uses field names, so each result exposes its box under "bbox".
    return [bbox.model_dump() for bbox in filtered_bboxes]
605
+
606
+
504
607
  def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
505
608
  """'florence2_roberta_vqa' is a tool that takes an image and analyzes
506
609
  its contents, generates detailed captions and then tries to answer the given
@@ -646,7 +749,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
646
749
  "tool": "closed_set_image_classification",
647
750
  "function_name": "clip",
648
751
  }
649
- resp_data = send_inference_request(data, "tools")
752
+ resp_data: dict[str, Any] = send_inference_request(data, "tools")
650
753
  resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
651
754
  return resp_data
652
755
 
@@ -674,7 +777,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
674
777
  "tool": "image_classification",
675
778
  "function_name": "vit_image_classification",
676
779
  }
677
- resp_data = send_inference_request(data, "tools")
780
+ resp_data: dict[str, Any] = send_inference_request(data, "tools")
678
781
  resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
679
782
  return resp_data
680
783
 
@@ -701,7 +804,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
701
804
  "image": image_b64,
702
805
  "function_name": "vit_nsfw_classification",
703
806
  }
704
- resp_data = send_inference_request(data, "nsfw-classification", v2=True)
807
+ resp_data: dict[str, Any] = send_inference_request(
808
+ data, "nsfw-classification", v2=True
809
+ )
705
810
  resp_data["score"] = round(resp_data["score"], 4)
706
811
  return resp_data
707
812
 
@@ -1559,6 +1664,74 @@ def overlay_heat_map(
1559
1664
  return np.array(combined)
1560
1665
 
1561
1666
 
1667
def overlay_counting_results(
    image: np.ndarray, instances: List[Dict[str, Any]]
) -> np.ndarray:
    """'overlay_counting_results' is a utility function that displays counting results on
    an image.

    Parameters:
        image (np.ndarray): The image to display the bounding boxes on.
        instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding
            box information of each instance

    Returns:
        np.ndarray: The image with the instance_id displayed

    Example
    -------
        >>> image_with_bboxes = overlay_counting_results(
            image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
        )
    """
    pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
    color = (158, 218, 229)

    width, height = pil_image.size
    # Scale the label font with the image size, but never go below 10.
    fontsize = max(10, int(min(width, height) / 80))
    # Darken the whole image so the index labels stand out.
    pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
    draw = ImageDraw.Draw(pil_image)
    font = ImageFont.truetype(
        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
        fontsize,
    )

    for i, elt in enumerate(instances):
        label = f"{i}"
        # denormalize the box if it is normalized
        x0, y0, x1, y1 = denormalize_bbox(elt["bbox"], (height, width))
        cx, cy = (x0 + x1) / 2, (y0 + y1) / 2

        # Measure the label with the same position and anchor that are used to
        # draw it, so the background rectangle exactly encloses the glyphs.
        text_box = draw.textbbox((cx, cy), text=label, font=font, anchor="mm")
        draw.rectangle(text_box, fill=color)

        # Draw the text centered on the bounding box's center point.
        draw.text((cx, cy), label, fill="black", font=font, anchor="mm")

    return np.array(pil_image)
1733
+
1734
+
1562
1735
  # TODO: add this function to the imports so that is picked in the agent
1563
1736
  def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
1564
1737
  """'florencev2_fine_tuning' is a tool that fine-tune florencev2 to be able
@@ -1679,8 +1852,7 @@ FUNCTION_TOOLS = [
1679
1852
  clip,
1680
1853
  vit_image_classification,
1681
1854
  vit_nsfw_classification,
1682
- loca_zero_shot_counting,
1683
- loca_visual_prompt_counting,
1855
+ countgd_counting,
1684
1856
  florence2_image_caption,
1685
1857
  florence2_ocr,
1686
1858
  florence2_sam2_image,
@@ -1703,6 +1875,7 @@ UTIL_TOOLS = [
1703
1875
  overlay_bounding_boxes,
1704
1876
  overlay_segmentation_masks,
1705
1877
  overlay_heat_map,
1878
+ overlay_counting_results,
1706
1879
  ]
1707
1880
 
1708
1881
  TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
@@ -1720,5 +1893,6 @@ UTILITIES_DOCSTRING = get_tool_documentation(
1720
1893
  overlay_bounding_boxes,
1721
1894
  overlay_segmentation_masks,
1722
1895
  overlay_heat_map,
1896
+ overlay_counting_results,
1723
1897
  ]
1724
1898
  )
@@ -1,8 +1,8 @@
1
1
  from enum import Enum
2
- from typing import List, Optional, Tuple
3
2
  from uuid import UUID
3
+ from typing import List, Tuple, Optional, Union
4
4
 
5
- from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, field_serializer
5
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo
6
6
 
7
7
 
8
8
  class BboxInput(BaseModel):
@@ -82,3 +82,16 @@ class JobStatus(str, Enum):
82
82
  SUCCEEDED = "SUCCEEDED"
83
83
  FAILED = "FAILED"
84
84
  STOPPED = "STOPPED"
85
+
86
+
87
class ODResponseData(BaseModel):
    """A single object-detection result, used to parse detection responses.

    ``bbox`` is read from the ``"bounding_box"`` key (its alias). Because
    ``populate_by_name`` is enabled, it may also be supplied as ``bbox``.
    """

    model_config = ConfigDict(
        populate_by_name=True,
    )

    label: str
    score: float
    # Box coordinates; either an all-int or an all-float list is accepted.
    bbox: Union[list[int], list[float]] = Field(alias="bounding_box")


# Alias for a list of detections.
BoundingBoxes = list[ODResponseData]
@@ -181,7 +181,7 @@ def denormalize_bbox(
181
181
  raise ValueError("Bounding box must be of length 4.")
182
182
 
183
183
  arr = np.array(bbox)
184
- if np.all((arr >= 0) & (arr <= 1)):
184
+ if np.all((arr[:2] >= 0) & (arr[:2] <= 1)):
185
185
  x1, y1, x2, y2 = bbox
186
186
  x1 = round(x1 * image_size[1])
187
187
  y1 = round(y1 * image_size[0])
File without changes
File without changes