ultralytics 8.0.196__py3-none-any.whl → 8.0.198__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Note: this version of ultralytics has been flagged as a potentially problematic release.

Files changed (49)
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/cfg/__init__.py +4 -5
  3. ultralytics/data/augment.py +2 -2
  4. ultralytics/data/converter.py +12 -13
  5. ultralytics/data/dataset.py +1 -1
  6. ultralytics/engine/__init__.py +1 -0
  7. ultralytics/engine/exporter.py +1 -1
  8. ultralytics/engine/trainer.py +2 -1
  9. ultralytics/hub/session.py +1 -1
  10. ultralytics/models/fastsam/predict.py +33 -2
  11. ultralytics/models/fastsam/prompt.py +38 -1
  12. ultralytics/models/fastsam/utils.py +5 -5
  13. ultralytics/models/fastsam/val.py +27 -1
  14. ultralytics/models/nas/model.py +20 -0
  15. ultralytics/models/nas/predict.py +23 -0
  16. ultralytics/models/nas/val.py +24 -0
  17. ultralytics/models/rtdetr/val.py +17 -5
  18. ultralytics/models/sam/modules/decoders.py +26 -1
  19. ultralytics/models/sam/modules/encoders.py +31 -3
  20. ultralytics/models/sam/modules/sam.py +22 -7
  21. ultralytics/models/sam/modules/tiny_encoder.py +147 -45
  22. ultralytics/models/sam/modules/transformer.py +47 -2
  23. ultralytics/models/sam/predict.py +19 -2
  24. ultralytics/models/utils/loss.py +20 -2
  25. ultralytics/models/utils/ops.py +5 -5
  26. ultralytics/nn/modules/block.py +33 -10
  27. ultralytics/nn/modules/conv.py +16 -4
  28. ultralytics/nn/modules/head.py +48 -17
  29. ultralytics/nn/modules/transformer.py +2 -2
  30. ultralytics/nn/tasks.py +7 -7
  31. ultralytics/utils/__init__.py +2 -1
  32. ultralytics/utils/benchmarks.py +13 -0
  33. ultralytics/utils/callbacks/mlflow.py +76 -36
  34. ultralytics/utils/callbacks/wb.py +92 -1
  35. ultralytics/utils/checks.py +4 -4
  36. ultralytics/utils/errors.py +12 -0
  37. ultralytics/utils/files.py +1 -1
  38. ultralytics/utils/instance.py +41 -3
  39. ultralytics/utils/loss.py +22 -19
  40. ultralytics/utils/metrics.py +106 -24
  41. ultralytics/utils/tal.py +1 -1
  42. ultralytics/utils/torch_utils.py +4 -2
  43. ultralytics/utils/tuner.py +10 -4
  44. {ultralytics-8.0.196.dist-info → ultralytics-8.0.198.dist-info}/METADATA +1 -1
  45. {ultralytics-8.0.196.dist-info → ultralytics-8.0.198.dist-info}/RECORD +49 -49
  46. {ultralytics-8.0.196.dist-info → ultralytics-8.0.198.dist-info}/LICENSE +0 -0
  47. {ultralytics-8.0.196.dist-info → ultralytics-8.0.198.dist-info}/WHEEL +0 -0
  48. {ultralytics-8.0.196.dist-info → ultralytics-8.0.198.dist-info}/entry_points.txt +0 -0
  49. {ultralytics-8.0.196.dist-info → ultralytics-8.0.198.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/sam.py
@@ -16,6 +16,20 @@ from .encoders import ImageEncoderViT, PromptEncoder
 
 
 class Sam(nn.Module):
+    """
+    Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image
+    embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask
+    decoder to predict object masks.
+
+    Attributes:
+        mask_threshold (float): Threshold value for mask prediction.
+        image_format (str): Format of the input image, default is 'RGB'.
+        image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
+        prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+        mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
+        pixel_mean (List[float]): Mean pixel values for image normalization.
+        pixel_std (List[float]): Standard deviation values for image normalization.
+    """
     mask_threshold: float = 0.0
     image_format: str = 'RGB'
 
@@ -28,18 +42,19 @@ class Sam(nn.Module):
         pixel_std: List[float] = (58.395, 57.12, 57.375)
     ) -> None:
         """
-        SAM predicts object masks from an image and input prompts.
+        Initialize the Sam class to predict object masks from an image and input prompts.
 
         Note:
             All forward() operations moved to SAMPredictor.
 
         Args:
-            image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
-                efficient mask prediction.
-            prompt_encoder (PromptEncoder): Encodes various types of input prompts.
-            mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
-            pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
-            pixel_std (list(float)): Std values for normalizing pixels in the input image.
+            image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
+            prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+            mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
+            pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to
+                (123.675, 116.28, 103.53).
+            pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to
+                (58.395, 57.12, 57.375).
         """
         super().__init__()
         self.image_encoder = image_encoder
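
The pixel_mean and pixel_std defaults documented in the new docstring are applied per channel before the image is encoded. A minimal sketch of that normalization, not the library's own preprocessing code; the preprocess() helper is hypothetical:

```python
# Per-channel normalization using the Sam docstring's documented defaults.
import torch

pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

def preprocess(image: torch.Tensor) -> torch.Tensor:
    """Normalize a (3, H, W) image tensor with the documented pixel_mean / pixel_std."""
    return (image.float() - pixel_mean) / pixel_std

x = torch.randint(0, 256, (3, 1024, 1024), dtype=torch.uint8)
print(preprocess(x).shape)  # torch.Size([3, 1024, 1024])
```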
ultralytics/models/sam/modules/tiny_encoder.py
@@ -21,6 +21,7 @@ from ultralytics.utils.instance import to_2tuple
 
 
 class Conv2d_BN(torch.nn.Sequential):
+    """A sequential container that performs 2D convolution followed by batch normalization."""
 
     def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
         """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
@@ -35,6 +36,7 @@ class Conv2d_BN(torch.nn.Sequential):
 
 
 class PatchEmbed(nn.Module):
+    """Embeds images into patches and projects them into a specified embedding dimension."""
 
     def __init__(self, in_chans, embed_dim, resolution, activation):
         """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
@@ -59,6 +61,7 @@ class PatchEmbed(nn.Module):
 
 
 class MBConv(nn.Module):
+    """Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture."""
 
     def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
         """Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
@@ -96,6 +99,7 @@ class MBConv(nn.Module):
 
 
 class PatchMerging(nn.Module):
+    """Merges neighboring patches in the feature map and projects to a new dimension."""
 
     def __init__(self, input_resolution, dim, out_dim, activation):
         """Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
@@ -130,6 +134,11 @@ class PatchMerging(nn.Module):
 
 
 class ConvLayer(nn.Module):
+    """
+    Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
+
+    Optionally applies downsample operations to the output, and provides support for gradient checkpointing.
+    """
 
     def __init__(
         self,
@@ -143,13 +152,27 @@ class ConvLayer(nn.Module):
         out_dim=None,
         conv_expand_ratio=4.,
     ):
+        """
+        Initializes the ConvLayer with the given dimensions and settings.
+
+        Args:
+            dim (int): The dimensionality of the input and output.
+            input_resolution (Tuple[int, int]): The resolution of the input image.
+            depth (int): The number of MBConv layers in the block.
+            activation (Callable): Activation function applied after each convolution.
+            drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv.
+            downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
+            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+            out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
+            conv_expand_ratio (float): Expansion ratio for the MBConv layers.
+        """
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
         self.depth = depth
         self.use_checkpoint = use_checkpoint
 
-        # build blocks
+        # Build blocks
         self.blocks = nn.ModuleList([
             MBConv(
                 dim,
@@ -159,7 +182,7 @@ class ConvLayer(nn.Module):
                 drop_path[i] if isinstance(drop_path, list) else drop_path,
             ) for i in range(depth)])
 
-        # patch merging layer
+        # Patch merging layer
         self.downsample = None if downsample is None else downsample(
             input_resolution, dim=dim, out_dim=out_dim, activation=activation)
 
@@ -171,6 +194,11 @@ class ConvLayer(nn.Module):
 
 
 class Mlp(nn.Module):
+    """
+    Multi-layer Perceptron (MLP) for transformer architectures.
+
+    This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
+    """
 
     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
         """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
@@ -194,6 +222,14 @@ class Mlp(nn.Module):
 
 
 class Attention(torch.nn.Module):
+    """
+    Multi-head attention module with support for spatial awareness, applying attention biases based on spatial
+    resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution
+    grid.
+
+    Attributes:
+        ab (Tensor, optional): Cached attention biases for inference, deleted during training.
+    """
 
     def __init__(
         self,
@@ -203,8 +239,21 @@ class Attention(torch.nn.Module):
         attn_ratio=4,
         resolution=(14, 14),
     ):
+        """
+        Initializes the Attention module.
+
+        Args:
+            dim (int): The dimensionality of the input and output.
+            key_dim (int): The dimensionality of the keys and queries.
+            num_heads (int, optional): Number of attention heads. Default is 8.
+            attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
+            resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14).
+
+        Raises:
+            AssertionError: If `resolution` is not a tuple of length 2.
+        """
         super().__init__()
-        # (h, w)
+
         assert isinstance(resolution, tuple) and len(resolution) == 2
         self.num_heads = num_heads
         self.scale = key_dim ** -0.5
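
The new Attention docstring describes trainable biases indexed by the unique offset between positions of the resolution grid. The following standalone sketch shows how such an offset-index table can be built; it is an illustration of the idea (LeViT-style), not the module's exact code, and the variable names are only assumed to mirror it:

```python
# Build a per-offset attention-bias index table for a spatial resolution grid.
import itertools
import torch

resolution = (14, 14)
num_heads = 2

points = list(itertools.product(range(resolution[0]), range(resolution[1])))
offsets, idxs = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        idxs.append(offsets.setdefault(offset, len(offsets)))  # one bias slot per unique offset

attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(offsets)))
attention_bias_idxs = torch.LongTensor(idxs).view(len(points), len(points))

# At attention time, the bias added to each head's (N, N) logits is:
bias = attention_biases[:, attention_bias_idxs]
print(bias.shape)  # torch.Size([2, 196, 196])
```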
@@ -241,8 +290,9 @@ class Attention(torch.nn.Module):
         else:
             self.ab = self.attention_biases[:, self.attention_bias_idxs]
 
-    def forward(self, x):  # x (B,N,C)
-        B, N, _ = x.shape
+    def forward(self, x):  # x
+        """Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values."""
+        B, N, _ = x.shape  # B, N, C
 
         # Normalization
         x = self.norm(x)
@@ -264,20 +314,7 @@ class Attention(torch.nn.Module):
 
 
 class TinyViTBlock(nn.Module):
-    """
-    TinyViT Block.
-
-    Args:
-        dim (int): Number of input channels.
-        input_resolution (tuple[int, int]): Input resolution.
-        num_heads (int): Number of attention heads.
-        window_size (int): Window size.
-        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-        drop (float, optional): Dropout rate. Default: 0.0
-        drop_path (float, optional): Stochastic depth rate. Default: 0.0
-        local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
-        activation (torch.nn): the activation function. Default: nn.GELU
-    """
+    """TinyViT Block that applies self-attention and a local convolution to the input."""
 
     def __init__(
         self,
@@ -291,6 +328,24 @@ class TinyViTBlock(nn.Module):
         local_conv_size=3,
         activation=nn.GELU,
     ):
+        """
+        Initializes the TinyViTBlock.
+
+        Args:
+            dim (int): The dimensionality of the input and output.
+            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+            num_heads (int): Number of attention heads.
+            window_size (int, optional): Window size for attention. Default is 7.
+            mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
+            drop (float, optional): Dropout rate. Default is 0.
+            drop_path (float, optional): Stochastic depth rate. Default is 0.
+            local_conv_size (int, optional): The kernel size of the local convolution. Default is 3.
+            activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
+
+        Raises:
+            AssertionError: If `window_size` is not greater than 0.
+            AssertionError: If `dim` is not divisible by `num_heads`.
+        """
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
@@ -338,11 +393,11 @@ class TinyViTBlock(nn.Module):
             pH, pW = H + pad_b, W + pad_r
             nH = pH // self.window_size
             nW = pW // self.window_size
-            # window partition
+            # Window partition
            x = x.view(B, nH, self.window_size, nW, self.window_size,
                       C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
            x = self.attn(x)
-            # window reverse
+            # Window reverse
            x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
 
            if padding:
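
The partition/reverse reshapes in the hunk above are easy to sanity-check in isolation. A small round-trip demo on a toy tensor, using the same reshape pattern with hypothetical sizes:

```python
# Round-trip of the window partition / reverse reshapes shown above.
import torch

B, H, W, C, ws = 1, 6, 6, 2, 3  # toy sizes; ws = window size
x = torch.arange(B * H * W * C, dtype=torch.float32).view(B, H, W, C)
nH, nW = H // ws, W // ws

windows = x.view(B, nH, ws, nW, ws, C).transpose(2, 3).reshape(B * nH * nW, ws * ws, C)  # partition
restored = windows.view(B, nH, nW, ws, ws, C).transpose(2, 3).reshape(B, H, W, C)        # reverse
assert torch.equal(x, restored)
```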
@@ -367,24 +422,7 @@ class TinyViTBlock(nn.Module):
 
 
 class BasicLayer(nn.Module):
-    """
-    A basic TinyViT layer for one stage.
-
-    Args:
-        dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
-        depth (int): Number of blocks.
-        num_heads (int): Number of attention heads.
-        window_size (int): Local window size.
-        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-        drop (float, optional): Dropout rate. Default: 0.0
-        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
-        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
-        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
-        local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3
-        activation (torch.nn): the activation function. Default: nn.GELU
-        out_dim (int | optional): the output dimension of the layer. Default: None
-    """
+    """A basic TinyViT layer for one stage in a TinyViT architecture."""
 
     def __init__(
         self,
@@ -402,13 +440,34 @@ class BasicLayer(nn.Module):
         activation=nn.GELU,
         out_dim=None,
     ):
+        """
+        Initializes the BasicLayer.
+
+        Args:
+            dim (int): The dimensionality of the input and output.
+            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+            depth (int): Number of TinyViT blocks.
+            num_heads (int): Number of attention heads.
+            window_size (int): Local window size.
+            mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
+            drop (float, optional): Dropout rate. Default is 0.
+            drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0.
+            downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None.
+            use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False.
+            local_conv_size (int, optional): Kernel size of the local convolution. Default is 3.
+            activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
+            out_dim (int | None, optional): The output dimension of the layer. Default is None.
+
+        Raises:
+            ValueError: If `drop_path` is a list of float but its length doesn't match `depth`.
+        """
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
         self.depth = depth
         self.use_checkpoint = use_checkpoint
 
-        # build blocks
+        # Build blocks
         self.blocks = nn.ModuleList([
             TinyViTBlock(
                 dim=dim,
@@ -422,7 +481,7 @@ class BasicLayer(nn.Module):
                 activation=activation,
             ) for i in range(depth)])
 
-        # patch merging layer
+        # Patch merging layer
         self.downsample = None if downsample is None else downsample(
             input_resolution, dim=dim, out_dim=out_dim, activation=activation)
 
@@ -456,6 +515,30 @@ class LayerNorm2d(nn.Module):
 
 
 class TinyViT(nn.Module):
+    """
+    The TinyViT architecture for vision tasks.
+
+    Attributes:
+        img_size (int): Input image size.
+        in_chans (int): Number of input channels.
+        num_classes (int): Number of classification classes.
+        embed_dims (List[int]): List of embedding dimensions for each layer.
+        depths (List[int]): List of depths for each layer.
+        num_heads (List[int]): List of number of attention heads for each layer.
+        window_sizes (List[int]): List of window sizes for each layer.
+        mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+        drop_rate (float): Dropout rate for drop layers.
+        drop_path_rate (float): Drop path rate for stochastic depth.
+        use_checkpoint (bool): Use checkpointing for efficient memory usage.
+        mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
+        local_conv_size (int): Local convolution kernel size.
+        layer_lr_decay (float): Layer-wise learning rate decay.
+
+    Note:
+        This implementation is generalized to accept a list of depths, attention heads,
+        embedding dimensions and window sizes, which allows you to create a
+        "stack" of TinyViT models of varying configurations.
+    """
 
     def __init__(
         self,
@@ -474,6 +557,25 @@ class TinyViT(nn.Module):
         local_conv_size=3,
         layer_lr_decay=1.0,
     ):
+        """
+        Initializes the TinyViT model.
+
+        Args:
+            img_size (int, optional): The input image size. Defaults to 224.
+            in_chans (int, optional): Number of input channels. Defaults to 3.
+            num_classes (int, optional): Number of classification classes. Defaults to 1000.
+            embed_dims (List[int], optional): List of embedding dimensions for each layer. Defaults to [96, 192, 384, 768].
+            depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2].
+            num_heads (List[int], optional): List of number of attention heads for each layer. Defaults to [3, 6, 12, 24].
+            window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7].
+            mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4.
+            drop_rate (float, optional): Dropout rate. Defaults to 0.
+            drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1.
+            use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False.
+            mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0.
+            local_conv_size (int, optional): Local convolution kernel size. Defaults to 3.
+            layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0.
+        """
         super().__init__()
         self.img_size = img_size
         self.num_classes = num_classes
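
Because the new __init__ docstring lists every constructor default, the backbone can be instantiated directly from this module. A hedged construction sketch using those documented defaults (import path taken from this package's layout; no forward pass is attempted here):

```python
# Construct a TinyViT backbone with the defaults documented above and count its parameters.
from ultralytics.models.sam.modules.tiny_encoder import TinyViT

model = TinyViT(
    img_size=224,
    in_chans=3,
    embed_dims=[96, 192, 384, 768],
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
    window_sizes=[7, 7, 14, 7],
)
print(sum(p.numel() for p in model.parameters()))
```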
@@ -491,10 +593,10 @@ class TinyViT(nn.Module):
         patches_resolution = self.patch_embed.patches_resolution
         self.patches_resolution = patches_resolution
 
-        # stochastic depth
+        # Stochastic depth
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
 
-        # build layers
+        # Build layers
         self.layers = nn.ModuleList()
         for i_layer in range(self.num_layers):
             kwargs = dict(
@@ -526,7 +628,7 @@ class TinyViT(nn.Module):
         self.norm_head = nn.LayerNorm(embed_dims[-1])
         self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
 
-        # init weights
+        # Init weights
         self.apply(self._init_weights)
         self.set_layer_lr_decay(layer_lr_decay)
         self.neck = nn.Sequential(
@@ -551,7 +653,7 @@ class TinyViT(nn.Module):
         """Sets the learning rate decay for each layer in the TinyViT model."""
         decay_rate = layer_lr_decay
 
-        # layers -> blocks (depth)
+        # Layers -> blocks (depth)
         depth = sum(self.depths)
         lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
 
ultralytics/models/sam/modules/transformer.py
@@ -10,6 +10,21 @@ from ultralytics.nn.modules import MLPBlock
 
 
 class TwoWayTransformer(nn.Module):
+    """
+    A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
+    serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
+    is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
+    processing.
+
+    Attributes:
+        depth (int): The number of layers in the transformer.
+        embedding_dim (int): The channel dimension for the input embeddings.
+        num_heads (int): The number of heads for multihead attention.
+        mlp_dim (int): The internal channel dimension for the MLP block.
+        layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
+        final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
+        norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
+    """
 
     def __init__(
         self,
@@ -98,6 +113,23 @@ class TwoWayTransformer(nn.Module):
 
 
 class TwoWayAttentionBlock(nn.Module):
+    """
+    An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
+    keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
+    of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
+    sparse inputs.
+
+    Attributes:
+        self_attn (Attention): The self-attention layer for the queries.
+        norm1 (nn.LayerNorm): Layer normalization following the first attention block.
+        cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
+        norm2 (nn.LayerNorm): Layer normalization following the second attention block.
+        mlp (MLPBlock): MLP block that transforms the query embeddings.
+        norm3 (nn.LayerNorm): Layer normalization following the MLP block.
+        norm4 (nn.LayerNorm): Layer normalization following the third attention block.
+        cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
+        skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+    """
 
     def __init__(
         self,
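
The TwoWayAttentionBlock docstring spells out a four-step update of sparse queries and dense keys. The following runnable sketch illustrates that pattern with stock torch.nn.MultiheadAttention layers standing in for the block's own Attention/MLPBlock modules; it is an illustration of the flow, not the block's actual implementation:

```python
# Four-step two-way attention flow, sketched with stock PyTorch layers.
import torch
import torch.nn as nn

dim, heads = 256, 8
self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
cross_token_to_image = nn.MultiheadAttention(dim, heads, batch_first=True)
cross_image_to_token = nn.MultiheadAttention(dim, heads, batch_first=True)
mlp = nn.Sequential(nn.Linear(dim, 2048), nn.ReLU(), nn.Linear(2048, dim))
norm1, norm2, norm3, norm4 = (nn.LayerNorm(dim) for _ in range(4))

queries = torch.randn(1, 5, dim)     # sparse prompt tokens
keys = torch.randn(1, 64 * 64, dim)  # flattened dense image embedding

queries = norm1(queries + self_attn(queries, queries, queries)[0])       # 1) self-attention on sparse inputs
queries = norm2(queries + cross_token_to_image(queries, keys, keys)[0])  # 2) cross-attention, sparse -> dense
queries = norm3(queries + mlp(queries))                                  # 3) MLP on sparse inputs
keys = norm4(keys + cross_image_to_token(keys, queries, queries)[0])     # 4) cross-attention, dense -> sparse
print(queries.shape, keys.shape)
```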
@@ -180,6 +212,17 @@ class Attention(nn.Module):
         num_heads: int,
         downsample_rate: int = 1,
     ) -> None:
+        """
+        Initializes the Attention model with the given dimensions and settings.
+
+        Args:
+            embedding_dim (int): The dimensionality of the input embeddings.
+            num_heads (int): The number of attention heads.
+            downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
+
+        Raises:
+            AssertionError: If 'num_heads' does not evenly divide the internal dimension (embedding_dim / downsample_rate).
+        """
         super().__init__()
         self.embedding_dim = embedding_dim
         self.internal_dim = embedding_dim // downsample_rate
@@ -191,13 +234,15 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
 
-    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+    @staticmethod
+    def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
         """Separate the input tensor into the specified number of attention heads."""
         b, n, c = x.shape
         x = x.reshape(b, n, num_heads, c // num_heads)
         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
 
-    def _recombine_heads(self, x: Tensor) -> Tensor:
+    @staticmethod
+    def _recombine_heads(x: Tensor) -> Tensor:
         """Recombine the separated attention heads into a single tensor."""
         b, n_heads, n_tokens, c_per_head = x.shape
         x = x.transpose(1, 2)
ultralytics/models/sam/predict.py
@@ -17,6 +17,24 @@ from .build import build_sam
 
 
 class Predictor(BasePredictor):
+    """
+    A prediction class for segmentation tasks, extending the BasePredictor.
+
+    This class serves as an interface for model inference for segmentation tasks.
+    It can preprocess input images, perform inference, and postprocess the output.
+    It also supports handling various types of input prompts including bounding boxes,
+    points, and low-resolution masks for better prediction results.
+
+    Attributes:
+        cfg (dict): Configuration dictionary.
+        overrides (dict): Dictionary of overriding values.
+        _callbacks (dict): Dictionary of callback functions.
+        args (namespace): Argument namespace.
+        im (torch.Tensor): Preprocessed image for current prediction.
+        features (torch.Tensor): Image features.
+        prompts (dict): Dictionary of prompts like bboxes, points, masks.
+        segment_all (bool): Whether to perform segmentation on all objects or not.
+    """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """Initializes the Predictor class with default or provided configuration, overrides, and callbacks."""
@@ -396,8 +414,7 @@ class Predictor(BasePredictor):
             unchanged = unchanged and not changed
 
             new_masks.append(torch.as_tensor(mask).unsqueeze(0))
-            # Give score=0 to changed masks and score=1 to unchanged masks
-            # so NMS will prefer ones that didn't need postprocessing
+            # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing
             scores.append(float(unchanged))
 
         # Recalculate boxes and remove any new duplicates
ultralytics/models/utils/loss.py
@@ -11,6 +11,24 @@ from .ops import HungarianMatcher
 
 
 class DETRLoss(nn.Module):
+    """
+    DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the
+    DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary
+    losses.
+
+    Attributes:
+        nc (int): The number of classes.
+        loss_gain (dict): Coefficients for different loss components.
+        aux_loss (bool): Whether to compute auxiliary losses.
+        use_fl (bool): Use FocalLoss or not.
+        use_vfl (bool): Use VarifocalLoss or not.
+        use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
+        uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True.
+        matcher (HungarianMatcher): Object to compute matching cost and indices.
+        fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None.
+        vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None.
+        device (torch.device): Device on which tensors are stored.
+    """
 
     def __init__(self,
                  nc=80,
@@ -48,7 +66,7 @@ class DETRLoss(nn.Module):
 
     def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
         """Computes the classification loss based on predictions, target values, and ground truth scores."""
-        # logits: [b, query, num_classes], gt_class: list[[n, 1]]
+        # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
         name_class = f'loss_class{postfix}'
         bs, nq = pred_scores.shape[:2]
         # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1]  # (bs, num_queries, num_classes)
@@ -72,7 +90,7 @@ class DETRLoss(nn.Module):
         """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
         boxes.
         """
-        # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
+        # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
         name_bbox = f'loss_bbox{postfix}'
         name_giou = f'loss_giou{postfix}'
 
ultralytics/models/utils/ops.py
@@ -188,7 +188,7 @@ def get_cdn_group(batch,
 
     num_group = num_dn // max_nums
     num_group = 1 if num_group == 0 else num_group
-    # pad gt to max_num of a batch
+    # Pad gt to max_num of a batch
     bs = len(gt_groups)
     gt_cls = batch['cls']  # (bs*num, )
     gt_bbox = batch['bboxes']  # bs*num, 4
@@ -204,10 +204,10 @@ def get_cdn_group(batch,
     neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
 
     if cls_noise_ratio > 0:
-        # half of bbox prob
+        # Half of bbox prob
         mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
         idx = torch.nonzero(mask).squeeze(-1)
-        # randomly put a new one here
+        # Randomly put a new one here
         new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
         dn_cls[idx] = new_label
 
@@ -240,9 +240,9 @@ def get_cdn_group(batch,
 
     tgt_size = num_dn + num_queries
     attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
-    # match query cannot see the reconstruct
+    # Match query cannot see the reconstruct
     attn_mask[num_dn:, :num_dn] = True
-    # reconstruct cannot see each other
+    # Reconstruct cannot see each other
     for i in range(num_group):
         if i == 0:
             attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
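
The comments in this last hunk describe the denoising attention mask: matching queries must not attend to the denoising slots, and denoising groups must not attend to each other. Below is a small self-contained sketch of such a mask with hypothetical sizes, written with a generalized per-group block instead of the i == 0 special case shown above:

```python
# Build a contrastive-denoising attention mask of the kind the comments above describe.
import torch

max_nums, num_group, num_queries = 3, 2, 5      # hypothetical sizes
num_dn = max_nums * 2 * num_group               # positive + negative denoising slots
tgt_size = num_dn + num_queries
attn_mask = torch.zeros(tgt_size, tgt_size, dtype=torch.bool)  # True = attention blocked

attn_mask[num_dn:, :num_dn] = True              # match queries cannot see the reconstruct part
for i in range(num_group):                      # each denoising group only sees itself
    start, end = max_nums * 2 * i, max_nums * 2 * (i + 1)
    attn_mask[start:end, :start] = True
    attn_mask[start:end, end:num_dn] = True
print(attn_mask.int())
```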