vision-agent 0.2.192__py3-none-any.whl → 0.2.195__py3-none-any.whl

@@ -0,0 +1,246 @@
+import logging
+import shutil
+import tempfile
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+
+import numpy as np
+from PIL import Image
+
+import vision_agent.tools as T
+from vision_agent.agent.agent_utils import (
+    DefaultImports,
+    extract_code,
+    extract_json,
+    extract_tag,
+)
+from vision_agent.agent.vision_agent_planner_prompts_v2 import (
+    CATEGORIZE_TOOL_REQUEST,
+    FINALIZE_PLAN,
+    PICK_TOOL,
+    TEST_TOOLS,
+    TEST_TOOLS_EXAMPLE1,
+    TEST_TOOLS_EXAMPLE2,
+)
+from vision_agent.lmm import AnthropicLMM
+from vision_agent.utils.execute import CodeInterpreterFactory
+from vision_agent.utils.image_utils import convert_to_b64
+from vision_agent.utils.sim import load_cached_sim
+
+TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
+TOOL_RECOMMENDER = load_cached_sim(T.TOOLS_DF)
+
+_LOGGER = logging.getLogger(__name__)
+EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"
+
+
+def extract_tool_info(
+    tool_choice_context: Dict[str, Any]
+) -> Tuple[Optional[Callable], str, str, str]:
+    tool_thoughts = tool_choice_context.get("thoughts", "")
+    tool_docstring = ""
+    tool = tool_choice_context.get("best_tool", None)
+    if tool in TOOL_FUNCTIONS:
+        tool = TOOL_FUNCTIONS[tool]
+        tool_docstring = T.TOOLS_INFO[tool.__name__]
+
+    return tool, tool_thoughts, tool_docstring, ""
+
+
+def get_tool_for_task(
+    task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
+) -> None:
+    """Given a task and one or more images this function will find a tool to accomplish
+    the jobs. It prints the tool documentation and thoughts on why it chose the tool.
+
+    It can produce tools for the following types of tasks:
+    - Object detection and counting
+    - Classification
+    - Segmentation
+    - OCR
+    - VQA
+    - Depth and pose estimation
+    - Video object tracking
+
+    Wait until the documentation is printed to use the function so you know what the
+    input and output signatures are.
+
+    Parameters:
+        task: str: The task to accomplish.
+        images: List[np.ndarray]: The images to use for the task.
+        exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
+            recommendations. This is helpful if you are calling get_tool_for_task twice
+            and do not want the same tool recommended.
+
+    Returns:
+        The tool to use for the task is printed to stdout
+
+    Examples
+    --------
+        >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
+    """
+    lmm = AnthropicLMM()
+
+    with (
+        tempfile.TemporaryDirectory() as tmpdirname,
+        CodeInterpreterFactory.new_instance() as code_interpreter,
+    ):
+        image_paths = []
+        for i, image in enumerate(images[:3]):
+            image_path = f"{tmpdirname}/image_{i}.png"
+            Image.fromarray(image).save(image_path)
+            image_paths.append(image_path)
+
+        query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
+        category = extract_tag(query, "category")  # type: ignore
+        if category is None:
+            category = task
+        else:
+            category = (
+                f"I need models from the {category.strip()} category of tools. {task}"
+            )
+
+        tool_docs = TOOL_RECOMMENDER.top_k(category, k=10, thresh=0.2)
+        if exclude_tools is not None and len(exclude_tools) > 0:
+            cleaned_tool_docs = []
+            for tool_doc in tool_docs:
+                if not tool_doc["name"] in exclude_tools:
+                    cleaned_tool_docs.append(tool_doc)
+            tool_docs = cleaned_tool_docs
+        tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
+
+        prompt = TEST_TOOLS.format(
+            tool_docs=tool_docs_str,
+            previous_attempts="",
+            user_request=task,
+            examples=EXAMPLES,
+            media=str(image_paths),
+        )
+
+        response = lmm.generate(prompt, media=image_paths)
+        code = extract_tag(response, "code")  # type: ignore
+        if code is None:
+            raise ValueError(f"Could not extract code from response: {response}")
+        tool_output = code_interpreter.exec_isolation(
+            DefaultImports.prepend_imports(code)
+        )
+        tool_output_str = tool_output.text(include_results=False).strip()
+
+        count = 1
+        while (
+            not tool_output.success
+            or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
+        ) and count <= 3:
+            if tool_output_str.strip() == "":
+                tool_output_str = "EMPTY"
+            prompt = TEST_TOOLS.format(
+                tool_docs=tool_docs_str,
+                previous_attempts=f"<code>\n{code}\n</code>\nTOOL OUTPUT\n{tool_output_str}",
+                user_request=task,
+                examples=EXAMPLES,
+                media=str(image_paths),
+            )
+            code = extract_code(lmm.generate(prompt, media=image_paths))  # type: ignore
+            tool_output = code_interpreter.exec_isolation(
+                DefaultImports.prepend_imports(code)
+            )
+            tool_output_str = tool_output.text(include_results=False).strip()
+
+        error_message = ""
+        prompt = PICK_TOOL.format(
+            tool_docs=tool_docs_str,
+            user_request=task,
+            context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+            previous_attempts=error_message,
+        )
+
+        response = lmm.generate(prompt, media=image_paths)
+        tool_choice_context = extract_tag(response, "json")  # type: ignore
+        tool_choice_context_dict = extract_json(tool_choice_context)  # type: ignore
+
+        tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
+            tool_choice_context_dict
+        )
+
+        count = 1
+        while tool is None and count <= 3:
+            prompt = PICK_TOOL.format(
+                tool_docs=tool_docs_str,
+                user_request=task,
+                context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+                previous_attempts=error_message,
+            )
+            tool_choice_context_dict = extract_json(lmm.generate(prompt, media=image_paths))  # type: ignore
+            tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
+                tool_choice_context_dict
+            )
+    try:
+        shutil.rmtree(tmpdirname)
+    except Exception as e:
+        _LOGGER.error(f"Error removing temp directory: {e}")
+
+    print(
+        f"[get_tool_for_task output]\n{tool_thoughts}\n\nTool Documentation:\n{tool_docstring}\n[end of get_tool_for_task output]\n"
+    )
+
+
+def finalize_plan(user_request: str, chain_of_thoughts: str) -> str:
+    """Finalizes the plan by taking the user request and the chain of thoughts that
+    represent the plan and returns the finalized plan.
+    """
+    lmm = AnthropicLMM()
+    prompt = FINALIZE_PLAN.format(
+        user_request=user_request, chain_of_thoughts=chain_of_thoughts
+    )
+    finalized_plan = cast(str, lmm.generate(prompt))
+    return finalized_plan
+
+
+def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
+    """Asks the Claude-3.5 model a question about the given media and returns an answer.
+
+    Parameters:
+        prompt: str: The question to ask the model.
+        medias: List[np.ndarray]: The images to ask the question about, it could also
+            be frames from a video. You can send up to 5 frames from a video.
+    """
+    lmm = AnthropicLMM()
+    if isinstance(medias, np.ndarray):
+        medias = [medias]
+    if isinstance(medias, list) and len(medias) > 5:
+        medias = medias[:5]
+    all_media_b64 = [
+        "data:image/png;base64," + convert_to_b64(media) for media in medias
+    ]
+
+    response = cast(str, lmm.generate(prompt, media=all_media_b64))
+    print(f"[claude35_vqa output]\n{response}\n[end of claude35_vqa output]")
+
+
+def suggestion(prompt: str, medias: List[np.ndarray]) -> None:
+    """Given your problem statement and the images, this will provide you with a
+    suggested plan on how to proceed. Always call suggestion when starting to solve
+    a problem.
+
+    Parameters:
+        prompt: str: The problem statement.
+        medias: List[np.ndarray]: The images to use for the problem
+    """
+    try:
+        from .suggestion import suggestion_impl  # type: ignore
+
+        suggestion = suggestion_impl(prompt, medias)
+        print(suggestion)
+    except ImportError:
+        print("")
+
+
+PLANNER_TOOLS = [
+    claude35_vqa,
+    suggestion,
+    get_tool_for_task,
+    T.load_image,
+    T.save_image,
+    T.extract_frames_and_timestamps,
+    T.save_video,
+]
+PLANNER_DOCSTRING = T.get_tool_documentation(PLANNER_TOOLS)  # type: ignore
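
The new planner-tools module above is built around a printed-output workflow: suggestion is meant to be called first, claude35_vqa answers open-ended questions about the images, and get_tool_for_task prints the chosen tool's documentation before you call that tool yourself. A minimal sketch of that flow, assuming these helpers are importable from the new module (its path is not shown in this diff) and using a hypothetical image file name:

import vision_agent.tools as T

image = T.load_image("workers.png")  # hypothetical local image

# 1. Always start with a suggested plan.
suggestion("Count the workers wearing safety helmets.", [image])

# 2. Ask Claude 3.5 about the scene to ground the plan.
claude35_vqa("How many people are visible, and are they wearing helmets?", [image])

# 3. Pick a tool; its documentation is printed to stdout before you use it.
get_tool_for_task("Detect and count safety helmets", [image])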
@@ -4,6 +4,7 @@ import os
 from base64 import b64encode
 from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

+import numpy as np
 import pandas as pd
 from IPython.display import display
 from pydantic import BaseModel
@@ -14,6 +15,7 @@ from urllib3.util.retry import Retry
 from vision_agent.tools.tools_types import BoundingBoxes
 from vision_agent.utils.exceptions import RemoteToolCallFailed
 from vision_agent.utils.execute import Error, MimeType
+from vision_agent.utils.image_utils import normalize_bbox
 from vision_agent.utils.type_defs import LandingaiAPIKey

 _LOGGER = logging.getLogger(__name__)
@@ -170,7 +172,7 @@ def get_tool_descriptions_by_names(


 def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
-    data: Dict[str, List[str]] = {"desc": [], "doc": []}
+    data: Dict[str, List[str]] = {"desc": [], "doc": [], "name": []}

     for func in funcs:
         desc = func.__doc__
@@ -182,6 +184,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
         doc = f"{func.__name__}{inspect.signature(func)}:\n{func.__doc__}"
         data["desc"].append(desc)
         data["doc"].append(doc)
+        data["name"].append(func.__name__)

     return pd.DataFrame(data)  # type: ignore

@@ -256,3 +259,64 @@ def filter_bboxes_by_threshold(
     bboxes: BoundingBoxes, threshold: float
 ) -> BoundingBoxes:
     return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
+
+
+def add_bboxes_from_masks(
+    all_preds: List[List[Dict[str, Any]]],
+) -> List[List[Dict[str, Any]]]:
+    for frame_preds in all_preds:
+        for preds in frame_preds:
+            if np.sum(preds["mask"]) == 0:
+                preds["bbox"] = []
+            else:
+                rows, cols = np.where(preds["mask"])
+                bbox = [
+                    float(np.min(cols)),
+                    float(np.min(rows)),
+                    float(np.max(cols)),
+                    float(np.max(rows)),
+                ]
+                bbox = normalize_bbox(bbox, preds["mask"].shape)
+                preds["bbox"] = bbox
+
+    return all_preds
+
+
+def calculate_iou(bbox1: List[float], bbox2: List[float]) -> float:
+    x1, y1, x2, y2 = bbox1
+    x3, y3, x4, y4 = bbox2
+
+    x_overlap = max(0, min(x2, x4) - max(x1, x3))
+    y_overlap = max(0, min(y2, y4) - max(y1, y3))
+    intersection = x_overlap * y_overlap
+
+    area1 = (x2 - x1) * (y2 - y1)
+    area2 = (x4 - x3) * (y4 - y3)
+    union = area1 + area2 - intersection
+
+    return intersection / union if union > 0 else 0
+
+
+def single_nms(
+    preds: List[Dict[str, Any]], iou_threshold: float
+) -> List[Dict[str, Any]]:
+    for i in range(len(preds)):
+        for j in range(i + 1, len(preds)):
+            if calculate_iou(preds[i]["bbox"], preds[j]["bbox"]) > iou_threshold:
+                if preds[i]["score"] > preds[j]["score"]:
+                    preds[j]["score"] = 0
+                else:
+                    preds[i]["score"] = 0
+
+    return [pred for pred in preds if pred["score"] > 0]
+
+
+def nms(
+    all_preds: List[List[Dict[str, Any]]], iou_threshold: float
+) -> List[List[Dict[str, Any]]]:
+    return_preds = []
+    for frame_preds in all_preds:
+        frame_preds = single_nms(frame_preds, iou_threshold)
+        return_preds.append(frame_preds)
+
+    return return_preds
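
A short worked example of the new helpers, assuming they are importable from vision_agent.tools.tool_utils as the import hunk further down suggests; the box and mask values are made up for illustration:

import numpy as np

from vision_agent.tools.tool_utils import (
    add_bboxes_from_masks,
    calculate_iou,
    single_nms,
)

# Two overlapping detections: IoU = 0.0625 / (0.09 + 0.09 - 0.0625) ≈ 0.53.
a = {"label": "person", "bbox": [0.10, 0.10, 0.40, 0.40], "score": 0.90}
b = {"label": "person", "bbox": [0.15, 0.15, 0.45, 0.45], "score": 0.60}
print(calculate_iou(a["bbox"], b["bbox"]))    # ~0.53
print(single_nms([a, b], iou_threshold=0.5))  # only the 0.90-score box survives

# A filled 2x3 region in a 4x6 mask becomes the normalized bbox
# [0.33, 0.25, 0.67, 0.5] via add_bboxes_from_masks (which calls normalize_bbox).
mask = np.zeros((4, 6), dtype=bool)
mask[1:3, 2:5] = True
preds = [[{"label": "0: object", "mask": mask, "score": 1.0}]]
print(add_bboxes_from_masks(preds)[0][0]["bbox"])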
@@ -17,15 +17,18 @@ from pillow_heif import register_heif_opener # type: ignore
 from pytube import YouTube  # type: ignore

 from vision_agent.clients.landing_public_api import LandingPublicAPI
-from vision_agent.lmm.lmm import OpenAILMM
+from vision_agent.lmm.lmm import AnthropicLMM, OpenAILMM
 from vision_agent.tools.tool_utils import (
+    add_bboxes_from_masks,
     filter_bboxes_by_threshold,
     get_tool_descriptions,
     get_tool_documentation,
     get_tools_df,
     get_tools_info,
+    nms,
     send_inference_request,
     send_task_inference_request,
+    single_nms,
 )
 from vision_agent.tools.tools_types import JobStatus, ODResponseData
 from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -260,8 +263,8 @@ def owl_v2_video(
         ...
     ]
     """
-    if len(frames) == 0:
-        raise ValueError("No frames provided")
+    if len(frames) == 0 or not isinstance(frames, List):
+        raise ValueError("Must provide a list of numpy arrays for frames")

     image_size = frames[0].shape[:2]
     buffer_bytes = frames_to_bytes(frames)
@@ -455,7 +458,7 @@ def florence2_sam2_image(
 def florence2_sam2_video_tracking(
     prompt: str,
     frames: List[np.ndarray],
-    chunk_length: Optional[int] = 3,
+    chunk_length: Optional[int] = 10,
     fine_tune_id: Optional[str] = None,
 ) -> List[List[Dict[str, Any]]]:
     """'florence2_sam2_video_tracking' is a tool that can segment and track multiple
@@ -473,11 +476,11 @@ def florence2_sam2_video_tracking(
         fine-tuned model ID here to use it.

     Returns:
-        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the label
-        and segment mask. The outer list represents each frame and the inner list is
-        the entities per frame. The label contains the object ID followed by the label
-        name. The objects are only identified in the first framed and tracked
-        throughout the video.
+        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+        label,segment mask and bounding boxes. The outer list represents each frame and
+        the inner list is the entities per frame. The label contains the object ID
+        followed by the label name. The objects are only identified in the first framed
+        and tracked throughout the video.

     Example
     -------
@@ -486,6 +489,7 @@ def florence2_sam2_video_tracking(
         [
             {
                 'label': '0: dinosaur',
+                'bbox': [0.1, 0.11, 0.35, 0.4],
                 'mask': array([[0, 0, 0, ..., 0, 0, 0],
                     [0, 0, 0, ..., 0, 0, 0],
                     ...,
@@ -496,8 +500,8 @@ def florence2_sam2_video_tracking(
         ...
     ]
     """
-    if len(frames) == 0:
-        raise ValueError("No frames provided")
+    if len(frames) == 0 or not isinstance(frames, List):
+        raise ValueError("Must provide a list of numpy arrays for frames")

     buffer_bytes = frames_to_bytes(frames)
     files = [("video", buffer_bytes)]
@@ -535,7 +539,8 @@ def florence2_sam2_video_tracking(
             label = str(detection["id"]) + ": " + detection["label"]
             return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
         return_data.append(return_frame_data)
-    return return_data
+    return_data = add_bboxes_from_masks(return_data)
+    return nms(return_data, iou_threshold=0.95)


 def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
@@ -677,8 +682,9 @@ def countgd_counting(
     image: np.ndarray,
     box_threshold: float = 0.23,
 ) -> List[Dict[str, Any]]:
-    """'countgd_counting' is a tool that can precisely count multiple instances of an
-    object given a text prompt. It returns a list of bounding boxes with normalized
+    """'countgd_counting' is a tool that can detect multiple instances of an object
+    given a text prompt. It is particularly useful when trying to detect and count a
+    large number of objects. It returns a list of bounding boxes with normalized
     coordinates, label names and associated confidence scores.

     Parameters:
@@ -711,7 +717,7 @@ def countgd_counting(
     buffer_bytes = numpy_to_bytes(image)
     files = [("image", buffer_bytes)]
     payload = {
-        "prompts": [prompt.replace(", ", " .")],
+        "prompts": [prompt.replace(", ", ". ")],
         "confidence": box_threshold,  # still not being used in the API
         "model": "countgd",
     }
@@ -733,7 +739,8 @@ def countgd_counting(
     ]
     # TODO: remove this once we start to use the confidence on countgd
     filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
-    return [bbox.model_dump() for bbox in filtered_bboxes]
+    return_data = [bbox.model_dump() for bbox in filtered_bboxes]
+    return single_nms(return_data, iou_threshold=0.80)


 def countgd_example_based_counting(
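
The prompt change above only affects how a comma-separated prompt is serialized for the CountGD endpoint, for example:

prompt = "person, car, dog"
print(prompt.replace(", ", " ."))  # 0.2.192 payload: "person .car .dog"
print(prompt.replace(", ", ". "))  # 0.2.195 payload: "person. car. dog"

In addition, the returned detections are now de-duplicated with single_nms at an IoU threshold of 0.80.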
@@ -864,9 +871,10 @@ def ixc25_image_vqa(prompt: str, image: np.ndarray) -> str:


 def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
-    """'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary images
-    including regular images or images of documents or presentations. It returns text
-    as an answer to the question.
+    """'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary
+    images including regular images or images of documents or presentations. It can be
+    very useful for document QA or OCR text extraction. It returns text as an answer to
+    the question.

     Parameters:
         prompt (str): The question about the document image
@@ -880,6 +888,9 @@ def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
         >>> qwen2_vl_images_vqa('Give a summary of the document', images)
         'The document talks about the history of the United States of America and its...'
     """
+    if isinstance(images, np.ndarray):
+        images = [images]
+
     for image in images:
         if image.shape[0] < 1 or image.shape[1] < 1:
             raise ValueError(f"Image is empty, image shape: {image.shape}")
@@ -896,6 +907,30 @@ def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
     return cast(str, data)


+def claude35_text_extraction(image: np.ndarray) -> str:
+    """'claude35_text_extraction' is a tool that can extract text from an image. It
+    returns the extracted text as a string and can be used as an alternative to OCR if
+    you do not need to know the exact bounding box of the text.
+
+    Parameters:
+        image (np.ndarray): The image to extract text from.
+
+    Returns:
+        str: The extracted text from the image.
+    """
+
+    lmm = AnthropicLMM()
+    buffer = io.BytesIO()
+    Image.fromarray(image).save(buffer, format="PNG")
+    image_bytes = buffer.getvalue()
+    image_b64 = "data:image/png;base64," + encode_image_bytes(image_bytes)
+    text = lmm.generate(
+        "Extract and return any text you see in this image and nothing else. If you do not read any text respond with an empty string.",
+        [image_b64],
+    )
+    return cast(str, text)
+
+
 def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
     """'ixc25_video_vqa' is a tool that can answer any questions about arbitrary videos
     including regular videos or videos of documents or presentations. It returns text
@@ -944,6 +979,9 @@ def qwen2_vl_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
         'Lionel Messi'
     """

+    if len(frames) == 0 or not isinstance(frames, List):
+        raise ValueError("Must provide a list of numpy arrays for frames")
+
     buffer_bytes = frames_to_bytes(frames)
     files = [("video", buffer_bytes)]
     payload = {
@@ -1798,24 +1836,33 @@ def flux_image_inpainting(
     ... )
     >>> save_image(result, "inpainted_room.png")
     """
-    if (
-        image.shape[0] < 8
-        or image.shape[1] < 8
-        or mask.shape[0] < 8
-        or mask.shape[1] < 8
-    ):
-        raise ValueError("The image or mask does not have enough size for inpainting")

-    if image.shape[0] % 8 != 0 or image.shape[1] % 8 != 0:
-        new_height = (image.shape[0] // 8) * 8
-        new_width = (image.shape[1] // 8) * 8
-        image = cv2.resize(image, (new_width, new_height))
-        mask = cv2.resize(mask, (new_width, new_height))
+    min_dim = 8
+
+    if any(dim < min_dim for dim in image.shape[:2] + mask.shape[:2]):
+        raise ValueError(f"Image and mask must be at least {min_dim}x{min_dim} pixels")
+
+    max_size = (512, 512)
+
+    if image.shape[0] > max_size[0] or image.shape[1] > max_size[1]:
+        scaling_factor = min(max_size[0] / image.shape[0], max_size[1] / image.shape[1])
+        new_size = (
+            int(image.shape[1] * scaling_factor),
+            int(image.shape[0] * scaling_factor),
+        )
+        new_size = ((new_size[0] // 8) * 8, (new_size[1] // 8) * 8)
+        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
+        mask = cv2.resize(mask, new_size, interpolation=cv2.INTER_NEAREST)
+
+    elif image.shape[0] % 8 != 0 or image.shape[1] % 8 != 0:
+        new_size = ((image.shape[1] // 8) * 8, (image.shape[0] // 8) * 8)
+        image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
+        mask = cv2.resize(mask, new_size, interpolation=cv2.INTER_NEAREST)

     if np.array_equal(mask, mask.astype(bool).astype(int)):
         mask = np.where(mask > 0, 255, 0).astype(np.uint8)
     else:
-        raise ValueError("The mask should be a binary mask with 0's and 1's")
+        raise ValueError("Mask should contain only binary values (0 or 1)")

     image_file = numpy_to_bytes(image)
     mask_file = numpy_to_bytes(mask)
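
A worked example of the new resize arithmetic, assuming a hypothetical 1024x600 (height x width) input: the scaling factor is min(512/1024, 512/600) = 0.5, the provisional (width, height) size is (300, 512), and snapping each side down to a multiple of 8 gives (296, 512). Images already within 512x512 only take the elif branch, which snaps them to multiples of 8.

# Just the arithmetic from the branch above, no cv2 calls.
h, w = 1024, 600                                               # hypothetical image.shape[:2]
max_size = (512, 512)

scaling_factor = min(max_size[0] / h, max_size[1] / w)         # 0.5
new_size = (int(w * scaling_factor), int(h * scaling_factor))  # (300, 512) as (W, H)
new_size = ((new_size[0] // 8) * 8, (new_size[1] // 8) * 8)    # (296, 512)
print(scaling_factor, new_size)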
@@ -2148,7 +2195,8 @@ def overlay_bounding_boxes(
         bboxes = bbox_int[i]
         bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)

-        if len(bboxes) > 40:
+        # if more than 50 boxes use small boxes to indicate objects else use regular boxes
+        if len(bboxes) > 50:
             pil_image = _plot_counting(pil_image, bboxes, color)
         else:
             width, height = pil_image.size
@@ -2179,7 +2227,14 @@ def overlay_bounding_boxes(
                 draw.text((box[0], box[1]), text, fill="black", font=font)

         frame_out.append(np.array(pil_image))
-    return frame_out[0] if len(frame_out) == 1 else frame_out
+    return_frame = frame_out[0] if len(frame_out) == 1 else frame_out
+
+    if isinstance(return_frame, np.ndarray):
+        from IPython.display import display
+
+        display(Image.fromarray(return_frame))
+
+    return return_frame  # type: ignore


 def _get_text_coords_from_mask(
@@ -2291,7 +2346,14 @@ def overlay_segmentation_masks(
                 draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
                 draw.text((x, y), text, fill="black", font=font)
         frame_out.append(np.array(pil_image))
-    return frame_out[0] if len(frame_out) == 1 else frame_out
+    return_frame = frame_out[0] if len(frame_out) == 1 else frame_out
+
+    if isinstance(return_frame, np.ndarray):
+        from IPython.display import display
+
+        display(Image.fromarray(return_frame))
+
+    return return_frame  # type: ignore


 def overlay_heat_map(
@@ -2399,6 +2461,7 @@ FUNCTION_TOOLS = [
     florence2_sam2_image,
     florence2_sam2_video_tracking,
     florence2_phrase_grounding,
+    claude35_text_extraction,
     detr_segmentation,
     depth_anything_v2,
     generate_pose_image,
@@ -42,10 +42,10 @@ def normalize_bbox(
 ) -> List[float]:
     r"""Normalize the bounding box coordinates to be between 0 and 1."""
     x1, y1, x2, y2 = bbox
-    x1 = round(x1 / image_size[1], 2)
-    y1 = round(y1 / image_size[0], 2)
-    x2 = round(x2 / image_size[1], 2)
-    y2 = round(y2 / image_size[0], 2)
+    x1 = max(round(x1 / image_size[1], 2), 0)
+    y1 = max(round(y1 / image_size[0], 2), 0)
+    x2 = min(round(x2 / image_size[1], 2), image_size[1])
+    y2 = min(round(y2 / image_size[0], 2), image_size[0])
     return [x1, y1, x2, y2]


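A quick check of the clamped normalization above, assuming an image of size (480, 640) (height, width) and a pixel box whose x1 falls slightly outside the frame:

image_size = (480, 640)
bbox = [-5, 30, 200, 120]           # x1, y1, x2, y2 in pixels

x1 = max(round(-5 / 640, 2), 0)     # 0 (would have been -0.01 without the clamp)
y1 = max(round(30 / 480, 2), 0)     # 0.06
x2 = min(round(200 / 640, 2), 640)  # 0.31
y2 = min(round(120 / 480, 2), 480)  # 0.25
print([x1, y1, x2, y2])             # [0, 0.06, 0.31, 0.25]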
@@ -175,9 +175,15 @@ def encode_media(media: Union[str, Path], resize: Optional[int] = None) -> str:
             return media[:-4] + ".png"
         return media

-    # if media is already a base64 encoded image return
+    # if media is in base64 ensure it's the correct resize
     if isinstance(media, str) and media.startswith("data:image/"):
-        return media
+        image_pil = b64_to_pil(media)
+        if resize is not None:
+            if image_pil.size[0] > resize or image_pil.size[1] > resize:
+                image_pil.thumbnail((resize, resize))
+        buffer = io.BytesIO()
+        image_pil.save(buffer, format="PNG")
+        return base64.b64encode(buffer.getvalue()).decode("utf-8")

     extension = "png"
     extension = Path(media).suffix
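
The new branch above re-encodes an already base64-encoded data URL so that the resize bound is honored instead of returning the string untouched. A rough standalone sketch of the same idea using only Pillow; the function name is hypothetical and the inline decode stands in for the package's own b64_to_pil helper, whose implementation is not shown in this diff:

import base64
import io

from PIL import Image


def shrink_data_url(media: str, resize: int = 512) -> str:
    # Hypothetical helper mirroring the branch above: decode, bound, re-encode.
    _, b64_data = media.split(",", 1)
    image_pil = Image.open(io.BytesIO(base64.b64decode(b64_data)))
    if image_pil.size[0] > resize or image_pil.size[1] > resize:
        image_pil.thumbnail((resize, resize))  # downscales in place, keeps aspect ratio
    buffer = io.BytesIO()
    image_pil.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")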