ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/nn/modules/block.py

@@ -1,10 +1,12 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """Block modules."""
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ultralytics.utils.torch_utils import fuse_conv_and_bn
+
 from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
 from .transformer import TransformerBlock
 
@@ -32,11 +34,22 @@ __all__ = (
     "RepC3",
     "ResNetLayer",
     "RepNCSPELAN4",
+    "ELAN1",
     "ADown",
+    "AConv",
     "SPPELAN",
     "CBFuse",
     "CBLinear",
-    "Silence",
+    "C3k2",
+    "C2fPSA",
+    "C2PSA",
+    "RepVGGDW",
+    "CIB",
+    "C2fCIB",
+    "Attention",
+    "PSA",
+    "SCDown",
+    "TorchVision",
 )
 
 
@@ -171,10 +184,9 @@ class SPPF(nn.Module):
 
     def forward(self, x):
         """Forward pass through Ghost Convolution block."""
-        x = self.cv1(x)
-        y1 = self.m(x)
-        y2 = self.m(y1)
-        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+        y = [self.cv1(x)]
+        y.extend(self.m(y[-1]) for _ in range(3))
+        return self.cv2(torch.cat(y, 1))
 
 
 class C1(nn.Module):
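The SPPF forward refactor above replaces the explicitly named intermediate tensors with a list that accumulates three successive pooling outputs; the two forms compute the same result. A minimal standalone check of that equivalence (nn.MaxPool2d stands in for self.m, a random tensor for the output of self.cv1):

    import torch
    import torch.nn as nn

    m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)  # plays the role of self.m in SPPF
    x = torch.randn(1, 8, 16, 16)                         # plays the role of self.cv1(x)

    # old form: explicit chaining of pooling calls
    y1, y2 = m(x), m(m(x))
    old = torch.cat((x, y1, y2, m(y2)), 1)

    # new form: accumulate pooled outputs into a list
    y = [x]
    y.extend(m(y[-1]) for _ in range(3))
    new = torch.cat(y, 1)

    print(torch.equal(old, new))  # True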
@@ -196,9 +208,7 @@ class C2(nn.Module):
     """CSP Bottleneck with 2 convolutions."""
 
     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
-        """Initializes the CSP Bottleneck with 2 convolutions module with arguments ch_in, ch_out, number, shortcut,
-        groups, expansion.
-        """
+        """Initializes a CSP Bottleneck with 2 convolutions and optional shortcut connection."""
         super().__init__()
         self.c = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, 2 * self.c, 1, 1)
@@ -216,9 +226,7 @@ class C2f(nn.Module):
     """Faster Implementation of CSP Bottleneck with 2 convolutions."""
 
     def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
-        expansion.
-        """
+        """Initializes a CSP bottleneck with 2 convolutions and n Bottleneck blocks for faster processing."""
         super().__init__()
         self.c = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, 2 * self.c, 1, 1)
@@ -233,7 +241,8 @@ class C2f(nn.Module):
 
     def forward_split(self, x):
         """Forward pass using split() instead of chunk()."""
-        y = list(self.cv1(x).split((self.c, self.c), 1))
+        y = self.cv1(x).split((self.c, self.c), 1)
+        y = [y[0], y[1]]
         y.extend(m(y[-1]) for m in self.m)
         return self.cv2(torch.cat(y, 1))
 
@@ -272,8 +281,8 @@ class RepC3(nn.Module):
         """Initialize CSP Bottleneck with a single convolution using input channels, output channels, and number."""
         super().__init__()
         c_ = int(c2 * e)  # hidden channels
-        self.cv1 = Conv(c1, c2, 1, 1)
-        self.cv2 = Conv(c1, c2, 1, 1)
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c1, c_, 1, 1)
         self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])
         self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()
 
@@ -327,9 +336,7 @@ class Bottleneck(nn.Module):
     """Standard bottleneck."""
 
     def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
-        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
-        expansion.
-        """
+        """Initializes a standard bottleneck module with optional shortcut connection and configurable parameters."""
        super().__init__()
         c_ = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, c_, k[0], 1)
@@ -337,7 +344,7 @@ class Bottleneck(nn.Module):
         self.add = shortcut and c1 == c2
 
     def forward(self, x):
-        """'forward()' applies the YOLO FPN to input data."""
+        """Applies the YOLO FPN to input data."""
         return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
 
 
@@ -441,9 +448,7 @@ class C2fAttn(nn.Module):
     """C2f module with an additional attn module."""
 
     def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5):
-        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
-        expansion.
-        """
+        """Initializes C2f module with attention mechanism for enhanced feature extraction and processing."""
         super().__init__()
         self.c = int(c2 * e)  # hidden channels
         self.cv1 = Conv(c1, 2 * self.c, 1, 1)
@@ -513,14 +518,13 @@ class ImagePoolingAttn(nn.Module):
 
 
 class ContrastiveHead(nn.Module):
-    """Contrastive Head for YOLO-World compute the region-text scores according to the similarity between image and text
-    features.
-    """
+    """Implements contrastive learning head for region-text similarity in vision-language models."""
 
     def __init__(self):
         """Initializes ContrastiveHead with specified region-text similarity parameters."""
         super().__init__()
-        self.bias = nn.Parameter(torch.zeros([]))
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
 
     def forward(self, x, w):
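Both contrastive heads (here and in BNContrastiveHead below) now start the bias at -10.0 instead of 0.0; per the NOTE comments this keeps the initial classification loss in line with the other heads. A minimal illustration of what the new bias does to the initial region-text scores, assuming they are squashed with a sigmoid inside the loss:

    import torch

    # A bias of -10.0 drives the initial scores toward "no match" for every class,
    # instead of the 0.5 probability a zero-initialized bias would give.
    print(torch.sigmoid(torch.tensor(-10.0)))  # tensor(4.5398e-05)
    print(torch.sigmoid(torch.tensor(0.0)))    # tensor(0.5000)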
@@ -543,7 +547,8 @@ class BNContrastiveHead(nn.Module):
         """Initialize ContrastiveHead with region-text similarity parameters."""
         super().__init__()
         self.norm = nn.BatchNorm2d(embed_dims)
-        self.bias = nn.Parameter(torch.zeros([]))
+        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
+        self.bias = nn.Parameter(torch.tensor([-10.0]))
         # use -1.0 is more stable
         self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
 
@@ -555,40 +560,25 @@ class BNContrastiveHead(nn.Module):
         return x * self.logit_scale.exp() + self.bias
 
 
-class RepBottleneck(nn.Module):
+class RepBottleneck(Bottleneck):
     """Rep bottleneck."""
 
     def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
-        """Initializes a RepBottleneck module with customizable in/out channels, shortcut option, groups and expansion
-        ratio.
-        """
-        super().__init__()
+        """Initializes a RepBottleneck module with customizable in/out channels, shortcuts, groups and expansion."""
+        super().__init__(c1, c2, shortcut, g, k, e)
         c_ = int(c2 * e)  # hidden channels
         self.cv1 = RepConv(c1, c_, k[0], 1)
-        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
-        self.add = shortcut and c1 == c2
-
-    def forward(self, x):
-        """Forward pass through RepBottleneck layer."""
-        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
 
 
-class RepCSP(nn.Module):
-    """Rep CSP Bottleneck with 3 convolutions."""
+class RepCSP(C3):
+    """Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction."""
 
     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
         """Initializes RepCSP layer with given channels, repetitions, shortcut, groups and expansion ratio."""
-        super().__init__()
+        super().__init__(c1, c2, n, shortcut, g, e)
         c_ = int(c2 * e)  # hidden channels
-        self.cv1 = Conv(c1, c_, 1, 1)
-        self.cv2 = Conv(c1, c_, 1, 1)
-        self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)
         self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
 
-    def forward(self, x):
-        """Forward pass through RepCSP layer."""
-        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
-
 
 class RepNCSPELAN4(nn.Module):
     """CSP-ELAN."""
@@ -615,6 +605,33 @@ class RepNCSPELAN4(nn.Module):
         return self.cv4(torch.cat(y, 1))
 
 
+class ELAN1(RepNCSPELAN4):
+    """ELAN1 module with 4 convolutions."""
+
+    def __init__(self, c1, c2, c3, c4):
+        """Initializes ELAN1 layer with specified channel sizes."""
+        super().__init__(c1, c2, c3, c4)
+        self.c = c3 // 2
+        self.cv1 = Conv(c1, c3, 1, 1)
+        self.cv2 = Conv(c3 // 2, c4, 3, 1)
+        self.cv3 = Conv(c4, c4, 3, 1)
+        self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
+
+
+class AConv(nn.Module):
+    """AConv."""
+
+    def __init__(self, c1, c2):
+        """Initializes AConv module with convolution layers."""
+        super().__init__()
+        self.cv1 = Conv(c1, c2, 3, 2, 1)
+
+    def forward(self, x):
+        """Forward pass through AConv layer."""
+        x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
+        return self.cv1(x)
+
+
 class ADown(nn.Module):
     """ADown."""
 
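A note on the terse positional arguments in the new AConv.forward: they map to F.avg_pool2d(input, kernel_size=2, stride=1, padding=0, ceil_mode=False, count_include_pad=True), so the pooling only smooths the features and the stride-2 Conv that follows does the actual downsampling. A quick sketch of the shapes involved (the tensor sizes here are arbitrary):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 16, 32, 32)
    # kernel_size=2, stride=1, padding=0, ceil_mode=False, count_include_pad=True
    y = F.avg_pool2d(x, 2, 1, 0, False, True)
    print(y.shape)  # torch.Size([1, 16, 31, 31]); the 3x3 stride-2 Conv afterwards halves the resolution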
 
@@ -655,31 +672,18 @@ class SPPELAN(nn.Module):
         return self.cv5(torch.cat(y, 1))
 
 
-class Silence(nn.Module):
-    """Silence."""
-
-    def __init__(self):
-        """Initializes the Silence module."""
-        super(Silence, self).__init__()
-
-    def forward(self, x):
-        """Forward pass through Silence layer."""
-        return x
-
-
 class CBLinear(nn.Module):
     """CBLinear."""
 
     def __init__(self, c1, c2s, k=1, s=1, p=None, g=1):
         """Initializes the CBLinear module, passing inputs unchanged."""
-        super(CBLinear, self).__init__()
+        super().__init__()
         self.c2s = c2s
         self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
 
     def forward(self, x):
         """Forward pass through CBLinear layer."""
-        outs = self.conv(x).split(self.c2s, dim=1)
-        return outs
+        return self.conv(x).split(self.c2s, dim=1)
 
 
 class CBFuse(nn.Module):
@@ -687,12 +691,468 @@ class CBFuse(nn.Module):
 
     def __init__(self, idx):
         """Initializes CBFuse module with layer index for selective feature fusion."""
-        super(CBFuse, self).__init__()
+        super().__init__()
         self.idx = idx
 
     def forward(self, xs):
         """Forward pass through CBFuse layer."""
         target_size = xs[-1].shape[2:]
         res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
-        out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
-        return out
+        return torch.sum(torch.stack(res + xs[-1:]), dim=0)
+
+
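The CBFuse forward above is a compact multi-scale fusion: each selected feature map is resized to the spatial size of the last input with nearest-neighbor interpolation, then everything is summed. A minimal standalone sketch of that idea (shapes are arbitrary):

    import torch
    import torch.nn.functional as F

    small = torch.randn(1, 8, 16, 16)
    large = torch.randn(1, 8, 32, 32)
    # resize to the last input's spatial size, then sum
    fused = F.interpolate(small, size=large.shape[2:], mode="nearest") + large
    print(fused.shape)  # torch.Size([1, 8, 32, 32])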
+class C3f(nn.Module):
+    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
+
+    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
+        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
+        expansion.
+        """
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, c_, 1, 1)
+        self.cv2 = Conv(c1, c_, 1, 1)
+        self.cv3 = Conv((2 + n) * c_, c2, 1)  # optional act=FReLU(c2)
+        self.m = nn.ModuleList(Bottleneck(c_, c_, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
+
+    def forward(self, x):
+        """Forward pass through C2f layer."""
+        y = [self.cv2(x), self.cv1(x)]
+        y.extend(m(y[-1]) for m in self.m)
+        return self.cv3(torch.cat(y, 1))
+
+
+class C3k2(C2f):
+    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
+
+    def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True):
+        """Initializes the C3k2 module, a faster CSP Bottleneck with 2 convolutions and optional C3k blocks."""
+        super().__init__(c1, c2, n, shortcut, g, e)
+        self.m = nn.ModuleList(
+            C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)
+        )
+
+
+class C3k(C3):
+    """C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks."""
+
+    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
+        """Initializes the C3k module with specified channels, number of layers, and configurations."""
+        super().__init__(c1, c2, n, shortcut, g, e)
+        c_ = int(c2 * e)  # hidden channels
+        # self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
+        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
+
+
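C3k2 is the building block used throughout the new yolo11*.yaml configs added in this release; its c3k flag switches the inner modules between plain Bottleneck blocks and C3k blocks. A small usage sketch, assuming the new package version is installed (channel and image sizes here are arbitrary):

    import torch
    from ultralytics.nn.modules.block import C3k2  # requires ultralytics >= 8.3

    m = C3k2(c1=64, c2=64, n=2, c3k=False, e=0.5)  # c3k=True would stack C3k blocks instead of Bottleneck
    x = torch.randn(1, 64, 40, 40)
    print(m(x).shape)  # torch.Size([1, 64, 40, 40])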
+class RepVGGDW(torch.nn.Module):
+    """RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture."""
+
+    def __init__(self, ed) -> None:
+        """Initializes RepVGGDW with depthwise separable convolutional layers for efficient processing."""
+        super().__init__()
+        self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
+        self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
+        self.dim = ed
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        """
+        Performs a forward pass of the RepVGGDW block.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            (torch.Tensor): Output tensor after applying the depth wise separable convolution.
+        """
+        return self.act(self.conv(x) + self.conv1(x))
+
+    def forward_fuse(self, x):
+        """
+        Performs a forward pass of the RepVGGDW block without fusing the convolutions.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            (torch.Tensor): Output tensor after applying the depth wise separable convolution.
+        """
+        return self.act(self.conv(x))
+
+    @torch.no_grad()
+    def fuse(self):
+        """
+        Fuses the convolutional layers in the RepVGGDW block.
+
+        This method fuses the convolutional layers and updates the weights and biases accordingly.
+        """
+        conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
+        conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)
+
+        conv_w = conv.weight
+        conv_b = conv.bias
+        conv1_w = conv1.weight
+        conv1_b = conv1.bias
+
+        conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])
+
+        final_conv_w = conv_w + conv1_w
+        final_conv_b = conv_b + conv1_b
+
+        conv.weight.data.copy_(final_conv_w)
+        conv.bias.data.copy_(final_conv_b)
+
+        self.conv = conv
+        del self.conv1
+
+
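RepVGGDW.fuse above merges the 7x7 and 3x3 depthwise branches into a single 7x7 convolution by zero-padding the 3x3 kernel by 2 on each side and summing the weights (after folding each branch's BatchNorm). A minimal standalone check of just the kernel-padding identity, with random weights and no BatchNorm:

    import torch
    import torch.nn.functional as F

    ed = 4  # depthwise convolution, so groups == channels
    x = torch.randn(1, ed, 16, 16)
    w3 = torch.randn(ed, 1, 3, 3)
    w7 = F.pad(w3, [2, 2, 2, 2])  # same padding used in fuse()
    y3 = F.conv2d(x, w3, padding=1, groups=ed)
    y7 = F.conv2d(x, w7, padding=3, groups=ed)
    print(torch.allclose(y3, y7, atol=1e-6))  # True: the padded 3x3 kernel behaves like the original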
+class CIB(nn.Module):
+    """
+    Conditional Identity Block (CIB) module.
+
+    Args:
+        c1 (int): Number of input channels.
+        c2 (int): Number of output channels.
+        shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.
+        e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.
+        lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.
+    """
+
+    def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
+        """Initializes the custom model with optional shortcut, scaling factor, and RepVGGDW layer."""
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = nn.Sequential(
+            Conv(c1, c1, 3, g=c1),
+            Conv(c1, 2 * c_, 1),
+            RepVGGDW(2 * c_) if lk else Conv(2 * c_, 2 * c_, 3, g=2 * c_),
+            Conv(2 * c_, c2, 1),
+            Conv(c2, c2, 3, g=c2),
+        )
+
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        """
+        Forward pass of the CIB module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            (torch.Tensor): Output tensor.
+        """
+        return x + self.cv1(x) if self.add else self.cv1(x)
+
+
+class C2fCIB(C2f):
+    """
+    C2fCIB class represents a convolutional block with C2f and CIB modules.
+
+    Args:
+        c1 (int): Number of input channels.
+        c2 (int): Number of output channels.
+        n (int, optional): Number of CIB modules to stack. Defaults to 1.
+        shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.
+        lk (bool, optional): Whether to use local key connection. Defaults to False.
+        g (int, optional): Number of groups for grouped convolution. Defaults to 1.
+        e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.
+    """
+
+    def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
+        """Initializes the module with specified parameters for channel, shortcut, local key, groups, and expansion."""
+        super().__init__(c1, c2, n, shortcut, g, e)
+        self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
+
+
+class Attention(nn.Module):
+    """
+    Attention module that performs self-attention on the input tensor.
+
+    Args:
+        dim (int): The input tensor dimension.
+        num_heads (int): The number of attention heads.
+        attn_ratio (float): The ratio of the attention key dimension to the head dimension.
+
+    Attributes:
+        num_heads (int): The number of attention heads.
+        head_dim (int): The dimension of each attention head.
+        key_dim (int): The dimension of the attention key.
+        scale (float): The scaling factor for the attention scores.
+        qkv (Conv): Convolutional layer for computing the query, key, and value.
+        proj (Conv): Convolutional layer for projecting the attended values.
+        pe (Conv): Convolutional layer for positional encoding.
+    """
+
+    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
+        """Initializes multi-head attention module with query, key, and value convolutions and positional encoding."""
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.key_dim = int(self.head_dim * attn_ratio)
+        self.scale = self.key_dim**-0.5
+        nh_kd = self.key_dim * num_heads
+        h = dim + nh_kd * 2
+        self.qkv = Conv(dim, h, 1, act=False)
+        self.proj = Conv(dim, dim, 1, act=False)
+        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
+
+    def forward(self, x):
+        """
+        Forward pass of the Attention module.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            (torch.Tensor): The output tensor after self-attention.
+        """
+        B, C, H, W = x.shape
+        N = H * W
+        qkv = self.qkv(x)
+        q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
+            [self.key_dim, self.key_dim, self.head_dim], dim=2
+        )
+
+        attn = (q.transpose(-2, -1) @ k) * self.scale
+        attn = attn.softmax(dim=-1)
+        x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
+        x = self.proj(x)
+        return x
+
+
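A quick shape check for the new Attention block above, assuming the new package version is installed (dim and spatial size are arbitrary as long as dim is divisible by num_heads):

    import torch
    from ultralytics.nn.modules.block import Attention  # requires ultralytics >= 8.3

    attn = Attention(dim=256, num_heads=8, attn_ratio=0.5)
    x = torch.randn(1, 256, 20, 20)
    print(attn(x).shape)  # torch.Size([1, 256, 20, 20]); self-attention preserves the input shape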
+class PSABlock(nn.Module):
+    """
+    PSABlock class implementing a Position-Sensitive Attention block for neural networks.
+
+    This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers
+    with optional shortcut connections.
+
+    Attributes:
+        attn (Attention): Multi-head attention module.
+        ffn (nn.Sequential): Feed-forward neural network module.
+        add (bool): Flag indicating whether to add shortcut connections.
+
+    Methods:
+        forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.
+
+    Examples:
+        Create a PSABlock and perform a forward pass
+        >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)
+        >>> input_tensor = torch.randn(1, 128, 32, 32)
+        >>> output_tensor = psablock(input_tensor)
+    """
+
+    def __init__(self, c, attn_ratio=0.5, num_heads=4, shortcut=True) -> None:
+        """Initializes the PSABlock with attention and feed-forward layers for enhanced feature extraction."""
+        super().__init__()
+
+        self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
+        self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
+        self.add = shortcut
+
+    def forward(self, x):
+        """Executes a forward pass through PSABlock, applying attention and feed-forward layers to the input tensor."""
+        x = x + self.attn(x) if self.add else self.attn(x)
+        x = x + self.ffn(x) if self.add else self.ffn(x)
+        return x
+
+
+class PSA(nn.Module):
+    """
+    PSA class for implementing Position-Sensitive Attention in neural networks.
+
+    This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to
+    input tensors, enhancing feature extraction and processing capabilities.
+
+    Attributes:
+        c (int): Number of hidden channels after applying the initial convolution.
+        cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
+        cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
+        attn (Attention): Attention module for position-sensitive attention.
+        ffn (nn.Sequential): Feed-forward network for further processing.
+
+    Methods:
+        forward: Applies position-sensitive attention and feed-forward network to the input tensor.
+
+    Examples:
+        Create a PSA module and apply it to an input tensor
+        >>> psa = PSA(c1=128, c2=128, e=0.5)
+        >>> input_tensor = torch.randn(1, 128, 64, 64)
+        >>> output_tensor = psa.forward(input_tensor)
+    """
+
+    def __init__(self, c1, c2, e=0.5):
+        """Initializes the PSA module with input/output channels and attention mechanism for feature extraction."""
+        super().__init__()
+        assert c1 == c2
+        self.c = int(c1 * e)
+        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
+        self.cv2 = Conv(2 * self.c, c1, 1)
+
+        self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
+        self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))
+
+    def forward(self, x):
+        """Executes forward pass in PSA module, applying attention and feed-forward layers to the input tensor."""
+        a, b = self.cv1(x).split((self.c, self.c), dim=1)
+        b = b + self.attn(b)
+        b = b + self.ffn(b)
+        return self.cv2(torch.cat((a, b), 1))
+
+
+class C2PSA(nn.Module):
+    """
+    C2PSA module with attention mechanism for enhanced feature extraction and processing.
+
+    This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing
+    capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations.
+
+    Attributes:
+        c (int): Number of hidden channels.
+        cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
+        cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
+        m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.
+
+    Methods:
+        forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.
+
+    Notes:
+        This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.
+
+    Examples:
+        >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)
+        >>> input_tensor = torch.randn(1, 256, 64, 64)
+        >>> output_tensor = c2psa(input_tensor)
+    """
+
+    def __init__(self, c1, c2, n=1, e=0.5):
+        """Initializes the C2PSA module with specified input/output channels, number of layers, and expansion ratio."""
+        super().__init__()
+        assert c1 == c2
+        self.c = int(c1 * e)
+        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
+        self.cv2 = Conv(2 * self.c, c1, 1)
+
+        self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)))
+
+    def forward(self, x):
+        """Processes the input tensor 'x' through a series of PSA blocks and returns the transformed tensor."""
+        a, b = self.cv1(x).split((self.c, self.c), dim=1)
+        b = self.m(b)
+        return self.cv2(torch.cat((a, b), 1))
+
+
+class C2fPSA(C2f):
+    """
+    C2fPSA module with enhanced feature extraction using PSA blocks.
+
+    This class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature extraction.
+
+    Attributes:
+        c (int): Number of hidden channels.
+        cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
+        cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
+        m (nn.ModuleList): List of PSA blocks for feature extraction.
+
+    Methods:
+        forward: Performs a forward pass through the C2fPSA module.
+        forward_split: Performs a forward pass using split() instead of chunk().
+
+    Examples:
+        >>> import torch
+        >>> from ultralytics.models.common import C2fPSA
+        >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5)
+        >>> x = torch.randn(1, 64, 128, 128)
+        >>> output = model(x)
+        >>> print(output.shape)
+    """
+
+    def __init__(self, c1, c2, n=1, e=0.5):
+        """Initializes the C2fPSA module, a variant of C2f with PSA blocks for enhanced feature extraction."""
+        assert c1 == c2
+        super().__init__(c1, c2, n=n, e=e)
+        self.m = nn.ModuleList(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))
+
+
+class SCDown(nn.Module):
+    """
+    SCDown module for downsampling with separable convolutions.
+
+    This module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in
+    efficiently reducing the spatial dimensions of the input tensor while maintaining the channel information.
+
+    Attributes:
+        cv1 (Conv): Pointwise convolution layer that reduces the number of channels.
+        cv2 (Conv): Depthwise convolution layer that performs spatial downsampling.
+
+    Methods:
+        forward: Applies the SCDown module to the input tensor.
+
+    Examples:
+        >>> import torch
+        >>> from ultralytics import SCDown
+        >>> model = SCDown(c1=64, c2=128, k=3, s=2)
+        >>> x = torch.randn(1, 64, 128, 128)
+        >>> y = model(x)
+        >>> print(y.shape)
+        torch.Size([1, 128, 64, 64])
+    """
+
+    def __init__(self, c1, c2, k, s):
+        """Initializes the SCDown module with specified input/output channels, kernel size, and stride."""
+        super().__init__()
+        self.cv1 = Conv(c1, c2, 1, 1)
+        self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
+
+    def forward(self, x):
+        """Applies convolution and downsampling to the input tensor in the SCDown module."""
+        return self.cv2(self.cv1(x))
+
+
+class TorchVision(nn.Module):
+    """
+    TorchVision module to allow loading any torchvision model.
+
+    This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating or unwrapping layers.
+
+    Attributes:
+        m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
+
+    Args:
+        c1 (int): Input channels.
+        c2 (): Output channels.
+        model (str): Name of the torchvision model to load.
+        weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
+        unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.
+        truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.
+        split (bool, optional): Returns output from intermediate child modules as list. Default is False.
+    """
+
+    def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
+        """Load the model and weights from torchvision."""
+        import torchvision
+
+        super().__init__()
+        if hasattr(torchvision.models, "get_model"):
+            self.m = torchvision.models.get_model(model, weights=weights)
+        else:
+            self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
+        if unwrap:
+            layers = list(self.m.children())[:-truncate]
+            if isinstance(layers[0], nn.Sequential):  # Second-level for some models like EfficientNet, Swin
+                layers = [*list(layers[0].children()), *layers[1:]]
+            self.m = nn.Sequential(*layers)
+            self.split = split
+        else:
+            self.split = False
+            self.m.head = self.m.heads = nn.Identity()
+
+    def forward(self, x):
+        """Forward pass through the model."""
+        if self.split:
+            y = [x]
+            y.extend(m(y[-1]) for m in self.m)
+        else:
+            y = self.m(x)
+        return y
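A short usage sketch for the TorchVision wrapper above, assuming the new package version and torchvision are installed (weights=None skips downloading pretrained weights; the ResNet-18 choice and sizes are arbitrary):

    import torch
    from ultralytics.nn.modules.block import TorchVision  # requires ultralytics >= 8.3

    # Wrap a torchvision ResNet-18, dropping its final avgpool/fc layers (truncate=2) to use it as a backbone.
    backbone = TorchVision(c1=3, c2=512, model="resnet18", weights=None, unwrap=True, truncate=2)
    x = torch.randn(1, 3, 224, 224)
    print(backbone(x).shape)  # torch.Size([1, 512, 7, 7])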