vision-agent 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -314,6 +314,7 @@ def _handle_extract_frames(
314
314
  image_to_data[image] = {
315
315
  "bboxes": [],
316
316
  "masks": [],
317
+ "heat_map": [],
317
318
  "labels": [],
318
319
  "scores": [],
319
320
  }
@@ -340,9 +341,12 @@ def _handle_viz_tools(
340
341
  return image_to_data
341
342
 
342
343
  for param, call_result in zip(parameters, tool_result["call_results"]):
343
- # calls can fail, so we need to check if the call was successful
344
+ # Calls can fail, so we need to check if the call was successful. It can either:
345
+ # 1. return a str or some error that's not a dictionary
346
+ # 2. return a dictionary but not have the necessary keys
347
+
344
348
  if not isinstance(call_result, dict) or (
345
- "bboxes" not in call_result and "masks" not in call_result
349
+ "bboxes" not in call_result and "heat_map" not in call_result
346
350
  ):
347
351
  return image_to_data
348
352
 
@@ -352,6 +356,7 @@ def _handle_viz_tools(
352
356
  image_to_data[image] = {
353
357
  "bboxes": [],
354
358
  "masks": [],
359
+ "heat_map": [],
355
360
  "labels": [],
356
361
  "scores": [],
357
362
  }
@@ -360,12 +365,28 @@ def _handle_viz_tools(
360
365
  image_to_data[image]["labels"].extend(call_result.get("labels", []))
361
366
  image_to_data[image]["scores"].extend(call_result.get("scores", []))
362
367
  image_to_data[image]["masks"].extend(call_result.get("masks", []))
368
+ # only single heatmap is returned
369
+ image_to_data[image]["heat_map"].append(call_result.get("heat_map", []))
363
370
  if "mask_shape" in call_result:
364
371
  image_to_data[image]["mask_shape"] = call_result["mask_shape"]
365
372
 
366
373
  return image_to_data
367
374
 
368
375
 
376
def sample_n_evenly_spaced(lst: Sequence, n: int) -> Sequence:
    """Pick at most ``n`` items from ``lst`` at evenly spaced positions.

    When more than one item is requested and ``lst`` holds more than ``n``
    items, the first and last items are always part of the sample and the
    remaining picks are spread between them.

    Parameters:
        lst: The sequence to sample from.
        n: The maximum number of items to return.

    Returns:
        A new list with the sampled items. It is empty when ``n <= 0`` or
        ``lst`` is empty, holds only the first item when ``n == 1``, and holds
        every item of ``lst`` when ``n >= len(lst)``.
    """
    size = len(lst)
    if n <= 0 or size == 0:
        return []
    if n == 1:
        return [lst[0]]
    if n >= size:
        # Copy instead of handing back the caller's own object so that every
        # branch returns an independent list the caller can mutate safely.
        return list(lst)

    # Fractional step between picks; index 0 and index size - 1 are the two
    # end points, and round() snaps each intermediate position to an index.
    spacing = (size - 1) / (n - 1)
    return [lst[round(spacing * i)] for i in range(n)]
388
+
389
+
369
390
  def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]:
370
391
  image_to_data: Dict[str, Dict] = {}
371
392
  for tool_result in all_tool_results:
@@ -466,9 +487,14 @@ class VisionAgent(Agent):
466
487
  """Invoke the vision agent.
467
488
 
468
489
  Parameters:
469
- input: a prompt that describe the task or a conversation in the format of
490
+ chat: A conversation in the format of
470
491
  [{"role": "user", "content": "describe your task here..."}].
471
- image: the input image referenced in the prompt parameter.
492
+ image: The input image referenced in the chat parameter.
493
+ reference_data: A dictionary containing the reference image, mask or bounding
494
+ box in the format of:
495
+ {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
496
+ where the bounding box coordinates are normalized.
497
+ visualize_output: Whether to visualize the output.
472
498
 
473
499
  Returns:
474
500
  The result of the vision agent in text.
@@ -508,12 +534,14 @@ class VisionAgent(Agent):
508
534
  """Chat with the vision agent and return the final answer and all tool results.
509
535
 
510
536
  Parameters:
511
- chat: a conversation in the format of
537
+ chat: A conversation in the format of
512
538
  [{"role": "user", "content": "describe your task here..."}].
513
- image: the input image referenced in the chat parameter.
514
- reference_data: a dictionary containing the reference image and mask. in the
515
- format of {"image": "image.jpg", "mask": "mask.jpg}
516
- visualize_output: whether to visualize the output.
539
+ image: The input image referenced in the chat parameter.
540
+ reference_data: A dictionary containing the reference image, mask or bounding
541
+ box in the format of:
542
+ {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
543
+ where the bounding box coordinates are normalized.
544
+ visualize_output: Whether to visualize the output.
517
545
 
518
546
  Returns:
519
547
  A tuple where the first item is the final answer and the second item is a
@@ -584,7 +612,7 @@ class VisionAgent(Agent):
584
612
  visualized_output = visualize_result(all_tool_results)
585
613
  all_tool_results.append({"visualized_output": visualized_output})
586
614
  if len(visualized_output) > 0:
587
- reflection_images = visualized_output
615
+ reflection_images = sample_n_evenly_spaced(visualized_output, 3)
588
616
  elif image is not None:
589
617
  reflection_images = [image]
590
618
  else:
@@ -211,7 +211,7 @@ def overlay_masks(
211
211
  }
212
212
 
213
213
  for label, mask in zip(masks["labels"], masks["masks"]):
214
- if isinstance(mask, str):
214
+ if isinstance(mask, str) or isinstance(mask, Path):
215
215
  mask = np.array(Image.open(mask))
216
216
  np_mask = np.zeros((image.size[1], image.size[0], 4))
217
217
  np_mask[mask > 0, :] = color[label] + (255 * alpha,)
@@ -221,7 +221,7 @@ def overlay_masks(
221
221
 
222
222
 
223
223
  def overlay_heat_map(
224
- image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
224
+ image: Union[str, Path, np.ndarray, ImageType], heat_map: Dict, alpha: float = 0.8
225
225
  ) -> ImageType:
226
226
  r"""Plots heat map on to an image.
227
227
 
@@ -238,14 +238,12 @@ def overlay_heat_map(
238
238
  elif isinstance(image, np.ndarray):
239
239
  image = Image.fromarray(image)
240
240
 
241
- if "masks" not in masks:
241
+ if "heat_map" not in heat_map:
242
242
  return image.convert("RGB")
243
243
 
244
- # Only one heat map per image, so no need to loop through masks
245
244
  image = image.convert("L")
246
-
247
- if isinstance(masks["masks"][0], str):
248
- mask = b64_to_pil(masks["masks"][0])
245
+ # Only one heat map per image, so no need to loop through masks
246
+ mask = Image.fromarray(heat_map["heat_map"][0])
249
247
 
250
248
  overlay = Image.new("RGBA", mask.size)
251
249
  odraw = ImageDraw.Draw(overlay)
vision_agent/lmm/lmm.py CHANGED
@@ -9,10 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Union, cast
9
9
  import requests
10
10
  from openai import AzureOpenAI, OpenAI
11
11
 
12
- from vision_agent.tools import (
13
- CHOOSE_PARAMS,
14
- SYSTEM_PROMPT,
15
- )
12
+ from vision_agent.tools import CHOOSE_PARAMS, SYSTEM_PROMPT
16
13
 
17
14
  _LOGGER = logging.getLogger(__name__)
18
15
 
@@ -12,12 +12,12 @@ from .tools import ( # Counter,
12
12
  GroundingDINO,
13
13
  GroundingSAM,
14
14
  ImageCaption,
15
- ZeroShotCounting,
16
- VisualPromptCounting,
17
- VisualQuestionAnswering,
18
15
  ImageQuestionAnswering,
19
16
  SegArea,
20
17
  SegIoU,
21
18
  Tool,
19
+ VisualPromptCounting,
20
+ VisualQuestionAnswering,
21
+ ZeroShotCounting,
22
22
  register_tool,
23
23
  )
@@ -11,15 +11,16 @@ from PIL import Image
11
11
  from PIL.Image import Image as ImageType
12
12
 
13
13
  from vision_agent.image_utils import (
14
+ b64_to_pil,
14
15
  convert_to_b64,
15
16
  denormalize_bbox,
16
17
  get_image_size,
17
18
  normalize_bbox,
18
19
  rle_decode,
19
20
  )
21
+ from vision_agent.lmm import OpenAILMM
20
22
  from vision_agent.tools.video import extract_frames_from_video
21
23
  from vision_agent.type_defs import LandingaiAPIKey
22
- from vision_agent.lmm import OpenAILMM
23
24
 
24
25
  _LOGGER = logging.getLogger(__name__)
25
26
  _LND_API_KEY = LandingaiAPIKey().api_key
@@ -516,7 +517,9 @@ class ZeroShotCounting(Tool):
516
517
  "image": image_b64,
517
518
  "tool": "zero_shot_counting",
518
519
  }
519
- return _send_inference_request(data, "tools")
520
+ resp_data = _send_inference_request(data, "tools")
521
+ resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
522
+ return resp_data
520
523
 
521
524
 
522
525
  class VisualPromptCounting(Tool):
@@ -585,7 +588,9 @@ class VisualPromptCounting(Tool):
585
588
  "prompt": prompt,
586
589
  "tool": "few_shot_counting",
587
590
  }
588
- return _send_inference_request(data, "tools")
591
+ resp_data = _send_inference_request(data, "tools")
592
+ resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
593
+ return resp_data
589
594
 
590
595
 
591
596
  class VisualQuestionAnswering(Tool):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vision-agent
3
- Version: 0.2.4
3
+ Version: 0.2.6
4
4
  Summary: Toolset for Vision Agent
5
5
  Author: Landing AI
6
6
  Author-email: dev@landing.ai
@@ -5,21 +5,21 @@ vision_agent/agent/easytool.py,sha256=oMHnBg7YBtIPgqQUNcZgq7uMgpPThs99_UnO7ERkMV
5
5
  vision_agent/agent/easytool_prompts.py,sha256=zdQQw6WpXOmvwOMtlBlNKY5a3WNlr65dbUvMIGiqdeo,4526
6
6
  vision_agent/agent/reflexion.py,sha256=4gz30BuFMeGxSsTzoDV4p91yE0R8LISXp28IaOI6wdM,10506
7
7
  vision_agent/agent/reflexion_prompts.py,sha256=G7UAeNz_g2qCb2yN6OaIC7bQVUkda4m3z42EG8wAyfE,9342
8
- vision_agent/agent/vision_agent.py,sha256=Ehb97lyPs7lYM9ipx07yxm6c2kUqz2OnjGQsv-nMwKA,24849
8
+ vision_agent/agent/vision_agent.py,sha256=xepqtPqxwIEj0V9OyPtlBr4hsE67BazmqcXHBjUO8a4,25971
9
9
  vision_agent/agent/vision_agent_prompts.py,sha256=W3Z72FpUt71UIJSkjAcgtQqxeMqkYuATqHAN5fYY26c,7342
10
10
  vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
12
- vision_agent/image_utils.py,sha256=YvP5KE9NrWdgJKuHW2NR1glzfObkxtcXBknpmj3Gsbs,7554
12
+ vision_agent/image_utils.py,sha256=BfUueXZKOvjdI_J7vi2vE57FfpaHOkxrTY0d3aQ1zgI,7552
13
13
  vision_agent/llm/__init__.py,sha256=BoUm_zSAKnLlE8s-gKTSQugXDqVZKPqYlWwlTLdhcz4,48
14
14
  vision_agent/llm/llm.py,sha256=1BkrSVBWEClyqLc0Rmyw4heLhi_ZVm6JO7-i1wd1ziw,5383
15
15
  vision_agent/lmm/__init__.py,sha256=nnNeKD1k7q_4vLb1x51O_EUTYaBgGfeiCx5F433gr3M,67
16
- vision_agent/lmm/lmm.py,sha256=sECjGMaGrv1QHq7OiFr-9LoBM5uRLjAqd0Ypp-zyFlw,10552
17
- vision_agent/tools/__init__.py,sha256=X6yJhWa8iKkQm4Mgf1KcV0_o39-Nrg3E56QAB5gWCO0,413
16
+ vision_agent/lmm/lmm.py,sha256=gK90vMxh0OcGSuIZQikBkDXm4pfkdFk1R2y7rtWDl84,10539
17
+ vision_agent/tools/__init__.py,sha256=HfUr0JQUwk0Kyieen93df9lMbbdpVf9Q6CcVFmKv_q4,413
18
18
  vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
19
- vision_agent/tools/tools.py,sha256=hYgRTHMCBwjT0kkT2SY5MN0FK89vuuecu-x1VqRlGbU,42779
19
+ vision_agent/tools/tools.py,sha256=DZX-w17XWttR4j8bQKY90QxLOs3-ZD5qOHA_53LL7Dk,43013
20
20
  vision_agent/tools/video.py,sha256=xTElFSFp1Jw4ulOMnk81Vxsh-9dTxcWUO6P9fzEi3AM,7653
21
21
  vision_agent/type_defs.py,sha256=4LTnTL4HNsfYqCrDn9Ppjg9bSG2ZGcoKSSd9YeQf4Bw,1792
22
- vision_agent-0.2.4.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
23
- vision_agent-0.2.4.dist-info/METADATA,sha256=2T1YLGMh2-n8F0gGf1P2BDhgzxmtmAiylpfW3E3Q4_c,7697
24
- vision_agent-0.2.4.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
25
- vision_agent-0.2.4.dist-info/RECORD,,
22
+ vision_agent-0.2.6.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
23
+ vision_agent-0.2.6.dist-info/METADATA,sha256=84i_O_9o8Ro6PbR3bi0rLyWbSwjqoxO6n6V9Bk06tP4,7697
24
+ vision_agent-0.2.6.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
25
+ vision_agent-0.2.6.dist-info/RECORD,,