vision-agent 0.2.120__tar.gz → 0.2.121__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {vision_agent-0.2.120 → vision_agent-0.2.121}/PKG-INFO +1 -1
- {vision_agent-0.2.120 → vision_agent-0.2.121}/pyproject.toml +1 -1
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/agent/vision_agent_coder_prompts.py +4 -5
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/lmm/lmm.py +0 -3
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/tools/__init__.py +3 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/tools/tool_utils.py +95 -51
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/tools/tools.py +182 -8
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/tools/tools_types.py +15 -2
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/utils/image_utils.py +1 -1
- {vision_agent-0.2.120 → vision_agent-0.2.121}/LICENSE +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/README.md +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/agent/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/agent/agent.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/agent/agent_utils.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/agent/vision_agent.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/agent/vision_agent_coder.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/agent/vision_agent_prompts.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/clients/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/clients/http.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/clients/landing_public_api.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/fonts/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/lmm/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/lmm/types.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/tools/meta_tools.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/tools/prompts.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/utils/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/utils/exceptions.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/utils/execute.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/utils/sim.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/utils/type_defs.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/utils/video.py +0 -0
{vision_agent-0.2.120 → vision_agent-0.2.121}/vision_agent/agent/vision_agent_coder_prompts.py
RENAMED
@@ -81,20 +81,19 @@ plan2:
|
|
81
81
|
- Count the number of detected objects labeled as 'person'.
|
82
82
|
plan3:
|
83
83
|
- Load the image from the provided file path 'image.jpg'.
|
84
|
-
- Use the '
|
84
|
+
- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
|
85
85
|
|
86
86
|
```python
|
87
|
-
from vision_agent.tools import load_image, owl_v2, grounding_sam,
|
87
|
+
from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
|
88
88
|
image = load_image("image.jpg")
|
89
89
|
owl_v2_out = owl_v2("person", image)
|
90
90
|
|
91
91
|
gsam_out = grounding_sam("person", image)
|
92
92
|
gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
|
93
93
|
|
94
|
-
|
95
|
-
loca_out = loca_out["count"]
|
94
|
+
cgd_out = countgd_counting("person", image)
|
96
95
|
|
97
|
-
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "
|
96
|
+
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
|
98
97
|
print(final_out)
|
99
98
|
```
|
100
99
|
"""
|
@@ -286,9 +286,6 @@ class OpenAILMM(LMM):
|
|
286
286
|
|
287
287
|
return lambda x: T.grounding_sam(params["prompt"], x)
|
288
288
|
|
289
|
-
def generate_zero_shot_counter(self, question: str) -> Callable:
|
290
|
-
return T.loca_zero_shot_counting
|
291
|
-
|
292
289
|
def generate_image_qa_tool(self, question: str) -> Callable:
|
293
290
|
return lambda x: T.git_vqa_v2(question, x)
|
294
291
|
|
@@ -37,10 +37,13 @@ from .tools import (
|
|
37
37
|
load_image,
|
38
38
|
loca_visual_prompt_counting,
|
39
39
|
loca_zero_shot_counting,
|
40
|
+
countgd_counting,
|
41
|
+
countgd_example_based_counting,
|
40
42
|
ocr,
|
41
43
|
overlay_bounding_boxes,
|
42
44
|
overlay_heat_map,
|
43
45
|
overlay_segmentation_masks,
|
46
|
+
overlay_counting_results,
|
44
47
|
owl_v2,
|
45
48
|
save_image,
|
46
49
|
save_json,
|
@@ -1,6 +1,6 @@
|
|
1
|
+
import os
|
1
2
|
import inspect
|
2
3
|
import logging
|
3
|
-
import os
|
4
4
|
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple
|
5
5
|
|
6
6
|
import pandas as pd
|
@@ -13,6 +13,7 @@ from urllib3.util.retry import Retry
|
|
13
13
|
from vision_agent.utils.exceptions import RemoteToolCallFailed
|
14
14
|
from vision_agent.utils.execute import Error, MimeType
|
15
15
|
from vision_agent.utils.type_defs import LandingaiAPIKey
|
16
|
+
from vision_agent.tools.tools_types import BoundingBoxes
|
16
17
|
|
17
18
|
_LOGGER = logging.getLogger(__name__)
|
18
19
|
_LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
|
@@ -34,61 +35,58 @@ def send_inference_request(
|
|
34
35
|
files: Optional[List[Tuple[Any, ...]]] = None,
|
35
36
|
v2: bool = False,
|
36
37
|
metadata_payload: Optional[Dict[str, Any]] = None,
|
37
|
-
) ->
|
38
|
+
) -> Any:
|
38
39
|
# TODO: runtime_tag and function_name should be in metadata_payload and not included
|
39
40
|
# in the service payload
|
40
|
-
|
41
|
-
|
42
|
-
|
41
|
+
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
|
42
|
+
payload["runtime_tag"] = runtime_tag
|
43
|
+
|
44
|
+
url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
|
45
|
+
if "TOOL_ENDPOINT_URL" in os.environ:
|
46
|
+
url = os.environ["TOOL_ENDPOINT_URL"]
|
47
|
+
|
48
|
+
headers = {"apikey": _LND_API_KEY}
|
49
|
+
if "TOOL_ENDPOINT_AUTH" in os.environ:
|
50
|
+
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
|
51
|
+
headers.pop("apikey")
|
52
|
+
|
53
|
+
session = _create_requests_session(
|
54
|
+
url=url,
|
55
|
+
num_retry=3,
|
56
|
+
headers=headers,
|
57
|
+
)
|
43
58
|
|
44
|
-
|
45
|
-
|
46
|
-
|
59
|
+
function_name = "unknown"
|
60
|
+
if "function_name" in payload:
|
61
|
+
function_name = payload["function_name"]
|
62
|
+
elif metadata_payload is not None and "function_name" in metadata_payload:
|
63
|
+
function_name = metadata_payload["function_name"]
|
47
64
|
|
48
|
-
|
49
|
-
endpoint_url=url,
|
50
|
-
request=payload,
|
51
|
-
response={},
|
52
|
-
error=None,
|
53
|
-
)
|
54
|
-
headers = {"apikey": _LND_API_KEY}
|
55
|
-
if "TOOL_ENDPOINT_AUTH" in os.environ:
|
56
|
-
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
|
57
|
-
headers.pop("apikey")
|
58
|
-
|
59
|
-
session = _create_requests_session(
|
60
|
-
url=url,
|
61
|
-
num_retry=3,
|
62
|
-
headers=headers,
|
63
|
-
)
|
65
|
+
response = _call_post(url, payload, session, files, function_name)
|
64
66
|
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
finally:
|
89
|
-
trace = tool_call_trace.model_dump()
|
90
|
-
trace["type"] = "tool_call"
|
91
|
-
display({MimeType.APPLICATION_JSON: trace}, raw=True)
|
67
|
+
# TODO: consider making the response schema the same between the two sources below
|
68
|
+
return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
|
69
|
+
|
70
|
+
|
71
|
+
def send_task_inference_request(
|
72
|
+
payload: Dict[str, Any],
|
73
|
+
task_name: str,
|
74
|
+
files: Optional[List[Tuple[Any, ...]]] = None,
|
75
|
+
metadata: Optional[Dict[str, Any]] = None,
|
76
|
+
) -> Any:
|
77
|
+
url = f"{_LND_API_URL_v2}/{task_name}"
|
78
|
+
headers = {"apikey": _LND_API_KEY}
|
79
|
+
session = _create_requests_session(
|
80
|
+
url=url,
|
81
|
+
num_retry=3,
|
82
|
+
headers=headers,
|
83
|
+
)
|
84
|
+
|
85
|
+
function_name = "unknown"
|
86
|
+
if metadata is not None and "function_name" in metadata:
|
87
|
+
function_name = metadata["function_name"]
|
88
|
+
response = _call_post(url, payload, session, files, function_name)
|
89
|
+
return response["data"]
|
92
90
|
|
93
91
|
|
94
92
|
def _create_requests_session(
|
@@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
|
|
195
193
|
data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"
|
196
194
|
|
197
195
|
return data
|
196
|
+
|
197
|
+
|
198
|
+
def _call_post(
|
199
|
+
url: str,
|
200
|
+
payload: dict[str, Any],
|
201
|
+
session: Session,
|
202
|
+
files: Optional[List[Tuple[Any, ...]]] = None,
|
203
|
+
function_name: str = "unknown",
|
204
|
+
) -> Any:
|
205
|
+
try:
|
206
|
+
tool_call_trace = ToolCallTrace(
|
207
|
+
endpoint_url=url,
|
208
|
+
request=payload,
|
209
|
+
response={},
|
210
|
+
error=None,
|
211
|
+
)
|
212
|
+
|
213
|
+
if files is not None:
|
214
|
+
response = session.post(url, data=payload, files=files)
|
215
|
+
else:
|
216
|
+
response = session.post(url, json=payload)
|
217
|
+
|
218
|
+
if response.status_code != 200:
|
219
|
+
tool_call_trace.error = Error(
|
220
|
+
name="RemoteToolCallFailed",
|
221
|
+
value=f"{response.status_code} - {response.text}",
|
222
|
+
traceback_raw=[],
|
223
|
+
)
|
224
|
+
_LOGGER.error(f"Request failed: {response.status_code} {response.text}")
|
225
|
+
raise RemoteToolCallFailed(
|
226
|
+
function_name, response.status_code, response.text
|
227
|
+
)
|
228
|
+
|
229
|
+
result = response.json()
|
230
|
+
tool_call_trace.response = result
|
231
|
+
return result
|
232
|
+
finally:
|
233
|
+
trace = tool_call_trace.model_dump()
|
234
|
+
trace["type"] = "tool_call"
|
235
|
+
display({MimeType.APPLICATION_JSON: trace}, raw=True)
|
236
|
+
|
237
|
+
|
238
|
+
def filter_bboxes_by_threshold(
|
239
|
+
bboxes: BoundingBoxes, threshold: float
|
240
|
+
) -> BoundingBoxes:
|
241
|
+
return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
|
@@ -13,7 +13,7 @@ import cv2
|
|
13
13
|
import numpy as np
|
14
14
|
import requests
|
15
15
|
from moviepy.editor import ImageSequenceClip
|
16
|
-
from PIL import Image, ImageDraw, ImageFont
|
16
|
+
from PIL import Image, ImageDraw, ImageFont, ImageEnhance
|
17
17
|
from pillow_heif import register_heif_opener # type: ignore
|
18
18
|
from pytube import YouTube # type: ignore
|
19
19
|
|
@@ -24,6 +24,8 @@ from vision_agent.tools.tool_utils import (
|
|
24
24
|
get_tools_df,
|
25
25
|
get_tools_info,
|
26
26
|
send_inference_request,
|
27
|
+
send_task_inference_request,
|
28
|
+
filter_bboxes_by_threshold,
|
27
29
|
)
|
28
30
|
from vision_agent.tools.tools_types import (
|
29
31
|
BboxInput,
|
@@ -32,6 +34,7 @@ from vision_agent.tools.tools_types import (
|
|
32
34
|
Florencev2FtRequest,
|
33
35
|
JobStatus,
|
34
36
|
PromptTask,
|
37
|
+
ODResponseData,
|
35
38
|
)
|
36
39
|
from vision_agent.utils import extract_frames_from_video
|
37
40
|
from vision_agent.utils.exceptions import FineTuneModelIsNotReady
|
@@ -455,7 +458,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
|
|
455
458
|
"image": image_b64,
|
456
459
|
"function_name": "loca_zero_shot_counting",
|
457
460
|
}
|
458
|
-
resp_data = send_inference_request(data, "loca", v2=True)
|
461
|
+
resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
|
459
462
|
resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
|
460
463
|
return resp_data
|
461
464
|
|
@@ -469,6 +472,8 @@ def loca_visual_prompt_counting(
|
|
469
472
|
|
470
473
|
Parameters:
|
471
474
|
image (np.ndarray): The image that contains a lot of instances of a single object
|
475
|
+
visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
|
476
|
+
[xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.
|
472
477
|
|
473
478
|
Returns:
|
474
479
|
Dict[str, Any]: A dictionary containing the key 'count' and the count as a
|
@@ -496,11 +501,109 @@ def loca_visual_prompt_counting(
|
|
496
501
|
"bbox": list(map(int, denormalize_bbox(bbox, image_size))),
|
497
502
|
"function_name": "loca_visual_prompt_counting",
|
498
503
|
}
|
499
|
-
resp_data = send_inference_request(data, "loca", v2=True)
|
504
|
+
resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
|
500
505
|
resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
|
501
506
|
return resp_data
|
502
507
|
|
503
508
|
|
509
|
+
def countgd_counting(
|
510
|
+
prompt: str,
|
511
|
+
image: np.ndarray,
|
512
|
+
box_threshold: float = 0.23,
|
513
|
+
) -> List[Dict[str, Any]]:
|
514
|
+
"""'countgd_counting' is a tool that can precisely count multiple instances of an
|
515
|
+
object given a text prompt. It returns a list of bounding boxes with normalized
|
516
|
+
coordinates, label names and associated confidence scores.
|
517
|
+
|
518
|
+
Parameters:
|
519
|
+
prompt (str): The object that needs to be counted.
|
520
|
+
image (np.ndarray): The image that contains multiple instances of the object.
|
521
|
+
box_threshold (float, optional): The threshold for detection. Defaults
|
522
|
+
to 0.23.
|
523
|
+
|
524
|
+
Returns:
|
525
|
+
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
|
526
|
+
bounding box of the detected objects with normalized coordinates between 0
|
527
|
+
and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
|
528
|
+
top-left and xmax and ymax are the coordinates of the bottom-right of the
|
529
|
+
bounding box.
|
530
|
+
|
531
|
+
Example
|
532
|
+
-------
|
533
|
+
>>> countgd_counting("flower", image)
|
534
|
+
[
|
535
|
+
{'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
|
536
|
+
{'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
|
537
|
+
{'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
|
538
|
+
{'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
|
539
|
+
]
|
540
|
+
"""
|
541
|
+
buffer_bytes = numpy_to_bytes(image)
|
542
|
+
files = [("image", buffer_bytes)]
|
543
|
+
prompt = prompt.replace(", ", " .")
|
544
|
+
payload = {"prompts": [prompt], "model": "countgd"}
|
545
|
+
metadata = {"function_name": "countgd_counting"}
|
546
|
+
resp_data = send_task_inference_request(
|
547
|
+
payload, "text-to-object-detection", files=files, metadata=metadata
|
548
|
+
)
|
549
|
+
bboxes_per_frame = resp_data[0]
|
550
|
+
bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
|
551
|
+
filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
|
552
|
+
return [bbox.model_dump() for bbox in filtered_bboxes]
|
553
|
+
|
554
|
+
|
555
|
+
def countgd_example_based_counting(
|
556
|
+
visual_prompts: List[List[float]],
|
557
|
+
image: np.ndarray,
|
558
|
+
box_threshold: float = 0.23,
|
559
|
+
) -> List[Dict[str, Any]]:
|
560
|
+
"""'countgd_example_based_counting' is a tool that can precisely count multiple
|
561
|
+
instances of an object given a few visual example prompts. It returns a list of bounding
|
562
|
+
boxes with normalized coordinates, label names and associated confidence scores.
|
563
|
+
|
564
|
+
Parameters:
|
565
|
+
visual_prompts (List[List[float]]): Bounding boxes of the object in format
|
566
|
+
[xmin, ymin, xmax, ymax]. Up to 3 bounding boxes can be provided.
|
567
|
+
image (np.ndarray): The image that contains multiple instances of the object.
|
568
|
+
box_threshold (float, optional): The threshold for detection. Defaults
|
569
|
+
to 0.23.
|
570
|
+
|
571
|
+
Returns:
|
572
|
+
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
|
573
|
+
bounding box of the detected objects with normalized coordinates between 0
|
574
|
+
and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
|
575
|
+
top-left and xmax and ymax are the coordinates of the bottom-right of the
|
576
|
+
bounding box.
|
577
|
+
|
578
|
+
Example
|
579
|
+
-------
|
580
|
+
>>> countgd_example_based_counting(
|
581
|
+
visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]],
|
582
|
+
image=image
|
583
|
+
)
|
584
|
+
[
|
585
|
+
{'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]},
|
586
|
+
{'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5]},
|
587
|
+
{'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52]},
|
588
|
+
{'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58]},
|
589
|
+
]
|
590
|
+
"""
|
591
|
+
buffer_bytes = numpy_to_bytes(image)
|
592
|
+
files = [("image", buffer_bytes)]
|
593
|
+
visual_prompts = [
|
594
|
+
denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
|
595
|
+
]
|
596
|
+
payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"}
|
597
|
+
metadata = {"function_name": "countgd_example_based_counting"}
|
598
|
+
resp_data = send_task_inference_request(
|
599
|
+
payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
|
600
|
+
)
|
601
|
+
bboxes_per_frame = resp_data[0]
|
602
|
+
bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
|
603
|
+
filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
|
604
|
+
return [bbox.model_dump() for bbox in filtered_bboxes]
|
605
|
+
|
606
|
+
|
504
607
|
def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
|
505
608
|
"""'florence2_roberta_vqa' is a tool that takes an image and analyzes
|
506
609
|
its contents, generates detailed captions and then tries to answer the given
|
@@ -646,7 +749,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
|
|
646
749
|
"tool": "closed_set_image_classification",
|
647
750
|
"function_name": "clip",
|
648
751
|
}
|
649
|
-
resp_data = send_inference_request(data, "tools")
|
752
|
+
resp_data: dict[str, Any] = send_inference_request(data, "tools")
|
650
753
|
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
|
651
754
|
return resp_data
|
652
755
|
|
@@ -674,7 +777,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
|
|
674
777
|
"tool": "image_classification",
|
675
778
|
"function_name": "vit_image_classification",
|
676
779
|
}
|
677
|
-
resp_data = send_inference_request(data, "tools")
|
780
|
+
resp_data: dict[str, Any] = send_inference_request(data, "tools")
|
678
781
|
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
|
679
782
|
return resp_data
|
680
783
|
|
@@ -701,7 +804,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
|
|
701
804
|
"image": image_b64,
|
702
805
|
"function_name": "vit_nsfw_classification",
|
703
806
|
}
|
704
|
-
resp_data = send_inference_request(
|
807
|
+
resp_data: dict[str, Any] = send_inference_request(
|
808
|
+
data, "nsfw-classification", v2=True
|
809
|
+
)
|
705
810
|
resp_data["score"] = round(resp_data["score"], 4)
|
706
811
|
return resp_data
|
707
812
|
|
@@ -1559,6 +1664,74 @@ def overlay_heat_map(
|
|
1559
1664
|
return np.array(combined)
|
1560
1665
|
|
1561
1666
|
|
1667
|
+
def overlay_counting_results(
|
1668
|
+
image: np.ndarray, instances: List[Dict[str, Any]]
|
1669
|
+
) -> np.ndarray:
|
1670
|
+
"""'overlay_counting_results' is a utility function that displays counting results on
|
1671
|
+
an image.
|
1672
|
+
|
1673
|
+
Parameters:
|
1674
|
+
image (np.ndarray): The image to display the bounding boxes on.
|
1675
|
+
instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding
|
1676
|
+
box information of each instance
|
1677
|
+
|
1678
|
+
Returns:
|
1679
|
+
np.ndarray: The image with the instance_id displayed
|
1680
|
+
|
1681
|
+
Example
|
1682
|
+
-------
|
1683
|
+
>>> image_with_bboxes = overlay_counting_results(
|
1684
|
+
image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
|
1685
|
+
)
|
1686
|
+
"""
|
1687
|
+
pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
|
1688
|
+
color = (158, 218, 229)
|
1689
|
+
|
1690
|
+
width, height = pil_image.size
|
1691
|
+
fontsize = max(10, int(min(width, height) / 80))
|
1692
|
+
pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
|
1693
|
+
draw = ImageDraw.Draw(pil_image)
|
1694
|
+
font = ImageFont.truetype(
|
1695
|
+
str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
|
1696
|
+
fontsize,
|
1697
|
+
)
|
1698
|
+
|
1699
|
+
for i, elt in enumerate(instances):
|
1700
|
+
label = f"{i}"
|
1701
|
+
box = elt["bbox"]
|
1702
|
+
|
1703
|
+
# denormalize the box if it is normalized
|
1704
|
+
box = denormalize_bbox(box, (height, width))
|
1705
|
+
x0, y0, x1, y1 = box
|
1706
|
+
cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
|
1707
|
+
|
1708
|
+
text_box = draw.textbbox(
|
1709
|
+
(cx, cy), text=label, font=font, align="center", anchor="mm"
|
1710
|
+
)
|
1711
|
+
|
1712
|
+
# Calculate the offset to center the text within the bounding box
|
1713
|
+
text_width = text_box[2] - text_box[0]
|
1714
|
+
text_height = text_box[3] - text_box[1]
|
1715
|
+
text_x0 = cx - text_width / 2
|
1716
|
+
text_y0 = cy - text_height / 2
|
1717
|
+
text_x1 = cx + text_width / 2
|
1718
|
+
text_y1 = cy + text_height / 2
|
1719
|
+
|
1720
|
+
# Draw the rectangle encapsulating the text
|
1721
|
+
draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color)
|
1722
|
+
|
1723
|
+
# Draw the text at the center of the bounding box
|
1724
|
+
draw.text(
|
1725
|
+
(text_x0, text_y0),
|
1726
|
+
label,
|
1727
|
+
fill="black",
|
1728
|
+
font=font,
|
1729
|
+
anchor="lt",
|
1730
|
+
)
|
1731
|
+
|
1732
|
+
return np.array(pil_image)
|
1733
|
+
|
1734
|
+
|
1562
1735
|
# TODO: add this function to the imports so that it is picked up by the agent
|
1563
1736
|
def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
|
1564
1737
|
"""'florencev2_fine_tuning' is a tool that fine-tune florencev2 to be able
|
@@ -1679,8 +1852,7 @@ FUNCTION_TOOLS = [
|
|
1679
1852
|
clip,
|
1680
1853
|
vit_image_classification,
|
1681
1854
|
vit_nsfw_classification,
|
1682
|
-
|
1683
|
-
loca_visual_prompt_counting,
|
1855
|
+
countgd_counting,
|
1684
1856
|
florence2_image_caption,
|
1685
1857
|
florence2_ocr,
|
1686
1858
|
florence2_sam2_image,
|
@@ -1703,6 +1875,7 @@ UTIL_TOOLS = [
|
|
1703
1875
|
overlay_bounding_boxes,
|
1704
1876
|
overlay_segmentation_masks,
|
1705
1877
|
overlay_heat_map,
|
1878
|
+
overlay_counting_results,
|
1706
1879
|
]
|
1707
1880
|
|
1708
1881
|
TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
|
@@ -1720,5 +1893,6 @@ UTILITIES_DOCSTRING = get_tool_documentation(
|
|
1720
1893
|
overlay_bounding_boxes,
|
1721
1894
|
overlay_segmentation_masks,
|
1722
1895
|
overlay_heat_map,
|
1896
|
+
overlay_counting_results,
|
1723
1897
|
]
|
1724
1898
|
)
|
@@ -1,8 +1,8 @@
|
|
1
1
|
from enum import Enum
|
2
|
-
from typing import List, Optional, Tuple
|
3
2
|
from uuid import UUID
|
3
|
+
from typing import List, Tuple, Optional, Union
|
4
4
|
|
5
|
-
from pydantic import BaseModel, ConfigDict, Field,
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo
|
6
6
|
|
7
7
|
|
8
8
|
class BboxInput(BaseModel):
|
@@ -82,3 +82,16 @@ class JobStatus(str, Enum):
|
|
82
82
|
SUCCEEDED = "SUCCEEDED"
|
83
83
|
FAILED = "FAILED"
|
84
84
|
STOPPED = "STOPPED"
|
85
|
+
|
86
|
+
|
87
|
+
class ODResponseData(BaseModel):
|
88
|
+
label: str
|
89
|
+
score: float
|
90
|
+
bbox: Union[list[int], list[float]] = Field(alias="bounding_box")
|
91
|
+
|
92
|
+
model_config = ConfigDict(
|
93
|
+
populate_by_name=True,
|
94
|
+
)
|
95
|
+
|
96
|
+
|
97
|
+
BoundingBoxes = list[ODResponseData]
|
@@ -181,7 +181,7 @@ def denormalize_bbox(
|
|
181
181
|
raise ValueError("Bounding box must be of length 4.")
|
182
182
|
|
183
183
|
arr = np.array(bbox)
|
184
|
-
if np.all((arr >= 0) & (arr <= 1)):
|
184
|
+
if np.all((arr[:2] >= 0) & (arr[:2] <= 1)):
|
185
185
|
x1, y1, x2, y2 = bbox
|
186
186
|
x1 = round(x1 * image_size[1])
|
187
187
|
y1 = round(y1 * image_size[0])
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|