vision-agent 0.2.193__py3-none-any.whl → 0.2.196__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/.sim_tools/df.csv +640 -0
- vision_agent/.sim_tools/embs.npy +0 -0
- vision_agent/agent/__init__.py +2 -0
- vision_agent/agent/agent_utils.py +211 -3
- vision_agent/agent/vision_agent_coder.py +5 -113
- vision_agent/agent/vision_agent_coder_prompts_v2.py +119 -0
- vision_agent/agent/vision_agent_coder_v2.py +341 -0
- vision_agent/agent/vision_agent_planner.py +2 -2
- vision_agent/agent/vision_agent_planner_prompts.py +1 -1
- vision_agent/agent/vision_agent_planner_prompts_v2.py +748 -0
- vision_agent/agent/vision_agent_planner_v2.py +432 -0
- vision_agent/lmm/lmm.py +4 -0
- vision_agent/tools/__init__.py +2 -1
- vision_agent/tools/planner_tools.py +246 -0
- vision_agent/tools/tool_utils.py +65 -1
- vision_agent/tools/tools.py +76 -22
- vision_agent/utils/image_utils.py +12 -6
- vision_agent/utils/sim.py +65 -14
- {vision_agent-0.2.193.dist-info → vision_agent-0.2.196.dist-info}/METADATA +2 -1
- vision_agent-0.2.196.dist-info/RECORD +42 -0
- vision_agent-0.2.193.dist-info/RECORD +0 -35
- {vision_agent-0.2.193.dist-info → vision_agent-0.2.196.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.193.dist-info → vision_agent-0.2.196.dist-info}/WHEEL +0 -0
vision_agent/tools/planner_tools.py
ADDED
@@ -0,0 +1,246 @@
+import logging
+import shutil
+import tempfile
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+
+import numpy as np
+from PIL import Image
+
+import vision_agent.tools as T
+from vision_agent.agent.agent_utils import (
+    DefaultImports,
+    extract_code,
+    extract_json,
+    extract_tag,
+)
+from vision_agent.agent.vision_agent_planner_prompts_v2 import (
+    CATEGORIZE_TOOL_REQUEST,
+    FINALIZE_PLAN,
+    PICK_TOOL,
+    TEST_TOOLS,
+    TEST_TOOLS_EXAMPLE1,
+    TEST_TOOLS_EXAMPLE2,
+)
+from vision_agent.lmm import AnthropicLMM
+from vision_agent.utils.execute import CodeInterpreterFactory
+from vision_agent.utils.image_utils import convert_to_b64
+from vision_agent.utils.sim import load_cached_sim
+
+TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
+TOOL_RECOMMENDER = load_cached_sim(T.TOOLS_DF)
+
+_LOGGER = logging.getLogger(__name__)
+EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"
+
+
+def extract_tool_info(
+    tool_choice_context: Dict[str, Any]
+) -> Tuple[Optional[Callable], str, str, str]:
+    tool_thoughts = tool_choice_context.get("thoughts", "")
+    tool_docstring = ""
+    tool = tool_choice_context.get("best_tool", None)
+    if tool in TOOL_FUNCTIONS:
+        tool = TOOL_FUNCTIONS[tool]
+        tool_docstring = T.TOOLS_INFO[tool.__name__]
+
+    return tool, tool_thoughts, tool_docstring, ""
+
+
+def get_tool_for_task(
+    task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
+) -> None:
+    """Given a task and one or more images this function will find a tool to accomplish
+    the jobs. It prints the tool documentation and thoughts on why it chose the tool.
+
+    It can produce tools for the following types of tasks:
+        - Object detection and counting
+        - Classification
+        - Segmentation
+        - OCR
+        - VQA
+        - Depth and pose estimation
+        - Video object tracking
+
+    Wait until the documentation is printed to use the function so you know what the
+    input and output signatures are.
+
+    Parameters:
+        task: str: The task to accomplish.
+        images: List[np.ndarray]: The images to use for the task.
+        exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
+            recommendations. This is helpful if you are calling get_tool_for_task twice
+            and do not want the same tool recommended.
+
+    Returns:
+        The tool to use for the task is printed to stdout
+
+    Examples
+    --------
+        >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
+    """
+    lmm = AnthropicLMM()
+
+    with (
+        tempfile.TemporaryDirectory() as tmpdirname,
+        CodeInterpreterFactory.new_instance() as code_interpreter,
+    ):
+        image_paths = []
+        for i, image in enumerate(images[:3]):
+            image_path = f"{tmpdirname}/image_{i}.png"
+            Image.fromarray(image).save(image_path)
+            image_paths.append(image_path)
+
+        query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
+        category = extract_tag(query, "category")  # type: ignore
+        if category is None:
+            category = task
+        else:
+            category = (
+                f"I need models from the {category.strip()} category of tools. {task}"
+            )
+
+        tool_docs = TOOL_RECOMMENDER.top_k(category, k=10, thresh=0.2)
+        if exclude_tools is not None and len(exclude_tools) > 0:
+            cleaned_tool_docs = []
+            for tool_doc in tool_docs:
+                if not tool_doc["name"] in exclude_tools:
+                    cleaned_tool_docs.append(tool_doc)
+            tool_docs = cleaned_tool_docs
+        tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
+
+        prompt = TEST_TOOLS.format(
+            tool_docs=tool_docs_str,
+            previous_attempts="",
+            user_request=task,
+            examples=EXAMPLES,
+            media=str(image_paths),
+        )
+
+        response = lmm.generate(prompt, media=image_paths)
+        code = extract_tag(response, "code")  # type: ignore
+        if code is None:
+            raise ValueError(f"Could not extract code from response: {response}")
+        tool_output = code_interpreter.exec_isolation(
+            DefaultImports.prepend_imports(code)
+        )
+        tool_output_str = tool_output.text(include_results=False).strip()
+
+        count = 1
+        while (
+            not tool_output.success
+            or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
+        ) and count <= 3:
+            if tool_output_str.strip() == "":
+                tool_output_str = "EMPTY"
+            prompt = TEST_TOOLS.format(
+                tool_docs=tool_docs_str,
+                previous_attempts=f"<code>\n{code}\n</code>\nTOOL OUTPUT\n{tool_output_str}",
+                user_request=task,
+                examples=EXAMPLES,
+                media=str(image_paths),
+            )
+            code = extract_code(lmm.generate(prompt, media=image_paths))  # type: ignore
+            tool_output = code_interpreter.exec_isolation(
+                DefaultImports.prepend_imports(code)
+            )
+            tool_output_str = tool_output.text(include_results=False).strip()
+
+        error_message = ""
+        prompt = PICK_TOOL.format(
+            tool_docs=tool_docs_str,
+            user_request=task,
+            context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+            previous_attempts=error_message,
+        )
+
+        response = lmm.generate(prompt, media=image_paths)
+        tool_choice_context = extract_tag(response, "json")  # type: ignore
+        tool_choice_context_dict = extract_json(tool_choice_context)  # type: ignore
+
+        tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
+            tool_choice_context_dict
+        )
+
+        count = 1
+        while tool is None and count <= 3:
+            prompt = PICK_TOOL.format(
+                tool_docs=tool_docs_str,
+                user_request=task,
+                context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+                previous_attempts=error_message,
+            )
+            tool_choice_context_dict = extract_json(lmm.generate(prompt, media=image_paths))  # type: ignore
+            tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
+                tool_choice_context_dict
+            )
+    try:
+        shutil.rmtree(tmpdirname)
+    except Exception as e:
+        _LOGGER.error(f"Error removing temp directory: {e}")
+
+    print(
+        f"[get_tool_for_task output]\n{tool_thoughts}\n\nTool Documentation:\n{tool_docstring}\n[end of get_tool_for_task output]\n"
+    )
+
+
+def finalize_plan(user_request: str, chain_of_thoughts: str) -> str:
+    """Finalizes the plan by taking the user request and the chain of thoughts that
+    represent the plan and returns the finalized plan.
+    """
+    lmm = AnthropicLMM()
+    prompt = FINALIZE_PLAN.format(
+        user_request=user_request, chain_of_thoughts=chain_of_thoughts
+    )
+    finalized_plan = cast(str, lmm.generate(prompt))
+    return finalized_plan
+
+
+def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
+    """Asks the Claude-3.5 model a question about the given media and returns an answer.
+
+    Parameters:
+        prompt: str: The question to ask the model.
+        medias: List[np.ndarray]: The images to ask the question about, it could also
+            be frames from a video. You can send up to 5 frames from a video.
+    """
+    lmm = AnthropicLMM()
+    if isinstance(medias, np.ndarray):
+        medias = [medias]
+    if isinstance(medias, list) and len(medias) > 5:
+        medias = medias[:5]
+    all_media_b64 = [
+        "data:image/png;base64," + convert_to_b64(media) for media in medias
+    ]
+
+    response = cast(str, lmm.generate(prompt, media=all_media_b64))
+    print(f"[claude35_vqa output]\n{response}\n[end of claude35_vqa output]")
+
+
+def suggestion(prompt: str, medias: List[np.ndarray]) -> None:
+    """Given your problem statement and the images, this will provide you with a
+    suggested plan on how to proceed. Always call suggestion when starting to solve
+    a problem.
+
+    Parameters:
+        prompt: str: The problem statement.
+        medias: List[np.ndarray]: The images to use for the problem
+    """
+    try:
+        from .suggestion import suggestion_impl  # type: ignore
+
+        suggestion = suggestion_impl(prompt, medias)
+        print(suggestion)
+    except ImportError:
+        print("")
+
+
+PLANNER_TOOLS = [
+    claude35_vqa,
+    suggestion,
+    get_tool_for_task,
+    T.load_image,
+    T.save_image,
+    T.extract_frames_and_timestamps,
+    T.save_video,
+]
+PLANNER_DOCSTRING = T.get_tool_documentation(PLANNER_TOOLS)  # type: ignore
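
The new planner_tools module above is what the v2 planner uses to drive its tool-selection loop, and its helpers can also be called directly. A minimal sketch of doing so, assuming an Anthropic API key is configured and a local file workers.png exists (both are assumptions for illustration, not part of the diff):

import vision_agent.tools as T
from vision_agent.tools.planner_tools import claude35_vqa, get_tool_for_task

# Load a test image with the standard tools API (the path is hypothetical).
image = T.load_image("workers.png")

# Ask Claude 3.5 about the scene; the answer is printed to stdout.
claude35_vqa("How many people are wearing hard hats?", [image])

# Let the planner test candidate tools and print the winner's documentation;
# follow-up code is expected to read that output before calling the chosen tool.
get_tool_for_task("Count the people wearing hard hats", [image])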
vision_agent/tools/tool_utils.py
CHANGED
@@ -4,6 +4,7 @@ import os
 from base64 import b64encode
 from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

+import numpy as np
 import pandas as pd
 from IPython.display import display
 from pydantic import BaseModel
@@ -14,6 +15,7 @@ from urllib3.util.retry import Retry
 from vision_agent.tools.tools_types import BoundingBoxes
 from vision_agent.utils.exceptions import RemoteToolCallFailed
 from vision_agent.utils.execute import Error, MimeType
+from vision_agent.utils.image_utils import normalize_bbox
 from vision_agent.utils.type_defs import LandingaiAPIKey

 _LOGGER = logging.getLogger(__name__)
@@ -170,7 +172,7 @@ def get_tool_descriptions_by_names(


 def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
-    data: Dict[str, List[str]] = {"desc": [], "doc": []}
+    data: Dict[str, List[str]] = {"desc": [], "doc": [], "name": []}

     for func in funcs:
         desc = func.__doc__
@@ -182,6 +184,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
         doc = f"{func.__name__}{inspect.signature(func)}:\n{func.__doc__}"
         data["desc"].append(desc)
         data["doc"].append(doc)
+        data["name"].append(func.__name__)

     return pd.DataFrame(data)  # type: ignore

@@ -256,3 +259,64 @@ def filter_bboxes_by_threshold(
     bboxes: BoundingBoxes, threshold: float
 ) -> BoundingBoxes:
     return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
+
+
+def add_bboxes_from_masks(
+    all_preds: List[List[Dict[str, Any]]],
+) -> List[List[Dict[str, Any]]]:
+    for frame_preds in all_preds:
+        for preds in frame_preds:
+            if np.sum(preds["mask"]) == 0:
+                preds["bbox"] = []
+            else:
+                rows, cols = np.where(preds["mask"])
+                bbox = [
+                    float(np.min(cols)),
+                    float(np.min(rows)),
+                    float(np.max(cols)),
+                    float(np.max(rows)),
+                ]
+                bbox = normalize_bbox(bbox, preds["mask"].shape)
+                preds["bbox"] = bbox
+
+    return all_preds
+
+
+def calculate_iou(bbox1: List[float], bbox2: List[float]) -> float:
+    x1, y1, x2, y2 = bbox1
+    x3, y3, x4, y4 = bbox2
+
+    x_overlap = max(0, min(x2, x4) - max(x1, x3))
+    y_overlap = max(0, min(y2, y4) - max(y1, y3))
+    intersection = x_overlap * y_overlap
+
+    area1 = (x2 - x1) * (y2 - y1)
+    area2 = (x4 - x3) * (y4 - y3)
+    union = area1 + area2 - intersection
+
+    return intersection / union if union > 0 else 0
+
+
+def single_nms(
+    preds: List[Dict[str, Any]], iou_threshold: float
+) -> List[Dict[str, Any]]:
+    for i in range(len(preds)):
+        for j in range(i + 1, len(preds)):
+            if calculate_iou(preds[i]["bbox"], preds[j]["bbox"]) > iou_threshold:
+                if preds[i]["score"] > preds[j]["score"]:
+                    preds[j]["score"] = 0
+                else:
+                    preds[i]["score"] = 0
+
+    return [pred for pred in preds if pred["score"] > 0]
+
+
+def nms(
+    all_preds: List[List[Dict[str, Any]]], iou_threshold: float
+) -> List[List[Dict[str, Any]]]:
+    return_preds = []
+    for frame_preds in all_preds:
+        frame_preds = single_nms(frame_preds, iou_threshold)
+        return_preds.append(frame_preds)
+
+    return return_preds
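
The IoU and NMS helpers added above are pure functions over bounding-box dictionaries, so they are easy to sanity-check in isolation. A small worked example (the boxes and scores below are made up for illustration):

from vision_agent.tools.tool_utils import calculate_iou, single_nms

box_a = [0.10, 0.10, 0.50, 0.50]
box_b = [0.12, 0.12, 0.52, 0.52]  # near-duplicate of box_a
box_c = [0.70, 0.70, 0.90, 0.90]  # separate object

print(calculate_iou(box_a, box_b))  # ~0.82, well above a 0.5 threshold
print(calculate_iou(box_a, box_c))  # 0, the boxes are disjoint

preds = [
    {"label": "person", "bbox": box_a, "score": 0.90},
    {"label": "person", "bbox": box_b, "score": 0.80},
    {"label": "person", "bbox": box_c, "score": 0.85},
]
# single_nms zeroes out the lower-scoring duplicate (box_b) and keeps the rest,
# so two predictions remain.
print(single_nms(preds, iou_threshold=0.5))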
vision_agent/tools/tools.py
CHANGED
@@ -17,15 +17,18 @@ from pillow_heif import register_heif_opener  # type: ignore
 from pytube import YouTube  # type: ignore

 from vision_agent.clients.landing_public_api import LandingPublicAPI
-from vision_agent.lmm.lmm import OpenAILMM
+from vision_agent.lmm.lmm import AnthropicLMM, OpenAILMM
 from vision_agent.tools.tool_utils import (
+    add_bboxes_from_masks,
     filter_bboxes_by_threshold,
     get_tool_descriptions,
     get_tool_documentation,
     get_tools_df,
     get_tools_info,
+    nms,
     send_inference_request,
     send_task_inference_request,
+    single_nms,
 )
 from vision_agent.tools.tools_types import JobStatus, ODResponseData
 from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -260,8 +263,8 @@ def owl_v2_video(
         ...
     ]
     """
-    if len(frames) == 0:
-        raise ValueError("
+    if len(frames) == 0 or not isinstance(frames, List):
+        raise ValueError("Must provide a list of numpy arrays for frames")

     image_size = frames[0].shape[:2]
     buffer_bytes = frames_to_bytes(frames)
@@ -455,7 +458,7 @@ def florence2_sam2_image(
 def florence2_sam2_video_tracking(
     prompt: str,
     frames: List[np.ndarray],
-    chunk_length: Optional[int] =
+    chunk_length: Optional[int] = 10,
     fine_tune_id: Optional[str] = None,
 ) -> List[List[Dict[str, Any]]]:
     """'florence2_sam2_video_tracking' is a tool that can segment and track multiple
@@ -473,11 +476,11 @@ def florence2_sam2_video_tracking(
         fine-tuned model ID here to use it.

     Returns:
-        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
-         and
-        the entities per frame. The label contains the object ID
-         name. The objects are only identified in the first framed
-         throughout the video.
+        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+        label,segment mask and bounding boxes. The outer list represents each frame and
+        the inner list is the entities per frame. The label contains the object ID
+        followed by the label name. The objects are only identified in the first framed
+        and tracked throughout the video.

     Example
     -------
@@ -486,6 +489,7 @@ def florence2_sam2_video_tracking(
         [
             {
                 'label': '0: dinosaur',
+                'bbox': [0.1, 0.11, 0.35, 0.4],
                 'mask': array([[0, 0, 0, ..., 0, 0, 0],
                     [0, 0, 0, ..., 0, 0, 0],
                     ...,
@@ -496,8 +500,8 @@ def florence2_sam2_video_tracking(
         ...
     ]
     """
-    if len(frames) == 0:
-        raise ValueError("
+    if len(frames) == 0 or not isinstance(frames, List):
+        raise ValueError("Must provide a list of numpy arrays for frames")

     buffer_bytes = frames_to_bytes(frames)
     files = [("video", buffer_bytes)]
@@ -535,7 +539,8 @@ def florence2_sam2_video_tracking(
             label = str(detection["id"]) + ": " + detection["label"]
             return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
         return_data.append(return_frame_data)
-
+    return_data = add_bboxes_from_masks(return_data)
+    return nms(return_data, iou_threshold=0.95)


 def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
@@ -677,8 +682,9 @@ def countgd_counting(
     image: np.ndarray,
     box_threshold: float = 0.23,
 ) -> List[Dict[str, Any]]:
-    """'countgd_counting' is a tool that can
-
+    """'countgd_counting' is a tool that can detect multiple instances of an object
+    given a text prompt. It is particularly useful when trying to detect and count a
+    large number of objects. It returns a list of bounding boxes with normalized
     coordinates, label names and associated confidence scores.

     Parameters:
@@ -711,7 +717,7 @@ def countgd_counting(
     buffer_bytes = numpy_to_bytes(image)
     files = [("image", buffer_bytes)]
     payload = {
-        "prompts": [prompt.replace(", ", "
+        "prompts": [prompt.replace(", ", ". ")],
         "confidence": box_threshold,  # still not being used in the API
         "model": "countgd",
     }
@@ -733,7 +739,8 @@ def countgd_counting(
     ]
     # TODO: remove this once we start to use the confidence on countgd
     filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
-
+    return_data = [bbox.model_dump() for bbox in filtered_bboxes]
+    return single_nms(return_data, iou_threshold=0.80)


 def countgd_example_based_counting(
@@ -864,9 +871,10 @@ def ixc25_image_vqa(prompt: str, image: np.ndarray) -> str:


 def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
-    """'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary
-     including regular images or images of documents or presentations. It
-     as an answer to
+    """'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary
+    images including regular images or images of documents or presentations. It can be
+    very useful for document QA or OCR text extraction. It returns text as an answer to
+    the question.

     Parameters:
         prompt (str): The question about the document image
@@ -880,6 +888,9 @@ def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
         >>> qwen2_vl_images_vqa('Give a summary of the document', images)
         'The document talks about the history of the United States of America and its...'
     """
+    if isinstance(images, np.ndarray):
+        images = [images]
+
     for image in images:
         if image.shape[0] < 1 or image.shape[1] < 1:
             raise ValueError(f"Image is empty, image shape: {image.shape}")
@@ -896,6 +907,30 @@ def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
     return cast(str, data)


+def claude35_text_extraction(image: np.ndarray) -> str:
+    """'claude35_text_extraction' is a tool that can extract text from an image. It
+    returns the extracted text as a string and can be used as an alternative to OCR if
+    you do not need to know the exact bounding box of the text.
+
+    Parameters:
+        image (np.ndarray): The image to extract text from.
+
+    Returns:
+        str: The extracted text from the image.
+    """
+
+    lmm = AnthropicLMM()
+    buffer = io.BytesIO()
+    Image.fromarray(image).save(buffer, format="PNG")
+    image_bytes = buffer.getvalue()
+    image_b64 = "data:image/png;base64," + encode_image_bytes(image_bytes)
+    text = lmm.generate(
+        "Extract and return any text you see in this image and nothing else. If you do not read any text respond with an empty string.",
+        [image_b64],
+    )
+    return cast(str, text)
+
+
 def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
     """'ixc25_video_vqa' is a tool that can answer any questions about arbitrary videos
     including regular videos or videos of documents or presentations. It returns text
@@ -944,6 +979,9 @@ def qwen2_vl_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
     'Lionel Messi'
     """

+    if len(frames) == 0 or not isinstance(frames, List):
+        raise ValueError("Must provide a list of numpy arrays for frames")
+
     buffer_bytes = frames_to_bytes(frames)
     files = [("video", buffer_bytes)]
     payload = {
@@ -2157,7 +2195,8 @@ def overlay_bounding_boxes(
         bboxes = bbox_int[i]
         bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)

-        if
+        # if more than 50 boxes use small boxes to indicate objects else use regular boxes
+        if len(bboxes) > 50:
             pil_image = _plot_counting(pil_image, bboxes, color)
         else:
             width, height = pil_image.size
@@ -2188,7 +2227,14 @@ def overlay_bounding_boxes(
                 draw.text((box[0], box[1]), text, fill="black", font=font)

         frame_out.append(np.array(pil_image))
-
+    return_frame = frame_out[0] if len(frame_out) == 1 else frame_out
+
+    if isinstance(return_frame, np.ndarray):
+        from IPython.display import display
+
+        display(Image.fromarray(return_frame))
+
+    return return_frame  # type: ignore


 def _get_text_coords_from_mask(
@@ -2300,7 +2346,14 @@ def overlay_segmentation_masks(
             draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
             draw.text((x, y), text, fill="black", font=font)
         frame_out.append(np.array(pil_image))
-
+    return_frame = frame_out[0] if len(frame_out) == 1 else frame_out
+
+    if isinstance(return_frame, np.ndarray):
+        from IPython.display import display
+
+        display(Image.fromarray(return_frame))
+
+    return return_frame  # type: ignore


 def overlay_heat_map(
@@ -2408,6 +2461,7 @@ FUNCTION_TOOLS = [
     florence2_sam2_image,
     florence2_sam2_video_tracking,
     florence2_phrase_grounding,
+    claude35_text_extraction,
     detr_segmentation,
     depth_anything_v2,
     generate_pose_image,
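
Two behavior changes in tools.py are worth calling out: countgd_counting now de-duplicates its detections with single_nms and returns plain dicts, and overlay_bounding_boxes returns a single frame (and displays it under IPython) when given a single image. A hedged usage sketch, assuming a valid LandingAI API key is configured and a local shelf.jpg exists (both are assumptions for illustration):

import vision_agent.tools as T

image = T.load_image("shelf.jpg")  # hypothetical local image

# Detections come back as dicts with normalized "bbox", "label" and "score",
# already filtered through NMS at iou_threshold=0.80.
dets = T.countgd_counting("bottle", image)

# For a single input image the overlay now returns one np.ndarray frame
# rather than a one-element list, and displays it when run inside IPython.
viz = T.overlay_bounding_boxes(image, dets)
T.save_image(viz, "shelf_counted.png")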
vision_agent/utils/image_utils.py
CHANGED
@@ -42,10 +42,10 @@ def normalize_bbox(
 ) -> List[float]:
     r"""Normalize the bounding box coordinates to be between 0 and 1."""
     x1, y1, x2, y2 = bbox
-    x1 = round(x1 / image_size[1], 2)
-    y1 = round(y1 / image_size[0], 2)
-    x2 = round(x2 / image_size[1], 2)
-    y2 = round(y2 / image_size[0], 2)
+    x1 = max(round(x1 / image_size[1], 2), 0)
+    y1 = max(round(y1 / image_size[0], 2), 0)
+    x2 = min(round(x2 / image_size[1], 2), image_size[1])
+    y2 = min(round(y2 / image_size[0], 2), image_size[0])
     return [x1, y1, x2, y2]


@@ -175,9 +175,15 @@ def encode_media(media: Union[str, Path], resize: Optional[int] = None) -> str:
             return media[:-4] + ".png"
         return media

-    # if media is
+    # if media is in base64 ensure it's the correct resize
     if isinstance(media, str) and media.startswith("data:image/"):
-
+        image_pil = b64_to_pil(media)
+        if resize is not None:
+            if image_pil.size[0] > resize or image_pil.size[1] > resize:
+                image_pil.thumbnail((resize, resize))
+        buffer = io.BytesIO()
+        image_pil.save(buffer, format="PNG")
+        return base64.b64encode(buffer.getvalue()).decode("utf-8")

     extension = "png"
     extension = Path(media).suffix