vision-agent 0.2.192__py3-none-any.whl → 0.2.195__py3-none-any.whl
- vision_agent/.sim_tools/df.csv +640 -0
- vision_agent/.sim_tools/embs.npy +0 -0
- vision_agent/agent/__init__.py +2 -0
- vision_agent/agent/agent_utils.py +211 -3
- vision_agent/agent/vision_agent_coder.py +5 -113
- vision_agent/agent/vision_agent_coder_prompts_v2.py +119 -0
- vision_agent/agent/vision_agent_coder_v2.py +341 -0
- vision_agent/agent/vision_agent_planner.py +2 -2
- vision_agent/agent/vision_agent_planner_prompts.py +1 -1
- vision_agent/agent/vision_agent_planner_prompts_v2.py +748 -0
- vision_agent/agent/vision_agent_planner_v2.py +432 -0
- vision_agent/lmm/lmm.py +4 -0
- vision_agent/tools/__init__.py +2 -1
- vision_agent/tools/planner_tools.py +246 -0
- vision_agent/tools/tool_utils.py +65 -1
- vision_agent/tools/tools.py +98 -35
- vision_agent/utils/image_utils.py +12 -6
- vision_agent/utils/sim.py +65 -14
- {vision_agent-0.2.192.dist-info → vision_agent-0.2.195.dist-info}/METADATA +1 -1
- vision_agent-0.2.195.dist-info/RECORD +42 -0
- vision_agent-0.2.192.dist-info/RECORD +0 -35
- {vision_agent-0.2.192.dist-info → vision_agent-0.2.195.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.192.dist-info → vision_agent-0.2.195.dist-info}/WHEEL +0 -0
vision_agent/tools/planner_tools.py
ADDED
@@ -0,0 +1,246 @@
+import logging
+import shutil
+import tempfile
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+
+import numpy as np
+from PIL import Image
+
+import vision_agent.tools as T
+from vision_agent.agent.agent_utils import (
+    DefaultImports,
+    extract_code,
+    extract_json,
+    extract_tag,
+)
+from vision_agent.agent.vision_agent_planner_prompts_v2 import (
+    CATEGORIZE_TOOL_REQUEST,
+    FINALIZE_PLAN,
+    PICK_TOOL,
+    TEST_TOOLS,
+    TEST_TOOLS_EXAMPLE1,
+    TEST_TOOLS_EXAMPLE2,
+)
+from vision_agent.lmm import AnthropicLMM
+from vision_agent.utils.execute import CodeInterpreterFactory
+from vision_agent.utils.image_utils import convert_to_b64
+from vision_agent.utils.sim import load_cached_sim
+
+TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
+TOOL_RECOMMENDER = load_cached_sim(T.TOOLS_DF)
+
+_LOGGER = logging.getLogger(__name__)
+EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"
+
+
+def extract_tool_info(
+    tool_choice_context: Dict[str, Any]
+) -> Tuple[Optional[Callable], str, str, str]:
+    tool_thoughts = tool_choice_context.get("thoughts", "")
+    tool_docstring = ""
+    tool = tool_choice_context.get("best_tool", None)
+    if tool in TOOL_FUNCTIONS:
+        tool = TOOL_FUNCTIONS[tool]
+        tool_docstring = T.TOOLS_INFO[tool.__name__]
+
+    return tool, tool_thoughts, tool_docstring, ""
+
+
+def get_tool_for_task(
+    task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
+) -> None:
+    """Given a task and one or more images this function will find a tool to accomplish
+    the jobs. It prints the tool documentation and thoughts on why it chose the tool.
+
+    It can produce tools for the following types of tasks:
+        - Object detection and counting
+        - Classification
+        - Segmentation
+        - OCR
+        - VQA
+        - Depth and pose estimation
+        - Video object tracking
+
+    Wait until the documentation is printed to use the function so you know what the
+    input and output signatures are.
+
+    Parameters:
+        task: str: The task to accomplish.
+        images: List[np.ndarray]: The images to use for the task.
+        exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
+            recommendations. This is helpful if you are calling get_tool_for_task twice
+            and do not want the same tool recommended.
+
+    Returns:
+        The tool to use for the task is printed to stdout
+
+    Examples
+    --------
+        >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
+    """
+    lmm = AnthropicLMM()
+
+    with (
+        tempfile.TemporaryDirectory() as tmpdirname,
+        CodeInterpreterFactory.new_instance() as code_interpreter,
+    ):
+        image_paths = []
+        for i, image in enumerate(images[:3]):
+            image_path = f"{tmpdirname}/image_{i}.png"
+            Image.fromarray(image).save(image_path)
+            image_paths.append(image_path)
+
+        query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
+        category = extract_tag(query, "category")  # type: ignore
+        if category is None:
+            category = task
+        else:
+            category = (
+                f"I need models from the {category.strip()} category of tools. {task}"
+            )
+
+        tool_docs = TOOL_RECOMMENDER.top_k(category, k=10, thresh=0.2)
+        if exclude_tools is not None and len(exclude_tools) > 0:
+            cleaned_tool_docs = []
+            for tool_doc in tool_docs:
+                if not tool_doc["name"] in exclude_tools:
+                    cleaned_tool_docs.append(tool_doc)
+            tool_docs = cleaned_tool_docs
+        tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
+
+        prompt = TEST_TOOLS.format(
+            tool_docs=tool_docs_str,
+            previous_attempts="",
+            user_request=task,
+            examples=EXAMPLES,
+            media=str(image_paths),
+        )
+
+        response = lmm.generate(prompt, media=image_paths)
+        code = extract_tag(response, "code")  # type: ignore
+        if code is None:
+            raise ValueError(f"Could not extract code from response: {response}")
+        tool_output = code_interpreter.exec_isolation(
+            DefaultImports.prepend_imports(code)
+        )
+        tool_output_str = tool_output.text(include_results=False).strip()
+
+        count = 1
+        while (
+            not tool_output.success
+            or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
+        ) and count <= 3:
+            if tool_output_str.strip() == "":
+                tool_output_str = "EMPTY"
+            prompt = TEST_TOOLS.format(
+                tool_docs=tool_docs_str,
+                previous_attempts=f"<code>\n{code}\n</code>\nTOOL OUTPUT\n{tool_output_str}",
+                user_request=task,
+                examples=EXAMPLES,
+                media=str(image_paths),
+            )
+            code = extract_code(lmm.generate(prompt, media=image_paths))  # type: ignore
+            tool_output = code_interpreter.exec_isolation(
+                DefaultImports.prepend_imports(code)
+            )
+            tool_output_str = tool_output.text(include_results=False).strip()
+
+        error_message = ""
+        prompt = PICK_TOOL.format(
+            tool_docs=tool_docs_str,
+            user_request=task,
+            context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+            previous_attempts=error_message,
+        )
+
+        response = lmm.generate(prompt, media=image_paths)
+        tool_choice_context = extract_tag(response, "json")  # type: ignore
+        tool_choice_context_dict = extract_json(tool_choice_context)  # type: ignore
+
+        tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
+            tool_choice_context_dict
+        )
+
+        count = 1
+        while tool is None and count <= 3:
+            prompt = PICK_TOOL.format(
+                tool_docs=tool_docs_str,
+                user_request=task,
+                context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+                previous_attempts=error_message,
+            )
+            tool_choice_context_dict = extract_json(lmm.generate(prompt, media=image_paths))  # type: ignore
+            tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
+                tool_choice_context_dict
+            )
+        try:
+            shutil.rmtree(tmpdirname)
+        except Exception as e:
+            _LOGGER.error(f"Error removing temp directory: {e}")
+
+    print(
+        f"[get_tool_for_task output]\n{tool_thoughts}\n\nTool Documentation:\n{tool_docstring}\n[end of get_tool_for_task output]\n"
+    )
+
+
+def finalize_plan(user_request: str, chain_of_thoughts: str) -> str:
+    """Finalizes the plan by taking the user request and the chain of thoughts that
+    represent the plan and returns the finalized plan.
+    """
+    lmm = AnthropicLMM()
+    prompt = FINALIZE_PLAN.format(
+        user_request=user_request, chain_of_thoughts=chain_of_thoughts
+    )
+    finalized_plan = cast(str, lmm.generate(prompt))
+    return finalized_plan
+
+
+def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
+    """Asks the Claude-3.5 model a question about the given media and returns an answer.
+
+    Parameters:
+        prompt: str: The question to ask the model.
+        medias: List[np.ndarray]: The images to ask the question about, it could also
+            be frames from a video. You can send up to 5 frames from a video.
+    """
+    lmm = AnthropicLMM()
+    if isinstance(medias, np.ndarray):
+        medias = [medias]
+    if isinstance(medias, list) and len(medias) > 5:
+        medias = medias[:5]
+    all_media_b64 = [
+        "data:image/png;base64," + convert_to_b64(media) for media in medias
+    ]
+
+    response = cast(str, lmm.generate(prompt, media=all_media_b64))
+    print(f"[claude35_vqa output]\n{response}\n[end of claude35_vqa output]")
+
+
+def suggestion(prompt: str, medias: List[np.ndarray]) -> None:
+    """Given your problem statement and the images, this will provide you with a
+    suggested plan on how to proceed. Always call suggestion when starting to solve
+    a problem.
+
+    Parameters:
+        prompt: str: The problem statement.
+        medias: List[np.ndarray]: The images to use for the problem
+    """
+    try:
+        from .suggestion import suggestion_impl  # type: ignore
+
+        suggestion = suggestion_impl(prompt, medias)
+        print(suggestion)
+    except ImportError:
+        print("")
+
+
+PLANNER_TOOLS = [
+    claude35_vqa,
+    suggestion,
+    get_tool_for_task,
+    T.load_image,
+    T.save_image,
+    T.extract_frames_and_timestamps,
+    T.save_video,
+]
+PLANNER_DOCSTRING = T.get_tool_documentation(PLANNER_TOOLS)  # type: ignore
vision_agent/tools/tool_utils.py
CHANGED
@@ -4,6 +4,7 @@ import os
 from base64 import b64encode
 from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple
 
+import numpy as np
 import pandas as pd
 from IPython.display import display
 from pydantic import BaseModel
@@ -14,6 +15,7 @@ from urllib3.util.retry import Retry
 from vision_agent.tools.tools_types import BoundingBoxes
 from vision_agent.utils.exceptions import RemoteToolCallFailed
 from vision_agent.utils.execute import Error, MimeType
+from vision_agent.utils.image_utils import normalize_bbox
 from vision_agent.utils.type_defs import LandingaiAPIKey
 
 _LOGGER = logging.getLogger(__name__)
@@ -170,7 +172,7 @@ def get_tool_descriptions_by_names(
 
 
 def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
-    data: Dict[str, List[str]] = {"desc": [], "doc": []}
+    data: Dict[str, List[str]] = {"desc": [], "doc": [], "name": []}
 
     for func in funcs:
         desc = func.__doc__
@@ -182,6 +184,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
         doc = f"{func.__name__}{inspect.signature(func)}:\n{func.__doc__}"
         data["desc"].append(desc)
         data["doc"].append(doc)
+        data["name"].append(func.__name__)
 
     return pd.DataFrame(data)  # type: ignore
 
@@ -256,3 +259,64 @@ def filter_bboxes_by_threshold(
     bboxes: BoundingBoxes, threshold: float
 ) -> BoundingBoxes:
     return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
+
+
+def add_bboxes_from_masks(
+    all_preds: List[List[Dict[str, Any]]],
+) -> List[List[Dict[str, Any]]]:
+    for frame_preds in all_preds:
+        for preds in frame_preds:
+            if np.sum(preds["mask"]) == 0:
+                preds["bbox"] = []
+            else:
+                rows, cols = np.where(preds["mask"])
+                bbox = [
+                    float(np.min(cols)),
+                    float(np.min(rows)),
+                    float(np.max(cols)),
+                    float(np.max(rows)),
+                ]
+                bbox = normalize_bbox(bbox, preds["mask"].shape)
+                preds["bbox"] = bbox
+
+    return all_preds
+
+
+def calculate_iou(bbox1: List[float], bbox2: List[float]) -> float:
+    x1, y1, x2, y2 = bbox1
+    x3, y3, x4, y4 = bbox2
+
+    x_overlap = max(0, min(x2, x4) - max(x1, x3))
+    y_overlap = max(0, min(y2, y4) - max(y1, y3))
+    intersection = x_overlap * y_overlap
+
+    area1 = (x2 - x1) * (y2 - y1)
+    area2 = (x4 - x3) * (y4 - y3)
+    union = area1 + area2 - intersection
+
+    return intersection / union if union > 0 else 0
+
+
+def single_nms(
+    preds: List[Dict[str, Any]], iou_threshold: float
+) -> List[Dict[str, Any]]:
+    for i in range(len(preds)):
+        for j in range(i + 1, len(preds)):
+            if calculate_iou(preds[i]["bbox"], preds[j]["bbox"]) > iou_threshold:
+                if preds[i]["score"] > preds[j]["score"]:
+                    preds[j]["score"] = 0
+                else:
+                    preds[i]["score"] = 0
+
+    return [pred for pred in preds if pred["score"] > 0]
+
+
+def nms(
+    all_preds: List[List[Dict[str, Any]]], iou_threshold: float
+) -> List[List[Dict[str, Any]]]:
+    return_preds = []
+    for frame_preds in all_preds:
+        frame_preds = single_nms(frame_preds, iou_threshold)
+        return_preds.append(frame_preds)
+
+    return return_preds
vision_agent/tools/tools.py
CHANGED
@@ -17,15 +17,18 @@ from pillow_heif import register_heif_opener # type: ignore
 from pytube import YouTube  # type: ignore
 
 from vision_agent.clients.landing_public_api import LandingPublicAPI
-from vision_agent.lmm.lmm import OpenAILMM
+from vision_agent.lmm.lmm import AnthropicLMM, OpenAILMM
 from vision_agent.tools.tool_utils import (
+    add_bboxes_from_masks,
     filter_bboxes_by_threshold,
     get_tool_descriptions,
     get_tool_documentation,
     get_tools_df,
     get_tools_info,
+    nms,
     send_inference_request,
     send_task_inference_request,
+    single_nms,
 )
 from vision_agent.tools.tools_types import JobStatus, ODResponseData
 from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -260,8 +263,8 @@ def owl_v2_video(
         ...
     ]
     """
-    if len(frames) == 0:
-        raise ValueError("
+    if len(frames) == 0 or not isinstance(frames, List):
+        raise ValueError("Must provide a list of numpy arrays for frames")
 
     image_size = frames[0].shape[:2]
     buffer_bytes = frames_to_bytes(frames)
@@ -455,7 +458,7 @@ def florence2_sam2_image(
 def florence2_sam2_video_tracking(
     prompt: str,
     frames: List[np.ndarray],
-    chunk_length: Optional[int] =
+    chunk_length: Optional[int] = 10,
     fine_tune_id: Optional[str] = None,
 ) -> List[List[Dict[str, Any]]]:
     """'florence2_sam2_video_tracking' is a tool that can segment and track multiple
@@ -473,11 +476,11 @@ def florence2_sam2_video_tracking(
         fine-tuned model ID here to use it.
 
     Returns:
-        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
-         and
-        the entities per frame. The label contains the object ID
-         name. The objects are only identified in the first framed
-         throughout the video.
+        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+            label,segment mask and bounding boxes. The outer list represents each frame and
+            the inner list is the entities per frame. The label contains the object ID
+            followed by the label name. The objects are only identified in the first framed
+            and tracked throughout the video.
 
     Example
     -------
@@ -486,6 +489,7 @@ def florence2_sam2_video_tracking(
         [
             {
                 'label': '0: dinosaur',
+                'bbox': [0.1, 0.11, 0.35, 0.4],
                 'mask': array([[0, 0, 0, ..., 0, 0, 0],
                     [0, 0, 0, ..., 0, 0, 0],
                     ...,
@@ -496,8 +500,8 @@ def florence2_sam2_video_tracking(
         ...
     ]
     """
-    if len(frames) == 0:
-        raise ValueError("
+    if len(frames) == 0 or not isinstance(frames, List):
+        raise ValueError("Must provide a list of numpy arrays for frames")
 
     buffer_bytes = frames_to_bytes(frames)
     files = [("video", buffer_bytes)]
@@ -535,7 +539,8 @@ def florence2_sam2_video_tracking(
             label = str(detection["id"]) + ": " + detection["label"]
             return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
         return_data.append(return_frame_data)
-
+    return_data = add_bboxes_from_masks(return_data)
+    return nms(return_data, iou_threshold=0.95)
 
 
 def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
@@ -677,8 +682,9 @@ def countgd_counting(
     image: np.ndarray,
     box_threshold: float = 0.23,
 ) -> List[Dict[str, Any]]:
-    """'countgd_counting' is a tool that can
-
+    """'countgd_counting' is a tool that can detect multiple instances of an object
+    given a text prompt. It is particularly useful when trying to detect and count a
+    large number of objects. It returns a list of bounding boxes with normalized
     coordinates, label names and associated confidence scores.
 
     Parameters:
@@ -711,7 +717,7 @@ def countgd_counting(
     buffer_bytes = numpy_to_bytes(image)
     files = [("image", buffer_bytes)]
     payload = {
-        "prompts": [prompt.replace(", ", "
+        "prompts": [prompt.replace(", ", ". ")],
         "confidence": box_threshold,  # still not being used in the API
         "model": "countgd",
     }
@@ -733,7 +739,8 @@ def countgd_counting(
     ]
     # TODO: remove this once we start to use the confidence on countgd
     filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
-
+    return_data = [bbox.model_dump() for bbox in filtered_bboxes]
+    return single_nms(return_data, iou_threshold=0.80)
 
 
 def countgd_example_based_counting(
@@ -864,9 +871,10 @@ def ixc25_image_vqa(prompt: str, image: np.ndarray) -> str:
 
 
 def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
-    """'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary
-    including regular images or images of documents or presentations. It
-    as an answer to
+    """'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary
+    images including regular images or images of documents or presentations. It can be
+    very useful for document QA or OCR text extraction. It returns text as an answer to
+    the question.
 
     Parameters:
         prompt (str): The question about the document image
@@ -880,6 +888,9 @@ def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
         >>> qwen2_vl_images_vqa('Give a summary of the document', images)
         'The document talks about the history of the United States of America and its...'
     """
+    if isinstance(images, np.ndarray):
+        images = [images]
+
     for image in images:
         if image.shape[0] < 1 or image.shape[1] < 1:
             raise ValueError(f"Image is empty, image shape: {image.shape}")
@@ -896,6 +907,30 @@ def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
     return cast(str, data)
 
 
+def claude35_text_extraction(image: np.ndarray) -> str:
+    """'claude35_text_extraction' is a tool that can extract text from an image. It
+    returns the extracted text as a string and can be used as an alternative to OCR if
+    you do not need to know the exact bounding box of the text.
+
+    Parameters:
+        image (np.ndarray): The image to extract text from.
+
+    Returns:
+        str: The extracted text from the image.
+    """
+
+    lmm = AnthropicLMM()
+    buffer = io.BytesIO()
+    Image.fromarray(image).save(buffer, format="PNG")
+    image_bytes = buffer.getvalue()
+    image_b64 = "data:image/png;base64," + encode_image_bytes(image_bytes)
+    text = lmm.generate(
+        "Extract and return any text you see in this image and nothing else. If you do not read any text respond with an empty string.",
+        [image_b64],
+    )
+    return cast(str, text)
+
+
 def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
     """'ixc25_video_vqa' is a tool that can answer any questions about arbitrary videos
     including regular videos or videos of documents or presentations. It returns text
@@ -944,6 +979,9 @@ def qwen2_vl_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
         'Lionel Messi'
     """
 
+    if len(frames) == 0 or not isinstance(frames, List):
+        raise ValueError("Must provide a list of numpy arrays for frames")
+
     buffer_bytes = frames_to_bytes(frames)
     files = [("video", buffer_bytes)]
     payload = {
@@ -1798,24 +1836,33 @@ def flux_image_inpainting(
     ... )
     >>> save_image(result, "inpainted_room.png")
     """
-    if (
-        image.shape[0] < 8
-        or image.shape[1] < 8
-        or mask.shape[0] < 8
-        or mask.shape[1] < 8
-    ):
-        raise ValueError("The image or mask does not have enough size for inpainting")
 
-
-
-
-
-
+    min_dim = 8
+
+    if any(dim < min_dim for dim in image.shape[:2] + mask.shape[:2]):
+        raise ValueError(f"Image and mask must be at least {min_dim}x{min_dim} pixels")
+
+    max_size = (512, 512)
+
+    if image.shape[0] > max_size[0] or image.shape[1] > max_size[1]:
+        scaling_factor = min(max_size[0] / image.shape[0], max_size[1] / image.shape[1])
+        new_size = (
+            int(image.shape[1] * scaling_factor),
+            int(image.shape[0] * scaling_factor),
+        )
+        new_size = ((new_size[0] // 8) * 8, (new_size[1] // 8) * 8)
+        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
+        mask = cv2.resize(mask, new_size, interpolation=cv2.INTER_NEAREST)
+
+    elif image.shape[0] % 8 != 0 or image.shape[1] % 8 != 0:
+        new_size = ((image.shape[1] // 8) * 8, (image.shape[0] // 8) * 8)
+        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
+        mask = cv2.resize(mask, new_size, interpolation=cv2.INTER_NEAREST)
 
     if np.array_equal(mask, mask.astype(bool).astype(int)):
         mask = np.where(mask > 0, 255, 0).astype(np.uint8)
     else:
-        raise ValueError("
+        raise ValueError("Mask should contain only binary values (0 or 1)")
 
     image_file = numpy_to_bytes(image)
     mask_file = numpy_to_bytes(mask)
@@ -2148,7 +2195,8 @@ def overlay_bounding_boxes(
         bboxes = bbox_int[i]
         bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)
 
-        if
+        # if more than 50 boxes use small boxes to indicate objects else use regular boxes
+        if len(bboxes) > 50:
             pil_image = _plot_counting(pil_image, bboxes, color)
         else:
             width, height = pil_image.size
@@ -2179,7 +2227,14 @@ def overlay_bounding_boxes(
             draw.text((box[0], box[1]), text, fill="black", font=font)
 
         frame_out.append(np.array(pil_image))
-
+    return_frame = frame_out[0] if len(frame_out) == 1 else frame_out
+
+    if isinstance(return_frame, np.ndarray):
+        from IPython.display import display
+
+        display(Image.fromarray(return_frame))
+
+    return return_frame  # type: ignore
 
 
 def _get_text_coords_from_mask(
@@ -2291,7 +2346,14 @@ def overlay_segmentation_masks(
             draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
             draw.text((x, y), text, fill="black", font=font)
         frame_out.append(np.array(pil_image))
-
+    return_frame = frame_out[0] if len(frame_out) == 1 else frame_out
+
+    if isinstance(return_frame, np.ndarray):
+        from IPython.display import display
+
+        display(Image.fromarray(return_frame))
+
+    return return_frame  # type: ignore
 
 
 def overlay_heat_map(
@@ -2399,6 +2461,7 @@ FUNCTION_TOOLS = [
     florence2_sam2_image,
     florence2_sam2_video_tracking,
     florence2_phrase_grounding,
+    claude35_text_extraction,
     detr_segmentation,
     depth_anything_v2,
     generate_pose_image,
vision_agent/utils/image_utils.py
CHANGED
@@ -42,10 +42,10 @@ def normalize_bbox(
 ) -> List[float]:
     r"""Normalize the bounding box coordinates to be between 0 and 1."""
     x1, y1, x2, y2 = bbox
-    x1 = round(x1 / image_size[1], 2)
-    y1 = round(y1 / image_size[0], 2)
-    x2 = round(x2 / image_size[1], 2)
-    y2 = round(y2 / image_size[0], 2)
+    x1 = max(round(x1 / image_size[1], 2), 0)
+    y1 = max(round(y1 / image_size[0], 2), 0)
+    x2 = min(round(x2 / image_size[1], 2), image_size[1])
+    y2 = min(round(y2 / image_size[0], 2), image_size[0])
     return [x1, y1, x2, y2]
 
 
@@ -175,9 +175,15 @@ def encode_media(media: Union[str, Path], resize: Optional[int] = None) -> str:
             return media[:-4] + ".png"
         return media
 
-    # if media is
+    # if media is in base64 ensure it's the correct resize
     if isinstance(media, str) and media.startswith("data:image/"):
-
+        image_pil = b64_to_pil(media)
+        if resize is not None:
+            if image_pil.size[0] > resize or image_pil.size[1] > resize:
+                image_pil.thumbnail((resize, resize))
+        buffer = io.BytesIO()
+        image_pil.save(buffer, format="PNG")
+        return base64.b64encode(buffer.getvalue()).decode("utf-8")
 
     extension = "png"
     extension = Path(media).suffix