vision-agent 0.2.193__py3-none-any.whl → 0.2.196__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
@@ -0,0 +1,246 @@
+ import logging
+ import shutil
+ import tempfile
+ from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+
+ import numpy as np
+ from PIL import Image
+
+ import vision_agent.tools as T
+ from vision_agent.agent.agent_utils import (
+     DefaultImports,
+     extract_code,
+     extract_json,
+     extract_tag,
+ )
+ from vision_agent.agent.vision_agent_planner_prompts_v2 import (
+     CATEGORIZE_TOOL_REQUEST,
+     FINALIZE_PLAN,
+     PICK_TOOL,
+     TEST_TOOLS,
+     TEST_TOOLS_EXAMPLE1,
+     TEST_TOOLS_EXAMPLE2,
+ )
+ from vision_agent.lmm import AnthropicLMM
+ from vision_agent.utils.execute import CodeInterpreterFactory
+ from vision_agent.utils.image_utils import convert_to_b64
+ from vision_agent.utils.sim import load_cached_sim
+
+ TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}
+ TOOL_RECOMMENDER = load_cached_sim(T.TOOLS_DF)
+
+ _LOGGER = logging.getLogger(__name__)
+ EXAMPLES = f"\n{TEST_TOOLS_EXAMPLE1}\n{TEST_TOOLS_EXAMPLE2}\n"
+
+
+ def extract_tool_info(
+     tool_choice_context: Dict[str, Any]
+ ) -> Tuple[Optional[Callable], str, str, str]:
+     tool_thoughts = tool_choice_context.get("thoughts", "")
+     tool_docstring = ""
+     tool = tool_choice_context.get("best_tool", None)
+     if tool in TOOL_FUNCTIONS:
+         tool = TOOL_FUNCTIONS[tool]
+         tool_docstring = T.TOOLS_INFO[tool.__name__]
+
+     return tool, tool_thoughts, tool_docstring, ""
+
+
+ def get_tool_for_task(
+     task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
+ ) -> None:
+     """Given a task and one or more images this function will find a tool to accomplish
+     the jobs. It prints the tool documentation and thoughts on why it chose the tool.
+
+     It can produce tools for the following types of tasks:
+         - Object detection and counting
+         - Classification
+         - Segmentation
+         - OCR
+         - VQA
+         - Depth and pose estimation
+         - Video object tracking
+
+     Wait until the documentation is printed to use the function so you know what the
+     input and output signatures are.
+
+     Parameters:
+         task: str: The task to accomplish.
+         images: List[np.ndarray]: The images to use for the task.
+         exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
+             recommendations. This is helpful if you are calling get_tool_for_task twice
+             and do not want the same tool recommended.
+
+     Returns:
+         The tool to use for the task is printed to stdout
+
+     Examples
+     --------
+         >>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
+     """
+     lmm = AnthropicLMM()
+
+     with (
+         tempfile.TemporaryDirectory() as tmpdirname,
+         CodeInterpreterFactory.new_instance() as code_interpreter,
+     ):
+         image_paths = []
+         for i, image in enumerate(images[:3]):
+             image_path = f"{tmpdirname}/image_{i}.png"
+             Image.fromarray(image).save(image_path)
+             image_paths.append(image_path)
+
+         query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
+         category = extract_tag(query, "category") # type: ignore
+         if category is None:
+             category = task
+         else:
+             category = (
+                 f"I need models from the {category.strip()} category of tools. {task}"
+             )
+
+         tool_docs = TOOL_RECOMMENDER.top_k(category, k=10, thresh=0.2)
+         if exclude_tools is not None and len(exclude_tools) > 0:
+             cleaned_tool_docs = []
+             for tool_doc in tool_docs:
+                 if not tool_doc["name"] in exclude_tools:
+                     cleaned_tool_docs.append(tool_doc)
+             tool_docs = cleaned_tool_docs
+         tool_docs_str = "\n".join([e["doc"] for e in tool_docs])
+
+         prompt = TEST_TOOLS.format(
+             tool_docs=tool_docs_str,
+             previous_attempts="",
+             user_request=task,
+             examples=EXAMPLES,
+             media=str(image_paths),
+         )
+
+         response = lmm.generate(prompt, media=image_paths)
+         code = extract_tag(response, "code") # type: ignore
+         if code is None:
+             raise ValueError(f"Could not extract code from response: {response}")
+         tool_output = code_interpreter.exec_isolation(
+             DefaultImports.prepend_imports(code)
+         )
+         tool_output_str = tool_output.text(include_results=False).strip()
+
+         count = 1
+         while (
+             not tool_output.success
+             or (len(tool_output.logs.stdout) == 0 and len(tool_output.logs.stderr) == 0)
+         ) and count <= 3:
+             if tool_output_str.strip() == "":
+                 tool_output_str = "EMPTY"
+             prompt = TEST_TOOLS.format(
+                 tool_docs=tool_docs_str,
+                 previous_attempts=f"<code>\n{code}\n</code>\nTOOL OUTPUT\n{tool_output_str}",
+                 user_request=task,
+                 examples=EXAMPLES,
+                 media=str(image_paths),
+             )
+             code = extract_code(lmm.generate(prompt, media=image_paths)) # type: ignore
+             tool_output = code_interpreter.exec_isolation(
+                 DefaultImports.prepend_imports(code)
+             )
+             tool_output_str = tool_output.text(include_results=False).strip()
+
+         error_message = ""
+         prompt = PICK_TOOL.format(
+             tool_docs=tool_docs_str,
+             user_request=task,
+             context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+             previous_attempts=error_message,
+         )
+
+         response = lmm.generate(prompt, media=image_paths)
+         tool_choice_context = extract_tag(response, "json") # type: ignore
+         tool_choice_context_dict = extract_json(tool_choice_context) # type: ignore
+
+         tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
+             tool_choice_context_dict
+         )
+
+         count = 1
+         while tool is None and count <= 3:
+             prompt = PICK_TOOL.format(
+                 tool_docs=tool_docs_str,
+                 user_request=task,
+                 context=f"<code>\n{code}\n</code>\n<tool_output>\n{tool_output_str}\n</tool_output>",
+                 previous_attempts=error_message,
+             )
+             tool_choice_context_dict = extract_json(lmm.generate(prompt, media=image_paths)) # type: ignore
+             tool, tool_thoughts, tool_docstring, error_message = extract_tool_info(
+                 tool_choice_context_dict
+             )
+     try:
+         shutil.rmtree(tmpdirname)
+     except Exception as e:
+         _LOGGER.error(f"Error removing temp directory: {e}")
+
+     print(
+         f"[get_tool_for_task output]\n{tool_thoughts}\n\nTool Documentation:\n{tool_docstring}\n[end of get_tool_for_task output]\n"
+     )
+
+
+ def finalize_plan(user_request: str, chain_of_thoughts: str) -> str:
+     """Finalizes the plan by taking the user request and the chain of thoughts that
+     represent the plan and returns the finalized plan.
+     """
+     lmm = AnthropicLMM()
+     prompt = FINALIZE_PLAN.format(
+         user_request=user_request, chain_of_thoughts=chain_of_thoughts
+     )
+     finalized_plan = cast(str, lmm.generate(prompt))
+     return finalized_plan
+
+
+ def claude35_vqa(prompt: str, medias: List[np.ndarray]) -> None:
+     """Asks the Claude-3.5 model a question about the given media and returns an answer.
+
+     Parameters:
+         prompt: str: The question to ask the model.
+         medias: List[np.ndarray]: The images to ask the question about, it could also
+             be frames from a video. You can send up to 5 frames from a video.
+     """
+     lmm = AnthropicLMM()
+     if isinstance(medias, np.ndarray):
+         medias = [medias]
+     if isinstance(medias, list) and len(medias) > 5:
+         medias = medias[:5]
+     all_media_b64 = [
+         "data:image/png;base64," + convert_to_b64(media) for media in medias
+     ]
+
+     response = cast(str, lmm.generate(prompt, media=all_media_b64))
+     print(f"[claude35_vqa output]\n{response}\n[end of claude35_vqa output]")
+
+
+ def suggestion(prompt: str, medias: List[np.ndarray]) -> None:
+     """Given your problem statement and the images, this will provide you with a
+     suggested plan on how to proceed. Always call suggestion when starting to solve
+     a problem.
+
+     Parameters:
+         prompt: str: The problem statement.
+         medias: List[np.ndarray]: The images to use for the problem
+     """
+     try:
+         from .suggestion import suggestion_impl # type: ignore
+
+         suggestion = suggestion_impl(prompt, medias)
+         print(suggestion)
+     except ImportError:
+         print("")
+
+
+ PLANNER_TOOLS = [
+     claude35_vqa,
+     suggestion,
+     get_tool_for_task,
+     T.load_image,
+     T.save_image,
+     T.extract_frames_and_timestamps,
+     T.save_video,
+ ]
+ PLANNER_DOCSTRING = T.get_tool_documentation(PLANNER_TOOLS) # type: ignore
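
The hunk above adds the v2 planner's helper-tool module in one piece. A minimal usage sketch follows; the module path and the example file name are assumptions (the wheel diff does not show file names), and both helpers print to stdout rather than returning values since they are meant to be called from planner-generated code.

import vision_agent.tools as T
# Assumed import path; the diff does not name the new module's file.
from vision_agent.agent import vision_agent_planner_tools as planner

image = T.load_image("people.jpg")  # hypothetical local image
planner.claude35_vqa("How many people are visible in this image?", [image])
planner.get_tool_for_task("Count the people in the image", [image])
# claude35_vqa prints the answer; get_tool_for_task prints the chosen tool's
# documentation. Both are backed by AnthropicLMM, so an Anthropic API key is
# presumably required.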
@@ -4,6 +4,7 @@ import os
  from base64 import b64encode
  from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

+ import numpy as np
  import pandas as pd
  from IPython.display import display
  from pydantic import BaseModel
@@ -14,6 +15,7 @@ from urllib3.util.retry import Retry
  from vision_agent.tools.tools_types import BoundingBoxes
  from vision_agent.utils.exceptions import RemoteToolCallFailed
  from vision_agent.utils.execute import Error, MimeType
+ from vision_agent.utils.image_utils import normalize_bbox
  from vision_agent.utils.type_defs import LandingaiAPIKey

  _LOGGER = logging.getLogger(__name__)
@@ -170,7 +172,7 @@ def get_tool_descriptions_by_names(


  def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
-     data: Dict[str, List[str]] = {"desc": [], "doc": []}
+     data: Dict[str, List[str]] = {"desc": [], "doc": [], "name": []}

      for func in funcs:
          desc = func.__doc__
@@ -182,6 +184,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
          doc = f"{func.__name__}{inspect.signature(func)}:\n{func.__doc__}"
          data["desc"].append(desc)
          data["doc"].append(doc)
+         data["name"].append(func.__name__)

      return pd.DataFrame(data) # type: ignore

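
get_tools_df now records each tool's function name alongside its description and full docstring. A small sketch of the resulting DataFrame (get_tools_df is imported from vision_agent.tools.tool_utils, as the tools.py hunk further down shows; the example function is hypothetical):

from vision_agent.tools.tool_utils import get_tools_df

def add(a: int, b: int) -> int:
    """Adds two integers.

    Parameters:
        a (int): The first number.
        b (int): The second number.
    """
    return a + b

df = get_tools_df([add])
df.columns.tolist()   # ['desc', 'doc', 'name'] with this change
df["name"].tolist()   # ['add']
# The new "name" column is what lets the planner's get_tool_for_task filter
# recommended tools against exclude_tools (see the first hunk in this diff).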
@@ -256,3 +259,64 @@ def filter_bboxes_by_threshold(
      bboxes: BoundingBoxes, threshold: float
  ) -> BoundingBoxes:
      return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
+
+
+ def add_bboxes_from_masks(
+     all_preds: List[List[Dict[str, Any]]],
+ ) -> List[List[Dict[str, Any]]]:
+     for frame_preds in all_preds:
+         for preds in frame_preds:
+             if np.sum(preds["mask"]) == 0:
+                 preds["bbox"] = []
+             else:
+                 rows, cols = np.where(preds["mask"])
+                 bbox = [
+                     float(np.min(cols)),
+                     float(np.min(rows)),
+                     float(np.max(cols)),
+                     float(np.max(rows)),
+                 ]
+                 bbox = normalize_bbox(bbox, preds["mask"].shape)
+                 preds["bbox"] = bbox
+
+     return all_preds
+
+
+ def calculate_iou(bbox1: List[float], bbox2: List[float]) -> float:
+     x1, y1, x2, y2 = bbox1
+     x3, y3, x4, y4 = bbox2
+
+     x_overlap = max(0, min(x2, x4) - max(x1, x3))
+     y_overlap = max(0, min(y2, y4) - max(y1, y3))
+     intersection = x_overlap * y_overlap
+
+     area1 = (x2 - x1) * (y2 - y1)
+     area2 = (x4 - x3) * (y4 - y3)
+     union = area1 + area2 - intersection
+
+     return intersection / union if union > 0 else 0
+
+
+ def single_nms(
+     preds: List[Dict[str, Any]], iou_threshold: float
+ ) -> List[Dict[str, Any]]:
+     for i in range(len(preds)):
+         for j in range(i + 1, len(preds)):
+             if calculate_iou(preds[i]["bbox"], preds[j]["bbox"]) > iou_threshold:
+                 if preds[i]["score"] > preds[j]["score"]:
+                     preds[j]["score"] = 0
+                 else:
+                     preds[i]["score"] = 0
+
+     return [pred for pred in preds if pred["score"] > 0]
+
+
+ def nms(
+     all_preds: List[List[Dict[str, Any]]], iou_threshold: float
+ ) -> List[List[Dict[str, Any]]]:
+     return_preds = []
+     for frame_preds in all_preds:
+         frame_preds = single_nms(frame_preds, iou_threshold)
+         return_preds.append(frame_preds)
+
+     return return_preds
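
For reference, a hand-computed example of the detection helpers added above, imported from vision_agent.tools.tool_utils as the next hunk does:

from vision_agent.tools.tool_utils import calculate_iou, single_nms

box_a = [0.10, 0.10, 0.50, 0.50]   # area 0.16
box_b = [0.30, 0.30, 0.70, 0.70]   # area 0.16, overlap 0.20 x 0.20 = 0.04
calculate_iou(box_a, box_b)        # 0.04 / (0.16 + 0.16 - 0.04) ~= 0.143

preds = [
    {"label": "person", "bbox": [0.10, 0.10, 0.50, 0.50], "score": 0.90},
    {"label": "person", "bbox": [0.11, 0.10, 0.50, 0.51], "score": 0.80},  # near-duplicate, IoU ~0.95
]
single_nms(preds, iou_threshold=0.80)  # keeps only the 0.90-score box
# single_nms zeroes the losing score in place and then filters; nms() applies
# the same pass to each frame of a video-level prediction list.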
@@ -17,15 +17,18 @@ from pillow_heif import register_heif_opener # type: ignore
  from pytube import YouTube # type: ignore

  from vision_agent.clients.landing_public_api import LandingPublicAPI
- from vision_agent.lmm.lmm import OpenAILMM
+ from vision_agent.lmm.lmm import AnthropicLMM, OpenAILMM
  from vision_agent.tools.tool_utils import (
+     add_bboxes_from_masks,
      filter_bboxes_by_threshold,
      get_tool_descriptions,
      get_tool_documentation,
      get_tools_df,
      get_tools_info,
+     nms,
      send_inference_request,
      send_task_inference_request,
+     single_nms,
  )
  from vision_agent.tools.tools_types import JobStatus, ODResponseData
  from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -260,8 +263,8 @@ def owl_v2_video(
          ...
      ]
      """
-     if len(frames) == 0:
-         raise ValueError("No frames provided")
+     if len(frames) == 0 or not isinstance(frames, List):
+         raise ValueError("Must provide a list of numpy arrays for frames")

      image_size = frames[0].shape[:2]
      buffer_bytes = frames_to_bytes(frames)
@@ -455,7 +458,7 @@ def florence2_sam2_image(
  def florence2_sam2_video_tracking(
      prompt: str,
      frames: List[np.ndarray],
-     chunk_length: Optional[int] = 3,
+     chunk_length: Optional[int] = 10,
      fine_tune_id: Optional[str] = None,
  ) -> List[List[Dict[str, Any]]]:
      """'florence2_sam2_video_tracking' is a tool that can segment and track multiple
@@ -473,11 +476,11 @@ def florence2_sam2_video_tracking(
          fine-tuned model ID here to use it.

      Returns:
-         List[List[Dict[str, Any]]]: A list of list of dictionaries containing the label
-         and segment mask. The outer list represents each frame and the inner list is
-         the entities per frame. The label contains the object ID followed by the label
-         name. The objects are only identified in the first framed and tracked
-         throughout the video.
+         List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+         label,segment mask and bounding boxes. The outer list represents each frame and
+         the inner list is the entities per frame. The label contains the object ID
+         followed by the label name. The objects are only identified in the first framed
+         and tracked throughout the video.

      Example
      -------
@@ -486,6 +489,7 @@ def florence2_sam2_video_tracking(
          [
              {
                  'label': '0: dinosaur',
+                 'bbox': [0.1, 0.11, 0.35, 0.4],
                  'mask': array([[0, 0, 0, ..., 0, 0, 0],
                      [0, 0, 0, ..., 0, 0, 0],
                      ...,
@@ -496,8 +500,8 @@ def florence2_sam2_video_tracking(
          ...
      ]
      """
-     if len(frames) == 0:
-         raise ValueError("No frames provided")
+     if len(frames) == 0 or not isinstance(frames, List):
+         raise ValueError("Must provide a list of numpy arrays for frames")

      buffer_bytes = frames_to_bytes(frames)
      files = [("video", buffer_bytes)]
@@ -535,7 +539,8 @@ def florence2_sam2_video_tracking(
              label = str(detection["id"]) + ": " + detection["label"]
              return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
          return_data.append(return_frame_data)
-     return return_data
+     return_data = add_bboxes_from_masks(return_data)
+     return nms(return_data, iou_threshold=0.95)


  def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
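
florence2_sam2_video_tracking now decorates every tracked mask with a normalized bounding box (via add_bboxes_from_masks) and drops near-duplicate tracks with nms at an IoU threshold of 0.95. A toy illustration of the mask-to-box step, using the helper defined in the tool_utils hunk above:

import numpy as np
from vision_agent.tools.tool_utils import add_bboxes_from_masks

mask = np.zeros((10, 10), dtype=bool)
mask[4:6, 2:4] = True              # a 2x2 blob covering rows 4-5, cols 2-3
preds = [[{"label": "0: box", "mask": mask, "score": 1.0}]]
add_bboxes_from_masks(preds)
# preds[0][0]["bbox"] is the pixel box [2, 4, 3, 5] normalized by the mask
# shape to [0.2, 0.4, 0.3, 0.5]; an all-zero mask yields an empty bbox list.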
@@ -677,8 +682,9 @@ def countgd_counting(
      image: np.ndarray,
      box_threshold: float = 0.23,
  ) -> List[Dict[str, Any]]:
-     """'countgd_counting' is a tool that can precisely count multiple instances of an
-     object given a text prompt. It returns a list of bounding boxes with normalized
+     """'countgd_counting' is a tool that can detect multiple instances of an object
+     given a text prompt. It is particularly useful when trying to detect and count a
+     large number of objects. It returns a list of bounding boxes with normalized
      coordinates, label names and associated confidence scores.

      Parameters:
@@ -711,7 +717,7 @@ def countgd_counting(
      buffer_bytes = numpy_to_bytes(image)
      files = [("image", buffer_bytes)]
      payload = {
-         "prompts": [prompt.replace(", ", " .")],
+         "prompts": [prompt.replace(", ", ". ")],
          "confidence": box_threshold, # still not being used in the API
          "model": "countgd",
      }
@@ -733,7 +739,8 @@ def countgd_counting(
      ]
      # TODO: remove this once we start to use the confidence on countgd
      filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
-     return [bbox.model_dump() for bbox in filtered_bboxes]
+     return_data = [bbox.model_dump() for bbox in filtered_bboxes]
+     return single_nms(return_data, iou_threshold=0.80)


  def countgd_example_based_counting(
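
Two behavioural changes to countgd_counting sit in the hunks above: comma-separated prompts are now joined with ". " instead of " .", and the threshold-filtered detections are deduplicated with single_nms at an IoU threshold of 0.80. The separator change in isolation:

prompt = "person, car"
prompt.replace(", ", " .")   # old payload text: 'person .car'
prompt.replace(", ", ". ")   # new payload text: 'person. car'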
@@ -864,9 +871,10 @@ def ixc25_image_vqa(prompt: str, image: np.ndarray) -> str:


  def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
-     """'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary images
-     including regular images or images of documents or presentations. It returns text
-     as an answer to the question.
+     """'qwen2_vl_images_vqa' is a tool that can answer any questions about arbitrary
+     images including regular images or images of documents or presentations. It can be
+     very useful for document QA or OCR text extraction. It returns text as an answer to
+     the question.

      Parameters:
          prompt (str): The question about the document image
@@ -880,6 +888,9 @@ def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
          >>> qwen2_vl_images_vqa('Give a summary of the document', images)
          'The document talks about the history of the United States of America and its...'
      """
+     if isinstance(images, np.ndarray):
+         images = [images]
+
      for image in images:
          if image.shape[0] < 1 or image.shape[1] < 1:
              raise ValueError(f"Image is empty, image shape: {image.shape}")
@@ -896,6 +907,30 @@ def qwen2_vl_images_vqa(prompt: str, images: List[np.ndarray]) -> str:
      return cast(str, data)


+ def claude35_text_extraction(image: np.ndarray) -> str:
+     """'claude35_text_extraction' is a tool that can extract text from an image. It
+     returns the extracted text as a string and can be used as an alternative to OCR if
+     you do not need to know the exact bounding box of the text.
+
+     Parameters:
+         image (np.ndarray): The image to extract text from.
+
+     Returns:
+         str: The extracted text from the image.
+     """
+
+     lmm = AnthropicLMM()
+     buffer = io.BytesIO()
+     Image.fromarray(image).save(buffer, format="PNG")
+     image_bytes = buffer.getvalue()
+     image_b64 = "data:image/png;base64," + encode_image_bytes(image_bytes)
+     text = lmm.generate(
+         "Extract and return any text you see in this image and nothing else. If you do not read any text respond with an empty string.",
+         [image_b64],
+     )
+     return cast(str, text)
+
+
  def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
      """'ixc25_video_vqa' is a tool that can answer any questions about arbitrary videos
      including regular videos or videos of documents or presentations. It returns text
@@ -944,6 +979,9 @@ def qwen2_vl_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
          'Lionel Messi'
      """

+     if len(frames) == 0 or not isinstance(frames, List):
+         raise ValueError("Must provide a list of numpy arrays for frames")
+
      buffer_bytes = frames_to_bytes(frames)
      files = [("video", buffer_bytes)]
      payload = {
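
The new claude35_text_extraction wraps AnthropicLMM as a bounding-box-free alternative to ocr. A hypothetical call, assuming the function is re-exported through vision_agent.tools like the other entries in FUNCTION_TOOLS and that the file name is illustrative:

import vision_agent.tools as T

receipt = T.load_image("receipt.png")      # hypothetical local file
text = T.claude35_text_extraction(receipt)
print(text)  # plain text only; use T.ocr(receipt) when box coordinates are needed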
@@ -2157,7 +2195,8 @@ def overlay_bounding_boxes(
          bboxes = bbox_int[i]
          bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)

-         if len(bboxes) > 40:
+         # if more than 50 boxes use small boxes to indicate objects else use regular boxes
+         if len(bboxes) > 50:
              pil_image = _plot_counting(pil_image, bboxes, color)
          else:
              width, height = pil_image.size
@@ -2188,7 +2227,14 @@ def overlay_bounding_boxes(
                  draw.text((box[0], box[1]), text, fill="black", font=font)

          frame_out.append(np.array(pil_image))
-     return frame_out[0] if len(frame_out) == 1 else frame_out
+     return_frame = frame_out[0] if len(frame_out) == 1 else frame_out
+
+     if isinstance(return_frame, np.ndarray):
+         from IPython.display import display
+
+         display(Image.fromarray(return_frame))
+
+     return return_frame # type: ignore


  def _get_text_coords_from_mask(
@@ -2300,7 +2346,14 @@ def overlay_segmentation_masks(
                  draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
                  draw.text((x, y), text, fill="black", font=font)
          frame_out.append(np.array(pil_image))
-     return frame_out[0] if len(frame_out) == 1 else frame_out
+     return_frame = frame_out[0] if len(frame_out) == 1 else frame_out
+
+     if isinstance(return_frame, np.ndarray):
+         from IPython.display import display
+
+         display(Image.fromarray(return_frame))
+
+     return return_frame # type: ignore


  def overlay_heat_map(
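
Both overlay helpers now display the annotated frame inline (via IPython.display) when a single image is passed, in addition to returning it. A sketch with the usual tool entry points; the file name is hypothetical:

import vision_agent.tools as T

image = T.load_image("street.jpg")
dets = T.countgd_counting("person", image)
annotated = T.overlay_bounding_boxes(image, dets)
# In a notebook the annotated image is rendered automatically; `annotated` is
# still returned as an np.ndarray that can be saved with T.save_image.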
@@ -2408,6 +2461,7 @@ FUNCTION_TOOLS = [
      florence2_sam2_image,
      florence2_sam2_video_tracking,
      florence2_phrase_grounding,
+     claude35_text_extraction,
      detr_segmentation,
      depth_anything_v2,
      generate_pose_image,
@@ -42,10 +42,10 @@ def normalize_bbox(
  ) -> List[float]:
      r"""Normalize the bounding box coordinates to be between 0 and 1."""
      x1, y1, x2, y2 = bbox
-     x1 = round(x1 / image_size[1], 2)
-     y1 = round(y1 / image_size[0], 2)
-     x2 = round(x2 / image_size[1], 2)
-     y2 = round(y2 / image_size[0], 2)
+     x1 = max(round(x1 / image_size[1], 2), 0)
+     y1 = max(round(y1 / image_size[0], 2), 0)
+     x2 = min(round(x2 / image_size[1], 2), image_size[1])
+     y2 = min(round(y2 / image_size[0], 2), image_size[0])
      return [x1, y1, x2, y2]


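
normalize_bbox now clamps the normalized corners, so detections that spill slightly off-image no longer produce negative coordinates. A worked example; normalize_bbox is imported from vision_agent.utils.image_utils, as the tool_utils hunk earlier in this diff does:

from vision_agent.utils.image_utils import normalize_bbox

image_size = (480, 640)                       # (height, width)
normalize_bbox([-12.0, 5.0, 320.0, 240.0], image_size)
# x1 = max(round(-12 / 640, 2), 0) = 0.0 (previously -0.02)
# y1 = 0.01, x2 = 0.5, y2 = 0.5 -> [0.0, 0.01, 0.5, 0.5]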
@@ -175,9 +175,15 @@ def encode_media(media: Union[str, Path], resize: Optional[int] = None) -> str:
          return media[:-4] + ".png"
      return media

-     # if media is already a base64 encoded image return
+     # if media is in base64 ensure it's the correct resize
      if isinstance(media, str) and media.startswith("data:image/"):
-         return media
+         image_pil = b64_to_pil(media)
+         if resize is not None:
+             if image_pil.size[0] > resize or image_pil.size[1] > resize:
+                 image_pil.thumbnail((resize, resize))
+         buffer = io.BytesIO()
+         image_pil.save(buffer, format="PNG")
+         return base64.b64encode(buffer.getvalue()).decode("utf-8")

      extension = "png"
      extension = Path(media).suffix
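
encode_media used to return data-URL inputs untouched; it now decodes them, thumbnails them to the requested resize when they are larger, and returns the re-encoded base64 payload. A sketch of the new branch; the import path is an assumption, since the diff does not name the file containing encode_media:

import base64
import io

import numpy as np
from PIL import Image

from vision_agent.utils.image_utils import encode_media  # assumed location

img = Image.fromarray(np.zeros((2000, 1500, 3), dtype=np.uint8))
buf = io.BytesIO()
img.save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")

small = encode_media(data_url, resize=768)
# `small` is a bare base64 PNG no larger than 768 px on its longest side,
# rather than the original 2000 x 1500 data URL being passed through unchanged.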