vision-agent 0.2.234__tar.gz → 0.2.236__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {vision_agent-0.2.234 → vision_agent-0.2.236}/PKG-INFO +1 -1
  2. {vision_agent-0.2.234 → vision_agent-0.2.236}/pyproject.toml +1 -1
  3. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_coder_prompts.py +1 -1
  4. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_coder_prompts_v2.py +1 -1
  5. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_planner_prompts_v2.py +1 -1
  6. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_planner_v2.py +2 -0
  7. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/tools/tool_utils.py +14 -9
  8. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/tools/tools.py +58 -21
  9. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/utils/video_tracking.py +59 -58
  10. {vision_agent-0.2.234 → vision_agent-0.2.236}/LICENSE +0 -0
  11. {vision_agent-0.2.234 → vision_agent-0.2.236}/README.md +0 -0
  12. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/.sim_tools/df.csv +0 -0
  13. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/.sim_tools/embs.npy +0 -0
  14. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/__init__.py +0 -0
  15. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/README.md +0 -0
  16. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/__init__.py +0 -0
  17. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/agent.py +0 -0
  18. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/agent_utils.py +0 -0
  19. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/types.py +0 -0
  20. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent.py +0 -0
  21. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_coder.py +0 -0
  22. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_coder_v2.py +0 -0
  23. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_planner.py +0 -0
  24. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_planner_prompts.py +0 -0
  25. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_prompts.py +0 -0
  26. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_prompts_v2.py +0 -0
  27. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_v2.py +0 -0
  28. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/clients/__init__.py +0 -0
  29. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/clients/http.py +0 -0
  30. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/clients/landing_public_api.py +0 -0
  31. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/configs/__init__.py +0 -0
  32. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/configs/anthropic_config.py +0 -0
  33. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/configs/anthropic_openai_config.py +0 -0
  34. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/configs/config.py +0 -0
  35. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/configs/openai_config.py +0 -0
  36. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/fonts/__init__.py +0 -0
  37. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
  38. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/lmm/__init__.py +0 -0
  39. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/lmm/lmm.py +0 -0
  40. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/lmm/types.py +0 -0
  41. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/tools/__init__.py +0 -0
  42. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/tools/meta_tools.py +0 -0
  43. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/tools/planner_tools.py +0 -0
  44. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/tools/prompts.py +0 -0
  45. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/tools/tools_types.py +0 -0
  46. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/utils/__init__.py +0 -0
  47. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/utils/exceptions.py +0 -0
  48. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/utils/execute.py +0 -0
  49. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/utils/image_utils.py +0 -0
  50. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/utils/sim.py +0 -0
  51. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/utils/type_defs.py +0 -0
  52. {vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/utils/video.py +0 -0
{vision_agent-0.2.234 → vision_agent-0.2.236}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vision-agent
-Version: 0.2.234
+Version: 0.2.236
 Summary: Toolset for Vision Agent
 Author: Landing AI
 Author-email: dev@landing.ai
{vision_agent-0.2.234 → vision_agent-0.2.236}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "vision-agent"
-version = "0.2.234"
+version = "0.2.236"
 description = "Toolset for Vision Agent"
 authors = ["Landing AI <dev@landing.ai>"]
 readme = "README.md"
{vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_coder_prompts.py
@@ -230,7 +230,7 @@ This is the documentation for the functions you have access to. You may call any
 
 
 FIX_BUG = """
-**Role** As a coder, your job is to find the error in the code and fix it. You are running in a notebook setting so you can run !pip install to install missing packages.
+**Role** As a coder, your job is to find the error in the code and fix it. You are running in a notebook setting but do not run !pip install to install new packages.
 
 **Documentation**:
 This is the documentation for the functions you have access to. You may call any of these functions to help you complete the task. They are available through importing `from vision_agent.tools import *`.
{vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_coder_prompts_v2.py
@@ -77,7 +77,7 @@ This is the documentation for the functions you have access to. You may call any
 
 
 FIX_BUG = """
-**Role**: As a coder, your job is to find the error in the code and fix it. You are running in a notebook setting so you can run !pip install to install missing packages.
+**Role** As a coder, your job is to find the error in the code and fix it. You are running in a notebook setting but do not run !pip install to install new packages.
 
 **Task**: A previous agent has written some code and some testing code according to a plan given to it. It has introduced a bug into it's code while trying to implement the plan. You are given the plan, code, test code and error. Your job is to fix the error in the code or test code.
 
{vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_planner_prompts_v2.py
@@ -1,7 +1,7 @@
 PLAN = """
 **Role**: You are an expert planning agent that can understand the user request and search for a plan to accomplish it.
 
-**Task**: As a planning agent you are required to understand the user's request and search for a plan to accomplish it. Use Chain-of-Thought approach to break down the problem, create a plan, and then provide a response. Esnure your response is clear, concise, and helpful. You can use an interactive Pyton (Jupyter Notebok) environment, executing code with <execute_python>, each execution is a new cell so old code and outputs are saved.
+**Task**: As a planning agent you are required to understand the user's request and search for a plan to accomplish it. Use Chain-of-Thought approach to break down the problem, create a plan, and then provide a response. Esnure your response is clear, concise, and helpful. You can use an interactive Pyton (Jupyter Notebok) environment but do not !pip install packages, execute code with <execute_python>, each execution is a new cell so old code and outputs are saved.
 
 **Documentation**: this is the documentation for the functions you can use to accomplish the task:
 {tool_desc}
{vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/agent/vision_agent_planner_v2.py
@@ -21,6 +21,7 @@ from vision_agent.agent.agent_utils import (
     extract_tag,
     print_code,
     print_table,
+    remove_installs_from_code,
 )
 from vision_agent.agent.types import AgentMessage, InteractionContext, PlanContext
 from vision_agent.agent.vision_agent_planner_prompts_v2 import (
@@ -180,6 +181,7 @@ def run_critic(
 
 
 def code_safeguards(code: str) -> str:
+    code = remove_installs_from_code(code)
     if "get_tool_for_task" in code:
         lines = code.split("\n")
         new_lines = []
{vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/tools/tool_utils.py
@@ -270,17 +270,22 @@ def add_bboxes_from_masks(
 ) -> List[List[Dict[str, Any]]]:
     for frame_preds in all_preds:
         for preds in frame_preds:
-            if np.sum(preds["mask"]) == 0:
+            mask = preds["mask"]
+            if mask.sum() == 0:
                 preds["bbox"] = []
             else:
-                rows, cols = np.where(preds["mask"])
-                bbox = [
-                    float(np.min(cols)),
-                    float(np.min(rows)),
-                    float(np.max(cols)),
-                    float(np.max(rows)),
-                ]
-                bbox = normalize_bbox(bbox, preds["mask"].shape)
+                # Get indices where mask is True using axis operations
+                rows = np.any(mask, axis=1)
+                cols = np.any(mask, axis=0)
+
+                # Find boundaries using argmax/argmin
+                y_min = np.argmax(rows)
+                y_max = len(rows) - np.argmax(rows[::-1])
+                x_min = np.argmax(cols)
+                x_max = len(cols) - np.argmax(cols[::-1])
+
+                bbox = [float(x_min), float(y_min), float(x_max), float(y_max)]
+                bbox = normalize_bbox(bbox, mask.shape)
             preds["bbox"] = bbox
 
     return all_preds
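The rewrite above computes the box from two boolean axis reductions plus argmax scans instead of np.where, which materializes an index array for every foreground pixel. A minimal standalone sketch of the same technique (the function name and toy mask are illustrative, not part of the package):

import numpy as np

def mask_to_bbox(mask: np.ndarray) -> list:
    # Collapse the 2D mask into per-row / per-column "any foreground" vectors.
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    # argmax returns the first True; scanning the reversed vector gives the last.
    y_min = int(np.argmax(rows))
    y_max = len(rows) - int(np.argmax(rows[::-1]))
    x_min = int(np.argmax(cols))
    x_max = len(cols) - int(np.argmax(cols[::-1]))
    return [x_min, y_min, x_max, y_max]

mask = np.zeros((8, 8), dtype=bool)
mask[2:5, 3:7] = True
assert mask_to_bbox(mask) == [3, 2, 7, 5]

Note the max edges are exclusive (7 and 5 here), matching the len - argmax arithmetic in the diff; an all-False mask never reaches this path because the mask.sum() == 0 branch returns an empty bbox first.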
{vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/tools/tools.py
@@ -234,16 +234,24 @@ def od_sam2_video_tracking(
     od_model: ODModels,
     prompt: str,
     frames: List[np.ndarray],
-    chunk_length: Optional[int] = 10,
+    chunk_length: Optional[int] = 50,
     fine_tune_id: Optional[str] = None,
 ) -> Dict[str, Any]:
-    SEGMENT_SIZE = 50
-    OVERLAP = 1  # Number of overlapping frames between segments
+    chunk_length = 50 if chunk_length is None else chunk_length
+    segment_size = chunk_length
+    # Number of overlapping frames between segments
+    overlap = 1
+    # chunk_length needs to be segment_size + 1 or else on the last segment it will
+    # run the OD model again and merging will not work
+    chunk_length = chunk_length + 1
+
+    if len(frames) == 0 or not isinstance(frames, List):
+        return {"files": [], "return_data": [], "display_data": []}
 
     image_size = frames[0].shape[:2]
 
     # Split frames into segments with overlap
-    segments = split_frames_into_segments(frames, segment_size, overlap)
+    segments = split_frames_into_segments(frames, segment_size, overlap)
 
     def _apply_object_detection(  # inner method to avoid circular importing issues.
         od_model: ODModels,
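chunk_length now doubles as the segment size for tracking, with a one-frame overlap carried between consecutive segments so tracks can be re-associated at the seam (see merge_segments in video_tracking.py below). The package's split_frames_into_segments is not shown in this diff; the standalone sketch below is an assumption about its windowing, for illustration only:

from typing import List
import numpy as np

def split_with_overlap(
    frames: List[np.ndarray], segment_size: int, overlap: int
) -> List[List[np.ndarray]]:
    # Each new segment starts where the previous one ended, so the boundary
    # frame is shared by consecutive segments.
    segments = []
    start = 0
    while start < len(frames):
        segments.append(frames[start : start + segment_size + overlap])
        start += segment_size
    return segments

frames = [np.zeros((2, 2), dtype=np.uint8) for _ in range(120)]
segments = split_with_overlap(frames, segment_size=50, overlap=1)
print([len(s) for s in segments])  # [51, 51, 20]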
@@ -538,7 +546,7 @@ def owlv2_sam2_instance_segmentation(
 def owlv2_sam2_video_tracking(
     prompt: str,
     frames: List[np.ndarray],
-    chunk_length: Optional[int] = 10,
+    chunk_length: Optional[int] = 25,
     fine_tune_id: Optional[str] = None,
 ) -> List[List[Dict[str, Any]]]:
     """'owlv2_sam2_video_tracking' is a tool that can track and segment multiple
@@ -771,7 +779,7 @@ def florence2_sam2_instance_segmentation(
 def florence2_sam2_video_tracking(
     prompt: str,
     frames: List[np.ndarray],
-    chunk_length: Optional[int] = 10,
+    chunk_length: Optional[int] = 25,
     fine_tune_id: Optional[str] = None,
 ) -> List[List[Dict[str, Any]]]:
     """'florence2_sam2_video_tracking' is a tool that can track and segment multiple
@@ -1110,7 +1118,7 @@ def countgd_sam2_instance_segmentation(
 def countgd_sam2_video_tracking(
     prompt: str,
     frames: List[np.ndarray],
-    chunk_length: Optional[int] = 10,
+    chunk_length: Optional[int] = 25,
 ) -> List[List[Dict[str, Any]]]:
     """'countgd_sam2_video_tracking' is a tool that can track and segment multiple
     objects in a video given a text prompt such as category names or referring
@@ -1322,7 +1330,7 @@ def custom_object_detection(
 def custom_od_sam2_video_tracking(
     deployment_id: str,
     frames: List[np.ndarray],
-    chunk_length: Optional[int] = 10,
+    chunk_length: Optional[int] = 25,
 ) -> List[List[Dict[str, Any]]]:
     """'custom_od_sam2_video_tracking' is a tool that can segment multiple objects given a
     custom model with predefined category names.
@@ -2366,7 +2374,7 @@ def agentic_sam2_instance_segmentation(
 def agentic_sam2_video_tracking(
     prompt: str,
     frames: List[np.ndarray],
-    chunk_length: Optional[int] = 10,
+    chunk_length: Optional[int] = 25,
     fine_tune_id: Optional[str] = None,
 ) -> List[List[Dict[str, Any]]]:
     """'agentic_sam2_video_tracking' is a tool that can track and segment multiple
@@ -2791,7 +2799,15 @@ def overlay_bounding_boxes(
             "Number of unique labels exceeds the number of available colors. Some labels may have the same color."
         )
 
-    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}
+    use_tracking_label = False
+    if all([":" in label for label in labels]):
+        unique_labels = set([label.split(":")[1].strip() for label in labels])
+        use_tracking_label = True
+        colors = {
+            label: COLORS[i % len(COLORS)] for i, label in enumerate(unique_labels)
+        }
+    else:
+        colors = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}
 
     frame_out = []
     for i, frame in enumerate(medias_int):
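Tracked detections carry labels of the form "<id>: <class>" (e.g. "3: person"), so keying the palette on the raw label would give every tracked instance its own color. Splitting on ":" keys colors by class instead. A small sketch of the idea with a made-up palette (COLORS here is a stand-in, not the package's list):

# Stand-in palette for illustration.
COLORS = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]

labels = ["1: person", "2: person", "1: car"]
if all(":" in label for label in labels):
    # Strip the "<id>: " prefix and key colors by the class name.
    unique_labels = set(label.split(":")[1].strip() for label in labels)
    colors = {lbl: COLORS[i % len(COLORS)] for i, lbl in enumerate(unique_labels)}
else:
    colors = {lbl: COLORS[i % len(COLORS)] for i, lbl in enumerate(labels)}

print(colors)  # both person tracks share one color, e.g. {'person': ..., 'car': ...}

Iterating a set makes the exact color order nondeterministic across runs; only the class-to-color grouping matters here.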
@@ -2802,7 +2818,7 @@
 
         # if more than 50 boxes use small boxes to indicate objects else use regular boxes
         if len(bboxes) > 50:
-            pil_image = _plot_counting(pil_image, bboxes, color)
+            pil_image = _plot_counting(pil_image, bboxes, colors, use_tracking_label)
         else:
             width, height = pil_image.size
             fontsize = max(12, int(min(width, height) / 40))
@@ -2817,18 +2833,20 @@
             )
 
             for elt in bboxes:
+                if use_tracking_label:
+                    color = colors[elt["label"].split(":")[1].strip()]
+                else:
+                    color = colors[elt["label"]]
                 label = elt["label"]
                 box = elt["bbox"]
                 scores = elt["score"]
 
                 # denormalize the box if it is normalized
                 box = denormalize_bbox(box, (height, width))
-                draw.rectangle(box, outline=color[label], width=4)
+                draw.rectangle(box, outline=color, width=4)
                 text = f"{label}: {scores:.2f}"
                 text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
-                draw.rectangle(
-                    (box[0], box[1], text_box[2], text_box[3]), fill=color[label]
-                )
+                draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color)
                 draw.text((box[0], box[1]), text, fill="black", font=font)
 
         frame_out.append(np.array(pil_image))
@@ -2911,7 +2929,16 @@ def overlay_segmentation_masks(
     for mask_i in masks_int:
         for mask_j in mask_i:
            labels.add(mask_j["label"])
-    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}
+
+    use_tracking_label = False
+    if all([":" in label for label in labels]):
+        use_tracking_label = True
+        unique_labels = set([label.split(":")[1].strip() for label in labels])
+        colors = {
+            label: COLORS[i % len(COLORS)] for i, label in enumerate(unique_labels)
+        }
+    else:
+        colors = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}
 
     width, height = Image.fromarray(medias_int[0]).size
     fontsize = max(12, int(min(width, height) / 40))
@@ -2925,12 +2952,16 @@
         pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA")
         for elt in masks_int[i]:
             mask = elt["mask"]
+            if use_tracking_label:
+                color = colors[elt["label"].split(":")[1].strip()]
+            else:
+                color = colors[elt["label"]]
             label = elt["label"]
             tracking_lbl = elt.get(secondary_label_key, None)
 
             # Create semi-transparent mask overlay
             np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
-            np_mask[mask > 0, :] = color[label] + (255 * 0.7,)
+            np_mask[mask > 0, :] = color + (255 * 0.7,)
             mask_img = Image.fromarray(np_mask.astype(np.uint8))
             pil_image = Image.alpha_composite(pil_image, mask_img)
 
@@ -2942,7 +2973,7 @@
             border_mask = np.zeros(
                 (pil_image.size[1], pil_image.size[0], 4), dtype=np.uint8
             )
-            cv2.drawContours(border_mask, contours, -1, color[label] + (255,), 8)
+            cv2.drawContours(border_mask, contours, -1, color + (255,), 8)
             border_img = Image.fromarray(border_mask)
             pil_image = Image.alpha_composite(pil_image, border_img)
 
@@ -2957,7 +2988,7 @@
             )
             if x != 0 and y != 0:
                 text_box = draw.textbbox((x, y), text=text, font=font)
-                draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
+                draw.rectangle((x, y, text_box[2], text_box[3]), fill=color)
                 draw.text((x, y), text, fill="black", font=font)
         frame_out.append(np.array(pil_image))
     return_frame = frame_out[0] if len(frame_out) == 1 else frame_out
@@ -3014,6 +3045,7 @@ def _plot_counting(
     image: Image.Image,
     bboxes: List[Dict[str, Any]],
     colors: Dict[str, Tuple[int, int, int]],
+    use_tracking_label: bool = False,
 ) -> Image.Image:
     width, height = image.size
     fontsize = max(12, int(min(width, height) / 40))
@@ -3023,7 +3055,12 @@
         fontsize,
     )
     for i, elt in enumerate(bboxes, 1):
-        label = f"{i}"
+        if use_tracking_label:
+            label = elt["label"].split(":")[0]
+            color = colors[elt["label"].split(":")[1].strip()]
+        else:
+            label = f"{i}"
+            color = colors[elt["label"]]
         box = elt["bbox"]
 
         # denormalize the box if it is normalized
@@ -3044,7 +3081,7 @@
         text_y1 = cy + text_height / 2
 
         # Draw the rectangle encapsulating the text
-        draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=colors[elt["label"]])
+        draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color)
 
         # Draw the text at the center of the bounding box
         draw.text(
{vision_agent-0.2.234 → vision_agent-0.2.236}/vision_agent/utils/video_tracking.py
@@ -3,10 +3,10 @@ from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import numpy as np
+from scipy.optimize import linear_sum_assignment  # type: ignore
 
 from vision_agent.tools.tool_utils import (
     add_bboxes_from_masks,
-    nms,
     send_task_inference_request,
 )
 from vision_agent.utils.image_utils import denormalize_bbox, rle_decode_array
@@ -171,63 +171,45 @@ def _calculate_mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
 def _match_by_iou(
     first_param: List[Dict],
     second_param: List[Dict],
-    iou_threshold: float = 0.8,
-) -> Tuple[List[Dict], Dict[int, int]]:
-    max_id = max((item["id"] for item in first_param), default=0)
-
-    matched_new_item_indices = set()
-    id_mapping = {}
-
-    for new_index, new_item in enumerate(second_param):
-        matched_id = None
-
-        for existing_item in first_param:
+    max_id: int,
+    iou_threshold: float = 0.05,
+) -> Tuple[Dict[int, int], int]:
+    max_first_id = max((item["id"] for item in first_param), default=0)
+    max_second_id = max((item["id"] for item in second_param), default=0)
+
+    cost_matrix = np.ones((max_first_id + 1, max_second_id + 1))
+    for first_item in first_param:
+        for second_item in second_param:
             iou = _calculate_mask_iou(
-                existing_item["decoded_mask"], new_item["decoded_mask"]
+                first_item["decoded_mask"], second_item["decoded_mask"]
             )
-            if iou > iou_threshold:
-                matched_id = existing_item["id"]
-                matched_new_item_indices.add(new_index)
-                id_mapping[new_item["id"]] = matched_id
-                break
-
-        if matched_id:
-            new_item["id"] = matched_id
-        else:
-            max_id += 1
-            id_mapping[new_item["id"]] = max_id
-            new_item["id"] = max_id
-
-    unmatched_items = [
-        item for i, item in enumerate(second_param) if i not in matched_new_item_indices
-    ]
-    combined_list = first_param + unmatched_items
-
-    return combined_list, id_mapping
+            cost_matrix[first_item["id"], second_item["id"]] = 1 - iou
 
+    row_ind, col_ind = linear_sum_assignment(cost_matrix)
+    id_mapping = {second_id: first_id for first_id, second_id in zip(row_ind, col_ind)}
+    first_id_to_label = {item["id"]: item["label"] for item in first_param}
 
-def _update_ids(detections: List[Dict], id_mapping: Dict[int, int]) -> None:
-    for inner_list in detections:
-        for detection in inner_list:
-            if detection["id"] in id_mapping:
-                detection["id"] = id_mapping[detection["id"]]
+    cleaned_mapping = {}
+    for elt in second_param:
+        second_id = elt["id"]
+        # if the id is not in the mapping, give it a new id
+        if second_id not in id_mapping:
+            max_id += 1
+            cleaned_mapping[second_id] = max_id
+        else:
+            first_id = id_mapping[second_id]
+            iou = 1 - cost_matrix[first_id, second_id]
+            # only map if the iou is above the threshold and the labels match
+            if iou > iou_threshold and first_id_to_label[first_id] == elt["label"]:
+                cleaned_mapping[second_id] = first_id
             else:
-                max_new_id = max(id_mapping.values(), default=0)
-                detection["id"] = max_new_id + 1
-                id_mapping[detection["id"]] = detection["id"]
+                max_id += 1
+                cleaned_mapping[second_id] = max_id
 
+    return cleaned_mapping, max_id
 
-def _convert_to_2d(detections_per_segment: List[Any]) -> List[Any]:
-    result = []
-    for i, segment in enumerate(detections_per_segment):
-        if i == 0:
-            result.extend(segment)
-        else:
-            result.extend(segment[1:])
-    return result
 
-
-def merge_segments(detections_per_segment: List[Any]) -> List[Any]:
+def merge_segments(detections_per_segment: List[Any], overlap: int = 1) -> List[Any]:
     """
     Merges detections from all segments into a unified result.
 
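The greedy first-match loop is replaced with a global assignment: scipy.optimize.linear_sum_assignment picks the one-to-one pairing of old and new ids that minimizes the total cost, where cost is 1 - IoU. A toy example of the call (the matrix values are made up):

import numpy as np
from scipy.optimize import linear_sum_assignment

# cost[i, j] = 1 - IoU between old id i and new id j (values made up).
cost_matrix = np.array(
    [
        [0.1, 0.9, 1.0],
        [0.8, 0.2, 1.0],
    ]
)
row_ind, col_ind = linear_sum_assignment(cost_matrix)
id_mapping = {int(new): int(old) for old, new in zip(row_ind, col_ind)}
print(id_mapping)  # {0: 0, 1: 1}

New id 2 receives no assignment, so the cleanup pass above hands it a fresh id; likewise a matched pair is kept only when its IoU clears the 0.05 threshold and the labels agree.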
@@ -242,16 +224,20 @@ def merge_segments(detections_per_segment: List[Any]) -> List[Any]:
         for item in detection:
             item["decoded_mask"] = rle_decode_array(item["mask"])
 
+    merged_result = detections_per_segment[0]
+    max_id = max((item["id"] for item in merged_result[-1]), default=0)
     for segment_idx in range(len(detections_per_segment) - 1):
-        combined_detection, id_mapping = _match_by_iou(
+        id_mapping, max_id = _match_by_iou(
             detections_per_segment[segment_idx][-1],
             detections_per_segment[segment_idx + 1][0],
+            max_id,
         )
-        _update_ids(detections_per_segment[segment_idx + 1], id_mapping)
-
-    merged_result = _convert_to_2d(detections_per_segment)
+        for frame in detections_per_segment[segment_idx + 1][overlap:]:
+            for detection in frame:
+                detection["id"] = id_mapping[detection["id"]]
+        merged_result.extend(detections_per_segment[segment_idx + 1][overlap:])
 
-    return merged_result
+    return merged_result  # type: ignore
 
 
 def post_process(
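With overlap=1, the first frame of each subsequent segment duplicates the last frame of the previous one: it exists only so the IoU matching has a shared frame to align ids on, and the [overlap:] slice then drops it so no frame is emitted twice. A toy walk-through with hand-written detections and a hand-written id_mapping (both illustrative, not real model output):

# Two segments of per-frame detections; the boundary frame appears in both.
seg0 = [[{"id": 1, "label": "person"}], [{"id": 1, "label": "person"}]]
seg1 = [[{"id": 7, "label": "person"}], [{"id": 7, "label": "person"}]]

overlap = 1
id_mapping = {7: 1}  # produced by the IoU matching on the shared frame

merged = seg0
for frame in seg1[overlap:]:
    for det in frame:
        det["id"] = id_mapping[det["id"]]
merged.extend(seg1[overlap:])

print(len(merged))          # 3 frames; the duplicated boundary frame is dropped
print(merged[-1][0]["id"])  # 1; the track keeps its id across the seam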
@@ -269,10 +255,26 @@ def post_process(
         Dict[str, Any]: Post-processed data including return_data and display_data.
     """
     return_data = []
-    for frame_idx, frame in enumerate(merged_detections):
+    label_remapping = {}
+    for _, frame in enumerate(merged_detections):
         return_frame_data = []
         for detection in frame:
-            label = f"{detection['id']}: {detection['label']}"
+            label = detection["label"]
+            id = detection["id"]
+
+            # Remap label IDs so for each label the IDs restart at 1. This makes it
+            # easier to count the number of instances per label.
+            if label not in label_remapping:
+                label_remapping[label] = {"max": 1, "remap": {id: 1}}
+            elif label in label_remapping and id not in label_remapping[label]["remap"]:  # type: ignore
+                max_id = label_remapping[label]["max"]
+                max_id += 1  # type: ignore
+                label_remapping[label]["remap"][id] = max_id  # type: ignore
+                label_remapping[label]["max"] = max_id
+
+            new_id = label_remapping[label]["remap"][id]  # type: ignore
+
+            label = f"{new_id}: {detection['label']}"
             return_frame_data.append(
                 {
                     "label": label,
@@ -285,7 +287,6 @@
         return_data.append(return_frame_data)
 
     return_data = add_bboxes_from_masks(return_data)
-    return_data = nms(return_data, iou_threshold=0.95)
 
     # We save the RLE for display purposes, re-calculting RLE can get very expensive.
     # Deleted here because we are returning the numpy masks instead
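The new remapping in post_process restarts ids at 1 within each class, so rendered labels read "1: person", "2: person", "1: car" regardless of the global ids the tracker assigned, which makes per-class counting straightforward. A compact sketch of the same bookkeeping on made-up detections:

detections = [
    {"id": 4, "label": "person"},
    {"id": 9, "label": "person"},
    {"id": 2, "label": "car"},
    {"id": 4, "label": "person"},  # the same track seen in a later frame
]

label_remapping = {}
for det in detections:
    label, track_id = det["label"], det["id"]
    if label not in label_remapping:
        # First instance of this class: start its ids at 1.
        label_remapping[label] = {"max": 1, "remap": {track_id: 1}}
    elif track_id not in label_remapping[label]["remap"]:
        # New track for a known class: hand out the next per-class id.
        new_max = label_remapping[label]["max"] + 1
        label_remapping[label]["remap"][track_id] = new_max
        label_remapping[label]["max"] = new_max
    print(f'{label_remapping[label]["remap"][track_id]}: {label}')
# 1: person / 2: person / 1: car / 1: person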