vision-agent 0.2.119__py3-none-any.whl → 0.2.121__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/vision_agent_coder_prompts.py +4 -5
- vision_agent/lmm/lmm.py +0 -3
- vision_agent/tools/__init__.py +3 -0
- vision_agent/tools/tool_utils.py +95 -51
- vision_agent/tools/tools.py +182 -8
- vision_agent/tools/tools_types.py +15 -2
- vision_agent/utils/image_utils.py +1 -1
- {vision_agent-0.2.119.dist-info → vision_agent-0.2.121.dist-info}/METADATA +1 -1
- {vision_agent-0.2.119.dist-info → vision_agent-0.2.121.dist-info}/RECORD +11 -11
- {vision_agent-0.2.119.dist-info → vision_agent-0.2.121.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.119.dist-info → vision_agent-0.2.121.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent_coder_prompts.py CHANGED
@@ -81,20 +81,19 @@ plan2:
 - Count the number of detected objects labeled as 'person'.
 plan3:
 - Load the image from the provided file path 'image.jpg'.
-- Use the '
+- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.

 ```python
-from vision_agent.tools import load_image, owl_v2, grounding_sam,
+from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
 image = load_image("image.jpg")
 owl_v2_out = owl_v2("person", image)

 gsam_out = grounding_sam("person", image)
 gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]

-loca_out = loca_zero_shot_counting(image)
-loca_out = loca_out["count"]
+cgd_out = countgd_counting(image)

-final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "
+final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
 print(final_out)
 ```
 """
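For reference, a minimal end-to-end sketch of the revised plan3 (a hedged illustration, not part of the package: it assumes a configured LANDINGAI_API_KEY, a local `image.jpg`, and network access to the tool endpoints). Note two deviations from the shipped snippet: `countgd_counting`, as defined later in this diff, takes the text prompt as its first argument, and the shipped snippet references a `florencev2_out` defined in earlier context not shown here, for which the sketch substitutes the `grounding_sam` output:

```python
from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting

image = load_image("image.jpg")  # placeholder path
owl_v2_out = owl_v2("person", image)

gsam_out = grounding_sam("person", image)
# Drop the large segmentation masks so the comparison output stays readable.
gsam_out = [{k: v for k, v in o.items() if k != "mask"} for o in gsam_out]

# One dict per detected instance; the person count is simply len(cgd_out).
cgd_out = countgd_counting("person", image)

final_out = {"owl_v2": owl_v2_out, "grounding_sam": gsam_out, "countgd_counting": cgd_out}
print(final_out)
```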
vision_agent/lmm/lmm.py CHANGED
@@ -286,9 +286,6 @@ class OpenAILMM(LMM):

         return lambda x: T.grounding_sam(params["prompt"], x)

-    def generate_zero_shot_counter(self, question: str) -> Callable:
-        return T.loca_zero_shot_counting
-
     def generate_image_qa_tool(self, question: str) -> Callable:
         return lambda x: T.git_vqa_v2(question, x)

vision_agent/tools/__init__.py CHANGED
@@ -37,10 +37,13 @@ from .tools import (
     load_image,
     loca_visual_prompt_counting,
     loca_zero_shot_counting,
+    countgd_counting,
+    countgd_example_based_counting,
     ocr,
     overlay_bounding_boxes,
     overlay_heat_map,
     overlay_segmentation_masks,
+    overlay_counting_results,
     owl_v2,
     save_image,
     save_json,
vision_agent/tools/tool_utils.py CHANGED
@@ -1,6 +1,6 @@
+import os
 import inspect
 import logging
-import os
 from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

 import pandas as pd
@@ -13,6 +13,7 @@ from urllib3.util.retry import Retry
 from vision_agent.utils.exceptions import RemoteToolCallFailed
 from vision_agent.utils.execute import Error, MimeType
 from vision_agent.utils.type_defs import LandingaiAPIKey
+from vision_agent.tools.tools_types import BoundingBoxes

 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
@@ -34,61 +35,58 @@ def send_inference_request(
     files: Optional[List[Tuple[Any, ...]]] = None,
     v2: bool = False,
     metadata_payload: Optional[Dict[str, Any]] = None,
-) ->
+) -> Any:
     # TODO: runtime_tag and function_name should be metadata_payload and now included
     # in the service payload
-    ...
+    if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
+        payload["runtime_tag"] = runtime_tag
+
+    url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
+    if "TOOL_ENDPOINT_URL" in os.environ:
+        url = os.environ["TOOL_ENDPOINT_URL"]
+
+    headers = {"apikey": _LND_API_KEY}
+    if "TOOL_ENDPOINT_AUTH" in os.environ:
+        headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
+        headers.pop("apikey")
+
+    session = _create_requests_session(
+        url=url,
+        num_retry=3,
+        headers=headers,
+    )

-    ...
+    function_name = "unknown"
+    if "function_name" in payload:
+        function_name = payload["function_name"]
+    elif metadata_payload is not None and "function_name" in metadata_payload:
+        function_name = metadata_payload["function_name"]

-    tool_call_trace = ToolCallTrace(
-        endpoint_url=url,
-        request=payload,
-        response={},
-        error=None,
-    )
-    headers = {"apikey": _LND_API_KEY}
-    if "TOOL_ENDPOINT_AUTH" in os.environ:
-        headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
-        headers.pop("apikey")
-
-    session = _create_requests_session(
-        url=url,
-        num_retry=3,
-        headers=headers,
-    )
+    response = _call_post(url, payload, session, files, function_name)

-    ...
-    finally:
-        trace = tool_call_trace.model_dump()
-        trace["type"] = "tool_call"
-        display({MimeType.APPLICATION_JSON: trace}, raw=True)
+    # TODO: consider making the response schema the same between below two sources
+    return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
+
+
+def send_task_inference_request(
+    payload: Dict[str, Any],
+    task_name: str,
+    files: Optional[List[Tuple[Any, ...]]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+) -> Any:
+    url = f"{_LND_API_URL_v2}/{task_name}"
+    headers = {"apikey": _LND_API_KEY}
+    session = _create_requests_session(
+        url=url,
+        num_retry=3,
+        headers=headers,
+    )
+
+    function_name = "unknown"
+    if metadata is not None and "function_name" in metadata:
+        function_name = metadata["function_name"]
+    response = _call_post(url, payload, session, files, function_name)
+    return response["data"]


 def _create_requests_session(

(The `...` markers stand for removed lines whose content this diff view did not capture; the bulk of the removed body of send_inference_request, the inline post and trace handling, was factored out into the new _call_post helper added in the next hunk.)
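To make the new routing behavior concrete, here is a self-contained sketch that mirrors the URL/header resolution in the refactored send_inference_request; `resolve_endpoint`, the example URL, and the token are illustrative names, not part of the library API:

```python
import os
from typing import Dict, Tuple

def resolve_endpoint(default_url: str, api_key: str) -> Tuple[str, Dict[str, str]]:
    # Mirrors the hunk above: TOOL_ENDPOINT_URL overrides the LandingAI URL,
    # and TOOL_ENDPOINT_AUTH swaps the "apikey" header for "Authorization".
    url = os.environ.get("TOOL_ENDPOINT_URL", default_url)
    headers = {"apikey": api_key}
    if "TOOL_ENDPOINT_AUTH" in os.environ:
        headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
        headers.pop("apikey")
    return url, headers

os.environ["TOOL_ENDPOINT_AUTH"] = "Bearer test-token"  # hypothetical token
print(resolve_endpoint("https://api.example.com/v2/tools", "my-api-key"))
# ('https://api.example.com/v2/tools', {'Authorization': 'Bearer test-token'})
```

Note also the new return contract visible in the hunk: with TOOL_ENDPOINT_AUTH set, the raw JSON body is returned; otherwise the function unwraps and returns `response["data"]`.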
@@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
     data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"

     return data
+
+
+def _call_post(
+    url: str,
+    payload: dict[str, Any],
+    session: Session,
+    files: Optional[List[Tuple[Any, ...]]] = None,
+    function_name: str = "unknown",
+) -> Any:
+    try:
+        tool_call_trace = ToolCallTrace(
+            endpoint_url=url,
+            request=payload,
+            response={},
+            error=None,
+        )
+
+        if files is not None:
+            response = session.post(url, data=payload, files=files)
+        else:
+            response = session.post(url, json=payload)
+
+        if response.status_code != 200:
+            tool_call_trace.error = Error(
+                name="RemoteToolCallFailed",
+                value=f"{response.status_code} - {response.text}",
+                traceback_raw=[],
+            )
+            _LOGGER.error(f"Request failed: {response.status_code} {response.text}")
+            raise RemoteToolCallFailed(
+                function_name, response.status_code, response.text
+            )
+
+        result = response.json()
+        tool_call_trace.response = result
+        return result
+    finally:
+        trace = tool_call_trace.model_dump()
+        trace["type"] = "tool_call"
+        display({MimeType.APPLICATION_JSON: trace}, raw=True)
+
+
+def filter_bboxes_by_threshold(
+    bboxes: BoundingBoxes, threshold: float
+) -> BoundingBoxes:
+    return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
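A short sketch of the new filter_bboxes_by_threshold helper together with ODResponseData from tools_types (both added in this diff): it simply keeps the boxes whose score meets the threshold, and it runs offline since no request is made:

```python
from vision_agent.tools.tool_utils import filter_bboxes_by_threshold
from vision_agent.tools.tools_types import ODResponseData

bboxes = [
    ODResponseData(label="person", score=0.91, bounding_box=[0.1, 0.1, 0.3, 0.5]),
    ODResponseData(label="person", score=0.12, bounding_box=[0.6, 0.2, 0.8, 0.7]),
]
kept = filter_bboxes_by_threshold(bboxes, 0.23)
print([b.model_dump() for b in kept])
# [{'label': 'person', 'score': 0.91, 'bbox': [0.1, 0.1, 0.3, 0.5]}]
```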
vision_agent/tools/tools.py CHANGED
@@ -13,7 +13,7 @@ import cv2
 import numpy as np
 import requests
 from moviepy.editor import ImageSequenceClip
-from PIL import Image, ImageDraw, ImageFont
+from PIL import Image, ImageDraw, ImageFont, ImageEnhance
 from pillow_heif import register_heif_opener  # type: ignore
 from pytube import YouTube  # type: ignore

@@ -24,6 +24,8 @@ from vision_agent.tools.tool_utils import (
     get_tools_df,
     get_tools_info,
     send_inference_request,
+    send_task_inference_request,
+    filter_bboxes_by_threshold,
 )
 from vision_agent.tools.tools_types import (
     BboxInput,
@@ -32,6 +34,7 @@ from vision_agent.tools.tools_types import (
     Florencev2FtRequest,
     JobStatus,
     PromptTask,
+    ODResponseData,
 )
 from vision_agent.utils import extract_frames_from_video
 from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -455,7 +458,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
         "image": image_b64,
         "function_name": "loca_zero_shot_counting",
     }
-    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
     resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data

@@ -469,6 +472,8 @@ def loca_visual_prompt_counting(

     Parameters:
         image (np.ndarray): The image that contains lot of instances of a single object
+        visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
+            [xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.

     Returns:
         Dict[str, Any]: A dictionary containing the key 'count' and the count as a
@@ -496,11 +501,109 @@
         "bbox": list(map(int, denormalize_bbox(bbox, image_size))),
         "function_name": "loca_visual_prompt_counting",
     }
-    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
     resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data


+def countgd_counting(
+    prompt: str,
+    image: np.ndarray,
+    box_threshold: float = 0.23,
+) -> List[Dict[str, Any]]:
+    """'countgd_counting' is a tool that can precisely count multiple instances of an
+    object given a text prompt. It returns a list of bounding boxes with normalized
+    coordinates, label names and associated confidence scores.
+
+    Parameters:
+        prompt (str): The object that needs to be counted.
+        image (np.ndarray): The image that contains multiple instances of the object.
+        box_threshold (float, optional): The threshold for detection. Defaults
+            to 0.23.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+            bounding box of the detected objects with normalized coordinates between 0
+            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+            top-left and xmax and ymax are the coordinates of the bottom-right of the
+            bounding box.
+
+    Example
+    -------
+    >>> countgd_counting("flower", image)
+    [
+        {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
+        {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
+        {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
+    ]
+    """
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    prompt = prompt.replace(", ", " .")
+    payload = {"prompts": [prompt], "model": "countgd"}
+    metadata = {"function_name": "countgd_counting"}
+    resp_data = send_task_inference_request(
+        payload, "text-to-object-detection", files=files, metadata=metadata
+    )
+    bboxes_per_frame = resp_data[0]
+    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
+    return [bbox.model_dump() for bbox in filtered_bboxes]
+
+
+def countgd_example_based_counting(
+    visual_prompts: List[List[float]],
+    image: np.ndarray,
+    box_threshold: float = 0.23,
+) -> List[Dict[str, Any]]:
+    """'countgd_example_based_counting' is a tool that can precisely count multiple
+    instances of an object given a few visual example prompts. It returns a list of
+    bounding boxes with normalized coordinates, label names and associated confidence
+    scores.
+
+    Parameters:
+        visual_prompts (List[List[float]]): Bounding boxes of the object in format
+            [xmin, ymin, xmax, ymax]. Up to 3 bounding boxes can be provided.
+        image (np.ndarray): The image that contains multiple instances of the object.
+        box_threshold (float, optional): The threshold for detection. Defaults
+            to 0.23.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+            bounding box of the detected objects with normalized coordinates between 0
+            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+            top-left and xmax and ymax are the coordinates of the bottom-right of the
+            bounding box.
+
+    Example
+    -------
+    >>> countgd_example_based_counting(
+            visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]],
+            image=image
+        )
+    [
+        {'score': 0.49, 'label': 'object', 'bounding_box': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 0.68, 'label': 'object', 'bounding_box': [0.2, 0.21, 0.45, 0.5]},
+        {'score': 0.78, 'label': 'object', 'bounding_box': [0.3, 0.35, 0.48, 0.52]},
+        {'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58]},
+    ]
+    """
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    visual_prompts = [
+        denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
+    ]
+    payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"}
+    metadata = {"function_name": "countgd_example_based_counting"}
+    resp_data = send_task_inference_request(
+        payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
+    )
+    bboxes_per_frame = resp_data[0]
+    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
+    return [bbox.model_dump() for bbox in filtered_bboxes]
+
+
 def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
     """'florence2_roberta_vqa' is a tool that takes an image and analyzes
     its contents, generates detailed captions and then tries to answer the given
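Assuming a configured LANDINGAI_API_KEY and network access, the two new counters can be exercised as below (the image here is a synthetic placeholder; real use would pass a loaded photo). Since the score filtering happens client-side on the returned boxes, raising box_threshold trades recall for precision:

```python
import numpy as np
from vision_agent.tools import countgd_counting, countgd_example_based_counting

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder image

# Text-prompted counting; the count is the number of returned boxes.
people = countgd_counting("person", image, box_threshold=0.4)
print(len(people))

# Example-based counting: up to 3 normalized example boxes of the target object.
objs = countgd_example_based_counting(
    visual_prompts=[[0.1, 0.1, 0.4, 0.42]], image=image
)
print(len(objs))
```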
@@ -646,7 +749,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
         "tool": "closed_set_image_classification",
         "function_name": "clip",
     }
-    resp_data = send_inference_request(data, "tools")
+    resp_data: dict[str, Any] = send_inference_request(data, "tools")
     resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
     return resp_data

@@ -674,7 +777,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
         "tool": "image_classification",
         "function_name": "vit_image_classification",
     }
-    resp_data = send_inference_request(data, "tools")
+    resp_data: dict[str, Any] = send_inference_request(data, "tools")
     resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
     return resp_data

@@ -701,7 +804,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
         "image": image_b64,
         "function_name": "vit_nsfw_classification",
     }
-    resp_data = send_inference_request(data, "nsfw-classification", v2=True)
+    resp_data: dict[str, Any] = send_inference_request(
+        data, "nsfw-classification", v2=True
+    )
     resp_data["score"] = round(resp_data["score"], 4)
     return resp_data

@@ -1559,6 +1664,74 @@ def overlay_heat_map(
     return np.array(combined)


+def overlay_counting_results(
+    image: np.ndarray, instances: List[Dict[str, Any]]
+) -> np.ndarray:
+    """'overlay_counting_results' is a utility function that displays counting results
+    on an image.
+
+    Parameters:
+        image (np.ndarray): The image to display the bounding boxes on.
+        instances (List[Dict[str, Any]]): A list of dictionaries containing the
+            bounding box information of each instance.
+
+    Returns:
+        np.ndarray: The image with the instance_id displayed.
+
+    Example
+    -------
+    >>> image_with_bboxes = overlay_counting_results(
+        image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
+    )
+    """
+    pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
+    color = (158, 218, 229)
+
+    width, height = pil_image.size
+    fontsize = max(10, int(min(width, height) / 80))
+    pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
+    draw = ImageDraw.Draw(pil_image)
+    font = ImageFont.truetype(
+        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+        fontsize,
+    )
+
+    for i, elt in enumerate(instances):
+        label = f"{i}"
+        box = elt["bbox"]
+
+        # denormalize the box if it is normalized
+        box = denormalize_bbox(box, (height, width))
+        x0, y0, x1, y1 = box
+        cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
+
+        text_box = draw.textbbox(
+            (cx, cy), text=label, font=font, align="center", anchor="mm"
+        )
+
+        # Calculate the offset to center the text within the bounding box
+        text_width = text_box[2] - text_box[0]
+        text_height = text_box[3] - text_box[1]
+        text_x0 = cx - text_width / 2
+        text_y0 = cy - text_height / 2
+        text_x1 = cx + text_width / 2
+        text_y1 = cy + text_height / 2
+
+        # Draw the rectangle encapsulating the text
+        draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color)
+
+        # Draw the text at the center of the bounding box
+        draw.text(
+            (text_x0, text_y0),
+            label,
+            fill="black",
+            font=font,
+            anchor="lt",
+        )
+
+    return np.array(pil_image)
+
+
 # TODO: add this function to the imports so that is picked in the agent
 def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
     """'florencev2_fine_tuning' is a tool that fine-tune florencev2 to be able
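The new overlay pairs naturally with the counters: it dims the image, then stamps each instance's index at the center of its (possibly normalized) box. A usage sketch, assuming a valid API key for the detection call; the placeholder image and PIL save step are illustrative:

```python
import numpy as np
from PIL import Image
from vision_agent.tools import countgd_counting, overlay_counting_results

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder image
instances = countgd_counting("car", image)       # remote call; needs an API key
labeled = overlay_counting_results(image, instances)
Image.fromarray(labeled).save("counted.png")
```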
@@ -1679,8 +1852,7 @@ FUNCTION_TOOLS = [
     clip,
     vit_image_classification,
     vit_nsfw_classification,
-    loca_zero_shot_counting,
-    loca_visual_prompt_counting,
+    countgd_counting,
     florence2_image_caption,
     florence2_ocr,
     florence2_sam2_image,
@@ -1703,6 +1875,7 @@ UTIL_TOOLS = [
     overlay_bounding_boxes,
     overlay_segmentation_masks,
     overlay_heat_map,
+    overlay_counting_results,
 ]

 TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
@@ -1720,5 +1893,6 @@ UTILITIES_DOCSTRING = get_tool_documentation(
         overlay_bounding_boxes,
         overlay_segmentation_masks,
         overlay_heat_map,
+        overlay_counting_results,
     ]
 )
vision_agent/tools/tools_types.py CHANGED
@@ -1,8 +1,8 @@
 from enum import Enum
-from typing import List, Optional, Tuple
 from uuid import UUID
+from typing import List, Tuple, Optional, Union

-from pydantic import BaseModel, ConfigDict, Field,
+from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo


 class BboxInput(BaseModel):
@@ -82,3 +82,16 @@ class JobStatus(str, Enum):
     SUCCEEDED = "SUCCEEDED"
     FAILED = "FAILED"
     STOPPED = "STOPPED"
+
+
+class ODResponseData(BaseModel):
+    label: str
+    score: float
+    bbox: Union[list[int], list[float]] = Field(alias="bounding_box")
+
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
+
+BoundingBoxes = list[ODResponseData]
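The alias plus populate_by_name=True means the model accepts the wire-format key bounding_box as well as the field name bbox, and serializes to bbox unless an alias dump is requested; a quick sketch:

```python
from vision_agent.tools.tools_types import ODResponseData

a = ODResponseData(label="flower", score=0.98, bounding_box=[0.44, 0.24, 0.49, 0.58])
b = ODResponseData(label="flower", score=0.98, bbox=[0.44, 0.24, 0.49, 0.58])
assert a == b  # both spellings build the same model

print(a.model_dump())               # {'label': 'flower', 'score': 0.98, 'bbox': [...]}
print(a.model_dump(by_alias=True))  # {'label': 'flower', 'score': 0.98, 'bounding_box': [...]}
```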
vision_agent/utils/image_utils.py CHANGED
@@ -181,7 +181,7 @@ def denormalize_bbox(
         raise ValueError("Bounding box must be of length 4.")

     arr = np.array(bbox)
-    if np.all((arr >= 0) & (arr <= 1)):
+    if np.all((arr[:2] >= 0) & (arr[:2] <= 1)):
         x1, y1, x2, y2 = bbox
         x1 = round(x1 * image_size[1])
         y1 = round(y1 * image_size[0])
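The relaxed check inspects only xmin/ymin when deciding whether a box is normalized, so a box whose xmax or ymax drifts marginally above 1.0 (presumably from float rounding at the right or bottom edge) is still scaled to pixels rather than silently treated as pixel coordinates. A minimal before/after comparison of just the changed condition:

```python
import numpy as np

bbox = [0.1, 0.2, 1.02, 0.9]  # normalized box whose xmax overshoots 1.0
arr = np.array(bbox)

old_check = np.all((arr >= 0) & (arr <= 1))          # False: old code would skip denormalizing
new_check = np.all((arr[:2] >= 0) & (arr[:2] <= 1))  # True: new code denormalizes as intended
print(bool(old_check), bool(new_check))              # False True
```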
{vision_agent-0.2.119.dist-info → vision_agent-0.2.121.dist-info}/RECORD RENAMED
@@ -4,7 +4,7 @@ vision_agent/agent/agent.py,sha256=2cjIOxEuSJrqbfPXYoV0qER5ihXsPFCoEFJa4jpqan0,5
 vision_agent/agent/agent_utils.py,sha256=22LiPhkJlS5mVeo2dIi259pc2NgA7PGHRpcbnrtKo78,1930
 vision_agent/agent/vision_agent.py,sha256=IEyXT_JPCuWmBHdEnM1Wrsj7hmCe5pKLf0gnZFJTddI,11046
 vision_agent/agent/vision_agent_coder.py,sha256=DOTmDdGPxcI06Jp6yx4ekRMP0vhiVaK9B9Dl8UyJHeo,34396
-vision_agent/agent/vision_agent_coder_prompts.py,sha256=
+vision_agent/agent/vision_agent_coder_prompts.py,sha256=Rg7-Ih7oFgFbHFFno0EHpaZEgm0SYj_nTdqqdp21YLo,11246
 vision_agent/agent/vision_agent_prompts.py,sha256=0GliXFtBf32aPu2ClU63FI5ii5CTxWYsvrsmnnDp-gs,7134
 vision_agent/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vision_agent/clients/http.py,sha256=k883i6M_4nl7zwwHSI-yP5sAgQZIDPM1nrKD6YFJ3Xs,2009
@@ -12,22 +12,22 @@ vision_agent/clients/landing_public_api.py,sha256=rGtACkr8o5egDuMHQ5MBO4NuvsgPTp
 vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
 vision_agent/lmm/__init__.py,sha256=YuUZRsMHdn8cMOv6iBU8yUqlIOLrbZQqZl9KPnofsHQ,103
-vision_agent/lmm/lmm.py,sha256=
+vision_agent/lmm/lmm.py,sha256=H3a5V7c073-vXRJfQOblE2j_CsZkH1CNNRoQgLjJZuQ,20751
 vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
-vision_agent/tools/__init__.py,sha256=
+vision_agent/tools/__init__.py,sha256=TILaqdFYicScvpnCXMxgBsFmSW22NQDIvucvEgo0etw,2289
 vision_agent/tools/meta_tools.py,sha256=Vu9WnKicGhafx9dPzDbQjQdcIzRCYYFPF68o79hDP-8,14616
 vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
-vision_agent/tools/tool_utils.py,sha256=
-vision_agent/tools/tools.py,sha256=
-vision_agent/tools/tools_types.py,sha256=
+vision_agent/tools/tool_utils.py,sha256=e_p-G2nwgWOpoaqpDitY3FJ6fFuTEg5GhDOD67wI2bE,7527
+vision_agent/tools/tools.py,sha256=Eec7-3ecjv_8s0CJcDMibDD5z99CLHMOx7SOL3kilVE,67010
+vision_agent/tools/tools_types.py,sha256=1AvGEb-eslXjz4iWQGNQIatgKm6JDoBCDP0h7TjsNkU,2468
 vision_agent/utils/__init__.py,sha256=pWk0ktvR4aUEhuEIzSLM9kSgW4WDVqptdvOTeGLkJ6M,230
 vision_agent/utils/exceptions.py,sha256=booSPSuoULF7OXRr_YbC4dtKt6gM_HyiFQHBuaW86C4,2052
 vision_agent/utils/execute.py,sha256=Ap8Yx80spQq5f2QtKGx1MK03BR45mJKhlp1kfh-rIao,26751
-vision_agent/utils/image_utils.py,sha256=
+vision_agent/utils/image_utils.py,sha256=UloC4byIQLM4CSCaH41SBciQ7X2OqKvsVvNOVKqIH_k,9856
 vision_agent/utils/sim.py,sha256=ebE9Cs00pVEDI1HMjAzUBk88tQQmc2U-yAzIDinnekU,5572
 vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
 vision_agent/utils/video.py,sha256=rNmU9KEIkZB5-EztZNlUiKYN0mm_55A_2VGUM0QpqLA,8779
-vision_agent-0.2.
-vision_agent-0.2.
-vision_agent-0.2.
-vision_agent-0.2.
+vision_agent-0.2.121.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.2.121.dist-info/METADATA,sha256=OEbC_dogT2Hg9xLN2H8Zb2FCLQjxf1wfx_0TM1aJrYU,12255
+vision_agent-0.2.121.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.2.121.dist-info/RECORD,,
{vision_agent-0.2.119.dist-info → vision_agent-0.2.121.dist-info}/LICENSE RENAMED
File without changes

{vision_agent-0.2.119.dist-info → vision_agent-0.2.121.dist-info}/WHEEL RENAMED
File without changes