ultralytics 8.3.142__py3-none-any.whl → 8.3.144__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (148)
  1. tests/conftest.py +7 -24
  2. tests/test_cli.py +1 -1
  3. tests/test_cuda.py +7 -2
  4. tests/test_engine.py +7 -8
  5. tests/test_exports.py +16 -16
  6. tests/test_integrations.py +1 -1
  7. tests/test_solutions.py +12 -12
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +16 -13
  10. ultralytics/data/annotator.py +6 -5
  11. ultralytics/data/augment.py +127 -126
  12. ultralytics/data/base.py +54 -51
  13. ultralytics/data/build.py +47 -23
  14. ultralytics/data/converter.py +47 -43
  15. ultralytics/data/dataset.py +51 -50
  16. ultralytics/data/loaders.py +77 -44
  17. ultralytics/data/split.py +22 -9
  18. ultralytics/data/split_dota.py +63 -39
  19. ultralytics/data/utils.py +59 -39
  20. ultralytics/engine/exporter.py +79 -27
  21. ultralytics/engine/model.py +39 -39
  22. ultralytics/engine/predictor.py +37 -28
  23. ultralytics/engine/results.py +187 -157
  24. ultralytics/engine/trainer.py +36 -19
  25. ultralytics/engine/tuner.py +12 -9
  26. ultralytics/engine/validator.py +7 -9
  27. ultralytics/hub/__init__.py +11 -13
  28. ultralytics/hub/auth.py +22 -2
  29. ultralytics/hub/google/__init__.py +19 -19
  30. ultralytics/hub/session.py +37 -51
  31. ultralytics/hub/utils.py +19 -5
  32. ultralytics/models/fastsam/model.py +30 -12
  33. ultralytics/models/fastsam/predict.py +5 -6
  34. ultralytics/models/fastsam/utils.py +3 -3
  35. ultralytics/models/fastsam/val.py +10 -6
  36. ultralytics/models/nas/model.py +9 -5
  37. ultralytics/models/nas/predict.py +6 -6
  38. ultralytics/models/nas/val.py +3 -3
  39. ultralytics/models/rtdetr/model.py +7 -6
  40. ultralytics/models/rtdetr/predict.py +14 -7
  41. ultralytics/models/rtdetr/train.py +10 -4
  42. ultralytics/models/rtdetr/val.py +36 -9
  43. ultralytics/models/sam/amg.py +30 -12
  44. ultralytics/models/sam/build.py +22 -22
  45. ultralytics/models/sam/model.py +10 -9
  46. ultralytics/models/sam/modules/blocks.py +76 -80
  47. ultralytics/models/sam/modules/decoders.py +6 -8
  48. ultralytics/models/sam/modules/encoders.py +23 -26
  49. ultralytics/models/sam/modules/memory_attention.py +13 -1
  50. ultralytics/models/sam/modules/sam.py +57 -26
  51. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  52. ultralytics/models/sam/modules/transformer.py +13 -13
  53. ultralytics/models/sam/modules/utils.py +11 -19
  54. ultralytics/models/sam/predict.py +114 -101
  55. ultralytics/models/utils/loss.py +98 -77
  56. ultralytics/models/utils/ops.py +116 -67
  57. ultralytics/models/yolo/classify/predict.py +5 -5
  58. ultralytics/models/yolo/classify/train.py +32 -28
  59. ultralytics/models/yolo/classify/val.py +7 -8
  60. ultralytics/models/yolo/detect/predict.py +1 -0
  61. ultralytics/models/yolo/detect/train.py +15 -14
  62. ultralytics/models/yolo/detect/val.py +37 -36
  63. ultralytics/models/yolo/model.py +106 -23
  64. ultralytics/models/yolo/obb/predict.py +3 -4
  65. ultralytics/models/yolo/obb/train.py +14 -6
  66. ultralytics/models/yolo/obb/val.py +29 -23
  67. ultralytics/models/yolo/pose/predict.py +9 -8
  68. ultralytics/models/yolo/pose/train.py +24 -16
  69. ultralytics/models/yolo/pose/val.py +44 -26
  70. ultralytics/models/yolo/segment/predict.py +5 -5
  71. ultralytics/models/yolo/segment/train.py +11 -7
  72. ultralytics/models/yolo/segment/val.py +2 -2
  73. ultralytics/models/yolo/world/train.py +33 -23
  74. ultralytics/models/yolo/world/train_world.py +11 -3
  75. ultralytics/models/yolo/yoloe/predict.py +11 -11
  76. ultralytics/models/yolo/yoloe/train.py +73 -21
  77. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  78. ultralytics/models/yolo/yoloe/val.py +42 -18
  79. ultralytics/nn/autobackend.py +59 -15
  80. ultralytics/nn/modules/__init__.py +4 -4
  81. ultralytics/nn/modules/activation.py +4 -1
  82. ultralytics/nn/modules/block.py +178 -111
  83. ultralytics/nn/modules/conv.py +6 -5
  84. ultralytics/nn/modules/head.py +469 -121
  85. ultralytics/nn/modules/transformer.py +147 -58
  86. ultralytics/nn/tasks.py +227 -20
  87. ultralytics/nn/text_model.py +30 -33
  88. ultralytics/solutions/ai_gym.py +1 -1
  89. ultralytics/solutions/analytics.py +7 -4
  90. ultralytics/solutions/config.py +10 -10
  91. ultralytics/solutions/distance_calculation.py +11 -10
  92. ultralytics/solutions/heatmap.py +1 -1
  93. ultralytics/solutions/instance_segmentation.py +6 -3
  94. ultralytics/solutions/object_blurrer.py +3 -3
  95. ultralytics/solutions/object_counter.py +16 -8
  96. ultralytics/solutions/object_cropper.py +12 -5
  97. ultralytics/solutions/parking_management.py +29 -28
  98. ultralytics/solutions/queue_management.py +6 -6
  99. ultralytics/solutions/region_counter.py +10 -3
  100. ultralytics/solutions/security_alarm.py +3 -3
  101. ultralytics/solutions/similarity_search.py +85 -24
  102. ultralytics/solutions/solutions.py +215 -85
  103. ultralytics/solutions/speed_estimation.py +28 -22
  104. ultralytics/solutions/streamlit_inference.py +17 -12
  105. ultralytics/solutions/trackzone.py +4 -4
  106. ultralytics/trackers/basetrack.py +16 -23
  107. ultralytics/trackers/bot_sort.py +30 -20
  108. ultralytics/trackers/byte_tracker.py +70 -64
  109. ultralytics/trackers/track.py +4 -8
  110. ultralytics/trackers/utils/gmc.py +31 -58
  111. ultralytics/trackers/utils/kalman_filter.py +37 -37
  112. ultralytics/trackers/utils/matching.py +1 -1
  113. ultralytics/utils/__init__.py +105 -89
  114. ultralytics/utils/autobatch.py +16 -3
  115. ultralytics/utils/autodevice.py +54 -24
  116. ultralytics/utils/benchmarks.py +42 -28
  117. ultralytics/utils/callbacks/base.py +3 -3
  118. ultralytics/utils/callbacks/clearml.py +9 -9
  119. ultralytics/utils/callbacks/comet.py +67 -25
  120. ultralytics/utils/callbacks/dvc.py +7 -10
  121. ultralytics/utils/callbacks/mlflow.py +2 -5
  122. ultralytics/utils/callbacks/neptune.py +7 -13
  123. ultralytics/utils/callbacks/raytune.py +1 -1
  124. ultralytics/utils/callbacks/tensorboard.py +5 -6
  125. ultralytics/utils/callbacks/wb.py +14 -14
  126. ultralytics/utils/checks.py +14 -13
  127. ultralytics/utils/dist.py +5 -5
  128. ultralytics/utils/downloads.py +94 -67
  129. ultralytics/utils/errors.py +5 -5
  130. ultralytics/utils/export.py +61 -47
  131. ultralytics/utils/files.py +23 -22
  132. ultralytics/utils/instance.py +48 -52
  133. ultralytics/utils/loss.py +78 -40
  134. ultralytics/utils/metrics.py +186 -130
  135. ultralytics/utils/ops.py +186 -190
  136. ultralytics/utils/patches.py +15 -17
  137. ultralytics/utils/plotting.py +71 -27
  138. ultralytics/utils/tal.py +21 -15
  139. ultralytics/utils/torch_utils.py +53 -50
  140. ultralytics/utils/triton.py +5 -4
  141. ultralytics/utils/tuner.py +5 -5
  142. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
  143. ultralytics-8.3.144.dist-info/RECORD +272 -0
  144. ultralytics-8.3.142.dist-info/RECORD +0 -272
  145. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
  146. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
  147. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
  148. {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
@@ -10,7 +10,7 @@
 # --------------------------------------------------------

 import itertools
-from typing import Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -24,32 +24,46 @@ class Conv2d_BN(torch.nn.Sequential):
     """
     A sequential container that performs 2D convolution followed by batch normalization.

+    This module combines a 2D convolution layer with batch normalization, providing a common building block
+    for convolutional neural networks. The batch normalization weights and biases are initialized to specific
+    values for optimal training performance.
+
     Attributes:
         c (torch.nn.Conv2d): 2D convolution layer.
         bn (torch.nn.BatchNorm2d): Batch normalization layer.

-    Methods:
-        __init__: Initializes the Conv2d_BN with specified parameters.
-
-    Args:
-        a (int): Number of input channels.
-        b (int): Number of output channels.
-        ks (int): Kernel size for the convolution. Defaults to 1.
-        stride (int): Stride for the convolution. Defaults to 1.
-        pad (int): Padding for the convolution. Defaults to 0.
-        dilation (int): Dilation factor for the convolution. Defaults to 1.
-        groups (int): Number of groups for the convolution. Defaults to 1.
-        bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
-
     Examples:
         >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
         >>> input_tensor = torch.randn(1, 3, 224, 224)
         >>> output = conv_bn(input_tensor)
         >>> print(output.shape)
+        torch.Size([1, 64, 224, 224])
     """

-    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
-        """Initializes a sequential container with 2D convolution followed by batch normalization."""
+    def __init__(
+        self,
+        a: int,
+        b: int,
+        ks: int = 1,
+        stride: int = 1,
+        pad: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bn_weight_init: float = 1,
+    ):
+        """
+        Initialize a sequential container with 2D convolution followed by batch normalization.
+
+        Args:
+            a (int): Number of input channels.
+            b (int): Number of output channels.
+            ks (int, optional): Kernel size for the convolution.
+            stride (int, optional): Stride for the convolution.
+            pad (int, optional): Padding for the convolution.
+            dilation (int, optional): Dilation factor for the convolution.
+            groups (int, optional): Number of groups for the convolution.
+            bn_weight_init (float, optional): Initial value for batch normalization weight.
+        """
         super().__init__()
         self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
         bn = torch.nn.BatchNorm2d(b)
@@ -60,7 +74,10 @@ class Conv2d_BN(torch.nn.Sequential):

 class PatchEmbed(nn.Module):
     """
-    Embeds images into patches and projects them into a specified embedding dimension.
+    Embed images into patches and project them into a specified embedding dimension.
+
+    This module converts input images into patch embeddings using a sequence of convolutional layers,
+    effectively downsampling the spatial dimensions while increasing the channel dimension.

     Attributes:
         patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
@@ -69,19 +86,25 @@ class PatchEmbed(nn.Module):
         embed_dim (int): Dimension of the embedding.
         seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.

-    Methods:
-        forward: Processes the input tensor through the patch embedding sequence.
-
     Examples:
         >>> import torch
         >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
         >>> x = torch.randn(1, 3, 224, 224)
         >>> output = patch_embed(x)
         >>> print(output.shape)
+        torch.Size([1, 96, 56, 56])
     """

-    def __init__(self, in_chans, embed_dim, resolution, activation):
-        """Initializes patch embedding with convolutional layers for image-to-patch conversion and projection."""
+    def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):
+        """
+        Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
+
+        Args:
+            in_chans (int): Number of input channels.
+            embed_dim (int): Dimension of the embedding.
+            resolution (int): Input image resolution.
+            activation (nn.Module): Activation function to use between convolutions.
+        """
         super().__init__()
         img_size: Tuple[int, int] = to_2tuple(resolution)
         self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
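The factor-of-4 reduction in `patches_resolution` comes from the stride-2 convolutions in the embedding sequence, which is why the docstring example above reports a 56×56 patch grid for a 224×224 input. A minimal sketch of that shape math, illustrative only and not part of the diff:

    # For resolution=224 and embed_dim=96, as in the PatchEmbed docstring example:
    resolution, embed_dim = 224, 96
    patches_resolution = (resolution // 4, resolution // 4)       # two stride-2 convs -> // 4
    print(patches_resolution)                                     # (56, 56)
    print((1, embed_dim, *patches_resolution))                    # (1, 96, 56, 56), matching torch.Size([1, 96, 56, 56])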
@@ -95,8 +118,8 @@ class PatchEmbed(nn.Module):
             Conv2d_BN(n // 2, n, 3, 2, 1),
         )

-    def forward(self, x):
-        """Processes input tensor through patch embedding sequence, converting images to patch embeddings."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input tensor through patch embedding sequence, converting images to patch embeddings."""
         return self.seq(x)


@@ -104,21 +127,21 @@ class MBConv(nn.Module):
     """
     Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.

+    This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution,
+    and projection phases, along with residual connections for improved gradient flow.
+
     Attributes:
         in_chans (int): Number of input channels.
-        hidden_chans (int): Number of hidden channels.
+        hidden_chans (int): Number of hidden channels after expansion.
         out_chans (int): Number of output channels.
-        conv1 (Conv2d_BN): First convolutional layer.
+        conv1 (Conv2d_BN): First convolutional layer for channel expansion.
         act1 (nn.Module): First activation function.
         conv2 (Conv2d_BN): Depthwise convolutional layer.
         act2 (nn.Module): Second activation function.
-        conv3 (Conv2d_BN): Final convolutional layer.
+        conv3 (Conv2d_BN): Final convolutional layer for projection.
         act3 (nn.Module): Third activation function.
         drop_path (nn.Module): Drop path layer (Identity for inference).

-    Methods:
-        forward: Performs the forward pass through the MBConv layer.
-
     Examples:
         >>> in_chans, out_chans = 32, 64
         >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
@@ -128,8 +151,17 @@ class MBConv(nn.Module):
         torch.Size([1, 64, 56, 56])
     """

-    def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
-        """Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation."""
+    def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):
+        """
+        Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
+
+        Args:
+            in_chans (int): Number of input channels.
+            out_chans (int): Number of output channels.
+            expand_ratio (float): Channel expansion ratio for the hidden layer.
+            activation (nn.Module): Activation function to use.
+            drop_path (float): Drop path rate for stochastic depth.
+        """
         super().__init__()
         self.in_chans = in_chans
         self.hidden_chans = int(in_chans * expand_ratio)
@@ -148,8 +180,8 @@ class MBConv(nn.Module):
         # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.drop_path = nn.Identity()

-    def forward(self, x):
-        """Implements the forward pass of MBConv, applying convolutions and skip connection."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Implement the forward pass of MBConv, applying convolutions and skip connection."""
         shortcut = x
         x = self.conv1(x)
         x = self.act1(x)
@@ -163,10 +195,11 @@

 class PatchMerging(nn.Module):
     """
-    Merges neighboring patches in the feature map and projects to a new dimension.
+    Merge neighboring patches in the feature map and project to a new dimension.

     This class implements a patch merging operation that combines spatial information and adjusts the feature
-    dimension. It uses a series of convolutional layers with batch normalization to achieve this.
+    dimension using a series of convolutional layers with batch normalization. It effectively reduces spatial
+    resolution while potentially increasing channel dimensions.

     Attributes:
         input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
@@ -177,19 +210,25 @@ class PatchMerging(nn.Module):
         conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
         conv3 (Conv2d_BN): The third convolutional layer for final projection.

-    Methods:
-        forward: Applies the patch merging operation to the input tensor.
-
     Examples:
         >>> input_resolution = (56, 56)
         >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
         >>> x = torch.randn(4, 64, 56, 56)
         >>> output = patch_merging(x)
         >>> print(output.shape)
+        torch.Size([4, 3136, 128])
     """

-    def __init__(self, input_resolution, dim, out_dim, activation):
-        """Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps."""
+    def __init__(self, input_resolution: Tuple[int, int], dim: int, out_dim: int, activation):
+        """
+        Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
+
+        Args:
+            input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
+            dim (int): The input dimension of the feature map.
+            out_dim (int): The output dimension after merging and projection.
+            activation (nn.Module): The activation function used between convolutions.
+        """
         super().__init__()

         self.input_resolution = input_resolution
@@ -201,8 +240,8 @@ class PatchMerging(nn.Module):
         self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
         self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

-    def forward(self, x):
-        """Applies patch merging and dimension projection to the input feature map."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply patch merging and dimension projection to the input feature map."""
         if x.ndim == 3:
             H, W = self.input_resolution
             B = len(x)
@@ -222,7 +261,8 @@ class ConvLayer(nn.Module):
     """
     Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).

-    This layer optionally applies downsample operations to the output and supports gradient checkpointing.
+    This layer optionally applies downsample operations to the output and supports gradient checkpointing
+    for memory efficiency during training.

     Attributes:
         dim (int): Dimensionality of the input and output.
@@ -230,32 +270,30 @@ class ConvLayer(nn.Module):
         depth (int): Number of MBConv layers in the block.
         use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
         blocks (nn.ModuleList): List of MBConv layers.
-        downsample (Optional[Callable]): Function for downsampling the output.
-
-    Methods:
-        forward: Processes the input through the convolutional layers.
+        downsample (Optional[nn.Module]): Function for downsampling the output.

     Examples:
         >>> input_tensor = torch.randn(1, 64, 56, 56)
         >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
         >>> output = conv_layer(input_tensor)
         >>> print(output.shape)
+        torch.Size([1, 3136, 128])
     """

     def __init__(
         self,
-        dim,
-        input_resolution,
-        depth,
+        dim: int,
+        input_resolution: Tuple[int, int],
+        depth: int,
         activation,
-        drop_path=0.0,
-        downsample=None,
-        use_checkpoint=False,
-        out_dim=None,
-        conv_expand_ratio=4.0,
+        drop_path: Union[float, List[float]] = 0.0,
+        downsample: Optional[nn.Module] = None,
+        use_checkpoint: bool = False,
+        out_dim: Optional[int] = None,
+        conv_expand_ratio: float = 4.0,
     ):
         """
-        Initializes the ConvLayer with the given dimensions and settings.
+        Initialize the ConvLayer with the given dimensions and settings.

         This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
         optionally applies downsampling to the output.
@@ -265,17 +303,11 @@ class ConvLayer(nn.Module):
             input_resolution (Tuple[int, int]): The resolution of the input image.
             depth (int): The number of MBConv layers in the block.
             activation (nn.Module): Activation function applied after each convolution.
-            drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
-            downsample (Optional[nn.Module]): Function for downsampling the output. None to skip downsampling.
-            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
-            out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
-            conv_expand_ratio (float): Expansion ratio for the MBConv layers.
-
-        Examples:
-            >>> input_tensor = torch.randn(1, 64, 56, 56)
-            >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
-            >>> output = conv_layer(input_tensor)
-            >>> print(output.shape)
+            drop_path (float | List[float], optional): Drop path rate. Single float or a list of floats for each MBConv.
+            downsample (Optional[nn.Module], optional): Function for downsampling the output. None to skip downsampling.
+            use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
+            out_dim (Optional[int], optional): The dimensionality of the output. None means it will be the same as `dim`.
+            conv_expand_ratio (float, optional): Expansion ratio for the MBConv layers.
         """
         super().__init__()
         self.dim = dim
@@ -304,19 +336,19 @@ class ConvLayer(nn.Module):
             else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
         )

-    def forward(self, x):
-        """Processes input through convolutional layers, applying MBConv blocks and optional downsampling."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input through convolutional layers, applying MBConv blocks and optional downsampling."""
         for blk in self.blocks:
             x = torch.utils.checkpoint(blk, x) if self.use_checkpoint else blk(x)  # warn: checkpoint is slow import
         return x if self.downsample is None else self.downsample(x)


-class Mlp(nn.Module):
+class MLP(nn.Module):
     """
     Multi-layer Perceptron (MLP) module for transformer architectures.

     This module applies layer normalization, two fully-connected layers with an activation function in between,
-    and dropout. It is commonly used in transformer-based architectures.
+    and dropout. It is commonly used in transformer-based architectures for processing token embeddings.

     Attributes:
         norm (nn.LayerNorm): Layer normalization applied to the input.
@@ -325,32 +357,45 @@ class Mlp(nn.Module):
         act (nn.Module): Activation function applied after the first fully-connected layer.
         drop (nn.Dropout): Dropout layer applied after the activation function.

-    Methods:
-        forward: Applies the MLP operations on the input tensor.
-
     Examples:
         >>> import torch
         >>> from torch import nn
-        >>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
+        >>> mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)
         >>> x = torch.randn(32, 100, 256)
         >>> output = mlp(x)
         >>> print(output.shape)
         torch.Size([32, 100, 256])
     """

-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
-        """Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions."""
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        activation=nn.GELU,
+        drop: float = 0.0,
+    ):
+        """
+        Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
+
+        Args:
+            in_features (int): Number of input features.
+            hidden_features (Optional[int], optional): Number of hidden features.
+            out_features (Optional[int], optional): Number of output features.
+            activation (nn.Module): Activation function applied after the first fully-connected layer.
+            drop (float, optional): Dropout probability.
+        """
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         self.norm = nn.LayerNorm(in_features)
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.fc2 = nn.Linear(hidden_features, out_features)
-        self.act = act_layer()
+        self.act = activation()
         self.drop = nn.Dropout(drop)

-    def forward(self, x):
-        """Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
         x = self.norm(x)
         x = self.fc1(x)
         x = self.act(x)
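For code that constructs this block directly, the visible API change in this hunk is the rename of `Mlp` to `MLP` and of its `act_layer` keyword to `activation`. A minimal before/after sketch, assuming these classes live in ultralytics/models/sam/modules/tiny_encoder.py as the file list above suggests:

    import torch
    from torch import nn
    from ultralytics.models.sam.modules.tiny_encoder import MLP  # named `Mlp` in 8.3.142

    # 8.3.142: Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
    # 8.3.144: same call with the new class name and `activation` keyword
    mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)
    out = mlp(torch.randn(32, 100, 256))
    print(out.shape)  # torch.Size([32, 100, 256])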
@@ -379,12 +424,8 @@ class Attention(torch.nn.Module):
         qkv (nn.Linear): Linear layer for computing query, key, and value projections.
         proj (nn.Linear): Linear layer for final projection.
         attention_biases (nn.Parameter): Learnable attention biases.
-        attention_bias_idxs (Tensor): Indices for attention biases.
-        ab (Tensor): Cached attention biases for inference, deleted during training.
-
-    Methods:
-        train: Sets the module in training mode and handles the 'ab' attribute.
-        forward: Performs the forward pass of the attention mechanism.
+        attention_bias_idxs (torch.Tensor): Indices for attention biases.
+        ab (torch.Tensor): Cached attention biases for inference, deleted during training.

     Examples:
         >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
@@ -396,14 +437,14 @@

     def __init__(
         self,
-        dim,
-        key_dim,
-        num_heads=8,
-        attn_ratio=4,
-        resolution=(14, 14),
+        dim: int,
+        key_dim: int,
+        num_heads: int = 8,
+        attn_ratio: float = 4,
+        resolution: Tuple[int, int] = (14, 14),
     ):
         """
-        Initializes the Attention module for multi-head attention with spatial awareness.
+        Initialize the Attention module for multi-head attention with spatial awareness.

         This module implements a multi-head attention mechanism with support for spatial awareness, applying
         attention biases based on spatial resolution. It includes trainable attention biases for each unique
@@ -412,16 +453,9 @@ class Attention(torch.nn.Module):
         Args:
             dim (int): The dimensionality of the input and output.
             key_dim (int): The dimensionality of the keys and queries.
-            num_heads (int): Number of attention heads.
-            attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors.
-            resolution (Tuple[int, int]): Spatial resolution of the input feature map.
-
-        Examples:
-            >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
-            >>> x = torch.randn(1, 196, 256)
-            >>> output = attn(x)
-            >>> print(output.shape)
-            torch.Size([1, 196, 256])
+            num_heads (int, optional): Number of attention heads.
+            attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors.
+            resolution (Tuple[int, int], optional): Spatial resolution of the input feature map.
         """
         super().__init__()

@@ -453,16 +487,16 @@ class Attention(torch.nn.Module):
         self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)

     @torch.no_grad()
-    def train(self, mode=True):
-        """Performs multi-head attention with spatial awareness and trainable attention biases."""
+    def train(self, mode: bool = True):
+        """Set the module in training mode and handle the 'ab' attribute for cached attention biases."""
         super().train(mode)
         if mode and hasattr(self, "ab"):
             del self.ab
         else:
             self.ab = self.attention_biases[:, self.attention_bias_idxs]

-    def forward(self, x):  # x
-        """Applies multi-head attention with spatial awareness and trainable attention biases."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply multi-head attention with spatial awareness and trainable attention biases."""
         B, N, _ = x.shape  # B, N, C

         # Normalization
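The reworded `train()` docstring now matches what the override actually does: leaving training mode caches the per-position biases in `self.ab`, while re-entering training mode drops that cache so the learnable `attention_biases` parameter is used directly. A short sketch of that toggle, assuming the same `Attention` class and the import path implied by the file list:

    import torch
    from ultralytics.models.sam.modules.tiny_encoder import Attention  # assumed location

    attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
    attn.eval()                  # train(False) builds the cached bias tensor
    print(hasattr(attn, "ab"))   # True
    attn.train()                 # train(True) deletes the cache for training
    print(hasattr(attn, "ab"))   # False
    print(attn(torch.randn(1, 196, 256)).shape)  # torch.Size([1, 196, 256]), per the class docstring example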
@@ -490,7 +524,8 @@ class TinyViTBlock(nn.Module):
     TinyViT Block that applies self-attention and a local convolution to the input.

     This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
-    local convolutions to process input features efficiently.
+    local convolutions to process input features efficiently. It supports windowed attention for
+    computational efficiency and includes residual connections.

     Attributes:
         dim (int): The dimensionality of the input and output.
@@ -500,13 +535,9 @@ class TinyViTBlock(nn.Module):
         mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
         drop_path (nn.Module): Stochastic depth layer, identity function during inference.
         attn (Attention): Self-attention module.
-        mlp (Mlp): Multi-layer perceptron module.
+        mlp (MLP): Multi-layer perceptron module.
         local_conv (Conv2d_BN): Depth-wise local convolution layer.

-    Methods:
-        forward: Processes the input through the TinyViT block.
-        extra_repr: Returns a string with extra information about the block's parameters.
-
     Examples:
         >>> input_tensor = torch.randn(1, 196, 192)
         >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
@@ -517,18 +548,18 @@ class TinyViTBlock(nn.Module):

     def __init__(
         self,
-        dim,
-        input_resolution,
-        num_heads,
-        window_size=7,
-        mlp_ratio=4.0,
-        drop=0.0,
-        drop_path=0.0,
-        local_conv_size=3,
+        dim: int,
+        input_resolution: Tuple[int, int],
+        num_heads: int,
+        window_size: int = 7,
+        mlp_ratio: float = 4.0,
+        drop: float = 0.0,
+        drop_path: float = 0.0,
+        local_conv_size: int = 3,
         activation=nn.GELU,
     ):
         """
-        Initializes a TinyViT block with self-attention and local convolution.
+        Initialize a TinyViT block with self-attention and local convolution.

         This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
         local convolutions to process input features efficiently.
@@ -537,23 +568,12 @@ class TinyViTBlock(nn.Module):
            dim (int): Dimensionality of the input and output features.
            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
            num_heads (int): Number of attention heads.
-            window_size (int): Size of the attention window. Must be greater than 0.
-            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
-            drop (float): Dropout rate.
-            drop_path (float): Stochastic depth rate.
-            local_conv_size (int): Kernel size of the local convolution.
-            activation (torch.nn.Module): Activation function for MLP.
-
-        Raises:
-            AssertionError: If window_size is not greater than 0.
-            AssertionError: If dim is not divisible by num_heads.
-
-        Examples:
-            >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
-            >>> input_tensor = torch.randn(1, 196, 192)
-            >>> output = block(input_tensor)
-            >>> print(output.shape)
-            torch.Size([1, 196, 192])
+            window_size (int, optional): Size of the attention window. Must be greater than 0.
+            mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float, optional): Dropout rate.
+            drop_path (float, optional): Stochastic depth rate.
+            local_conv_size (int, optional): Kernel size of the local convolution.
+            activation (nn.Module): Activation function for MLP.
         """
         super().__init__()
         self.dim = dim
@@ -575,13 +595,13 @@ class TinyViTBlock(nn.Module):

         mlp_hidden_dim = int(dim * mlp_ratio)
         mlp_activation = activation
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)
+        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, activation=mlp_activation, drop=drop)

         pad = local_conv_size // 2
         self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)

-    def forward(self, x):
-        """Applies self-attention, local convolution, and MLP operations to the input tensor."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply self-attention, local convolution, and MLP operations to the input tensor."""
         h, w = self.input_resolution
         b, hw, c = x.shape  # batch, height*width, channels
         assert hw == h * w, "input feature has wrong size"
@@ -624,7 +644,7 @@ class TinyViTBlock(nn.Module):

     def extra_repr(self) -> str:
         """
-        Returns a string representation of the TinyViTBlock's parameters.
+        Return a string representation of the TinyViTBlock's parameters.

         This method provides a formatted string containing key information about the TinyViTBlock, including its
         dimension, input resolution, number of attention heads, window size, and MLP ratio.
@@ -648,7 +668,8 @@ class BasicLayer(nn.Module):
     A basic TinyViT layer for one stage in a TinyViT architecture.

     This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
-    and an optional downsampling operation.
+    and an optional downsampling operation. It processes features at a specific resolution and
+    dimensionality within the overall architecture.

     Attributes:
         dim (int): The dimensionality of the input and output features.
@@ -658,10 +679,6 @@ class BasicLayer(nn.Module):
         blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
         downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.

-    Methods:
-        forward: Processes the input through the layer's blocks and optional downsampling.
-        extra_repr: Returns a string with the layer's parameters for printing.
-
     Examples:
         >>> input_tensor = torch.randn(1, 3136, 192)
         >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
@@ -672,22 +689,22 @@ class BasicLayer(nn.Module):

     def __init__(
         self,
-        dim,
-        input_resolution,
-        depth,
-        num_heads,
-        window_size,
-        mlp_ratio=4.0,
-        drop=0.0,
-        drop_path=0.0,
-        downsample=None,
-        use_checkpoint=False,
-        local_conv_size=3,
+        dim: int,
+        input_resolution: Tuple[int, int],
+        depth: int,
+        num_heads: int,
+        window_size: int,
+        mlp_ratio: float = 4.0,
+        drop: float = 0.0,
+        drop_path: Union[float, List[float]] = 0.0,
+        downsample: Optional[nn.Module] = None,
+        use_checkpoint: bool = False,
+        local_conv_size: int = 3,
         activation=nn.GELU,
-        out_dim=None,
+        out_dim: Optional[int] = None,
     ):
         """
-        Initializes a BasicLayer in the TinyViT architecture.
+        Initialize a BasicLayer in the TinyViT architecture.

         This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
         process feature maps at a specific resolution and dimensionality within the TinyViT model.
@@ -698,23 +715,14 @@ class BasicLayer(nn.Module):
            depth (int): Number of TinyViT blocks in this layer.
            num_heads (int): Number of attention heads in each TinyViT block.
            window_size (int): Size of the local window for attention computation.
-            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
-            drop (float): Dropout rate.
-            drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block.
-            downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling.
-            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
-            local_conv_size (int): Kernel size for the local convolution in each TinyViT block.
+            mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float, optional): Dropout rate.
+            drop_path (float | List[float], optional): Stochastic depth rate. Can be a float or a list of floats for each block.
+            downsample (nn.Module | None, optional): Downsampling layer at the end of the layer. None to skip downsampling.
+            use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
+            local_conv_size (int, optional): Kernel size for the local convolution in each TinyViT block.
             activation (nn.Module): Activation function used in the MLP.
-            out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`.
-
-        Raises:
-            ValueError: If `drop_path` is a list and its length doesn't match `depth`.
-
-        Examples:
-            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
-            >>> x = torch.randn(1, 56 * 56, 96)
-            >>> output = layer(x)
-            >>> print(output.shape)
+            out_dim (int | None, optional): Output dimension after downsampling. None means it will be the same as `dim`.
         """
         super().__init__()
         self.dim = dim
@@ -747,14 +755,14 @@ class BasicLayer(nn.Module):
             else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
         )

-    def forward(self, x):
-        """Processes input through TinyViT blocks and optional downsampling."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input through TinyViT blocks and optional downsampling."""
         for blk in self.blocks:
             x = torch.utils.checkpoint(blk, x) if self.use_checkpoint else blk(x)  # warn: checkpoint is slow import
         return x if self.downsample is None else self.downsample(x)

     def extra_repr(self) -> str:
-        """Returns a string with the layer's parameters for printing."""
+        """Return a string with the layer's parameters for printing."""
         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"


@@ -763,12 +771,13 @@ class TinyViT(nn.Module):
     TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.

     This class implements the TinyViT model, which combines elements of vision transformers and convolutional
-    neural networks for improved efficiency and performance on vision tasks.
+    neural networks for improved efficiency and performance on vision tasks. It features hierarchical processing
+    with patch embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.

     Attributes:
         img_size (int): Input image size.
         num_classes (int): Number of classification classes.
-        depths (List[int]): Number of blocks in each stage.
+        depths (Tuple[int, int, int, int]): Number of blocks in each stage.
         num_layers (int): Total number of layers in the network.
         mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
         patch_embed (PatchEmbed): Module for patch embedding.
@@ -778,66 +787,52 @@ class TinyViT(nn.Module):
         head (nn.Linear): Linear layer for final classification.
         neck (nn.Sequential): Neck module for feature refinement.

-    Methods:
-        set_layer_lr_decay: Sets layer-wise learning rate decay.
-        _init_weights: Initializes weights for linear and normalization layers.
-        no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
-        forward_features: Processes input through the feature extraction layers.
-        forward: Performs a forward pass through the entire network.
-
     Examples:
         >>> model = TinyViT(img_size=224, num_classes=1000)
         >>> x = torch.randn(1, 3, 224, 224)
         >>> features = model.forward_features(x)
         >>> print(features.shape)
-        torch.Size([1, 256, 64, 64])
+        torch.Size([1, 256, 56, 56])
     """

     def __init__(
         self,
-        img_size=224,
-        in_chans=3,
-        num_classes=1000,
-        embed_dims=(96, 192, 384, 768),
-        depths=(2, 2, 6, 2),
-        num_heads=(3, 6, 12, 24),
-        window_sizes=(7, 7, 14, 7),
-        mlp_ratio=4.0,
-        drop_rate=0.0,
-        drop_path_rate=0.1,
-        use_checkpoint=False,
-        mbconv_expand_ratio=4.0,
-        local_conv_size=3,
-        layer_lr_decay=1.0,
+        img_size: int = 224,
+        in_chans: int = 3,
+        num_classes: int = 1000,
+        embed_dims: Tuple[int, int, int, int] = (96, 192, 384, 768),
+        depths: Tuple[int, int, int, int] = (2, 2, 6, 2),
+        num_heads: Tuple[int, int, int, int] = (3, 6, 12, 24),
+        window_sizes: Tuple[int, int, int, int] = (7, 7, 14, 7),
+        mlp_ratio: float = 4.0,
+        drop_rate: float = 0.0,
+        drop_path_rate: float = 0.1,
+        use_checkpoint: bool = False,
+        mbconv_expand_ratio: float = 4.0,
+        local_conv_size: int = 3,
+        layer_lr_decay: float = 1.0,
     ):
         """
-        Initializes the TinyViT model.
+        Initialize the TinyViT model.

         This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
         attention and convolution blocks, and a classification head.

         Args:
-            img_size (int): Size of the input image.
-            in_chans (int): Number of input channels.
-            num_classes (int): Number of classes for classification.
-            embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
-            depths (Tuple[int, int, int, int]): Number of blocks in each stage.
-            num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
-            window_sizes (Tuple[int, int, int, int]): Window sizes for each stage.
-            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
-            drop_rate (float): Dropout rate.
-            drop_path_rate (float): Stochastic depth rate.
-            use_checkpoint (bool): Whether to use checkpointing to save memory.
-            mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
-            local_conv_size (int): Kernel size for local convolutions.
-            layer_lr_decay (float): Layer-wise learning rate decay factor.
-
-        Examples:
-            >>> model = TinyViT(img_size=224, num_classes=1000)
-            >>> x = torch.randn(1, 3, 224, 224)
-            >>> output = model(x)
-            >>> print(output.shape)
-            torch.Size([1, 1000])
+            img_size (int, optional): Size of the input image.
+            in_chans (int, optional): Number of input channels.
+            num_classes (int, optional): Number of classes for classification.
+            embed_dims (Tuple[int, int, int, int], optional): Embedding dimensions for each stage.
+            depths (Tuple[int, int, int, int], optional): Number of blocks in each stage.
+            num_heads (Tuple[int, int, int, int], optional): Number of attention heads in each stage.
+            window_sizes (Tuple[int, int, int, int], optional): Window sizes for each stage.
+            mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding dim.
+            drop_rate (float, optional): Dropout rate.
+            drop_path_rate (float, optional): Stochastic depth rate.
+            use_checkpoint (bool, optional): Whether to use checkpointing to save memory.
+            mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer.
+            local_conv_size (int, optional): Kernel size for local convolutions.
+            layer_lr_decay (float, optional): Layer-wise learning rate decay factor.
         """
         super().__init__()
         self.img_size = img_size
@@ -914,8 +909,8 @@ class TinyViT(nn.Module):
             LayerNorm2d(256),
         )

-    def set_layer_lr_decay(self, layer_lr_decay):
-        """Sets layer-wise learning rate decay for the TinyViT model based on depth."""
+    def set_layer_lr_decay(self, layer_lr_decay: float):
+        """Set layer-wise learning rate decay for the TinyViT model based on depth."""
         decay_rate = layer_lr_decay

         # Layers -> blocks (depth)
@@ -923,7 +918,7 @@ class TinyViT(nn.Module):
         lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]

         def _set_lr_scale(m, scale):
-            """Sets the learning rate scale for each layer in the model based on the layer's depth."""
+            """Set the learning rate scale for each layer in the model based on the layer's depth."""
             for p in m.parameters():
                 p.lr_scale = scale

@@ -943,7 +938,7 @@ class TinyViT(nn.Module):
             p.param_name = k

         def _check_lr_scale(m):
-            """Checks if the learning rate scale attribute is present in module's parameters."""
+            """Check if the learning rate scale attribute is present in module's parameters."""
             for p in m.parameters():
                 assert hasattr(p, "lr_scale"), p.param_name

@@ -951,7 +946,7 @@ class TinyViT(nn.Module):

     @staticmethod
     def _init_weights(m):
-        """Initializes weights for linear and normalization layers in the TinyViT model."""
+        """Initialize weights for linear and normalization layers in the TinyViT model."""
         if isinstance(m, nn.Linear):
             # NOTE: This initialization is needed only for training.
             # trunc_normal_(m.weight, std=.02)
@@ -963,11 +958,11 @@ class TinyViT(nn.Module):

     @torch.jit.ignore
     def no_weight_decay_keywords(self):
-        """Returns a set of keywords for parameters that should not use weight decay."""
+        """Return a set of keywords for parameters that should not use weight decay."""
         return {"attention_biases"}

-    def forward_features(self, x):
-        """Processes input through feature extraction layers, returning spatial features."""
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input through feature extraction layers, returning spatial features."""
         x = self.patch_embed(x)  # x input is (N, C, H, W)

         x = self.layers[0](x)
@@ -981,11 +976,11 @@ class TinyViT(nn.Module):
         x = x.permute(0, 3, 1, 2)
         return self.neck(x)

-    def forward(self, x):
-        """Performs the forward pass through the TinyViT model, extracting features from the input image."""
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform the forward pass through the TinyViT model, extracting features from the input image."""
         return self.forward_features(x)

-    def set_imgsz(self, imgsz=[1024, 1024]):
+    def set_imgsz(self, imgsz: List[int] = [1024, 1024]):
         """Set image size to make model compatible with different image sizes."""
         imgsz = [s // 4 for s in imgsz]
         self.patches_resolution = imgsz
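The retyped `set_imgsz` keeps the same behavior: the stored patch grid is the requested size divided by the factor-of-4 reduction of the patch embedding. A hedged usage sketch, assuming the TinyViT class lives in ultralytics/models/sam/modules/tiny_encoder.py as the file list suggests:

    import torch
    from ultralytics.models.sam.modules.tiny_encoder import TinyViT  # assumed location

    model = TinyViT(img_size=1024)
    model.set_imgsz([1024, 1024])    # patches_resolution becomes [256, 256] (each side // 4)
    print(model.patches_resolution)  # [256, 256]
    features = model(torch.randn(1, 3, 1024, 1024))  # spatial features via forward_features() and the neck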