dgenerate-ultralytics-headless 8.3.222__py3-none-any.whl → 8.3.225__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/memory_attention.py

@@ -11,8 +11,7 @@ from .blocks import RoPEAttention
 
 
  class MemoryAttentionLayer(nn.Module):
- """
- Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
+ """Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
 
  This class combines self-attention, cross-attention, and feedforward components to process input tensors and
  generate memory-based attention outputs.
@@ -61,8 +60,7 @@ class MemoryAttentionLayer(nn.Module):
  pos_enc_at_cross_attn_keys: bool = True,
  pos_enc_at_cross_attn_queries: bool = False,
  ):
- """
- Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
+ """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
 
  Args:
  d_model (int): Dimensionality of the model.
@@ -145,8 +143,7 @@ class MemoryAttentionLayer(nn.Module):
  query_pos: torch.Tensor | None = None,
  num_k_exclude_rope: int = 0,
  ) -> torch.Tensor:
- """
- Process input tensors through self-attention, cross-attention, and feedforward network layers.
+ """Process input tensors through self-attention, cross-attention, and feedforward network layers.
 
  Args:
  tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).
@@ -168,11 +165,10 @@ class MemoryAttentionLayer(nn.Module):
 
 
  class MemoryAttention(nn.Module):
- """
- Memory attention module for processing sequential data with self and cross-attention mechanisms.
+ """Memory attention module for processing sequential data with self and cross-attention mechanisms.
 
- This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
- for processing sequential data, particularly useful in transformer-like architectures.
+ This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
+ processing sequential data, particularly useful in transformer-like architectures.
 
  Attributes:
  d_model (int): The dimension of the model's hidden state.
@@ -206,11 +202,10 @@ class MemoryAttention(nn.Module):
  num_layers: int,
  batch_first: bool = True,  # Do layers expect batch first input?
  ):
- """
- Initialize MemoryAttention with specified layers and normalization for sequential data processing.
+ """Initialize MemoryAttention with specified layers and normalization for sequential data processing.
 
- This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
- for processing sequential data, particularly useful in transformer-like architectures.
+ This class implements a multi-layer attention mechanism that combines self-attention and cross-attention for
+ processing sequential data, particularly useful in transformer-like architectures.
 
  Args:
  d_model (int): The dimension of the model's hidden state.
@@ -247,8 +242,7 @@ class MemoryAttention(nn.Module):
  memory_pos: torch.Tensor | None = None,  # pos_enc for cross-attention inputs
  num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
  ) -> torch.Tensor:
- """
- Process inputs through attention layers, applying self and cross-attention with positional encoding.
+ """Process inputs through attention layers, applying self and cross-attention with positional encoding.
 
  Args:
  curr (torch.Tensor): Self-attention input tensor, representing the current state.
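Nearly all of the hunks in this release are the same mechanical docstring reflow rather than behavior changes: the one-line summary moves up onto the opening `"""` line and the remaining description is re-wrapped to a wider column, consistent with the pydocstyle convention that a multi-line docstring summary starts on the first line. A minimal before/after illustration of the pattern (illustrative only, not code taken from the package):

```python
# Illustrative only: the docstring reflow pattern applied across these hunks.

def before():
    """
    Process inputs through attention layers.

    Longer description wrapped at a narrower column,
    continuing on the next line.
    """


def after():
    """Process inputs through attention layers.

    Longer description re-wrapped at a wider column, continuing on the next line.
    """
```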
ultralytics/models/sam/modules/sam.py

@@ -23,11 +23,10 @@ NO_OBJ_SCORE = -1024.0
 
 
  class SAMModel(nn.Module):
- """
- Segment Anything Model (SAM) for object segmentation tasks.
+ """Segment Anything Model (SAM) for object segmentation tasks.
 
- This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
- and input prompts.
+ This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input
+ prompts.
 
  Attributes:
  mask_threshold (float): Threshold value for mask prediction.
@@ -61,8 +60,7 @@ class SAMModel(nn.Module):
  pixel_mean: list[float] = (123.675, 116.28, 103.53),
  pixel_std: list[float] = (58.395, 57.12, 57.375),
  ) -> None:
- """
- Initialize the SAMModel class to predict object masks from an image and input prompts.
+ """Initialize the SAMModel class to predict object masks from an image and input prompts.
 
  Args:
  image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
@@ -98,11 +96,10 @@ class SAMModel(nn.Module):
 
 
  class SAM2Model(torch.nn.Module):
- """
- SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
+ """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
 
- This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
- for temporal consistency and efficient tracking of objects across frames.
+ This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal
+ consistency and efficient tracking of objects across frames.
 
  Attributes:
  mask_threshold (float): Threshold value for mask prediction.
@@ -136,24 +133,24 @@ class SAM2Model(torch.nn.Module):
  use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
  no_obj_embed_spatial (torch.Tensor | None): No-object embedding for spatial frames.
  max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
- directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
- first frame.
- multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
- conditioning frames.
+ directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+ frame.
+ multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+ frames.
  multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
  multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
  multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
  use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
  iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
  memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
- non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
- memory encoder during evaluation.
+ non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+ encoder during evaluation.
  sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
  sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
- binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
- with clicks during evaluation.
- use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
- prompt encoder and mask decoder on frames with mask input.
+ binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+ clicks during evaluation.
+ use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM prompt
+ encoder and mask decoder on frames with mask input.
 
  Methods:
  forward_image: Process image batch through encoder to extract multi-level features.
@@ -208,8 +205,7 @@ class SAM2Model(torch.nn.Module):
  sam_mask_decoder_extra_args=None,
  compile_image_encoder: bool = False,
  ):
- """
- Initialize the SAM2Model for video object segmentation with memory-based tracking.
+ """Initialize the SAM2Model for video object segmentation with memory-based tracking.
 
  Args:
  image_encoder (nn.Module): Visual encoder for extracting image features.
@@ -220,35 +216,35 @@ class SAM2Model(torch.nn.Module):
  backbone_stride (int): Stride of the image backbone output.
  sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
  sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
- binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
- with clicks during evaluation.
+ binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with
+ clicks during evaluation.
  use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
  prompt encoder and mask decoder on frames with mask input.
  max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
- directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
- first frame.
+ directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first
+ frame.
  use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
- multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial
- conditioning frames.
+ multimask_output_in_sam (bool): Whether to output multiple masks for the first click on initial conditioning
+ frames.
  multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
  multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
  multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
  use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
  iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
  memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
- non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
- memory encoder during evaluation.
+ non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory
+ encoder during evaluation.
  use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
  max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
  cross-attention.
- add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
- the encoder.
+ add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the
+ encoder.
  proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
  encoding in object pointers.
  use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance in the temporal positional encoding
  in the object pointers.
- only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
- during evaluation.
+ only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during
+ evaluation.
  pred_obj_scores (bool): Whether to predict if there is an object in the frame.
  pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
  fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
@@ -428,25 +424,23 @@ class SAM2Model(torch.nn.Module):
  high_res_features=None,
  multimask_output=False,
  ):
- """
- Forward pass through SAM prompt encoders and mask heads.
+ """Forward pass through SAM prompt encoders and mask heads.
 
  This method processes image features and optional point/mask inputs to generate object masks and scores.
 
  Args:
  backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
  point_inputs (dict[str, torch.Tensor] | None): Dictionary containing point prompts.
- 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
- pixel-unit coordinates in (x, y) format for P input points.
- 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
- 0 means negative clicks, and -1 means padding.
- mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
- same spatial size as the image.
- high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes
- (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
- for SAM decoder.
- multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
- output only 1 mask and its IoU estimate.
+ 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute pixel-unit coordinates in
+ (x, y) format for P input points.
+ 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, 0 means negative
+ clicks, and -1 means padding.
+ mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the same spatial
+ size as the image.
+ high_res_features (list[torch.Tensor] | None): List of two feature maps with shapes (B, C, 4*H, 4*W) and (B,
+ C, 2*H, 2*W) respectively, used as high-resolution feature maps for SAM decoder.
+ multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, output only 1
+ mask and its IoU estimate.
 
  Returns:
  low_res_multimasks (torch.Tensor): Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
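For orientation on the prompt layout documented in the SAM2Model hunk above, the sketch below constructs tensors with the documented shapes and dtypes. The variable names and example sizes are illustrative assumptions only; nothing here calls into the package itself.

```python
import torch

B, P, H, W = 1, 2, 64, 64  # assumed example sizes; H, W are the backbone feature-map height/width

# Point prompts as described above: absolute (x, y) pixel coordinates as float32,
# labels as int32 where 1 = positive click, 0 = negative click, -1 = padding.
point_inputs = {
    "point_coords": torch.tensor([[[320.0, 240.0], [100.0, 80.0]]], dtype=torch.float32),  # (B, P, 2)
    "point_labels": torch.tensor([[1, 0]], dtype=torch.int32),  # (B, P)
}

# Optional dense mask prompt with the same spatial size as the image: (B, 1, H*16, W*16).
mask_inputs = torch.zeros(B, 1, H * 16, W * 16, dtype=torch.float32)
```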
ultralytics/models/sam/modules/tiny_encoder.py

@@ -22,12 +22,11 @@ from ultralytics.utils.instance import to_2tuple
 
 
  class Conv2d_BN(torch.nn.Sequential):
- """
- A sequential container that performs 2D convolution followed by batch normalization.
+ """A sequential container that performs 2D convolution followed by batch normalization.
 
- This module combines a 2D convolution layer with batch normalization, providing a common building block
- for convolutional neural networks. The batch normalization weights and biases are initialized to specific
- values for optimal training performance.
+ This module combines a 2D convolution layer with batch normalization, providing a common building block for
+ convolutional neural networks. The batch normalization weights and biases are initialized to specific values for
+ optimal training performance.
 
  Attributes:
  c (torch.nn.Conv2d): 2D convolution layer.
@@ -52,8 +51,7 @@ class Conv2d_BN(torch.nn.Sequential):
  groups: int = 1,
  bn_weight_init: float = 1,
  ):
- """
- Initialize a sequential container with 2D convolution followed by batch normalization.
+ """Initialize a sequential container with 2D convolution followed by batch normalization.
 
  Args:
  a (int): Number of input channels.
@@ -74,11 +72,10 @@ class Conv2d_BN(torch.nn.Sequential):
 
 
  class PatchEmbed(nn.Module):
- """
- Embed images into patches and project them into a specified embedding dimension.
+ """Embed images into patches and project them into a specified embedding dimension.
 
- This module converts input images into patch embeddings using a sequence of convolutional layers,
- effectively downsampling the spatial dimensions while increasing the channel dimension.
+ This module converts input images into patch embeddings using a sequence of convolutional layers, effectively
+ downsampling the spatial dimensions while increasing the channel dimension.
 
  Attributes:
  patches_resolution (tuple[int, int]): Resolution of the patches after embedding.
@@ -97,8 +94,7 @@ class PatchEmbed(nn.Module):
  """
 
  def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):
- """
- Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
+ """Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
 
  Args:
  in_chans (int): Number of input channels.
@@ -125,11 +121,10 @@ class PatchEmbed(nn.Module):
 
 
  class MBConv(nn.Module):
- """
- Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
+ """Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
 
- This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution,
- and projection phases, along with residual connections for improved gradient flow.
+ This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution, and
+ projection phases, along with residual connections for improved gradient flow.
 
  Attributes:
  in_chans (int): Number of input channels.
@@ -153,8 +148,7 @@ class MBConv(nn.Module):
  """
 
  def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):
- """
- Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
+ """Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
 
  Args:
  in_chans (int): Number of input channels.
@@ -195,12 +189,11 @@ class MBConv(nn.Module):
 
 
  class PatchMerging(nn.Module):
- """
- Merge neighboring patches in the feature map and project to a new dimension.
+ """Merge neighboring patches in the feature map and project to a new dimension.
 
- This class implements a patch merging operation that combines spatial information and adjusts the feature
- dimension using a series of convolutional layers with batch normalization. It effectively reduces spatial
- resolution while potentially increasing channel dimensions.
+ This class implements a patch merging operation that combines spatial information and adjusts the feature dimension
+ using a series of convolutional layers with batch normalization. It effectively reduces spatial resolution while
+ potentially increasing channel dimensions.
 
  Attributes:
  input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.
@@ -221,8 +214,7 @@ class PatchMerging(nn.Module):
  """
 
  def __init__(self, input_resolution: tuple[int, int], dim: int, out_dim: int, activation):
- """
- Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
+ """Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
 
  Args:
  input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.
@@ -259,11 +251,10 @@ class PatchMerging(nn.Module):
 
 
  class ConvLayer(nn.Module):
- """
- Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
+ """Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
 
- This layer optionally applies downsample operations to the output and supports gradient checkpointing
- for memory efficiency during training.
+ This layer optionally applies downsample operations to the output and supports gradient checkpointing for memory
+ efficiency during training.
 
  Attributes:
  dim (int): Dimensionality of the input and output.
@@ -293,11 +284,10 @@ class ConvLayer(nn.Module):
  out_dim: int | None = None,
  conv_expand_ratio: float = 4.0,
  ):
- """
- Initialize the ConvLayer with the given dimensions and settings.
+ """Initialize the ConvLayer with the given dimensions and settings.
 
- This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
- optionally applies downsampling to the output.
+ This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and optionally
+ applies downsampling to the output.
 
  Args:
  dim (int): The dimensionality of the input and output.
@@ -307,7 +297,7 @@ class ConvLayer(nn.Module):
  drop_path (float | list[float], optional): Drop path rate. Single float or a list of floats for each MBConv.
  downsample (Optional[nn.Module], optional): Function for downsampling the output. None to skip downsampling.
  use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
- out_dim (Optional[int], optional): The dimensionality of the output. None means it will be the same as `dim`.
+ out_dim (Optional[int], optional): Output dimensions. None means it will be the same as `dim`.
  conv_expand_ratio (float, optional): Expansion ratio for the MBConv layers.
  """
  super().__init__()
@@ -345,11 +335,10 @@ class ConvLayer(nn.Module):
 
 
  class MLP(nn.Module):
- """
- Multi-layer Perceptron (MLP) module for transformer architectures.
+ """Multi-layer Perceptron (MLP) module for transformer architectures.
 
- This module applies layer normalization, two fully-connected layers with an activation function in between,
- and dropout. It is commonly used in transformer-based architectures for processing token embeddings.
+ This module applies layer normalization, two fully-connected layers with an activation function in between, and
+ dropout. It is commonly used in transformer-based architectures for processing token embeddings.
 
  Attributes:
  norm (nn.LayerNorm): Layer normalization applied to the input.
@@ -376,8 +365,7 @@ class MLP(nn.Module):
  activation=nn.GELU,
  drop: float = 0.0,
  ):
- """
- Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
+ """Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
 
  Args:
  in_features (int): Number of input features.
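The MLP covered in the hunks just above is a small, self-contained module, so a sketch of the structure its docstring lists (LayerNorm, two fully-connected layers with an activation in between, and dropout) may help when reading the diff. This is an assumed re-implementation for illustration; the class name, default handling, and forward ordering are chosen here rather than taken from the package.

```python
import torch
import torch.nn as nn


class MLPSketch(nn.Module):
    # Hypothetical re-implementation of the structure the MLP docstring above describes:
    # LayerNorm -> Linear -> activation -> Dropout -> Linear -> Dropout.
    def __init__(self, in_features: int, hidden_features: int | None = None,
                 out_features: int | None = None, activation=nn.GELU, drop: float = 0.0):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.norm = nn.LayerNorm(in_features)
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = activation()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.drop(self.act(self.fc1(self.norm(x))))
        return self.drop(self.fc2(x))


# Example: token embeddings of shape (batch, tokens, features).
tokens = torch.randn(2, 196, 96)
out = MLPSketch(96, hidden_features=384)(tokens)  # -> (2, 196, 96)
```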
@@ -406,12 +394,11 @@ class MLP(nn.Module):
 
 
  class Attention(torch.nn.Module):
- """
- Multi-head attention module with spatial awareness and trainable attention biases.
+ """Multi-head attention module with spatial awareness and trainable attention biases.
 
- This module implements a multi-head attention mechanism with support for spatial awareness, applying
- attention biases based on spatial resolution. It includes trainable attention biases for each unique
- offset between spatial positions in the resolution grid.
+ This module implements a multi-head attention mechanism with support for spatial awareness, applying attention
+ biases based on spatial resolution. It includes trainable attention biases for each unique offset between spatial
+ positions in the resolution grid.
 
  Attributes:
  num_heads (int): Number of attention heads.
@@ -444,12 +431,11 @@ class Attention(torch.nn.Module):
  attn_ratio: float = 4,
  resolution: tuple[int, int] = (14, 14),
  ):
- """
- Initialize the Attention module for multi-head attention with spatial awareness.
+ """Initialize the Attention module for multi-head attention with spatial awareness.
 
- This module implements a multi-head attention mechanism with support for spatial awareness, applying
- attention biases based on spatial resolution. It includes trainable attention biases for each unique
- offset between spatial positions in the resolution grid.
+ This module implements a multi-head attention mechanism with support for spatial awareness, applying attention
+ biases based on spatial resolution. It includes trainable attention biases for each unique offset between
+ spatial positions in the resolution grid.
 
  Args:
  dim (int): The dimensionality of the input and output.
@@ -521,12 +507,11 @@ class Attention(torch.nn.Module):
 
 
  class TinyViTBlock(nn.Module):
- """
- TinyViT Block that applies self-attention and a local convolution to the input.
+ """TinyViT Block that applies self-attention and a local convolution to the input.
 
- This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
- local convolutions to process input features efficiently. It supports windowed attention for
- computational efficiency and includes residual connections.
+ This block is a key component of the TinyViT architecture, combining self-attention mechanisms with local
+ convolutions to process input features efficiently. It supports windowed attention for computational efficiency and
+ includes residual connections.
 
  Attributes:
  dim (int): The dimensionality of the input and output.
@@ -559,11 +544,10 @@ class TinyViTBlock(nn.Module):
  local_conv_size: int = 3,
  activation=nn.GELU,
  ):
- """
- Initialize a TinyViT block with self-attention and local convolution.
+ """Initialize a TinyViT block with self-attention and local convolution.
 
- This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
- local convolutions to process input features efficiently.
+ This block is a key component of the TinyViT architecture, combining self-attention mechanisms with local
+ convolutions to process input features efficiently.
 
  Args:
  dim (int): Dimensionality of the input and output features.
@@ -644,8 +628,7 @@ class TinyViTBlock(nn.Module):
  return x + self.drop_path(self.mlp(x))
 
  def extra_repr(self) -> str:
- """
- Return a string representation of the TinyViTBlock's parameters.
+ """Return a string representation of the TinyViTBlock's parameters.
 
  This method provides a formatted string containing key information about the TinyViTBlock, including its
  dimension, input resolution, number of attention heads, window size, and MLP ratio.
@@ -665,12 +648,11 @@ class TinyViTBlock(nn.Module):
 
 
  class BasicLayer(nn.Module):
- """
- A basic TinyViT layer for one stage in a TinyViT architecture.
+ """A basic TinyViT layer for one stage in a TinyViT architecture.
 
- This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
- and an optional downsampling operation. It processes features at a specific resolution and
- dimensionality within the overall architecture.
+ This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks and an optional
+ downsampling operation. It processes features at a specific resolution and dimensionality within the overall
+ architecture.
 
  Attributes:
  dim (int): The dimensionality of the input and output features.
@@ -704,11 +686,10 @@ class BasicLayer(nn.Module):
  activation=nn.GELU,
  out_dim: int | None = None,
  ):
- """
- Initialize a BasicLayer in the TinyViT architecture.
+ """Initialize a BasicLayer in the TinyViT architecture.
 
- This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
- process feature maps at a specific resolution and dimensionality within the TinyViT model.
+ This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to process
+ feature maps at a specific resolution and dimensionality within the TinyViT model.
 
  Args:
  dim (int): Dimensionality of the input and output features.
@@ -718,12 +699,14 @@ class BasicLayer(nn.Module):
  window_size (int): Size of the local window for attention computation.
  mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
  drop (float, optional): Dropout rate.
- drop_path (float | list[float], optional): Stochastic depth rate. Can be a float or a list of floats for each block.
- downsample (nn.Module | None, optional): Downsampling layer at the end of the layer. None to skip downsampling.
+ drop_path (float | list[float], optional): Stochastic depth rate. Can be a float or a list of floats for
+ each block.
+ downsample (nn.Module | None, optional): Downsampling layer at the end of the layer. None to skip
+ downsampling.
  use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
  local_conv_size (int, optional): Kernel size for the local convolution in each TinyViT block.
  activation (nn.Module): Activation function used in the MLP.
- out_dim (int | None, optional): Output dimension after downsampling. None means it will be the same as `dim`.
+ out_dim (int | None, optional): Output dimension after downsampling. None means it will be the same as dim.
  """
  super().__init__()
  self.dim = dim
@@ -768,12 +751,11 @@ class BasicLayer(nn.Module):
 
 
  class TinyViT(nn.Module):
- """
- TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
+ """TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
 
- This class implements the TinyViT model, which combines elements of vision transformers and convolutional
- neural networks for improved efficiency and performance on vision tasks. It features hierarchical processing
- with patch embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.
+ This class implements the TinyViT model, which combines elements of vision transformers and convolutional neural
+ networks for improved efficiency and performance on vision tasks. It features hierarchical processing with patch
+ embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.
 
  Attributes:
  img_size (int): Input image size.
@@ -813,11 +795,10 @@ class TinyViT(nn.Module):
  local_conv_size: int = 3,
  layer_lr_decay: float = 1.0,
  ):
- """
- Initialize the TinyViT model.
+ """Initialize the TinyViT model.
 
- This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
- attention and convolution blocks, and a classification head.
+ This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of attention and
+ convolution blocks, and a classification head.
 
  Args:
  img_size (int, optional): Size of the input image.