ultralytics 8.3.101__py3-none-any.whl → 8.3.103__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (62)
  1. tests/test_exports.py +14 -5
  2. tests/test_solutions.py +140 -76
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/cfg/__init__.py +1 -1
  5. ultralytics/engine/exporter.py +23 -8
  6. ultralytics/engine/tuner.py +8 -2
  7. ultralytics/hub/__init__.py +29 -2
  8. ultralytics/hub/google/__init__.py +18 -1
  9. ultralytics/models/fastsam/predict.py +12 -1
  10. ultralytics/models/nas/predict.py +21 -3
  11. ultralytics/models/rtdetr/val.py +26 -2
  12. ultralytics/models/sam/amg.py +22 -1
  13. ultralytics/models/sam/modules/encoders.py +85 -4
  14. ultralytics/models/sam/modules/memory_attention.py +61 -3
  15. ultralytics/models/sam/modules/utils.py +108 -5
  16. ultralytics/models/utils/loss.py +38 -2
  17. ultralytics/models/utils/ops.py +15 -1
  18. ultralytics/models/yolo/classify/predict.py +11 -1
  19. ultralytics/models/yolo/classify/train.py +17 -1
  20. ultralytics/models/yolo/classify/val.py +82 -6
  21. ultralytics/models/yolo/detect/predict.py +20 -1
  22. ultralytics/models/yolo/model.py +55 -4
  23. ultralytics/models/yolo/obb/predict.py +16 -1
  24. ultralytics/models/yolo/obb/train.py +35 -2
  25. ultralytics/models/yolo/obb/val.py +87 -6
  26. ultralytics/models/yolo/pose/predict.py +18 -1
  27. ultralytics/models/yolo/pose/train.py +48 -3
  28. ultralytics/models/yolo/pose/val.py +113 -8
  29. ultralytics/models/yolo/segment/predict.py +27 -2
  30. ultralytics/models/yolo/segment/train.py +61 -3
  31. ultralytics/models/yolo/segment/val.py +10 -1
  32. ultralytics/models/yolo/world/train_world.py +29 -1
  33. ultralytics/models/yolo/yoloe/train.py +47 -3
  34. ultralytics/nn/autobackend.py +9 -8
  35. ultralytics/nn/modules/activation.py +26 -3
  36. ultralytics/nn/modules/block.py +89 -0
  37. ultralytics/nn/modules/head.py +3 -92
  38. ultralytics/nn/modules/utils.py +70 -4
  39. ultralytics/nn/tasks.py +3 -0
  40. ultralytics/nn/text_model.py +93 -17
  41. ultralytics/solutions/instance_segmentation.py +15 -7
  42. ultralytics/solutions/solutions.py +2 -47
  43. ultralytics/utils/benchmarks.py +1 -1
  44. ultralytics/utils/callbacks/base.py +22 -5
  45. ultralytics/utils/callbacks/comet.py +93 -5
  46. ultralytics/utils/callbacks/dvc.py +64 -5
  47. ultralytics/utils/callbacks/neptune.py +25 -2
  48. ultralytics/utils/callbacks/tensorboard.py +30 -2
  49. ultralytics/utils/callbacks/wb.py +16 -1
  50. ultralytics/utils/dist.py +35 -2
  51. ultralytics/utils/errors.py +27 -6
  52. ultralytics/utils/metrics.py +1 -1
  53. ultralytics/utils/patches.py +33 -5
  54. ultralytics/utils/torch_utils.py +14 -6
  55. ultralytics/utils/triton.py +16 -3
  56. ultralytics/utils/tuner.py +17 -9
  57. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/METADATA +3 -4
  58. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/RECORD +62 -62
  59. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/WHEEL +0 -0
  60. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/entry_points.txt +0 -0
  61. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/licenses/LICENSE +0 -0
  62. {ultralytics-8.3.101.dist-info → ultralytics-8.3.103.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/encoders.py
@@ -386,7 +386,24 @@ class MemoryEncoder(nn.Module):
  out_dim,
  in_dim=256, # in_dim of pix_feats
  ):
- """Initialize the MemoryEncoder for encoding pixel features and masks into memory representations."""
+ """
+ Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
+
+ This encoder processes pixel-level features and masks, fusing them to generate encoded memory representations
+ suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
+
+ Args:
+ out_dim (int): Output dimension of the encoded features.
+ in_dim (int): Input dimension of the pixel features. Default is 256.
+
+ Examples:
+ >>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
+ >>> pix_feat = torch.randn(1, 256, 64, 64)
+ >>> masks = torch.randn(1, 1, 64, 64)
+ >>> encoded_feat, pos = encoder(pix_feat, masks)
+ >>> print(encoded_feat.shape, pos.shape)
+ torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
+ """
  super().__init__()

  self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
@@ -453,7 +470,26 @@ class ImageEncoder(nn.Module):
  neck: nn.Module,
  scalp: int = 0,
  ):
- """Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
+ """
+ Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
+
+ This encoder combines a trunk network for feature extraction with a neck network for feature refinement
+ and positional encoding generation. It can optionally discard the lowest resolution features.
+
+ Args:
+ trunk (nn.Module): The trunk network for initial feature extraction.
+ neck (nn.Module): The neck network for feature refinement and positional encoding generation.
+ scalp (int): Number of lowest resolution feature levels to discard.
+
+ Examples:
+ >>> trunk = SomeTrunkNetwork()
+ >>> neck = SomeNeckNetwork()
+ >>> encoder = ImageEncoder(trunk, neck, scalp=1)
+ >>> image = torch.randn(1, 3, 224, 224)
+ >>> output = encoder(image)
+ >>> print(output.keys())
+ dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
+ """
  super().__init__()
  self.trunk = trunk
  self.neck = neck
@@ -681,7 +717,34 @@ class Hiera(nn.Module):
  ),
  return_interm_layers=True, # return feats from every stage
  ):
- """Initialize the Hiera model, configuring its hierarchical vision transformer architecture."""
+ """
+ Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
+
+ Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction
+ in image processing tasks. It uses a series of transformer blocks organized into stages, with optional
+ pooling and global attention mechanisms.
+
+ Args:
+ embed_dim (int): Initial embedding dimension for the model.
+ num_heads (int): Initial number of attention heads.
+ drop_path_rate (float): Stochastic depth rate.
+ q_pool (int): Number of query pooling stages.
+ q_stride (Tuple[int, int]): Downsampling stride between stages.
+ stages (Tuple[int, ...]): Number of blocks per stage.
+ dim_mul (float): Dimension multiplier factor at stage transitions.
+ head_mul (float): Head multiplier factor at stage transitions.
+ window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
+ window_spec (Tuple[int, ...]): Window sizes for each stage when not using global attention.
+ global_att_blocks (Tuple[int, ...]): Indices of blocks that use global attention.
+ return_interm_layers (bool): Whether to return intermediate layer outputs.
+
+ Examples:
+ >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
+ >>> input_tensor = torch.randn(1, 3, 224, 224)
+ >>> output_features = model(input_tensor)
+ >>> for feat in output_features:
+ ... print(feat.shape)
+ """
  super().__init__()

  assert len(stages) == len(window_spec)
@@ -756,7 +819,25 @@ class Hiera(nn.Module):
  return pos_embed

  def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
- """Perform forward pass through Hiera model, extracting multiscale features from input images."""
+ """
+ Perform forward pass through Hiera model, extracting multiscale features from input images.
+
+ Args:
+ x (torch.Tensor): Input tensor with shape (B, C, H, W) representing a batch of images.
+
+ Returns:
+ (List[torch.Tensor]): List of feature maps at different scales, each with shape (B, C_i, H_i, W_i), where
+ C_i is the channel dimension and H_i, W_i are the spatial dimensions at scale i. The list is ordered
+ from highest resolution (fine features) to lowest resolution (coarse features) if return_interm_layers
+ is True, otherwise contains only the final output.
+
+ Examples:
+ >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
+ >>> input_tensor = torch.randn(1, 3, 224, 224)
+ >>> output_features = model(input_tensor)
+ >>> for feat in output_features:
+ ... print(feat.shape)
+ """
  x = self.patch_embed(x)
  # x: (B, H, W, C)

ultralytics/models/sam/modules/memory_attention.py
@@ -60,7 +60,17 @@ class MemoryAttentionLayer(nn.Module):
  pos_enc_at_cross_attn_keys: bool = True,
  pos_enc_at_cross_attn_queries: bool = False,
  ):
- """Initialize a memory attention layer with self-attention, cross-attention, and feedforward components."""
+ """
+ Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
+
+ Args:
+ d_model (int): Dimensionality of the model.
+ dim_feedforward (int): Dimensionality of the feedforward network.
+ dropout (float): Dropout rate for regularization.
+ pos_enc_at_attn (bool): Whether to add positional encoding at attention.
+ pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
+ pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
+ """
  super().__init__()
  self.d_model = d_model
  self.dim_feedforward = dim_feedforward
@@ -183,7 +193,31 @@ class MemoryAttention(nn.Module):
  num_layers: int,
  batch_first: bool = True, # Do layers expect batch first input?
  ):
- """Initialize MemoryAttention with specified layers and normalization for sequential data processing."""
+ """
+ Initialize MemoryAttention with specified layers and normalization for sequential data processing.
+
+ This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
+ for processing sequential data, particularly useful in transformer-like architectures.
+
+ Args:
+ d_model (int): The dimension of the model's hidden state.
+ pos_enc_at_input (bool): Whether to apply positional encoding at the input.
+ layer (nn.Module): The attention layer to be used in the module.
+ num_layers (int): The number of attention layers.
+ batch_first (bool): Whether the input tensors are in batch-first format.
+
+ Examples:
+ >>> d_model = 256
+ >>> layer = MemoryAttentionLayer(d_model)
+ >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
+ >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
+ >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
+ >>> curr_pos = torch.randn(10, 32, d_model)
+ >>> memory_pos = torch.randn(20, 32, d_model)
+ >>> output = attention(curr, memory, curr_pos, memory_pos)
+ >>> print(output.shape)
+ torch.Size([10, 32, 256])
+ """
  super().__init__()
  self.d_model = d_model
  self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
@@ -200,7 +234,31 @@ class MemoryAttention(nn.Module):
  memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
  num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
  ) -> torch.Tensor:
- """Process inputs through attention layers, applying self and cross-attention with positional encoding."""
+ """
+ Process inputs through attention layers, applying self and cross-attention with positional encoding.
+
+ Args:
+ curr (torch.Tensor): Self-attention input tensor, representing the current state.
+ memory (torch.Tensor): Cross-attention input tensor, representing memory information.
+ curr_pos (Optional[Tensor]): Positional encoding for self-attention inputs.
+ memory_pos (Optional[Tensor]): Positional encoding for cross-attention inputs.
+ num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.
+
+ Returns:
+ (torch.Tensor): Processed output tensor after applying attention layers and normalization.
+
+ Examples:
+ >>> d_model = 256
+ >>> layer = MemoryAttentionLayer(d_model)
+ >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
+ >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
+ >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
+ >>> curr_pos = torch.randn(10, 32, d_model)
+ >>> memory_pos = torch.randn(20, 32, d_model)
+ >>> output = attention(curr, memory, curr_pos, memory_pos)
+ >>> print(output.shape)
+ torch.Size([10, 32, 256])
+ """
  if isinstance(curr, list):
  assert isinstance(curr_pos, list)
  assert len(curr) == len(curr_pos) == 1
ultralytics/models/sam/modules/utils.py
@@ -61,7 +61,23 @@ def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num


  def get_1d_sine_pe(pos_inds, dim, temperature=10000):
- """Generate 1D sinusoidal positional embeddings for given positions and dimensions."""
+ """
+ Generate 1D sinusoidal positional embeddings for given positions and dimensions.
+
+ Args:
+ pos_inds (torch.Tensor): Position indices for which to generate embeddings.
+ dim (int): Dimension of the positional embeddings. Should be an even number.
+ temperature (float): Scaling factor for the frequency of the sinusoidal functions.
+
+ Returns:
+ (torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).
+
+ Examples:
+ >>> pos = torch.tensor([0, 1, 2, 3])
+ >>> embeddings = get_1d_sine_pe(pos, 128)
+ >>> embeddings.shape
+ torch.Size([4, 128])
+ """
  pe_dim = dim // 2
  dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
  dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
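
A minimal usage sketch of the function documented above (editor's illustration, not part of the diff; assumes ultralytics 8.3.103 is installed, importing from the ultralytics/models/sam/modules/utils.py module in the files list):

import torch
from ultralytics.models.sam.modules.utils import get_1d_sine_pe

pos = torch.arange(8)              # eight position indices, e.g. frame or object-pointer slots
pe = get_1d_sine_pe(pos, dim=256)  # sinusoidal embeddings of shape (8, 256), per the docstring contract
print(pe.shape)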
@@ -72,7 +88,30 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):


  def init_t_xy(end_x: int, end_y: int):
- """Initialize 1D and 2D coordinate tensors for a grid of specified dimensions."""
+ """
+ Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
+
+ This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor
+ and corresponding x and y coordinate tensors.
+
+ Args:
+ end_x (int): Width of the grid (number of columns).
+ end_y (int): Height of the grid (number of rows).
+
+ Returns:
+ t (torch.Tensor): Linear indices for each position in the grid, with shape (end_x * end_y).
+ t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
+ t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y).
+
+ Examples:
+ >>> t, t_x, t_y = init_t_xy(3, 2)
+ >>> print(t)
+ tensor([0., 1., 2., 3., 4., 5.])
+ >>> print(t_x)
+ tensor([0., 1., 2., 0., 1., 2.])
+ >>> print(t_y)
+ tensor([0., 0., 0., 1., 1., 1.])
+ """
  t = torch.arange(end_x * end_y, dtype=torch.float32)
  t_x = (t % end_x).float()
  t_y = torch.div(t, end_x, rounding_mode="floor").float()
@@ -80,7 +119,32 @@ def init_t_xy(end_x: int, end_y: int):


  def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
- """Compute axial complex exponential positional encodings for 2D spatial positions in a grid."""
+ """
+ Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
+
+ This function generates complex exponential positional encodings for a 2D grid of spatial positions,
+ using separate frequency components for the x and y dimensions.
+
+ Args:
+ dim (int): Dimension of the positional encoding.
+ end_x (int): Width of the 2D grid.
+ end_y (int): Height of the 2D grid.
+ theta (float, optional): Scaling factor for frequency computation.
+
+ Returns:
+ freqs_cis_x (torch.Tensor): Complex exponential positional encodings for x-dimension with shape
+ (end_x*end_y, dim//4).
+ freqs_cis_y (torch.Tensor): Complex exponential positional encodings for y-dimension with shape
+ (end_x*end_y, dim//4).
+
+ Examples:
+ >>> dim, end_x, end_y = 128, 8, 8
+ >>> freqs_cis_x, freqs_cis_y = compute_axial_cis(dim, end_x, end_y)
+ >>> freqs_cis_x.shape
+ torch.Size([64, 32])
+ >>> freqs_cis_y.shape
+ torch.Size([64, 32])
+ """
  freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
  freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

@@ -93,7 +157,22 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):


  def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
- """Reshape frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
+ """
+ Reshape frequency tensor for broadcasting with input tensor.
+
+ Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
+ This function is typically used in positional encoding operations.
+
+ Args:
+ freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
+ x (torch.Tensor): Input tensor to broadcast with.
+
+ Returns:
+ (torch.Tensor): Reshaped frequency tensor ready for broadcasting with the input tensor.
+
+ Raises:
+ AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.
+ """
  ndim = x.ndim
  assert 0 <= 1 < ndim
  assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
@@ -107,7 +186,31 @@ def apply_rotary_enc(
  freqs_cis: torch.Tensor,
  repeat_freqs_k: bool = False,
  ):
- """Apply rotary positional encoding to query and key tensors using complex-valued frequency components."""
+ """
+ Apply rotary positional encoding to query and key tensors.
+
+ This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
+ components. RoPE is a technique that injects relative position information into self-attention mechanisms.
+
+ Args:
+ xq (torch.Tensor): Query tensor to encode with positional information.
+ xk (torch.Tensor): Key tensor to encode with positional information.
+ freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
+ last two dimensions of xq.
+ repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
+ to match key sequence length.
+
+ Returns:
+ xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
+ xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.
+
+ Examples:
+ >>> import torch
+ >>> xq = torch.randn(2, 8, 16, 64) # [batch, heads, seq_len, dim]
+ >>> xk = torch.randn(2, 8, 16, 64)
+ >>> freqs_cis = compute_axial_cis(64, 4, 4) # For a 4x4 spatial grid with dim=64
+ >>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)
+ """
  xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
  xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
  freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
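
The docstring example above pairs with compute_axial_cis from the same module; a short sketch (editor's illustration, not part of the diff), where the sequence length must equal the end_x * end_y grid size (4 * 4 = 16 here):

import torch
from ultralytics.models.sam.modules.utils import apply_rotary_enc, compute_axial_cis

xq = torch.randn(2, 8, 16, 64)                           # (batch, heads, seq_len, head_dim)
xk = torch.randn(2, 8, 16, 64)
freqs_cis = compute_axial_cis(dim=64, end_x=4, end_y=4)  # rotary frequencies for a 4x4 token grid
q_rot, k_rot = apply_rotary_enc(xq, xk, freqs_cis)       # output shapes match the inputs: (2, 8, 16, 64)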
ultralytics/models/utils/loss.py
@@ -65,7 +65,25 @@ class DETRLoss(nn.Module):
  self.device = None

  def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
- """Compute classification loss based on predictions, target values, and ground truth scores."""
+ """
+ Compute classification loss based on predictions, target values, and ground truth scores.
+
+ Args:
+ pred_scores (torch.Tensor): Predicted class scores with shape (batch_size, num_queries, num_classes).
+ targets (torch.Tensor): Target class indices with shape (batch_size, num_queries).
+ gt_scores (torch.Tensor): Ground truth confidence scores with shape (batch_size, num_queries).
+ num_gts (int): Number of ground truth objects.
+ postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.
+
+ Returns:
+ loss_cls (torch.Tensor): Classification loss value.
+
+ Notes:
+ The function supports different classification loss types:
+ - Varifocal Loss (if self.vfl is True and num_gts > 0)
+ - Focal Loss (if self.fl is True)
+ - BCE Loss (default fallback)
+ """
  # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
  name_class = f"loss_class{postfix}"
  bs, nq = pred_scores.shape[:2]
@@ -87,7 +105,25 @@ class DETRLoss(nn.Module):
  return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

  def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
- """Compute bounding box and GIoU losses for predicted and ground truth bounding boxes."""
+ """
+ Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
+
+ Args:
+ pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
+ gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4), where N is the total
+ number of ground truth boxes.
+ postfix (str): String to append to the loss names for identification in multi-loss scenarios.
+
+ Returns:
+ loss (dict): Dictionary containing:
+ - loss_bbox{postfix} (torch.Tensor): L1 loss between predicted and ground truth boxes,
+ scaled by the bbox loss gain.
+ - loss_giou{postfix} (torch.Tensor): GIoU loss between predicted and ground truth boxes,
+ scaled by the giou loss gain.
+
+ Notes:
+ If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
+ """
  # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
  name_bbox = f"loss_bbox{postfix}"
  name_giou = f"loss_giou{postfix}"
ultralytics/models/utils/ops.py
@@ -31,7 +31,21 @@ class HungarianMatcher(nn.Module):
  """

  def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
- """Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes."""
+ """
+ Initialize a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes.
+
+ The HungarianMatcher uses a cost function that considers classification scores, bounding box coordinates,
+ and optionally mask predictions to perform optimal bipartite matching between predictions and ground truths.
+
+ Args:
+ cost_gain (dict, optional): Dictionary of cost coefficients for different components of the matching cost.
+ Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
+ use_fl (bool, optional): Whether to use Focal Loss for the classification cost calculation.
+ with_mask (bool, optional): Whether the model makes mask predictions.
+ num_sample_points (int, optional): Number of sample points used in mask cost calculation.
+ alpha (float, optional): Alpha factor in Focal Loss calculation.
+ gamma (float, optional): Gamma factor in Focal Loss calculation.
+ """
  super().__init__()
  if cost_gain is None:
  cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
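
A minimal construction sketch for the documented parameters (editor's illustration, not part of the diff; the import path follows ultralytics/models/utils/ops.py from the files list, and the cost values are illustrative rather than recommended settings):

from ultralytics.models.utils.ops import HungarianMatcher

# Same keys as the default cost_gain shown above, with a heavier classification cost.
matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}, use_fl=True)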
ultralytics/models/yolo/classify/predict.py
@@ -36,7 +36,17 @@ class ClassificationPredictor(BasePredictor):
  """

  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
- """Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'."""
+ """
+ Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
+
+ This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
+ tasks. It ensures the task is set to 'classify' regardless of input configuration.
+
+ Args:
+ cfg (dict): Default configuration dictionary containing prediction settings. Defaults to DEFAULT_CFG.
+ overrides (dict, optional): Configuration overrides that take precedence over cfg.
+ _callbacks (list, optional): List of callback functions to be executed during prediction.
+ """
  super().__init__(cfg, overrides, _callbacks)
  self.args.task = "classify"
  self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"
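
Since the task is forced to 'classify', classification checkpoints also run through the high-level API without extra configuration; a brief sketch (editor's illustration, not part of the diff; assumes the yolo11n-cls.pt weights are available or downloadable):

from ultralytics import YOLO

model = YOLO("yolo11n-cls.pt")                             # ClassificationPredictor is used under the hood
results = model("https://ultralytics.com/images/bus.jpg")
print(results[0].probs.top5, results[0].probs.top5conf)    # top-5 class indices and confidences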
ultralytics/models/yolo/classify/train.py
@@ -48,7 +48,23 @@ class ClassificationTrainer(BaseTrainer):
  """

  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
- """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
+ """
+ Initialize a ClassificationTrainer object.
+
+ This constructor sets up a trainer for image classification tasks, configuring the task type and default
+ image size if not specified.
+
+ Args:
+ cfg (dict, optional): Default configuration dictionary containing training parameters.
+ overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+ _callbacks (list, optional): List of callback functions to be executed during training.
+
+ Examples:
+ >>> from ultralytics.models.yolo.classify import ClassificationTrainer
+ >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
+ >>> trainer = ClassificationTrainer(overrides=args)
+ >>> trainer.train()
+ """
  if overrides is None:
  overrides = {}
  overrides["task"] = "classify"
ultralytics/models/yolo/classify/val.py
@@ -49,7 +49,25 @@ class ClassificationValidator(BaseValidator):
  """

  def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
- """Initialize ClassificationValidator with dataloader, save directory, and other parameters."""
+ """
+ Initialize ClassificationValidator with dataloader, save directory, and other parameters.
+
+ This validator handles the validation process for classification models, including metrics calculation,
+ confusion matrix generation, and visualization of results.
+
+ Args:
+ dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+ save_dir (str | Path, optional): Directory to save results.
+ pbar (bool, optional): Display a progress bar.
+ args (dict, optional): Arguments containing model and validation configuration.
+ _callbacks (list, optional): List of callback functions to be called during validation.
+
+ Examples:
+ >>> from ultralytics.models.yolo.classify import ClassificationValidator
+ >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
+ >>> validator = ClassificationValidator(args=args)
+ >>> validator()
+ """
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
  self.targets = None
  self.pred = None
@@ -76,13 +94,38 @@ class ClassificationValidator(BaseValidator):
  return batch

  def update_metrics(self, preds, batch):
- """Update running metrics with model predictions and batch targets."""
+ """
+ Update running metrics with model predictions and batch targets.
+
+ Args:
+ preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
+ batch (dict): Batch data containing images and class labels.
+
+ This method appends the top-N predictions (sorted by confidence in descending order) to the
+ prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
+ """
  n5 = min(len(self.names), 5)
  self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
  self.targets.append(batch["cls"].type(torch.int32).cpu())

  def finalize_metrics(self, *args, **kwargs):
- """Finalize metrics including confusion matrix and processing speed."""
+ """
+ Finalize metrics including confusion matrix and processing speed.
+
+ This method processes the accumulated predictions and targets to generate the confusion matrix,
+ optionally plots it, and updates the metrics object with speed information.
+
+ Args:
+ *args (Any): Variable length argument list.
+ **kwargs (Any): Arbitrary keyword arguments.
+
+ Examples:
+ >>> validator = ClassificationValidator()
+ >>> validator.pred = [torch.tensor([[0, 1, 2]])] # Top-3 predictions for one sample
+ >>> validator.targets = [torch.tensor([0])] # Ground truth class
+ >>> validator.finalize_metrics()
+ >>> print(validator.metrics.confusion_matrix) # Access the confusion matrix
+ """
  self.confusion_matrix.process_cls_preds(self.pred, self.targets)
  if self.args.plots:
  for normalize in True, False:
@@ -107,7 +150,16 @@ class ClassificationValidator(BaseValidator):
  return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)

  def get_dataloader(self, dataset_path, batch_size):
- """Build and return a data loader for classification validation."""
+ """
+ Build and return a data loader for classification validation.
+
+ Args:
+ dataset_path (str | Path): Path to the dataset directory.
+ batch_size (int): Number of samples per batch.
+
+ Returns:
+ (torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
+ """
  dataset = self.build_dataset(dataset_path)
  return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)

@@ -117,7 +169,18 @@ class ClassificationValidator(BaseValidator):
  LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))

  def plot_val_samples(self, batch, ni):
- """Plot validation image samples with their ground truth labels."""
+ """
+ Plot validation image samples with their ground truth labels.
+
+ Args:
+ batch (dict): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
+ ni (int): Batch index used for naming the output file.
+
+ Examples:
+ >>> validator = ClassificationValidator()
+ >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
+ >>> validator.plot_val_samples(batch, 0)
+ """
  plot_images(
  images=batch["img"],
  batch_idx=torch.arange(len(batch["img"])),
@@ -128,7 +191,20 @@ class ClassificationValidator(BaseValidator):
  )

  def plot_predictions(self, batch, preds, ni):
- """Plot images with their predicted class labels and save the visualization."""
+ """
+ Plot images with their predicted class labels and save the visualization.
+
+ Args:
+ batch (dict): Batch data containing images and other information.
+ preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
+ ni (int): Batch index used for naming the output file.
+
+ Examples:
+ >>> validator = ClassificationValidator()
+ >>> batch = {"img": torch.rand(16, 3, 224, 224)}
+ >>> preds = torch.rand(16, 10) # 16 images, 10 classes
+ >>> validator.plot_predictions(batch, preds, 0)
+ """
  plot_images(
  batch["img"],
  batch_idx=torch.arange(len(batch["img"])),
ultralytics/models/yolo/detect/predict.py
@@ -31,7 +31,26 @@ class DetectionPredictor(BasePredictor):
  """

  def postprocess(self, preds, img, orig_imgs, **kwargs):
- """Post-processes predictions and returns a list of Results objects."""
+ """
+ Post-process predictions and return a list of Results objects.
+
+ This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
+ further analysis.
+
+ Args:
+ preds (torch.Tensor): Raw predictions from the model.
+ img (torch.Tensor): Processed input image tensor in model input format.
+ orig_imgs (torch.Tensor | list): Original input images before preprocessing.
+ **kwargs (Any): Additional keyword arguments.
+
+ Returns:
+ (list): List of Results objects containing the post-processed predictions.
+
+ Examples:
+ >>> predictor = DetectionPredictor(overrides=dict(model="yolov8n.pt"))
+ >>> results = predictor.predict("path/to/image.jpg")
+ >>> processed_results = predictor.postprocess(preds, img, orig_imgs)
+ """
  preds = ops.non_max_suppression(
  preds,
  self.args.conf,
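
The Results objects built here are what the high-level detection API returns; a standard usage sketch (editor's illustration, not part of the diff; yolov8n.pt as in the docstring example):

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
results = model("https://ultralytics.com/images/bus.jpg")   # postprocess() applies NMS and builds Results
for r in results:
    print(r.boxes.xyxy, r.boxes.conf, r.boxes.cls)          # filtered boxes, confidences, class indices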