ultralytics-8.3.89-py3-none-any.whl → ultralytics-8.3.91-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (156)
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_exports.py +2 -2
  5. tests/test_integrations.py +1 -5
  6. tests/test_python.py +16 -16
  7. tests/test_solutions.py +9 -9
  8. ultralytics/__init__.py +1 -1
  9. ultralytics/cfg/__init__.py +3 -1
  10. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  14. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  23. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  24. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  30. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  31. ultralytics/data/annotator.py +9 -14
  32. ultralytics/data/base.py +118 -30
  33. ultralytics/data/build.py +63 -24
  34. ultralytics/data/converter.py +5 -5
  35. ultralytics/data/dataset.py +207 -53
  36. ultralytics/data/loaders.py +1 -0
  37. ultralytics/data/split_dota.py +39 -12
  38. ultralytics/data/utils.py +15 -19
  39. ultralytics/engine/exporter.py +24 -23
  40. ultralytics/engine/model.py +67 -88
  41. ultralytics/engine/predictor.py +106 -21
  42. ultralytics/engine/trainer.py +32 -23
  43. ultralytics/engine/tuner.py +21 -18
  44. ultralytics/engine/validator.py +75 -41
  45. ultralytics/hub/__init__.py +12 -13
  46. ultralytics/hub/auth.py +9 -12
  47. ultralytics/hub/session.py +76 -21
  48. ultralytics/hub/utils.py +19 -17
  49. ultralytics/models/fastsam/model.py +20 -11
  50. ultralytics/models/fastsam/predict.py +36 -16
  51. ultralytics/models/fastsam/utils.py +5 -5
  52. ultralytics/models/fastsam/val.py +6 -6
  53. ultralytics/models/nas/model.py +22 -11
  54. ultralytics/models/nas/predict.py +9 -4
  55. ultralytics/models/nas/val.py +5 -5
  56. ultralytics/models/rtdetr/model.py +20 -11
  57. ultralytics/models/rtdetr/predict.py +18 -15
  58. ultralytics/models/rtdetr/train.py +20 -16
  59. ultralytics/models/rtdetr/val.py +42 -6
  60. ultralytics/models/sam/__init__.py +1 -1
  61. ultralytics/models/sam/amg.py +50 -4
  62. ultralytics/models/sam/model.py +8 -14
  63. ultralytics/models/sam/modules/decoders.py +18 -21
  64. ultralytics/models/sam/modules/encoders.py +25 -46
  65. ultralytics/models/sam/modules/memory_attention.py +19 -15
  66. ultralytics/models/sam/modules/sam.py +18 -25
  67. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  68. ultralytics/models/sam/modules/transformer.py +35 -57
  69. ultralytics/models/sam/modules/utils.py +15 -15
  70. ultralytics/models/sam/predict.py +0 -3
  71. ultralytics/models/utils/loss.py +87 -36
  72. ultralytics/models/utils/ops.py +26 -31
  73. ultralytics/models/yolo/classify/predict.py +24 -3
  74. ultralytics/models/yolo/classify/train.py +77 -10
  75. ultralytics/models/yolo/classify/val.py +40 -15
  76. ultralytics/models/yolo/detect/predict.py +23 -10
  77. ultralytics/models/yolo/detect/train.py +85 -15
  78. ultralytics/models/yolo/detect/val.py +145 -21
  79. ultralytics/models/yolo/model.py +1 -2
  80. ultralytics/models/yolo/obb/predict.py +12 -4
  81. ultralytics/models/yolo/obb/train.py +7 -0
  82. ultralytics/models/yolo/obb/val.py +25 -7
  83. ultralytics/models/yolo/pose/predict.py +22 -6
  84. ultralytics/models/yolo/pose/train.py +17 -1
  85. ultralytics/models/yolo/pose/val.py +46 -21
  86. ultralytics/models/yolo/segment/predict.py +22 -8
  87. ultralytics/models/yolo/segment/train.py +6 -0
  88. ultralytics/models/yolo/segment/val.py +100 -14
  89. ultralytics/models/yolo/world/train.py +38 -8
  90. ultralytics/models/yolo/world/train_world.py +39 -10
  91. ultralytics/nn/autobackend.py +28 -14
  92. ultralytics/nn/modules/__init__.py +3 -0
  93. ultralytics/nn/modules/activation.py +12 -3
  94. ultralytics/nn/modules/block.py +587 -84
  95. ultralytics/nn/modules/conv.py +418 -54
  96. ultralytics/nn/modules/head.py +3 -4
  97. ultralytics/nn/modules/transformer.py +320 -34
  98. ultralytics/nn/modules/utils.py +17 -3
  99. ultralytics/nn/tasks.py +221 -69
  100. ultralytics/solutions/ai_gym.py +2 -2
  101. ultralytics/solutions/analytics.py +4 -4
  102. ultralytics/solutions/heatmap.py +4 -4
  103. ultralytics/solutions/instance_segmentation.py +10 -4
  104. ultralytics/solutions/object_blurrer.py +2 -2
  105. ultralytics/solutions/object_counter.py +2 -2
  106. ultralytics/solutions/object_cropper.py +2 -2
  107. ultralytics/solutions/parking_management.py +9 -9
  108. ultralytics/solutions/queue_management.py +1 -1
  109. ultralytics/solutions/region_counter.py +2 -2
  110. ultralytics/solutions/security_alarm.py +7 -7
  111. ultralytics/solutions/solutions.py +7 -4
  112. ultralytics/solutions/speed_estimation.py +2 -2
  113. ultralytics/solutions/streamlit_inference.py +6 -6
  114. ultralytics/solutions/trackzone.py +9 -2
  115. ultralytics/solutions/vision_eye.py +4 -4
  116. ultralytics/trackers/basetrack.py +1 -1
  117. ultralytics/trackers/bot_sort.py +23 -22
  118. ultralytics/trackers/byte_tracker.py +4 -4
  119. ultralytics/trackers/track.py +2 -1
  120. ultralytics/trackers/utils/gmc.py +26 -27
  121. ultralytics/trackers/utils/kalman_filter.py +31 -29
  122. ultralytics/trackers/utils/matching.py +7 -7
  123. ultralytics/utils/__init__.py +32 -27
  124. ultralytics/utils/autobatch.py +5 -5
  125. ultralytics/utils/benchmarks.py +111 -18
  126. ultralytics/utils/callbacks/base.py +3 -3
  127. ultralytics/utils/callbacks/clearml.py +11 -11
  128. ultralytics/utils/callbacks/comet.py +42 -24
  129. ultralytics/utils/callbacks/dvc.py +11 -10
  130. ultralytics/utils/callbacks/hub.py +8 -8
  131. ultralytics/utils/callbacks/mlflow.py +1 -1
  132. ultralytics/utils/callbacks/neptune.py +12 -10
  133. ultralytics/utils/callbacks/raytune.py +1 -1
  134. ultralytics/utils/callbacks/tensorboard.py +6 -6
  135. ultralytics/utils/callbacks/wb.py +16 -16
  136. ultralytics/utils/checks.py +116 -35
  137. ultralytics/utils/dist.py +15 -2
  138. ultralytics/utils/downloads.py +13 -9
  139. ultralytics/utils/files.py +12 -13
  140. ultralytics/utils/instance.py +112 -45
  141. ultralytics/utils/loss.py +28 -33
  142. ultralytics/utils/metrics.py +246 -181
  143. ultralytics/utils/ops.py +61 -53
  144. ultralytics/utils/patches.py +8 -6
  145. ultralytics/utils/plotting.py +65 -45
  146. ultralytics/utils/tal.py +88 -57
  147. ultralytics/utils/torch_utils.py +181 -33
  148. ultralytics/utils/triton.py +13 -3
  149. ultralytics/utils/tuner.py +8 -16
  150. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
  151. ultralytics-8.3.91.dist-info/RECORD +250 -0
  152. ultralytics-8.3.89.dist-info/RECORD +0 -250
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
  156. {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
@@ -35,40 +35,112 @@ def autopad(k, p=None, d=1): # kernel, padding, dilation
35
35
 
36
36
 
37
37
  class Conv(nn.Module):
38
- """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
38
+ """
39
+ Standard convolution module with batch normalization and activation.
40
+
41
+ Attributes:
42
+ conv (nn.Conv2d): Convolutional layer.
43
+ bn (nn.BatchNorm2d): Batch normalization layer.
44
+ act (nn.Module): Activation function layer.
45
+ default_act (nn.Module): Default activation function (SiLU).
46
+ """
39
47
 
40
48
  default_act = nn.SiLU() # default activation
41
49
 
42
50
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
43
- """Initialize Conv layer with given arguments including activation."""
51
+ """
52
+ Initialize Conv layer with given parameters.
53
+
54
+ Args:
55
+ c1 (int): Number of input channels.
56
+ c2 (int): Number of output channels.
57
+ k (int): Kernel size.
58
+ s (int): Stride.
59
+ p (int, optional): Padding.
60
+ g (int): Groups.
61
+ d (int): Dilation.
62
+ act (bool | nn.Module): Activation function.
63
+ """
44
64
  super().__init__()
45
65
  self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
46
66
  self.bn = nn.BatchNorm2d(c2)
47
67
  self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
48
68
 
49
69
  def forward(self, x):
50
- """Apply convolution, batch normalization and activation to input tensor."""
70
+ """
71
+ Apply convolution, batch normalization and activation to input tensor.
72
+
73
+ Args:
74
+ x (torch.Tensor): Input tensor.
75
+
76
+ Returns:
77
+ (torch.Tensor): Output tensor.
78
+ """
51
79
  return self.act(self.bn(self.conv(x)))
52
80
 
53
81
  def forward_fuse(self, x):
54
- """Apply convolution and activation without batch normalization."""
82
+ """
83
+ Apply convolution and activation without batch normalization.
84
+
85
+ Args:
86
+ x (torch.Tensor): Input tensor.
87
+
88
+ Returns:
89
+ (torch.Tensor): Output tensor.
90
+ """
55
91
  return self.act(self.conv(x))
56
92
 
57
93
 
58
94
  class Conv2(Conv):
59
- """Simplified RepConv module with Conv fusing."""
95
+ """
96
+ Simplified RepConv module with Conv fusing.
97
+
98
+ Attributes:
99
+ conv (nn.Conv2d): Main 3x3 convolutional layer.
100
+ cv2 (nn.Conv2d): Additional 1x1 convolutional layer.
101
+ bn (nn.BatchNorm2d): Batch normalization layer.
102
+ act (nn.Module): Activation function layer.
103
+ """
60
104
 
61
105
  def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
62
- """Initialize Conv layer with given arguments including activation."""
106
+ """
107
+ Initialize Conv2 layer with given parameters.
108
+
109
+ Args:
110
+ c1 (int): Number of input channels.
111
+ c2 (int): Number of output channels.
112
+ k (int): Kernel size.
113
+ s (int): Stride.
114
+ p (int, optional): Padding.
115
+ g (int): Groups.
116
+ d (int): Dilation.
117
+ act (bool | nn.Module): Activation function.
118
+ """
63
119
  super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
64
120
  self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
65
121
 
66
122
  def forward(self, x):
67
- """Apply convolution, batch normalization and activation to input tensor."""
123
+ """
124
+ Apply convolution, batch normalization and activation to input tensor.
125
+
126
+ Args:
127
+ x (torch.Tensor): Input tensor.
128
+
129
+ Returns:
130
+ (torch.Tensor): Output tensor.
131
+ """
68
132
  return self.act(self.bn(self.conv(x) + self.cv2(x)))
69
133
 
70
134
  def forward_fuse(self, x):
71
- """Apply fused convolution, batch normalization and activation to input tensor."""
135
+ """
136
+ Apply fused convolution, batch normalization and activation to input tensor.
137
+
138
+ Args:
139
+ x (torch.Tensor): Input tensor.
140
+
141
+ Returns:
142
+ (torch.Tensor): Output tensor.
143
+ """
72
144
  return self.act(self.bn(self.conv(x)))
73
145
 
74
146
  def fuse_convs(self):
@@ -83,106 +155,257 @@ class Conv2(Conv):
83
155
 
84
156
  class LightConv(nn.Module):
85
157
  """
86
- Light convolution with args(ch_in, ch_out, kernel).
158
+ Light convolution module with 1x1 and depthwise convolutions.
159
+
160
+ This implementation is based on the PaddleDetection HGNetV2 backbone.
87
161
 
88
- https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
162
+ Attributes:
163
+ conv1 (Conv): 1x1 convolution layer.
164
+ conv2 (DWConv): Depthwise convolution layer.
89
165
  """
90
166
 
91
167
  def __init__(self, c1, c2, k=1, act=nn.ReLU()):
92
- """Initialize Conv layer with given arguments including activation."""
168
+ """
169
+ Initialize LightConv layer with given parameters.
170
+
171
+ Args:
172
+ c1 (int): Number of input channels.
173
+ c2 (int): Number of output channels.
174
+ k (int): Kernel size for depthwise convolution.
175
+ act (nn.Module): Activation function.
176
+ """
93
177
  super().__init__()
94
178
  self.conv1 = Conv(c1, c2, 1, act=False)
95
179
  self.conv2 = DWConv(c2, c2, k, act=act)
96
180
 
97
181
  def forward(self, x):
98
- """Apply 2 convolutions to input tensor."""
182
+ """
183
+ Apply 2 convolutions to input tensor.
184
+
185
+ Args:
186
+ x (torch.Tensor): Input tensor.
187
+
188
+ Returns:
189
+ (torch.Tensor): Output tensor.
190
+ """
99
191
  return self.conv2(self.conv1(x))
100
192
 
101
193
 
102
194
  class DWConv(Conv):
103
- """Depth-wise convolution."""
195
+ """Depth-wise convolution module."""
104
196
 
105
- def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
106
- """Initialize Depth-wise convolution with given parameters."""
197
+ def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
198
+ """
199
+ Initialize depth-wise convolution with given parameters.
200
+
201
+ Args:
202
+ c1 (int): Number of input channels.
203
+ c2 (int): Number of output channels.
204
+ k (int): Kernel size.
205
+ s (int): Stride.
206
+ d (int): Dilation.
207
+ act (bool | nn.Module): Activation function.
208
+ """
107
209
  super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
108
210
 
109
211
 
110
212
  class DWConvTranspose2d(nn.ConvTranspose2d):
111
- """Depth-wise transpose convolution."""
213
+ """Depth-wise transpose convolution module."""
112
214
 
113
- def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
114
- """Initialize DWConvTranspose2d class with given parameters."""
215
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):
216
+ """
217
+ Initialize depth-wise transpose convolution with given parameters.
218
+
219
+ Args:
220
+ c1 (int): Number of input channels.
221
+ c2 (int): Number of output channels.
222
+ k (int): Kernel size.
223
+ s (int): Stride.
224
+ p1 (int): Padding.
225
+ p2 (int): Output padding.
226
+ """
115
227
  super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
116
228
 
117
229
 
118
230
  class ConvTranspose(nn.Module):
119
- """Convolution transpose 2d layer."""
231
+ """
232
+ Convolution transpose module with optional batch normalization and activation.
233
+
234
+ Attributes:
235
+ conv_transpose (nn.ConvTranspose2d): Transposed convolution layer.
236
+ bn (nn.BatchNorm2d | nn.Identity): Batch normalization layer.
237
+ act (nn.Module): Activation function layer.
238
+ default_act (nn.Module): Default activation function (SiLU).
239
+ """
120
240
 
121
241
  default_act = nn.SiLU() # default activation
122
242
 
123
243
  def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
124
- """Initialize ConvTranspose2d layer with batch normalization and activation function."""
244
+ """
245
+ Initialize ConvTranspose layer with given parameters.
246
+
247
+ Args:
248
+ c1 (int): Number of input channels.
249
+ c2 (int): Number of output channels.
250
+ k (int): Kernel size.
251
+ s (int): Stride.
252
+ p (int): Padding.
253
+ bn (bool): Use batch normalization.
254
+ act (bool | nn.Module): Activation function.
255
+ """
125
256
  super().__init__()
126
257
  self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
127
258
  self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
128
259
  self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
129
260
 
130
261
  def forward(self, x):
131
- """Applies transposed convolutions, batch normalization and activation to input."""
262
+ """
263
+ Apply transposed convolution, batch normalization and activation to input.
264
+
265
+ Args:
266
+ x (torch.Tensor): Input tensor.
267
+
268
+ Returns:
269
+ (torch.Tensor): Output tensor.
270
+ """
132
271
  return self.act(self.bn(self.conv_transpose(x)))
133
272
 
134
273
  def forward_fuse(self, x):
135
- """Applies activation and convolution transpose operation to input."""
274
+ """
275
+ Apply activation and convolution transpose operation to input.
276
+
277
+ Args:
278
+ x (torch.Tensor): Input tensor.
279
+
280
+ Returns:
281
+ (torch.Tensor): Output tensor.
282
+ """
136
283
  return self.act(self.conv_transpose(x))
137
284
 
138
285
 
139
286
  class Focus(nn.Module):
140
- """Focus wh information into c-space."""
287
+ """
288
+ Focus module for concentrating feature information.
289
+
290
+ Slices input tensor into 4 parts and concatenates them in the channel dimension.
291
+
292
+ Attributes:
293
+ conv (Conv): Convolution layer.
294
+ """
141
295
 
142
296
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
143
- """Initializes Focus object with user defined channel, convolution, padding, group and activation values."""
297
+ """
298
+ Initialize Focus module with given parameters.
299
+
300
+ Args:
301
+ c1 (int): Number of input channels.
302
+ c2 (int): Number of output channels.
303
+ k (int): Kernel size.
304
+ s (int): Stride.
305
+ p (int, optional): Padding.
306
+ g (int): Groups.
307
+ act (bool | nn.Module): Activation function.
308
+ """
144
309
  super().__init__()
145
310
  self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
146
311
  # self.contract = Contract(gain=2)
147
312
 
148
313
  def forward(self, x):
149
314
  """
150
- Applies convolution to concatenated tensor and returns the output.
315
+ Apply Focus operation and convolution to input tensor.
151
316
 
152
317
  Input shape is (b,c,w,h) and output shape is (b,4c,w/2,h/2).
318
+
319
+ Args:
320
+ x (torch.Tensor): Input tensor.
321
+
322
+ Returns:
323
+ (torch.Tensor): Output tensor.
153
324
  """
154
325
  return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
155
326
  # return self.conv(self.contract(x))
156
327
 
157
328
 
158
329
  class GhostConv(nn.Module):
159
- """Ghost Convolution https://github.com/huawei-noah/ghostnet."""
330
+ """
331
+ Ghost Convolution module.
332
+
333
+ Generates more features with fewer parameters by using cheap operations.
334
+
335
+ Attributes:
336
+ cv1 (Conv): Primary convolution.
337
+ cv2 (Conv): Cheap operation convolution.
338
+
339
+ References:
340
+ https://github.com/huawei-noah/ghostnet
341
+ """
160
342
 
161
343
  def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
162
- """Initializes Ghost Convolution module with primary and cheap operations for efficient feature learning."""
344
+ """
345
+ Initialize Ghost Convolution module with given parameters.
346
+
347
+ Args:
348
+ c1 (int): Number of input channels.
349
+ c2 (int): Number of output channels.
350
+ k (int): Kernel size.
351
+ s (int): Stride.
352
+ g (int): Groups.
353
+ act (bool | nn.Module): Activation function.
354
+ """
163
355
  super().__init__()
164
356
  c_ = c2 // 2 # hidden channels
165
357
  self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
166
358
  self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
167
359
 
168
360
  def forward(self, x):
169
- """Forward propagation through a Ghost Bottleneck layer with skip connection."""
361
+ """
362
+ Apply Ghost Convolution to input tensor.
363
+
364
+ Args:
365
+ x (torch.Tensor): Input tensor.
366
+
367
+ Returns:
368
+ (torch.Tensor): Output tensor with concatenated features.
369
+ """
170
370
  y = self.cv1(x)
171
371
  return torch.cat((y, self.cv2(y)), 1)
172
372
 
173
373
 
174
374
  class RepConv(nn.Module):
175
375
  """
176
- RepConv is a basic rep-style block, including training and deploy status.
376
+ RepConv module with training and deploy modes.
377
+
378
+ This module is used in RT-DETR and can fuse convolutions during inference for efficiency.
379
+
380
+ Attributes:
381
+ conv1 (Conv): 3x3 convolution.
382
+ conv2 (Conv): 1x1 convolution.
383
+ bn (nn.BatchNorm2d, optional): Batch normalization for identity branch.
384
+ act (nn.Module): Activation function.
385
+ default_act (nn.Module): Default activation function (SiLU).
177
386
 
178
- This module is used in RT-DETR.
179
- Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
387
+ References:
388
+ https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
180
389
  """
181
390
 
182
391
  default_act = nn.SiLU() # default activation
183
392
 
184
393
  def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
185
- """Initializes Light Convolution layer with inputs, outputs & optional activation function."""
394
+ """
395
+ Initialize RepConv module with given parameters.
396
+
397
+ Args:
398
+ c1 (int): Number of input channels.
399
+ c2 (int): Number of output channels.
400
+ k (int): Kernel size.
401
+ s (int): Stride.
402
+ p (int): Padding.
403
+ g (int): Groups.
404
+ d (int): Dilation.
405
+ act (bool | nn.Module): Activation function.
406
+ bn (bool): Use batch normalization for identity branch.
407
+ deploy (bool): Deploy mode for inference.
408
+ """
186
409
  super().__init__()
187
410
  assert k == 3 and p == 1
188
411
  self.g = g
@@ -195,16 +418,39 @@ class RepConv(nn.Module):
195
418
  self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
196
419
 
197
420
  def forward_fuse(self, x):
198
- """Forward process."""
421
+ """
422
+ Forward pass for deploy mode.
423
+
424
+ Args:
425
+ x (torch.Tensor): Input tensor.
426
+
427
+ Returns:
428
+ (torch.Tensor): Output tensor.
429
+ """
199
430
  return self.act(self.conv(x))
200
431
 
201
432
  def forward(self, x):
202
- """Forward process."""
433
+ """
434
+ Forward pass for training mode.
435
+
436
+ Args:
437
+ x (torch.Tensor): Input tensor.
438
+
439
+ Returns:
440
+ (torch.Tensor): Output tensor.
441
+ """
203
442
  id_out = 0 if self.bn is None else self.bn(x)
204
443
  return self.act(self.conv1(x) + self.conv2(x) + id_out)
205
444
 
206
445
  def get_equivalent_kernel_bias(self):
207
- """Returns equivalent kernel and bias by adding 3x3 kernel, 1x1 kernel and identity kernel with their biases."""
446
+ """
447
+ Calculate equivalent kernel and bias by fusing convolutions.
448
+
449
+ Returns:
450
+ (tuple): Tuple containing:
451
+ - Equivalent kernel (torch.Tensor)
452
+ - Equivalent bias (torch.Tensor)
453
+ """
208
454
  kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
209
455
  kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
210
456
  kernelid, biasid = self._fuse_bn_tensor(self.bn)
@@ -212,14 +458,32 @@ class RepConv(nn.Module):
212
458
 
213
459
  @staticmethod
214
460
  def _pad_1x1_to_3x3_tensor(kernel1x1):
215
- """Pads a 1x1 tensor to a 3x3 tensor."""
461
+ """
462
+ Pad a 1x1 kernel to 3x3 size.
463
+
464
+ Args:
465
+ kernel1x1 (torch.Tensor): 1x1 convolution kernel.
466
+
467
+ Returns:
468
+ (torch.Tensor): Padded 3x3 kernel.
469
+ """
216
470
  if kernel1x1 is None:
217
471
  return 0
218
472
  else:
219
473
  return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
220
474
 
221
475
  def _fuse_bn_tensor(self, branch):
222
- """Generates appropriate kernels and biases for convolution by fusing branches of the neural network."""
476
+ """
477
+ Fuse batch normalization with convolution weights.
478
+
479
+ Args:
480
+ branch (Conv | nn.BatchNorm2d | None): Branch to fuse.
481
+
482
+ Returns:
483
+ (tuple): Tuple containing:
484
+ - Fused kernel (torch.Tensor)
485
+ - Fused bias (torch.Tensor)
486
+ """
223
487
  if branch is None:
224
488
  return 0, 0
225
489
  if isinstance(branch, Conv):
@@ -247,7 +511,7 @@ class RepConv(nn.Module):
247
511
  return kernel * t, beta - running_mean * gamma / std
248
512
 
249
513
  def fuse_convs(self):
250
- """Combines two convolution layers into a single layer and removes unused attributes from the class."""
514
+ """Fuse convolutions for inference by creating a single equivalent convolution."""
251
515
  if hasattr(self, "conv"):
252
516
  return
253
517
  kernel, bias = self.get_equivalent_kernel_bias()
@@ -276,25 +540,63 @@ class RepConv(nn.Module):
276
540
 
277
541
 
278
542
  class ChannelAttention(nn.Module):
279
- """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
543
+ """
544
+ Channel-attention module for feature recalibration.
545
+
546
+ Applies attention weights to channels based on global average pooling.
547
+
548
+ Attributes:
549
+ pool (nn.AdaptiveAvgPool2d): Global average pooling.
550
+ fc (nn.Conv2d): Fully connected layer implemented as 1x1 convolution.
551
+ act (nn.Sigmoid): Sigmoid activation for attention weights.
552
+
553
+ References:
554
+ https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet
555
+ """
280
556
 
281
557
  def __init__(self, channels: int) -> None:
282
- """Initializes the class and sets the basic configurations and instance variables required."""
558
+ """
559
+ Initialize Channel-attention module.
560
+
561
+ Args:
562
+ channels (int): Number of input channels.
563
+ """
283
564
  super().__init__()
284
565
  self.pool = nn.AdaptiveAvgPool2d(1)
285
566
  self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
286
567
  self.act = nn.Sigmoid()
287
568
 
288
569
  def forward(self, x: torch.Tensor) -> torch.Tensor:
289
- """Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""
570
+ """
571
+ Apply channel attention to input tensor.
572
+
573
+ Args:
574
+ x (torch.Tensor): Input tensor.
575
+
576
+ Returns:
577
+ (torch.Tensor): Channel-attended output tensor.
578
+ """
290
579
  return x * self.act(self.fc(self.pool(x)))
291
580
 
292
581
 
293
582
  class SpatialAttention(nn.Module):
294
- """Spatial-attention module."""
583
+ """
584
+ Spatial-attention module for feature recalibration.
585
+
586
+ Applies attention weights to spatial dimensions based on channel statistics.
587
+
588
+ Attributes:
589
+ cv1 (nn.Conv2d): Convolution layer for spatial attention.
590
+ act (nn.Sigmoid): Sigmoid activation for attention weights.
591
+ """
295
592
 
296
593
  def __init__(self, kernel_size=7):
297
- """Initialize Spatial-attention module with kernel size argument."""
594
+ """
595
+ Initialize Spatial-attention module.
596
+
597
+ Args:
598
+ kernel_size (int): Size of the convolutional kernel (3 or 7).
599
+ """
298
600
  super().__init__()
299
601
  assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
300
602
  padding = 3 if kernel_size == 7 else 1
@@ -302,49 +604,111 @@ class SpatialAttention(nn.Module):
302
604
  self.act = nn.Sigmoid()
303
605
 
304
606
  def forward(self, x):
305
- """Apply channel and spatial attention on input for feature recalibration."""
607
+ """
608
+ Apply spatial attention to input tensor.
609
+
610
+ Args:
611
+ x (torch.Tensor): Input tensor.
612
+
613
+ Returns:
614
+ (torch.Tensor): Spatial-attended output tensor.
615
+ """
306
616
  return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
307
617
 
308
618
 
309
619
  class CBAM(nn.Module):
310
- """Convolutional Block Attention Module."""
620
+ """
621
+ Convolutional Block Attention Module.
622
+
623
+ Combines channel and spatial attention mechanisms for comprehensive feature refinement.
624
+
625
+ Attributes:
626
+ channel_attention (ChannelAttention): Channel attention module.
627
+ spatial_attention (SpatialAttention): Spatial attention module.
628
+ """
311
629
 
312
630
  def __init__(self, c1, kernel_size=7):
313
- """Initialize CBAM with given input channel (c1) and kernel size."""
631
+ """
632
+ Initialize CBAM with given parameters.
633
+
634
+ Args:
635
+ c1 (int): Number of input channels.
636
+ kernel_size (int): Size of the convolutional kernel for spatial attention.
637
+ """
314
638
  super().__init__()
315
639
  self.channel_attention = ChannelAttention(c1)
316
640
  self.spatial_attention = SpatialAttention(kernel_size)
317
641
 
318
642
  def forward(self, x):
319
- """Applies the forward pass through C1 module."""
643
+ """
644
+ Apply channel and spatial attention sequentially to input tensor.
645
+
646
+ Args:
647
+ x (torch.Tensor): Input tensor.
648
+
649
+ Returns:
650
+ (torch.Tensor): Attended output tensor.
651
+ """
320
652
  return self.spatial_attention(self.channel_attention(x))
321
653
 
322
654
 
323
655
  class Concat(nn.Module):
324
- """Concatenate a list of tensors along dimension."""
656
+ """
657
+ Concatenate a list of tensors along specified dimension.
658
+
659
+ Attributes:
660
+ d (int): Dimension along which to concatenate tensors.
661
+ """
325
662
 
326
663
  def __init__(self, dimension=1):
327
- """Concatenates a list of tensors along a specified dimension."""
664
+ """
665
+ Initialize Concat module.
666
+
667
+ Args:
668
+ dimension (int): Dimension along which to concatenate tensors.
669
+ """
328
670
  super().__init__()
329
671
  self.d = dimension
330
672
 
331
673
  def forward(self, x):
332
- """Forward pass for the YOLOv8 mask Proto module."""
674
+ """
675
+ Concatenate input tensors along specified dimension.
676
+
677
+ Args:
678
+ x (List[torch.Tensor]): List of input tensors.
679
+
680
+ Returns:
681
+ (torch.Tensor): Concatenated tensor.
682
+ """
333
683
  return torch.cat(x, self.d)
334
684
 
335
685
 
336
686
  class Index(nn.Module):
337
- """Returns a particular index of the input."""
687
+ """
688
+ Returns a particular index of the input.
689
+
690
+ Attributes:
691
+ index (int): Index to select from input.
692
+ """
338
693
 
339
694
  def __init__(self, index=0):
340
- """Returns a particular index of the input."""
695
+ """
696
+ Initialize Index module.
697
+
698
+ Args:
699
+ index (int): Index to select from input.
700
+ """
341
701
  super().__init__()
342
702
  self.index = index
343
703
 
344
704
  def forward(self, x):
345
705
  """
346
- Forward pass.
706
+ Select and return a particular index from input.
707
+
708
+ Args:
709
+ x (List[torch.Tensor]): List of input tensors.
347
710
 
348
- Expects a list of tensors as input.
711
+ Returns:
712
+ (torch.Tensor): Selected tensor.
349
713
  """
350
714
  return x[self.index]