vision-agent 0.2.220__py3-none-any.whl → 0.2.222__py3-none-any.whl

@@ -21,7 +21,7 @@ from pillow_heif import register_heif_opener # type: ignore
21
21
  from pytube import YouTube # type: ignore
22
22
 
23
23
  from vision_agent.clients.landing_public_api import LandingPublicAPI
24
- from vision_agent.lmm.lmm import AnthropicLMM, OpenAILMM
24
+ from vision_agent.lmm.lmm import AnthropicLMM
25
25
  from vision_agent.tools.tool_utils import (
26
26
  ToolCallTrace,
27
27
  add_bboxes_from_masks,
@@ -112,6 +112,114 @@ def _display_tool_trace(
112
112
  display({MimeType.APPLICATION_JSON: tool_call_trace.model_dump()}, raw=True)
113
113
 
114
114
 
115
+ def _sam2(
116
+ image: np.ndarray,
117
+ detections: List[Dict[str, Any]],
118
+ image_size: Tuple[int, ...],
119
+ image_bytes: Optional[bytes] = None,
120
+ ) -> Dict[str, Any]:
121
+ if image_bytes is None:
122
+ image_bytes = numpy_to_bytes(image)
123
+
124
+ files = [("images", image_bytes)]
125
+ payload = {
126
+ "model": "sam2",
127
+ "bboxes": json.dumps(
128
+ [
129
+ {
130
+ "labels": [d["label"] for d in detections],
131
+ "bboxes": [
132
+ denormalize_bbox(d["bbox"], image_size) for d in detections
133
+ ],
134
+ }
135
+ ]
136
+ ),
137
+ }
138
+
139
+ metadata = {"function_name": "sam2"}
140
+ pred_detections = send_task_inference_request(
141
+ payload, "sam2", files=files, metadata=metadata
142
+ )
143
+ frame = pred_detections[0]
144
+ return_data = []
145
+ display_data = []
146
+ for inp_detection, detection in zip(detections, frame):
147
+ mask = rle_decode_array(detection["mask"])
148
+ label = detection["label"]
149
+ bbox = normalize_bbox(detection["bounding_box"], detection["mask"]["size"])
150
+ return_data.append(
151
+ {
152
+ "label": label,
153
+ "bbox": bbox,
154
+ "mask": mask,
155
+ "score": inp_detection["score"],
156
+ }
157
+ )
158
+ display_data.append(
159
+ {
160
+ "label": label,
161
+ "bbox": detection["bounding_box"],
162
+ "mask": detection["mask"],
163
+ "score": inp_detection["score"],
164
+ }
165
+ )
166
+ return {"files": files, "return_data": return_data, "display_data": display_data}
167
+
168
+
169
+ def sam2(
170
+ image: np.ndarray,
171
+ detections: List[Dict[str, Any]],
172
+ ) -> List[Dict[str, Any]]:
173
+ """'sam2' is a tool that can segment multiple objects given an input bounding box,
174
+ label and score. It returns a set of masks along with the corresponding bounding
175
+ boxes and labels.
176
+
177
+ Parameters:
178
+ image (np.ndarray): The image that contains multiple instances of the object.
179
+ detections (List[Dict[str, Any]]): A list of dictionaries containing the score,
180
+ label, and bounding box of the detected objects with normalized coordinates
181
+ between 0 and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates
182
+ of the top-left and xmax and ymax are the coordinates of the bottom-right of
183
+ the bounding box.
184
+
185
+ Returns:
186
+ List[Dict[str, Any]]: A list of dictionaries containing the score, label,
187
+ bounding box, and mask of the detected objects with normalized coordinates
188
+ (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
189
+ and xmax and ymax are the coordinates of the bottom-right of the bounding box.
190
+ The mask is a binary 2D numpy array where 1 indicates the object and 0 indicates
191
+ the background.
192
+
193
+ Example
194
+ -------
195
+ >>> sam2(image, [
196
+ {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
197
+ ])
198
+ [
199
+ {
200
+ 'score': 0.49,
201
+ 'label': 'flower',
202
+ 'bbox': [0.1, 0.11, 0.35, 0.4],
203
+ 'mask': array([[0, 0, 0, ..., 0, 0, 0],
204
+ [0, 0, 0, ..., 0, 0, 0],
205
+ ...,
206
+ [0, 0, 0, ..., 0, 0, 0],
207
+ [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
208
+ },
209
+ ]
210
+ """
211
+ image_size = image.shape[:2]
212
+ ret = _sam2(image, detections, image_size)
213
+ _display_tool_trace(
214
+ sam2.__name__,
215
+ {},
216
+ ret["display_data"],
217
+ ret["files"],
218
+ )
219
+
220
+ return ret["return_data"] # type: ignore
221
+
222
+
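A minimal composition sketch for the new top-level sam2 tool, assuming countgd_object_detection and sam2 are importable from vision_agent.tools, API credentials are configured, and the placeholder array stands in for a real RGB image:

    import numpy as np
    from vision_agent.tools import countgd_object_detection, sam2

    # Placeholder RGB frame; any H x W x 3 uint8 image works here.
    image = np.zeros((480, 640, 3), dtype=np.uint8)

    # Detect first, then hand the detections straight to sam2 for masks.
    detections = countgd_object_detection("flower", image)
    segments = sam2(image, detections)
    for seg in segments:
        print(seg["label"], seg["score"], seg["bbox"], seg["mask"].shape)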
115
223
  class ODModels(str, Enum):
116
224
  COUNTGD = "countgd"
117
225
  FLORENCE2 = "florence2"
@@ -139,15 +247,15 @@ def od_sam2_video_tracking(
139
247
  results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx])
140
248
  function_name = "countgd_object_detection"
141
249
  elif od_model == ODModels.OWLV2:
142
- results[idx] = owl_v2_image(
250
+ results[idx] = owlv2_object_detection(
143
251
  prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
144
252
  )
145
- function_name = "owl_v2_image"
253
+ function_name = "owlv2_object_detection"
146
254
  elif od_model == ODModels.FLORENCE2:
147
- results[idx] = florence2_sam2_image(
255
+ results[idx] = florence2_object_detection(
148
256
  prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
149
257
  )
150
- function_name = "florence2_sam2_image"
258
+ function_name = "florence2_object_detection"
151
259
  else:
152
260
  raise NotImplementedError(
153
261
  f"Object detection model '{od_model}' is not implemented."
@@ -183,7 +291,7 @@ def od_sam2_video_tracking(
183
291
 
184
292
  buffer_bytes = frames_to_bytes(frames)
185
293
  files = [("video", buffer_bytes)]
186
- payload = {"bboxes": json.dumps(output), "chunk_length": chunk_length}
294
+ payload = {"bboxes": json.dumps(output), "chunk_length_frames": chunk_length}
187
295
  metadata = {"function_name": function_name}
188
296
 
189
297
  detections = send_task_inference_request(
@@ -226,53 +334,27 @@ def od_sam2_video_tracking(
226
334
  return {"files": files, "return_data": return_data, "display_data": detections}
227
335
 
228
336
 
229
- def owl_v2_image(
337
+ # Owl V2 Tools
338
+
339
+
340
+ def _owlv2_object_detection(
230
341
  prompt: str,
231
342
  image: np.ndarray,
232
- box_threshold: float = 0.10,
343
+ box_threshold: float,
344
+ image_size: Tuple[int, ...],
345
+ image_bytes: Optional[bytes] = None,
233
346
  fine_tune_id: Optional[str] = None,
234
- ) -> List[Dict[str, Any]]:
235
- """'owl_v2_image' is a tool that can detect and count multiple objects given a text
236
- prompt such as category names or referring expressions on images. The categories in
237
- text prompt are separated by commas. It returns a list of bounding boxes with
238
- normalized coordinates, label names and associated probability scores.
239
-
240
- Parameters:
241
- prompt (str): The prompt to ground to the image.
242
- image (np.ndarray): The image to ground the prompt to.
243
- box_threshold (float, optional): The threshold for the box detection. Defaults
244
- to 0.10.
245
- fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
246
- fine-tuned model ID here to use it.
247
-
248
- Returns:
249
- List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
250
- bounding box of the detected objects with normalized coordinates between 0
251
- and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
252
- top-left and xmax and ymax are the coordinates of the bottom-right of the
253
- bounding box.
254
-
255
- Example
256
- -------
257
- >>> owl_v2_image("car, dinosaur", image)
258
- [
259
- {'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]},
260
- {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5},
261
- ]
262
- """
263
-
264
- image_size = image.shape[:2]
265
- if image_size[0] < 1 or image_size[1] < 1:
266
- return []
347
+ ) -> Dict[str, Any]:
348
+ if image_bytes is None:
349
+ image_bytes = numpy_to_bytes(image)
267
350
 
268
- buffer_bytes = numpy_to_bytes(image)
269
- files = [("image", buffer_bytes)]
351
+ files = [("image", image_bytes)]
270
352
  payload = {
271
353
  "prompts": [s.strip() for s in prompt.split(",")],
272
354
  "confidence": box_threshold,
273
355
  "model": "owlv2",
274
356
  }
275
- metadata = {"function_name": "owl_v2_image"}
357
+ metadata = {"function_name": "owlv2_object_detection"}
276
358
 
277
359
  if fine_tune_id is not None:
278
360
  landing_api = LandingPublicAPI()
@@ -302,127 +384,96 @@ def owl_v2_image(
302
384
  {
303
385
  "label": bbox["label"],
304
386
  "bbox": normalize_bbox(bbox["bounding_box"], image_size),
305
- "score": round(bbox["score"], 2),
387
+ "score": bbox["score"],
306
388
  }
307
389
  for bbox in bboxes
308
390
  ]
309
-
310
- _display_tool_trace(
311
- owl_v2_image.__name__,
312
- payload,
313
- detections[0],
314
- files,
315
- )
316
- return bboxes_formatted
391
+ display_data = [
392
+ {
393
+ "label": bbox["label"],
394
+ "bbox": bbox["bounding_box"],
395
+ "score": bbox["score"],
396
+ }
397
+ for bbox in bboxes
398
+ ]
399
+ return {
400
+ "files": files,
401
+ "return_data": bboxes_formatted,
402
+ "display_data": display_data,
403
+ }
317
404
 
318
405
 
319
- def owl_v2_video(
406
+ def owlv2_object_detection(
320
407
  prompt: str,
321
- frames: List[np.ndarray],
408
+ image: np.ndarray,
322
409
  box_threshold: float = 0.10,
323
410
  fine_tune_id: Optional[str] = None,
324
- ) -> List[List[Dict[str, Any]]]:
325
- """'owl_v2_video' will run owl_v2 on each frame of a video. It can detect multiple
326
- objects independently per frame given a text prompt such as a category name or
327
- referring expression but does not track objects across frames. The categories in
328
- text prompt are separated by commas. It returns a list of lists where each inner
329
- list contains the score, label, and bounding box of the detections for that frame.
411
+ ) -> List[Dict[str, Any]]:
412
+ """'owlv2_object_detection' is a tool that can detect and count multiple objects
413
+ given a text prompt such as category names or referring expressions on images. The
414
+ categories in text prompt are separated by commas. It returns a list of bounding
415
+ boxes with normalized coordinates, label names and associated probability scores.
330
416
 
331
417
  Parameters:
332
- prompt (str): The prompt to ground to the video.
333
- frames (List[np.ndarray]): The list of frames to ground the prompt to.
418
+ prompt (str): The prompt to ground to the image.
419
+ image (np.ndarray): The image to ground the prompt to.
334
420
  box_threshold (float, optional): The threshold for the box detection. Defaults
335
- to 0.30.
421
+ to 0.10.
336
422
  fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
337
423
  fine-tuned model ID here to use it.
338
424
 
339
425
  Returns:
340
- List[List[Dict[str, Any]]]: A list of lists of dictionaries containing the
341
- score, label, and bounding box of the detected objects with normalized
342
- coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the
343
- coordinates of the top-left and xmax and ymax are the coordinates of the
344
- bottom-right of the bounding box.
426
+ List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
427
+ bounding box of the detected objects with normalized coordinates between 0
428
+ and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
429
+ top-left and xmax and ymax are the coordinates of the bottom-right of the
430
+ bounding box.
345
431
 
346
432
  Example
347
433
  -------
348
- >>> owl_v2_video("car, dinosaur", frames)
434
+ >>> owlv2_object_detection("car, dinosaur", image)
349
435
  [
350
- [
351
- {'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]},
352
- {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5},
353
- ],
354
- ...
436
+ {'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]},
437
+ {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5]},
355
438
  ]
356
439
  """
357
- if len(frames) == 0 or not isinstance(frames, List):
358
- raise ValueError("Must provide a list of numpy arrays for frames")
359
-
360
- image_size = frames[0].shape[:2]
361
- buffer_bytes = frames_to_bytes(frames)
362
- files = [("video", buffer_bytes)]
363
- payload = {
364
- "prompts": [s.strip() for s in prompt.split(",")],
365
- "confidence": box_threshold,
366
- "model": "owlv2",
367
- }
368
- metadata = {"function_name": "owl_v2_video"}
369
-
370
- if fine_tune_id is not None:
371
- landing_api = LandingPublicAPI()
372
- status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
373
- if status is not JobStatus.SUCCEEDED:
374
- raise FineTuneModelIsNotReady(
375
- f"Fine-tuned model {fine_tune_id} is not ready yet"
376
- )
377
440
 
378
- # we can only execute fine-tuned models with florence2
379
- payload = {
380
- "prompts": payload["prompts"],
381
- "jobId": fine_tune_id,
382
- "model": "florence2",
383
- }
441
+ image_size = image.shape[:2]
442
+ if image_size[0] < 1 or image_size[1] < 1:
443
+ return []
384
444
 
385
- detections = send_task_inference_request(
386
- payload,
387
- "text-to-object-detection",
388
- files=files,
389
- metadata=metadata,
445
+ ret = _owlv2_object_detection(
446
+ prompt, image, box_threshold, image_size, fine_tune_id=fine_tune_id
390
447
  )
391
448
 
392
- bboxes_formatted = []
393
- for frame_data in detections:
394
- bboxes_formatted_per_frame = [
395
- {
396
- "label": bbox["label"],
397
- "bbox": normalize_bbox(bbox["bounding_box"], image_size),
398
- "score": round(bbox["score"], 2),
399
- }
400
- for bbox in frame_data
401
- ]
402
- bboxes_formatted.append(bboxes_formatted_per_frame)
403
449
  _display_tool_trace(
404
- owl_v2_video.__name__,
405
- payload,
406
- detections[0],
407
- files,
450
+ owlv2_object_detection.__name__,
451
+ {
452
+ "prompts": prompt,
453
+ "confidence": box_threshold,
454
+ },
455
+ ret["display_data"],
456
+ ret["files"],
408
457
  )
409
- return bboxes_formatted
458
+ return ret["return_data"] # type: ignore
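A short usage sketch for the renamed detector, assuming it is importable from vision_agent.tools and the zero array stands in for a real image; the comma-separated prompt and box_threshold behave as documented above:

    import numpy as np
    from vision_agent.tools import owlv2_object_detection

    image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB frame

    # Comma-separated categories are split into individual prompts by the tool.
    detections = owlv2_object_detection("car, dinosaur", image, box_threshold=0.10)
    for det in detections:
        xmin, ymin, xmax, ymax = det["bbox"]  # normalized to [0, 1]
        print(det["label"], round(det["score"], 2), (xmin, ymin, xmax, ymax))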
410
459
 
411
460
 
412
- def owlv2_sam2_video_tracking(
461
+ def owlv2_sam2_instance_segmentation(
413
462
  prompt: str,
414
- frames: List[np.ndarray],
415
- chunk_length: Optional[int] = 10,
416
- fine_tune_id: Optional[str] = None,
417
- ) -> List[List[Dict[str, Any]]]:
418
- """'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text
419
- prompt such as category names or referring expressions. The categories in the text
420
- prompt are separated by commas. It returns a list of bounding boxes, label names,
421
- mask file names and associated probability scores.
463
+ image: np.ndarray,
464
+ box_threshold: float = 0.10,
465
+ ) -> List[Dict[str, Any]]:
466
+ """'owlv2_sam2_instance_segmentation' is a tool that can detect and count multiple
467
+ instances of objects given a text prompt such as category names or referring
468
+ expressions on images. The categories in text prompt are separated by commas. It
469
+ returns a list of bounding boxes with normalized coordinates, label names, masks
470
+ and associated probability scores.
422
471
 
423
472
  Parameters:
424
- prompt (str): The prompt to ground to the image.
425
- image (np.ndarray): The image to ground the prompt to.
473
+ prompt (str): The object that needs to be counted.
474
+ image (np.ndarray): The image that contains multiple instances of the object.
475
+ box_threshold (float, optional): The threshold for detection. Defaults
476
+ to 0.10.
426
477
 
427
478
  Returns:
428
479
  List[Dict[str, Any]]: A list of dictionaries containing the score, label,
@@ -434,16 +485,76 @@ def owlv2_sam2_video_tracking(
434
485
 
435
486
  Example
436
487
  -------
437
- >>> countgd_sam2_video_tracking("car, dinosaur", frames)
488
+ >>> owlv2_sam2_instance_segmentation("flower", image)
438
489
  [
439
- [
440
- {
441
- 'label': '0: dinosaur',
442
- 'bbox': [0.1, 0.11, 0.35, 0.4],
443
- 'mask': array([[0, 0, 0, ..., 0, 0, 0],
444
- [0, 0, 0, ..., 0, 0, 0],
445
- ...,
446
- [0, 0, 0, ..., 0, 0, 0],
490
+ {
491
+ 'score': 0.49,
492
+ 'label': 'flower',
493
+ 'bbox': [0.1, 0.11, 0.35, 0.4],
494
+ 'mask': array([[0, 0, 0, ..., 0, 0, 0],
495
+ [0, 0, 0, ..., 0, 0, 0],
496
+ ...,
497
+ [0, 0, 0, ..., 0, 0, 0],
498
+ [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
499
+ },
500
+ ]
501
+ """
502
+
503
+ od_ret = _owlv2_object_detection(prompt, image, box_threshold, image.shape[:2])
504
+ seg_ret = _sam2(
505
+ image, od_ret["return_data"], image.shape[:2], image_bytes=od_ret["files"][0][1]
506
+ )
507
+
508
+ _display_tool_trace(
509
+ countgd_sam2_instance_segmentation.__name__,
510
+ {
511
+ "prompts": prompt,
512
+ "confidence": box_threshold,
513
+ },
514
+ seg_ret["display_data"],
515
+ seg_ret["files"],
516
+ )
517
+
518
+ return seg_ret["return_data"] # type: ignore
519
+
520
+
521
+ def owlv2_sam2_video_tracking(
522
+ prompt: str,
523
+ frames: List[np.ndarray],
524
+ chunk_length: Optional[int] = 10,
525
+ fine_tune_id: Optional[str] = None,
526
+ ) -> List[List[Dict[str, Any]]]:
527
+ """'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text
528
+ prompt such as category names or referring expressions. The categories in the text
529
+ prompt are separated by commas. It returns a list of bounding boxes, label names,
530
+ mask file names and associated probability scores.
531
+
532
+ Parameters:
533
+ prompt (str): The prompt to ground to the image.
534
+ image (np.ndarray): The image to ground the prompt to.
535
+ fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
536
+ fine-tuned model ID here to use it.
537
+
538
+ Returns:
539
+ List[Dict[str, Any]]: A list of dictionaries containing the score, label,
540
+ bounding box, and mask of the detected objects with normalized coordinates
541
+ (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
542
+ and xmax and ymax are the coordinates of the bottom-right of the bounding box.
543
+ The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
544
+ the background.
545
+
546
+ Example
547
+ -------
548
+ >>> owlv2_sam2_video_tracking("car, dinosaur", frames)
549
+ [
550
+ [
551
+ {
552
+ 'label': '0: dinosaur',
553
+ 'bbox': [0.1, 0.11, 0.35, 0.4],
554
+ 'mask': array([[0, 0, 0, ..., 0, 0, 0],
555
+ [0, 0, 0, ..., 0, 0, 0],
556
+ ...,
557
+ [0, 0, 0, ..., 0, 0, 0],
447
558
  [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
448
559
  },
449
560
  ],
@@ -467,13 +578,96 @@ def owlv2_sam2_video_tracking(
467
578
  return ret["return_data"] # type: ignore
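A hedged sketch of calling the video tracker, assuming owlv2_sam2_video_tracking is importable from vision_agent.tools and the placeholder frames stand in for a decoded video clip:

    import numpy as np
    from vision_agent.tools import owlv2_sam2_video_tracking

    # Placeholder frames; in practice these come from a decoded video.
    frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(10)]

    # chunk_length controls how often the detector is re-run to pick up new objects.
    tracks = owlv2_sam2_video_tracking("car, dinosaur", frames, chunk_length=10)
    for frame_detections in tracks:
        for det in frame_detections:
            print(det["label"], det["bbox"])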
468
579
 
469
580
 
470
- def florence2_sam2_image(
581
+ # Florence2 Tools
582
+
583
+
584
+ def florence2_object_detection(
471
585
  prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
472
586
  ) -> List[Dict[str, Any]]:
473
- """'florence2_sam2_image' is a tool that can segment multiple objects given a text
474
- prompt such as category names or referring expressions. The categories in the text
475
- prompt are separated by commas. It returns a list of bounding boxes, label names,
476
- mask file names and associated probability scores of 1.0.
587
+ """'florence2_object_detection' is a tool that can detect multiple
588
+ objects given a text prompt which can be object names or a caption. You
589
+ can optionally separate the object names in the text with commas. It returns a list
590
+ of bounding boxes with normalized coordinates, label names and associated
591
+ confidence scores of 1.0.
592
+
593
+ Parameters:
594
+ prompt (str): The prompt to ground to the image.
595
+ image (np.ndarray): The image used to detect objects.
596
+ fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
597
+ fine-tuned model ID here to use it.
598
+
599
+ Returns:
600
+ List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
601
+ bounding box of the detected objects with normalized coordinates between 0
602
+ and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
603
+ top-left and xmax and ymax are the coordinates of the bottom-right of the
604
+ bounding box. The scores are always 1.0 and cannot be thresholded
605
+
606
+ Example
607
+ -------
608
+ >>> florence2_object_detection('person looking at a coyote', image)
609
+ [
610
+ {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
611
+ {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5]},
612
+ ]
613
+ """
614
+ image_size = image.shape[:2]
615
+ if image_size[0] < 1 or image_size[1] < 1:
616
+ return []
617
+
618
+ buffer_bytes = numpy_to_bytes(image)
619
+ files = [("image", buffer_bytes)]
620
+ payload = {
621
+ "prompts": [s.strip() for s in prompt.split(",")],
622
+ "model": "florence2",
623
+ }
624
+ metadata = {"function_name": "florence2_object_detection"}
625
+
626
+ if fine_tune_id is not None:
627
+ landing_api = LandingPublicAPI()
628
+ status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
629
+ if status is not JobStatus.SUCCEEDED:
630
+ raise FineTuneModelIsNotReady(
631
+ f"Fine-tuned model {fine_tune_id} is not ready yet"
632
+ )
633
+
634
+ payload["jobId"] = fine_tune_id
635
+
636
+ detections = send_task_inference_request(
637
+ payload,
638
+ "text-to-object-detection",
639
+ files=files,
640
+ metadata=metadata,
641
+ )
642
+
643
+ # get the first frame
644
+ bboxes = detections[0]
645
+ bboxes_formatted = [
646
+ {
647
+ "label": bbox["label"],
648
+ "bbox": normalize_bbox(bbox["bounding_box"], image_size),
649
+ "score": round(bbox["score"], 2),
650
+ }
651
+ for bbox in bboxes
652
+ ]
653
+
654
+ _display_tool_trace(
655
+ florence2_object_detection.__name__,
656
+ payload,
657
+ detections[0],
658
+ files,
659
+ )
660
+ return [bbox for bbox in bboxes_formatted]
661
+
662
+
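A sketch of passing a fine-tuned model to the renamed grounding tool; the import path and the UUID below are hypothetical placeholders, and the call raises FineTuneModelIsNotReady until the fine-tuning job reports SUCCEEDED:

    import numpy as np
    from vision_agent.tools import florence2_object_detection

    image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB frame

    # Hypothetical job id from an earlier fine-tuning run; not a real UUID.
    FINE_TUNE_ID = "00000000-0000-0000-0000-000000000000"

    detections = florence2_object_detection(
        "person looking at a coyote", image, fine_tune_id=FINE_TUNE_ID
    )
    print(detections)  # florence2 scores are always 1.0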
663
+ def florence2_sam2_instance_segmentation(
664
+ prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
665
+ ) -> List[Dict[str, Any]]:
666
+ """'florence2_sam2_instance_segmentation' is a tool that can segment multiple
667
+ objects given a text prompt such as category names or referring expressions. The
668
+ categories in the text prompt are separated by commas. It returns a list of
669
+ bounding boxes, label names, mask file names and associated probability scores of
670
+ 1.0.
477
671
 
478
672
  Parameters:
479
673
  prompt (str): The prompt to ground to the image.
@@ -491,7 +685,7 @@ def florence2_sam2_image(
491
685
 
492
686
  Example
493
687
  -------
494
- >>> florence2_sam2_image("car, dinosaur", image)
688
+ >>> florence2_sam2_instance_segmentation("car, dinosaur", image)
495
689
  [
496
690
  {
497
691
  'score': 1.0,
@@ -514,7 +708,7 @@ def florence2_sam2_image(
514
708
  "prompt": prompt,
515
709
  "model": "florence2sam2",
516
710
  }
517
- metadata = {"function_name": "florence2_sam2_image"}
711
+ metadata = {"function_name": "florence2_sam2_instance_segmentation"}
518
712
 
519
713
  if fine_tune_id is not None:
520
714
  landing_api = LandingPublicAPI()
@@ -543,7 +737,7 @@ def florence2_sam2_image(
543
737
  return_data.append({"label": label, "bbox": bbox, "mask": mask, "score": 1.0})
544
738
 
545
739
  _display_tool_trace(
546
- florence2_sam2_image.__name__,
740
+ florence2_sam2_instance_segmentation.__name__,
547
741
  payload,
548
742
  detections[0],
549
743
  files,
@@ -580,7 +774,7 @@ def florence2_sam2_video_tracking(
580
774
 
581
775
  Example
582
776
  -------
583
- >>> florence2_sam2_video("car, dinosaur", frames)
777
+ >>> florence2_sam2_video_tracking("car, dinosaur", frames)
584
778
  [
585
779
  [
586
780
  {
@@ -665,10 +859,11 @@ def florence2_sam2_video_tracking(
665
859
  return return_data
666
860
 
667
861
 
668
- def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
669
- """'ocr' extracts text from an image. It returns a list of detected text, bounding
670
- boxes with normalized coordinates, and confidence scores. The results are sorted
671
- from top-left to bottom right.
862
+ def florence2_ocr(image: np.ndarray) -> List[Dict[str, Any]]:
863
+ """'florence2_ocr' is a tool that can detect text and text regions in an image.
864
+ Each text region contains one line of text. It returns a list of detected text,
865
+ the text region as a bounding box with normalized coordinates, and confidence
866
+ scores. The results are sorted from top-left to bottom right.
672
867
 
673
868
  Parameters:
674
869
  image (np.ndarray): The image to extract text from.
@@ -679,173 +874,59 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
679
874
 
680
875
  Example
681
876
  -------
682
- >>> ocr(image)
877
+ >>> florence2_ocr(image)
683
878
  [
684
879
  {'label': 'hello world', 'bbox': [0.1, 0.11, 0.35, 0.4], 'score': 0.99},
685
880
  ]
686
881
  """
687
882
 
688
- pil_image = Image.fromarray(image).convert("RGB")
689
- image_size = pil_image.size[::-1]
883
+ image_size = image.shape[:2]
690
884
  if image_size[0] < 1 or image_size[1] < 1:
691
885
  return []
692
- image_buffer = io.BytesIO()
693
- pil_image.save(image_buffer, format="PNG")
694
- buffer_bytes = image_buffer.getvalue()
695
- image_buffer.close()
696
-
697
- res = requests.post(
698
- _OCR_URL,
699
- files={"images": buffer_bytes},
700
- data={"language": "en"},
701
- headers={"contentType": "multipart/form-data", "apikey": _API_KEY},
702
- )
703
-
704
- if res.status_code != 200:
705
- raise ValueError(f"OCR request failed with status code {res.status_code}")
706
-
707
- data = res.json()
708
- output = []
709
- for det in data[0]:
710
- label = det["text"]
711
- box = [
712
- det["location"][0]["x"],
713
- det["location"][0]["y"],
714
- det["location"][2]["x"],
715
- det["location"][2]["y"],
716
- ]
717
- box = normalize_bbox(box, image_size)
718
- output.append({"label": label, "bbox": box, "score": round(det["score"], 2)})
886
+ image_b64 = convert_to_b64(image)
887
+ data = {
888
+ "image": image_b64,
889
+ "task": "<OCR_WITH_REGION>",
890
+ "function_name": "florence2_ocr",
891
+ }
719
892
 
893
+ detections = send_inference_request(data, "florence2", v2=True)
894
+ detections = detections["<OCR_WITH_REGION>"]
895
+ return_data = []
896
+ for i in range(len(detections["quad_boxes"])):
897
+ return_data.append(
898
+ {
899
+ "label": detections["labels"][i],
900
+ "bbox": normalize_bbox(
901
+ convert_quad_box_to_bbox(detections["quad_boxes"][i]), image_size
902
+ ),
903
+ "score": 1.0,
904
+ }
905
+ )
720
906
  _display_tool_trace(
721
- ocr.__name__,
907
+ florence2_ocr.__name__,
722
908
  {},
723
- data,
724
- cast(List[Tuple[str, bytes]], [("image", buffer_bytes)]),
909
+ detections,
910
+ image_b64,
725
911
  )
726
- return sorted(output, key=lambda x: (x["bbox"][1], x["bbox"][0]))
912
+ return return_data
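Since florence2_ocr returns bounding boxes normalized to [0, 1], a caller that needs pixel coordinates can rescale them with the image size; a minimal sketch, assuming the tool is importable from vision_agent.tools and the zero array stands in for a real document image:

    import numpy as np
    from vision_agent.tools import florence2_ocr

    image = np.zeros((720, 1280, 3), dtype=np.uint8)  # placeholder document image
    height, width = image.shape[:2]

    for line in florence2_ocr(image):
        xmin, ymin, xmax, ymax = line["bbox"]  # normalized to [0, 1]
        pixel_box = (xmin * width, ymin * height, xmax * width, ymax * height)
        print(line["label"], pixel_box)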
727
913
 
728
914
 
729
- def _sam2(
915
+ # CountGD Tools
916
+
917
+
918
+ def _countgd_object_detection(
919
+ prompt: str,
730
920
  image: np.ndarray,
731
- detections: List[Dict[str, Any]],
921
+ box_threshold: float,
732
922
  image_size: Tuple[int, ...],
733
923
  image_bytes: Optional[bytes] = None,
734
924
  ) -> Dict[str, Any]:
735
925
  if image_bytes is None:
736
926
  image_bytes = numpy_to_bytes(image)
737
927
 
738
- files = [("images", image_bytes)]
739
- payload = {
740
- "model": "sam2",
741
- "bboxes": json.dumps(
742
- [
743
- {
744
- "labels": [d["label"] for d in detections],
745
- "bboxes": [
746
- denormalize_bbox(d["bbox"], image_size) for d in detections
747
- ],
748
- }
749
- ]
750
- ),
751
- }
752
-
753
- metadata = {"function_name": "sam2"}
754
- pred_detections = send_task_inference_request(
755
- payload, "sam2", files=files, metadata=metadata
756
- )
757
- frame = pred_detections[0]
758
- return_data = []
759
- display_data = []
760
- for inp_detection, detection in zip(detections, frame):
761
- mask = rle_decode_array(detection["mask"])
762
- label = detection["label"]
763
- bbox = normalize_bbox(detection["bounding_box"], detection["mask"]["size"])
764
- return_data.append(
765
- {
766
- "label": label,
767
- "bbox": bbox,
768
- "mask": mask,
769
- "score": inp_detection["score"],
770
- }
771
- )
772
- display_data.append(
773
- {
774
- "label": label,
775
- "bbox": detection["bounding_box"],
776
- "mask": detection["mask"],
777
- "score": inp_detection["score"],
778
- }
779
- )
780
- return {"files": files, "return_data": return_data, "display_data": display_data}
781
-
782
-
783
- def sam2(
784
- image: np.ndarray,
785
- detections: List[Dict[str, Any]],
786
- ) -> List[Dict[str, Any]]:
787
- """'sam2' is a tool that can segment multiple objects given an input bounding box,
788
- label and score. It returns a set of masks along with the corresponding bounding
789
- boxes and labels.
790
-
791
- Parameters:
792
- image (np.ndarray): The image that contains multiple instances of the object.
793
- detections (List[Dict[str, Any]]): A list of dictionaries containing the score,
794
- label, and bounding box of the detected objects with normalized coordinates
795
- between 0 and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates
796
- of the top-left and xmax and ymax are the coordinates of the bottom-right of
797
- the bounding box.
798
-
799
- Returns:
800
- List[Dict[str, Any]]: A list of dictionaries containing the score, label,
801
- bounding box, and mask of the detected objects with normalized coordinates
802
- (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
803
- and xmax and ymax are the coordinates of the bottom-right of the bounding box.
804
- The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
805
- the background.
806
-
807
- Example
808
- -------
809
- >>> sam2(image, [
810
- {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
811
- ])
812
- [
813
- {
814
- 'score': 0.49,
815
- 'label': 'flower',
816
- 'bbox': [0.1, 0.11, 0.35, 0.4],
817
- 'mask': array([[0, 0, 0, ..., 0, 0, 0],
818
- [0, 0, 0, ..., 0, 0, 0],
819
- ...,
820
- [0, 0, 0, ..., 0, 0, 0],
821
- [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
822
- },
823
- ]
824
- """
825
- image_size = image.shape[:2]
826
- ret = _sam2(image, detections, image_size)
827
- _display_tool_trace(
828
- sam2.__name__,
829
- {},
830
- ret["display_data"],
831
- ret["files"],
832
- )
833
-
834
- return ret["return_data"] # type: ignore
835
-
836
-
837
- def _countgd_object_detection(
838
- prompt: str,
839
- image: np.ndarray,
840
- box_threshold: float,
841
- image_size: Tuple[int, ...],
842
- image_bytes: Optional[bytes] = None,
843
- ) -> Dict[str, Any]:
844
- if image_bytes is None:
845
- image_bytes = numpy_to_bytes(image)
846
-
847
- files = [("image", image_bytes)]
848
- prompts = [p.strip() for p in prompt.split(", ")]
928
+ files = [("image", image_bytes)]
929
+ prompts = [p.strip() for p in prompt.split(", ")]
849
930
 
850
931
  def _run_countgd(prompt: str) -> List[Dict[str, Any]]:
851
932
  payload = {
@@ -853,7 +934,7 @@ def _countgd_object_detection(
853
934
  "confidence": box_threshold, # still not being used in the API
854
935
  "model": "countgd",
855
936
  }
856
- metadata = {"function_name": "countgd_counting"}
937
+ metadata = {"function_name": "countgd_object_detection"}
857
938
 
858
939
  detections = send_task_inference_request(
859
940
  payload, "text-to-object-detection", files=files, metadata=metadata
@@ -939,16 +1020,16 @@ def countgd_object_detection(
939
1020
  return ret["return_data"] # type: ignore
940
1021
 
941
1022
 
942
- def countgd_sam2_object_detection(
1023
+ def countgd_sam2_instance_segmentation(
943
1024
  prompt: str,
944
1025
  image: np.ndarray,
945
1026
  box_threshold: float = 0.23,
946
1027
  ) -> List[Dict[str, Any]]:
947
- """'countgd_sam2_object_detection' is a tool that can detect multiple instances of
948
- an object given a text prompt. It is particularly useful when trying to detect and
949
- count a large number of objects. You can optionally separate object names in the
950
- prompt with commas. It returns a list of bounding boxes with normalized coordinates,
951
- label names, masks associated confidence scores.
1028
+ """'countgd_sam2_instance_segmentation' is a tool that can detect multiple
1029
+ instances of an object given a text prompt. It is particularly useful when trying
1030
+ to detect and count a large number of objects. You can optionally separate object
1031
+ names in the prompt with commas. It returns a list of bounding boxes with
1032
+ normalized coordinates, label names, masks and associated confidence scores.
952
1033
 
953
1034
  Parameters:
954
1035
  prompt (str): The object that needs to be counted.
@@ -966,7 +1047,7 @@ def countgd_sam2_object_detection(
966
1047
 
967
1048
  Example
968
1049
  -------
969
- >>> countgd_object_detection("flower", image)
1050
+ >>> countgd_sam2_instance_segmentation("flower", image)
970
1051
  [
971
1052
  {
972
1053
  'score': 0.49,
@@ -987,7 +1068,7 @@ def countgd_sam2_object_detection(
987
1068
  )
988
1069
 
989
1070
  _display_tool_trace(
990
- countgd_sam2_object_detection.__name__,
1071
+ countgd_sam2_instance_segmentation.__name__,
991
1072
  {
992
1073
  "prompts": prompt,
993
1074
  "confidence": box_threshold,
@@ -1012,6 +1093,8 @@ def countgd_sam2_video_tracking(
1012
1093
  Parameters:
1013
1094
  prompt (str): The prompt to ground to the image.
1014
1095
  image (np.ndarray): The image to ground the prompt to.
1096
+ chunk_length (Optional[int]): The number of frames to re-run countgd to find
1097
+ new objects.
1015
1098
 
1016
1099
  Returns:
1017
1100
  List[Dict[str, Any]]: A list of dictionaries containing the score, label,
@@ -1052,14 +1135,15 @@ def countgd_sam2_video_tracking(
1052
1135
  return ret["return_data"] # type: ignore
1053
1136
 
1054
1137
 
1055
- def countgd_example_based_counting(
1138
+ def countgd_visual_prompt_object_detection(
1056
1139
  visual_prompts: List[List[float]],
1057
1140
  image: np.ndarray,
1058
1141
  box_threshold: float = 0.23,
1059
1142
  ) -> List[Dict[str, Any]]:
1060
- """'countgd_example_based_counting' is a tool that can precisely count multiple
1061
- instances of an object given few visual example prompts. It returns a list of bounding
1062
- boxes with normalized coordinates, label names and associated confidence scores.
1143
+ """'countgd_visual_prompt_object_detection' is a tool that can precisely count
1144
+ multiple instances of an object given few visual example prompts. It returns a list
1145
+ of bounding boxes with normalized coordinates, label names and associated
1146
+ confidence scores.
1063
1147
 
1064
1148
  Parameters:
1065
1149
  visual_prompts (List[List[float]]): Bounding boxes of the object in format
@@ -1077,7 +1161,7 @@ def countgd_example_based_counting(
1077
1161
 
1078
1162
  Example
1079
1163
  -------
1080
- >>> countgd_example_based_counting(
1164
+ >>> countgd_visual_prompt_object_detection(
1081
1165
  visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]],
1082
1166
  image=image
1083
1167
  )
@@ -1098,7 +1182,7 @@ def countgd_example_based_counting(
1098
1182
  denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
1099
1183
  ]
1100
1184
  payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"}
1101
- metadata = {"function_name": "countgd_example_based_counting"}
1185
+ metadata = {"function_name": "countgd_visual_prompt_object_detection"}
1102
1186
 
1103
1187
  detections = send_task_inference_request(
1104
1188
  payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
@@ -1115,7 +1199,7 @@ def countgd_example_based_counting(
1115
1199
  for bbox in bboxes_per_frame
1116
1200
  ]
1117
1201
  _display_tool_trace(
1118
- countgd_example_based_counting.__name__,
1202
+ countgd_visual_prompt_object_detection.__name__,
1119
1203
  payload,
1120
1204
  [
1121
1205
  {
@@ -1214,6 +1298,67 @@ def qwen2_vl_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
1214
1298
  return cast(str, data)
1215
1299
 
1216
1300
 
1301
+ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
1302
+ """'ocr' extracts text from an image. It returns a list of detected text, bounding
1303
+ boxes with normalized coordinates, and confidence scores. The results are sorted
1304
+ from top-left to bottom right.
1305
+
1306
+ Parameters:
1307
+ image (np.ndarray): The image to extract text from.
1308
+
1309
+ Returns:
1310
+ List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox
1311
+ with normalized coordinates, and confidence score.
1312
+
1313
+ Example
1314
+ -------
1315
+ >>> ocr(image)
1316
+ [
1317
+ {'label': 'hello world', 'bbox': [0.1, 0.11, 0.35, 0.4], 'score': 0.99},
1318
+ ]
1319
+ """
1320
+
1321
+ pil_image = Image.fromarray(image).convert("RGB")
1322
+ image_size = pil_image.size[::-1]
1323
+ if image_size[0] < 1 or image_size[1] < 1:
1324
+ return []
1325
+ image_buffer = io.BytesIO()
1326
+ pil_image.save(image_buffer, format="PNG")
1327
+ buffer_bytes = image_buffer.getvalue()
1328
+ image_buffer.close()
1329
+
1330
+ res = requests.post(
1331
+ _OCR_URL,
1332
+ files={"images": buffer_bytes},
1333
+ data={"language": "en"},
1334
+ headers={"contentType": "multipart/form-data", "apikey": _API_KEY},
1335
+ )
1336
+
1337
+ if res.status_code != 200:
1338
+ raise ValueError(f"OCR request failed with status code {res.status_code}")
1339
+
1340
+ data = res.json()
1341
+ output = []
1342
+ for det in data[0]:
1343
+ label = det["text"]
1344
+ box = [
1345
+ det["location"][0]["x"],
1346
+ det["location"][0]["y"],
1347
+ det["location"][2]["x"],
1348
+ det["location"][2]["y"],
1349
+ ]
1350
+ box = normalize_bbox(box, image_size)
1351
+ output.append({"label": label, "bbox": box, "score": round(det["score"], 2)})
1352
+
1353
+ _display_tool_trace(
1354
+ ocr.__name__,
1355
+ {},
1356
+ data,
1357
+ cast(List[Tuple[str, bytes]], [("image", buffer_bytes)]),
1358
+ )
1359
+ return sorted(output, key=lambda x: (x["bbox"][1], x["bbox"][0]))
1360
+
1361
+
1217
1362
  def claude35_text_extraction(image: np.ndarray) -> str:
1218
1363
  """'claude35_text_extraction' is a tool that can extract text from an image. It
1219
1364
  returns the extracted text as a string and can be used as an alternative to OCR if
@@ -1238,67 +1383,146 @@ def claude35_text_extraction(image: np.ndarray) -> str:
1238
1383
  return cast(str, text)
1239
1384
 
1240
1385
 
1241
- def gpt4o_image_vqa(prompt: str, image: np.ndarray) -> str:
1242
- """'gpt4o_image_vqa' is a tool that can answer any questions about arbitrary images
1243
- including regular images or images of documents or presentations. It returns text
1244
- as an answer to the question.
1386
+ def document_extraction(image: np.ndarray) -> Dict[str, Any]:
1387
+ """'document_extraction' is a tool that can extract structured information out of
1388
+ documents with different layouts. It returns the extracted data in a structured
1389
+ hierarchical format containing text, tables, pictures, charts, and other
1390
+ information.
1245
1391
 
1246
1392
  Parameters:
1247
- prompt (str): The question about the image
1248
- image (np.ndarray): The reference image used for the question
1393
+ image (np.ndarray): The document image to analyze
1249
1394
 
1250
1395
  Returns:
1251
- str: A string which is the answer to the given prompt.
1396
+ Dict[str, Any]: A dictionary containing the extracted information.
1252
1397
 
1253
1398
  Example
1254
1399
  -------
1255
- >>> gpt4o_image_vqa('What is the cat doing?', image)
1256
- 'drinking milk'
1400
+ >>> document_analysis(image)
1401
+ {'pages':
1402
+ [{'bbox': [0, 0, 1.0, 1.0],
1403
+ 'chunks': [{'bbox': [0.8, 0.1, 1.0, 0.2],
1404
+ 'label': 'page_header',
1405
+ 'order': 75
1406
+ 'caption': 'Annual Report 2024',
1407
+ 'summary': 'This annual report summarizes ...' },
1408
+ {'bbox': [0.2, 0.9, 0.9, 1.0],
1409
+ 'label': 'table',
1410
+ 'order': 1119,
1411
+ 'caption': [{'Column 1': 'Value 1', 'Column 2': 'Value 2'},
1412
+ 'summary': 'This table illustrates a trend of ...'},
1413
+ ],
1257
1414
  """
1258
1415
 
1259
- lmm = OpenAILMM()
1260
- buffer = io.BytesIO()
1261
- Image.fromarray(image).save(buffer, format="PNG")
1262
- image_bytes = buffer.getvalue()
1263
- image_b64 = "data:image/png;base64," + encode_image_bytes(image_bytes)
1264
- resp = lmm.generate(prompt, [image_b64])
1265
- return cast(str, resp)
1416
+ image_file = numpy_to_bytes(image)
1266
1417
 
1418
+ files = [("image", image_file)]
1267
1419
 
1268
- def gpt4o_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
1269
- """'gpt4o_video_vqa' is a tool that can answer any questions about arbitrary videos
1270
- including regular videos or videos of documents or presentations. It returns text
1271
- as an answer to the question.
1420
+ payload = {
1421
+ "model": "document-analysis",
1422
+ }
1423
+
1424
+ data: Dict[str, Any] = send_inference_request(
1425
+ payload=payload,
1426
+ endpoint_name="document-analysis",
1427
+ files=files,
1428
+ v2=True,
1429
+ metadata_payload={"function_name": "document_analysis"},
1430
+ )
1431
+
1432
+ # don't display normalized bboxes
1433
+ _display_tool_trace(
1434
+ document_extraction.__name__,
1435
+ payload,
1436
+ data,
1437
+ files,
1438
+ )
1439
+
1440
+ def normalize(data: Any) -> Dict[str, Any]:
1441
+ if isinstance(data, Dict):
1442
+ if "bbox" in data:
1443
+ data["bbox"] = normalize_bbox(data["bbox"], image.shape[:2])
1444
+ for key in data:
1445
+ data[key] = normalize(data[key])
1446
+ elif isinstance(data, List):
1447
+ for i in range(len(data)):
1448
+ data[i] = normalize(data[i])
1449
+ return data # type: ignore
1450
+
1451
+ data = normalize(data)
1452
+
1453
+ return data
1454
+
1455
+
1456
+ def document_qa(
1457
+ prompt: str,
1458
+ image: np.ndarray,
1459
+ ) -> str:
1460
+ """'document_qa' is a tool that can answer any questions about arbitrary documents,
1461
+ presentations, or tables. It's very useful for document QA tasks, you can ask it a
1462
+ specific question or ask it to return a JSON object answering multiple questions
1463
+ about the document.
1272
1464
 
1273
1465
  Parameters:
1274
- prompt (str): The question about the video
1275
- frames (List[np.ndarray]): The reference frames used for the question
1466
+ prompt (str): The question to be answered about the document image.
1467
+ image (np.ndarray): The document image to analyze.
1276
1468
 
1277
1469
  Returns:
1278
- str: A string which is the answer to the given prompt.
1470
+ str: The answer to the question based on the document's context.
1279
1471
 
1280
1472
  Example
1281
1473
  -------
1282
- >>> gpt4o_video_vqa('Which football player made the goal?', frames)
1283
- 'Lionel Messi'
1474
+ >>> document_qa(image, question)
1475
+ 'The answer to the question ...'
1284
1476
  """
1285
1477
 
1286
- lmm = OpenAILMM()
1478
+ image_file = numpy_to_bytes(image)
1287
1479
 
1288
- if len(frames) > 10:
1289
- step = len(frames) / 10
1290
- frames = [frames[int(i * step)] for i in range(10)]
1480
+ files = [("image", image_file)]
1291
1481
 
1292
- frames_b64 = []
1293
- for frame in frames:
1294
- buffer = io.BytesIO()
1295
- Image.fromarray(frame).save(buffer, format="PNG")
1296
- image_bytes = buffer.getvalue()
1297
- image_b64 = "data:image/png;base64," + encode_image_bytes(image_bytes)
1298
- frames_b64.append(image_b64)
1482
+ payload = {
1483
+ "model": "document-analysis",
1484
+ }
1299
1485
 
1300
- resp = lmm.generate(prompt, frames_b64)
1301
- return cast(str, resp)
1486
+ data: Dict[str, Any] = send_inference_request(
1487
+ payload=payload,
1488
+ endpoint_name="document-analysis",
1489
+ files=files,
1490
+ v2=True,
1491
+ metadata_payload={"function_name": "document_qa"},
1492
+ )
1493
+
1494
+ def normalize(data: Any) -> Dict[str, Any]:
1495
+ if isinstance(data, Dict):
1496
+ if "bbox" in data:
1497
+ data["bbox"] = normalize_bbox(data["bbox"], image.shape[:2])
1498
+ for key in data:
1499
+ data[key] = normalize(data[key])
1500
+ elif isinstance(data, List):
1501
+ for i in range(len(data)):
1502
+ data[i] = normalize(data[i])
1503
+ return data # type: ignore
1504
+
1505
+ data = normalize(data)
1506
+
1507
+ prompt = f"""
1508
+ Document Context:
1509
+ {data}\n
1510
+ Question: {prompt}\n
1511
+ Answer the question directly using only the information from the document, do not answer with any additional text besides the answer. If the answer is not definitively contained in the document, say "I cannot find the answer in the provided document."
1512
+ """
1513
+
1514
+ lmm = AnthropicLMM()
1515
+ llm_output = lmm.generate(prompt=prompt)
1516
+ llm_output = cast(str, llm_output)
1517
+
1518
+ _display_tool_trace(
1519
+ document_qa.__name__,
1520
+ payload,
1521
+ llm_output,
1522
+ files,
1523
+ )
1524
+
1525
+ return llm_output
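A usage sketch for the new document tools, assuming they are importable from vision_agent.tools, an Anthropic key is configured for the QA step, and the placeholder array stands in for a scanned page:

    import numpy as np
    from vision_agent.tools import document_extraction, document_qa

    document = np.zeros((1100, 850, 3), dtype=np.uint8)  # placeholder page scan

    layout = document_extraction(document)  # structured chunks with normalized bboxes
    answer = document_qa("What is the total on the invoice?", document)
    print(answer)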
1302
1526
 
1303
1527
 
1304
1528
  def video_temporal_localization(
@@ -1360,284 +1584,70 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
1360
1584
  Parameters:
1361
1585
  image (np.ndarray): The image to classify or tag
1362
1586
 
1363
- Returns:
1364
- Dict[str, Any]: A dictionary containing the labels and scores. One dictionary
1365
- contains a list of labels and other a list of scores.
1366
-
1367
- Example
1368
- -------
1369
- >>> vit_image_classification(image)
1370
- {"labels": ["leopard", "lemur, otter", "bird"], "scores": [0.68, 0.30, 0.02]},
1371
- """
1372
- if image.shape[0] < 1 or image.shape[1] < 1:
1373
- return {"labels": [], "scores": []}
1374
-
1375
- image_b64 = convert_to_b64(image)
1376
- data = {
1377
- "image": image_b64,
1378
- "tool": "image_classification",
1379
- "function_name": "vit_image_classification",
1380
- }
1381
- resp_data: dict[str, Any] = send_inference_request(data, "tools")
1382
- resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
1383
- _display_tool_trace(
1384
- vit_image_classification.__name__,
1385
- data,
1386
- resp_data,
1387
- image_b64,
1388
- )
1389
- return resp_data
1390
-
1391
-
1392
- def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
1393
- """'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'.
1394
- It returns the predicted label and their probability scores based on image content.
1395
-
1396
- Parameters:
1397
- image (np.ndarray): The image to classify or tag
1398
-
1399
- Returns:
1400
- Dict[str, Any]: A dictionary containing the labels and scores. One dictionary
1401
- contains a list of labels and other a list of scores.
1402
-
1403
- Example
1404
- -------
1405
- >>> vit_nsfw_classification(image)
1406
- {"label": "normal", "scores": 0.68},
1407
- """
1408
- if image.shape[0] < 1 or image.shape[1] < 1:
1409
- raise ValueError(f"Image is empty, image shape: {image.shape}")
1410
-
1411
- image_b64 = convert_to_b64(image)
1412
- data = {
1413
- "image": image_b64,
1414
- "function_name": "vit_nsfw_classification",
1415
- }
1416
- resp_data: dict[str, Any] = send_inference_request(
1417
- data, "nsfw-classification", v2=True
1418
- )
1419
- resp_data["score"] = round(resp_data["score"], 4)
1420
- _display_tool_trace(
1421
- vit_nsfw_classification.__name__,
1422
- data,
1423
- resp_data,
1424
- image_b64,
1425
- )
1426
- return resp_data
1427
-
1428
-
1429
- def florence2_phrase_grounding(
1430
- prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
1431
- ) -> List[Dict[str, Any]]:
1432
- """'florence2_phrase_grounding' is a tool that can detect multiple
1433
- objects given a text prompt which can be object names or caption. You
1434
- can optionally separate the object names in the text with commas. It returns a list
1435
- of bounding boxes with normalized coordinates, label names and associated
1436
- confidence scores of 1.0.
1437
-
1438
- Parameters:
1439
- prompt (str): The prompt to ground to the image.
1440
- image (np.ndarray): The image to used to detect objects
1441
- fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
1442
- fine-tuned model ID here to use it.
1443
-
1444
- Returns:
1445
- List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
1446
- bounding box of the detected objects with normalized coordinates between 0
1447
- and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
1448
- top-left and xmax and ymax are the coordinates of the bottom-right of the
1449
- bounding box. The scores are always 1.0 and cannot be thresholded
1450
-
1451
- Example
1452
- -------
1453
- >>> florence2_phrase_grounding('person looking at a coyote', image)
1454
- [
1455
- {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
1456
- {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5},
1457
- ]
1458
- """
1459
- image_size = image.shape[:2]
1460
- if image_size[0] < 1 or image_size[1] < 1:
1461
- return []
1462
-
1463
- buffer_bytes = numpy_to_bytes(image)
1464
- files = [("image", buffer_bytes)]
1465
- payload = {
1466
- "prompts": [s.strip() for s in prompt.split(",")],
1467
- "model": "florence2",
1468
- }
1469
- metadata = {"function_name": "florence2_phrase_grounding"}
1470
-
1471
- if fine_tune_id is not None:
1472
- landing_api = LandingPublicAPI()
1473
- status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
1474
- if status is not JobStatus.SUCCEEDED:
1475
- raise FineTuneModelIsNotReady(
1476
- f"Fine-tuned model {fine_tune_id} is not ready yet"
1477
- )
1478
-
1479
- payload["jobId"] = fine_tune_id
1480
-
1481
- detections = send_task_inference_request(
1482
- payload,
1483
- "text-to-object-detection",
1484
- files=files,
1485
- metadata=metadata,
1486
- )
1487
-
1488
- # get the first frame
1489
- bboxes = detections[0]
1490
- bboxes_formatted = [
1491
- {
1492
- "label": bbox["label"],
1493
- "bbox": normalize_bbox(bbox["bounding_box"], image_size),
1494
- "score": round(bbox["score"], 2),
1495
- }
1496
- for bbox in bboxes
1497
- ]
1498
-
1499
- _display_tool_trace(
1500
- florence2_phrase_grounding.__name__,
1501
- payload,
1502
- detections[0],
1503
- files,
1504
- )
1505
- return [bbox for bbox in bboxes_formatted]
1506
-
1507
-
1508
- def florence2_phrase_grounding_video(
1509
- prompt: str, frames: List[np.ndarray], fine_tune_id: Optional[str] = None
1510
- ) -> List[List[Dict[str, Any]]]:
1511
- """'florence2_phrase_grounding_video' will run florence2 on each frame of a video.
1512
- It can detect multiple objects given a text prompt which can be object names or
1513
- caption. You can optionally separate the object names in the text with commas.
1514
- It returns a list of lists where each inner list contains bounding boxes with
1515
- normalized coordinates, label names and associated probability scores of 1.0.
1516
-
1517
- Parameters:
1518
- prompt (str): The prompt to ground to the video.
1519
- frames (List[np.ndarray]): The list of frames to detect objects.
1520
- fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
1521
- fine-tuned model ID here to use it.
1522
-
1523
- Returns:
1524
- List[List[Dict[str, Any]]]: A list of lists of dictionaries containing the score,
1525
- label, and bounding box of the detected objects with normalized coordinates
1526
- between 0 and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates
1527
- of the top-left and xmax and ymax are the coordinates of the bottom-right of
1528
- the bounding box. The scores are always 1.0 and cannot be thresholded.
1529
-
1530
- Example
1531
- -------
1532
- >>> florence2_phrase_grounding_video('person looking at a coyote', frames)
1533
- [
1534
- [
1535
- {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
1536
- {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5},
1537
- ],
1538
- ...
1539
- ]
1540
- """
1541
- if len(frames) == 0:
1542
- raise ValueError("No frames provided")
1543
-
1544
- image_size = frames[0].shape[:2]
1545
- buffer_bytes = frames_to_bytes(frames)
1546
- files = [("video", buffer_bytes)]
1547
- payload = {
1548
- "prompts": [s.strip() for s in prompt.split(",")],
1549
- "model": "florence2",
1550
- }
1551
- metadata = {"function_name": "florence2_phrase_grounding_video"}
1552
-
1553
- if fine_tune_id is not None:
1554
- landing_api = LandingPublicAPI()
1555
- status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
1556
- if status is not JobStatus.SUCCEEDED:
1557
- raise FineTuneModelIsNotReady(
1558
- f"Fine-tuned model {fine_tune_id} is not ready yet"
1559
- )
1560
-
1561
- payload["jobId"] = fine_tune_id
1562
-
1563
- detections = send_task_inference_request(
1564
- payload,
1565
- "text-to-object-detection",
1566
- files=files,
1567
- metadata=metadata,
1568
- )
1569
-
1570
- bboxes_formatted = []
1571
- for frame_data in detections:
1572
- bboxes_formatted_per_frame = [
1573
- {
1574
- "label": bbox["label"],
1575
- "bbox": normalize_bbox(bbox["bounding_box"], image_size),
1576
- "score": round(bbox["score"], 2),
1577
- }
1578
- for bbox in frame_data
1579
- ]
1580
- bboxes_formatted.append(bboxes_formatted_per_frame)
1587
+ Returns:
1588
+ Dict[str, Any]: A dictionary containing the labels and scores. One dictionary
1589
+ contains a list of labels and other a list of scores.
1590
+
1591
+ Example
1592
+ -------
1593
+ >>> vit_image_classification(image)
1594
+ {"labels": ["leopard", "lemur, otter", "bird"], "scores": [0.68, 0.30, 0.02]},
1595
+ """
1596
+ if image.shape[0] < 1 or image.shape[1] < 1:
1597
+ return {"labels": [], "scores": []}
1598
+
1599
+ image_b64 = convert_to_b64(image)
1600
+ data = {
1601
+ "image": image_b64,
1602
+ "tool": "image_classification",
1603
+ "function_name": "vit_image_classification",
1604
+ }
1605
+ resp_data: dict[str, Any] = send_inference_request(data, "tools")
1606
+ resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
1581
1607
  _display_tool_trace(
1582
- florence2_phrase_grounding_video.__name__,
1583
- payload,
1584
- detections,
1585
- files,
1608
+ vit_image_classification.__name__,
1609
+ data,
1610
+ resp_data,
1611
+ image_b64,
1586
1612
  )
1587
- return bboxes_formatted
1613
+ return resp_data
1588
1614
 
1589
1615
 
1590
- def florence2_ocr(image: np.ndarray) -> List[Dict[str, Any]]:
1591
- """'florence2_ocr' is a tool that can detect text and text regions in an image.
1592
- Each text region contains one line of text. It returns a list of detected text,
1593
- the text region as a bounding box with normalized coordinates, and confidence
1594
- scores. The results are sorted from top-left to bottom right.
1616
+ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
1617
+ """'vit_nsfw_classification' is a tool that can classify an image as 'nsfw' or 'normal'.
1618
+ It returns the predicted label and their probability scores based on image content.
1595
1619
 
1596
1620
  Parameters:
1597
- image (np.ndarray): The image to extract text from.
1621
+ image (np.ndarray): The image to classify or tag
1598
1622
 
1599
1623
  Returns:
1600
- List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox
1601
- with normalized coordinates, and confidence score.
1624
+ Dict[str, Any]: A dictionary containing the labels and scores. One dictionary
1625
+ contains a list of labels and other a list of scores.
1602
1626
 
1603
1627
  Example
1604
1628
  -------
1605
- >>> florence2_ocr(image)
1606
- [
1607
- {'label': 'hello world', 'bbox': [0.1, 0.11, 0.35, 0.4], 'score': 0.99},
1608
- ]
1629
+ >>> vit_nsfw_classification(image)
1630
+ {"label": "normal", "scores": 0.68},
1609
1631
  """
1632
+ if image.shape[0] < 1 or image.shape[1] < 1:
1633
+ raise ValueError(f"Image is empty, image shape: {image.shape}")
1610
1634
 
1611
- image_size = image.shape[:2]
1612
- if image_size[0] < 1 or image_size[1] < 1:
1613
- return []
1614
1635
  image_b64 = convert_to_b64(image)
1615
1636
  data = {
1616
1637
  "image": image_b64,
1617
- "task": "<OCR_WITH_REGION>",
1618
- "function_name": "florence2_ocr",
1638
+ "function_name": "vit_nsfw_classification",
1619
1639
  }
1620
-
1621
- detections = send_inference_request(data, "florence2", v2=True)
1622
- detections = detections["<OCR_WITH_REGION>"]
1623
- return_data = []
1624
- for i in range(len(detections["quad_boxes"])):
1625
- return_data.append(
1626
- {
1627
- "label": detections["labels"][i],
1628
- "bbox": normalize_bbox(
1629
- convert_quad_box_to_bbox(detections["quad_boxes"][i]), image_size
1630
- ),
1631
- "score": 1.0,
1632
- }
1633
- )
1640
+ resp_data: dict[str, Any] = send_inference_request(
1641
+ data, "nsfw-classification", v2=True
1642
+ )
1643
+ resp_data["score"] = round(resp_data["score"], 4)
1634
1644
  _display_tool_trace(
1635
- florence2_ocr.__name__,
1636
- {},
1637
- detections,
1645
+ vit_nsfw_classification.__name__,
1646
+ data,
1647
+ resp_data,
1638
1648
  image_b64,
1639
1649
  )
1640
- return return_data
1650
+ return resp_data
1641
1651
 
1642
1652
 
1643
1653
  def detr_segmentation(image: np.ndarray) -> List[Dict[str, Any]]:
@@ -2097,164 +2107,6 @@ def closest_box_distance(
2097
2107
  return cast(float, np.sqrt(horizontal_distance**2 + vertical_distance**2))
2098
2108
 
2099
2109
 
2100
- def document_extraction(image: np.ndarray) -> Dict[str, Any]:
2101
- """'document_extraction' is a tool that can extract structured information out of
2102
- documents with different layouts. It returns the extracted data in a structured
2103
- hierarchical format containing text, tables, pictures, charts, and other
2104
- information.
2105
-
2106
- Parameters:
2107
- image (np.ndarray): The document image to analyze
2108
-
2109
- Returns:
2110
- Dict[str, Any]: A dictionary containing the extracted information.
2111
-
2112
- Example
2113
- -------
2114
- >>> document_analysis(image)
2115
- {'pages':
2116
- [{'bbox': [0, 0, 1.0, 1.0],
2117
- 'chunks': [{'bbox': [0.8, 0.1, 1.0, 0.2],
2118
- 'label': 'page_header',
2119
- 'order': 75
2120
- 'caption': 'Annual Report 2024',
2121
- 'summary': 'This annual report summarizes ...' },
2122
- {'bbox': [0.2, 0.9, 0.9, 1.0],
2123
- 'label': table',
2124
- 'order': 1119,
2125
- 'caption': [{'Column 1': 'Value 1', 'Column 2': 'Value 2'},
2126
- 'summary': 'This table illustrates a trend of ...'},
2127
- ],
2128
- """
2129
-
2130
- image_file = numpy_to_bytes(image)
2131
-
2132
- files = [("image", image_file)]
2133
-
2134
- payload = {
2135
- "model": "document-analysis",
2136
- }
2137
-
2138
- data: Dict[str, Any] = send_inference_request(
2139
- payload=payload,
2140
- endpoint_name="document-analysis",
2141
- files=files,
2142
- v2=True,
2143
- metadata_payload={"function_name": "document_analysis"},
2144
- )
2145
-
2146
- # don't display normalized bboxes
2147
- _display_tool_trace(
2148
- document_extraction.__name__,
2149
- payload,
2150
- data,
2151
- files,
2152
- )
2153
-
2154
- def normalize(data: Any) -> Dict[str, Any]:
2155
- if isinstance(data, Dict):
2156
- if "bbox" in data:
2157
- data["bbox"] = normalize_bbox(data["bbox"], image.shape[:2])
2158
- for key in data:
2159
- data[key] = normalize(data[key])
2160
- elif isinstance(data, List):
2161
- for i in range(len(data)):
2162
- data[i] = normalize(data[i])
2163
- return data # type: ignore
2164
-
2165
- data = normalize(data)
2166
-
2167
- return data
2168
-
2169
-
2170
- def document_qa(
2171
- prompt: str,
2172
- image: np.ndarray,
2173
- ) -> str:
2174
- """'document_qa' is a tool that can answer any questions about arbitrary documents,
2175
- presentations, or tables. It's very useful for document QA tasks, you can ask it a
2176
- specific question or ask it to return a JSON object answering multiple questions
2177
- about the document.
2178
-
2179
- Parameters:
2180
- prompt (str): The question to be answered about the document image.
2181
- image (np.ndarray): The document image to analyze.
2182
-
2183
- Returns:
2184
- str: The answer to the question based on the document's context.
2185
-
2186
- Example
2187
- -------
2188
- >>> document_qa(image, question)
2189
- 'The answer to the question ...'
2190
- """
2191
-
2192
- image_file = numpy_to_bytes(image)
2193
-
2194
- files = [("image", image_file)]
2195
-
2196
- payload = {
2197
- "model": "document-analysis",
2198
- }
2199
-
2200
- data: Dict[str, Any] = send_inference_request(
2201
- payload=payload,
2202
- endpoint_name="document-analysis",
2203
- files=files,
2204
- v2=True,
2205
- metadata_payload={"function_name": "document_qa"},
2206
- )
2207
-
2208
- def normalize(data: Any) -> Dict[str, Any]:
2209
- if isinstance(data, Dict):
2210
- if "bbox" in data:
2211
- data["bbox"] = normalize_bbox(data["bbox"], image.shape[:2])
2212
- for key in data:
2213
- data[key] = normalize(data[key])
2214
- elif isinstance(data, List):
2215
- for i in range(len(data)):
2216
- data[i] = normalize(data[i])
2217
- return data # type: ignore
2218
-
2219
- data = normalize(data)
2220
-
2221
- prompt = f"""
2222
- Document Context:
2223
- {data}\n
2224
- Question: {prompt}\n
2225
- Answer the question directly using only the information from the document, do not answer with any additional text besides the answer. If the answer is not definitively contained in the document, say "I cannot find the answer in the provided document."
2226
- """
2227
-
2228
- lmm = AnthropicLMM()
2229
- llm_output = lmm.generate(prompt=prompt)
2230
- llm_output = cast(str, llm_output)
2231
-
2232
- _display_tool_trace(
2233
- document_qa.__name__,
2234
- payload,
2235
- llm_output,
2236
- files,
2237
- )
2238
-
2239
- return llm_output
2240
-
2241
-
2242
- def stella_embeddings(prompts: List[str]) -> List[np.ndarray]:
2243
- payload = {
2244
- "input": prompts,
2245
- "model": "stella1.5b",
2246
- }
2247
-
2248
- data: Dict[str, Any] = send_inference_request(
2249
- payload=payload,
2250
- endpoint_name="embeddings",
2251
- v2=True,
2252
- metadata_payload={"function_name": "get_embeddings"},
2253
- is_form=True,
2254
- )
2255
- return [d["embedding"] for d in data] # type: ignore
2256
-
2257
-
2258
2110
  # Utility and visualization functions
2259
2111
 
2260
2112
 
@@ -2772,31 +2624,31 @@ def _plot_counting(
2772
2624
 
2773
2625
 
2774
2626
  FUNCTION_TOOLS = [
2775
- owl_v2_image,
2776
- owl_v2_video,
2777
- ocr,
2778
- vit_image_classification,
2779
- vit_nsfw_classification,
2627
+ owlv2_object_detection,
2628
+ owlv2_sam2_instance_segmentation,
2629
+ owlv2_sam2_video_tracking,
2780
2630
  countgd_object_detection,
2781
- countgd_sam2_object_detection,
2631
+ countgd_sam2_instance_segmentation,
2632
+ countgd_sam2_video_tracking,
2782
2633
  florence2_ocr,
2783
- florence2_sam2_image,
2634
+ florence2_sam2_instance_segmentation,
2784
2635
  florence2_sam2_video_tracking,
2785
- florence2_phrase_grounding,
2636
+ florence2_object_detection,
2786
2637
  claude35_text_extraction,
2638
+ document_extraction,
2639
+ document_qa,
2640
+ ocr,
2641
+ qwen2_vl_images_vqa,
2642
+ qwen2_vl_video_vqa,
2787
2643
  detr_segmentation,
2788
2644
  depth_anything_v2,
2789
2645
  generate_pose_image,
2790
- minimum_distance,
2791
- qwen2_vl_images_vqa,
2792
- qwen2_vl_video_vqa,
2793
- document_extraction,
2794
- document_qa,
2646
+ vit_image_classification,
2647
+ vit_nsfw_classification,
2795
2648
  video_temporal_localization,
2796
2649
  flux_image_inpainting,
2797
2650
  siglip_classification,
2798
- owlv2_sam2_video_tracking,
2799
- countgd_sam2_video_tracking,
2651
+ minimum_distance,
2800
2652
  ]
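A hedged sketch of how a caller might index the reorganized tool list by name, assuming FUNCTION_TOOLS is re-exported by vision_agent.tools; the registry below is illustrative and not necessarily how vision-agent wires tools internally:

    from vision_agent.tools import FUNCTION_TOOLS

    # Index the public tools by function name for simple lookup.
    registry = {tool.__name__: tool for tool in FUNCTION_TOOLS}
    detector = registry["owlv2_object_detection"]
    segmenter = registry["owlv2_sam2_instance_segmentation"]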
2801
2653
 
2802
2654
  UTIL_TOOLS = [