ultralytics-8.2.36-py3-none-any.whl → ultralytics-8.2.38-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of ultralytics has been flagged as potentially problematic.

tests/test_python.py CHANGED
@@ -577,3 +577,12 @@ def test_yolo_world():
         close_mosaic=1,
         trainer=WorldTrainerFromScratch,
     )
+
+
+def test_yolov10():
+    """A simple test for yolov10 for now."""
+    model = YOLO("yolov10n.yaml")
+    # train/val/predict
+    model.train(data="coco8.yaml", epochs=1, imgsz=32, close_mosaic=1, cache="disk")
+    model.val(data="coco8.yaml", imgsz=32)
+    model(SOURCE)
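For context, the new test builds YOLOv10-nano from the bundled yolov10n.yaml config and runs a short train/val/predict cycle. A minimal standalone sketch of the same flow (my own example, not part of the diff; the sample image URL is an assumption):

from ultralytics import YOLO

model = YOLO("yolov10n.yaml")  # build YOLOv10-n from the packaged config, randomly initialized
model.train(data="coco8.yaml", epochs=1, imgsz=32, close_mosaic=1, cache="disk")  # quick smoke-training run
metrics = model.val(data="coco8.yaml", imgsz=32)  # validate on the same tiny dataset
results = model("https://ultralytics.com/images/bus.jpg")  # run inference on a sample image
print(results[0].boxes)  # predicted boxes, if any

The added test uses the suite's shared SOURCE fixture as the prediction input; the call pattern is otherwise identical.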
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.2.36"
+__version__ = "8.2.38"
 
 import os
 
@@ -0,0 +1,42 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  b: [0.67, 1.00, 512]
+
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2fCIB, [512, True]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
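The six new YOLOv10 configs (this one and the five that follow) differ mainly in their scales entry: a [depth, width, max_channels] triple keyed by the size letter, which the model parser uses to scale every layer's repeat count and channel width. A rough illustration of that scaling rule (a sketch of the idea, not the exact Ultralytics parser):

import math

def scale_layer(repeats, channels, depth=0.67, width=1.00, max_channels=512):
    """Approximate YOLO compound scaling: scale repeats by depth, channels by width (capped, rounded to 8)."""
    n = max(round(repeats * depth), 1) if repeats > 1 else repeats
    c = math.ceil(min(channels, max_channels) * width / 8) * 8
    return n, c

# With the 'b' scale above (depth=0.67, width=1.00, max_channels=512):
print(scale_layer(6, 512))   # the 6-repeat C2f stage at P4 -> (4, 512)
print(scale_layer(3, 1024))  # the 3-repeat C2fCIB stage at P5 -> (2, 512), capped by max_channels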
@@ -0,0 +1,42 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  l: [1.00, 1.00, 512]
+
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2fCIB, [512, True]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
@@ -0,0 +1,42 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  m: [0.67, 0.75, 768]
+
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
@@ -0,0 +1,42 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024]
+
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2f, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
@@ -0,0 +1,42 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  s: [0.33, 0.50, 1024]
+
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
@@ -0,0 +1,42 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv10 object detection model. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  x: [1.00, 1.25, 512]
+
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2fCIB, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2fCIB, [512, True]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/data/loaders.py CHANGED
@@ -362,7 +362,7 @@ class LoadImagesAndVideos:
                 self.mode = "image"
                 im0 = cv2.imread(path)  # BGR
                 if im0 is None:
-                    raise FileNotFoundError(f"Image Not Found {path}")
+                    raise FileNotFoundError(f"Image Read Error {path}")
                 paths.append(path)
                 imgs.append(im0)
                 info.append(f"image {self.count + 1}/{self.nf} {path}: ")
ultralytics/engine/exporter.py CHANGED
@@ -920,6 +920,7 @@ class Exporter:
     @try_export
     def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")):
         """YOLOv8 TensorFlow Lite export."""
+        # BUG https://github.com/ultralytics/ultralytics/issues/13436
         import tensorflow as tf  # noqa
 
         LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
ultralytics/nn/modules/__init__.py CHANGED
@@ -22,18 +22,22 @@ from .block import (
     C2,
     C3,
     C3TR,
+    CIB,
     DFL,
     ELAN1,
+    PSA,
     SPP,
     SPPELAN,
     SPPF,
     AConv,
     ADown,
+    Attention,
     BNContrastiveHead,
     Bottleneck,
     BottleneckCSP,
     C2f,
     C2fAttn,
+    C2fCIB,
     C3Ghost,
     C3x,
     CBFuse,
@@ -46,7 +50,9 @@ from .block import (
     Proto,
     RepC3,
     RepNCSPELAN4,
+    RepVGGDW,
     ResNetLayer,
+    SCDown,
 )
 from .conv import (
     CBAM,
@@ -63,7 +69,7 @@ from .conv import (
     RepConv,
     SpatialAttention,
 )
-from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect
+from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, v10Detect
 from .transformer import (
     AIFI,
     MLP,
@@ -137,4 +143,10 @@ __all__ = (
     "CBLinear",
     "AConv",
     "ELAN1",
+    "RepVGGDW",
+    "CIB",
+    "C2fCIB",
+    "Attention",
+    "PSA",
+    "SCDown",
 )
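With these re-exports, the YOLOv10 building blocks become part of the public ultralytics.nn.modules namespace. A quick import check (assumes ultralytics 8.2.38 is installed):

from ultralytics.nn.modules import CIB, C2fCIB, PSA, SCDown, RepVGGDW, Attention, v10Detect

print([cls.__name__ for cls in (CIB, C2fCIB, PSA, SCDown, RepVGGDW, Attention, v10Detect)])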
ultralytics/nn/modules/block.py CHANGED
@@ -5,6 +5,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ultralytics.utils.torch_utils import fuse_conv_and_bn
+
 from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
 from .transformer import TransformerBlock
 
@@ -39,6 +41,12 @@ __all__ = (
     "CBFuse",
     "CBLinear",
     "Silence",
+    "RepVGGDW",
+    "CIB",
+    "C2fCIB",
+    "Attention",
+    "PSA",
+    "SCDown",
 )
 
 
@@ -672,18 +680,6 @@ class SPPELAN(nn.Module):
         return self.cv5(torch.cat(y, 1))
 
 
-class Silence(nn.Module):
-    """Silence."""
-
-    def __init__(self):
-        """Initializes the Silence module."""
-        super(Silence, self).__init__()
-
-    def forward(self, x):
-        """Forward pass through Silence layer."""
-        return x
-
-
 class CBLinear(nn.Module):
     """CBLinear."""
 
@@ -711,3 +707,251 @@ class CBFuse(nn.Module):
         target_size = xs[-1].shape[2:]
         res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
         return torch.sum(torch.stack(res + xs[-1:]), dim=0)
+
+
+class RepVGGDW(torch.nn.Module):
+    """RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture."""
+
+    def __init__(self, ed) -> None:
+        super().__init__()
+        self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
+        self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
+        self.dim = ed
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        """
+        Performs a forward pass of the RepVGGDW block.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            (torch.Tensor): Output tensor after applying the depth wise separable convolution.
+        """
+        return self.act(self.conv(x) + self.conv1(x))
+
+    def forward_fuse(self, x):
+        """
+        Performs a forward pass of the RepVGGDW block without fusing the convolutions.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            (torch.Tensor): Output tensor after applying the depth wise separable convolution.
+        """
+        return self.act(self.conv(x))
+
+    @torch.no_grad()
+    def fuse(self):
+        """
+        Fuses the convolutional layers in the RepVGGDW block.
+
+        This method fuses the convolutional layers and updates the weights and biases accordingly.
+        """
+        conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
+        conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)
+
+        conv_w = conv.weight
+        conv_b = conv.bias
+        conv1_w = conv1.weight
+        conv1_b = conv1.bias
+
+        conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])
+
+        final_conv_w = conv_w + conv1_w
+        final_conv_b = conv_b + conv1_b
+
+        conv.weight.data.copy_(final_conv_w)
+        conv.bias.data.copy_(final_conv_b)
+
+        self.conv = conv
+        del self.conv1
+
+
+class CIB(nn.Module):
+    """
+    Conditional Identity Block (CIB) module.
+
+    Args:
+        c1 (int): Number of input channels.
+        c2 (int): Number of output channels.
+        shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.
+        e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.
+        lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.
+    """
+
+    def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
+        """Initializes the custom model with optional shortcut, scaling factor, and RepVGGDW layer."""
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = nn.Sequential(
+            Conv(c1, c1, 3, g=c1),
+            Conv(c1, 2 * c_, 1),
+            Conv(2 * c_, 2 * c_, 3, g=2 * c_) if not lk else RepVGGDW(2 * c_),
+            Conv(2 * c_, c2, 1),
+            Conv(c2, c2, 3, g=c2),
+        )
+
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        """
+        Forward pass of the CIB module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            (torch.Tensor): Output tensor.
+        """
+        return x + self.cv1(x) if self.add else self.cv1(x)
+
+
+class C2fCIB(C2f):
+    """
+    C2fCIB class represents a convolutional block with C2f and CIB modules.
+
+    Args:
+        c1 (int): Number of input channels.
+        c2 (int): Number of output channels.
+        n (int, optional): Number of CIB modules to stack. Defaults to 1.
+        shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.
+        lk (bool, optional): Whether to use local key connection. Defaults to False.
+        g (int, optional): Number of groups for grouped convolution. Defaults to 1.
+        e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.
+    """
+
+    def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
+        """Initializes the module with specified parameters for channel, shortcut, local key, groups, and expansion."""
+        super().__init__(c1, c2, n, shortcut, g, e)
+        self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
+
+
+class Attention(nn.Module):
+    """
+    Attention module that performs self-attention on the input tensor.
+
+    Args:
+        dim (int): The input tensor dimension.
+        num_heads (int): The number of attention heads.
+        attn_ratio (float): The ratio of the attention key dimension to the head dimension.
+
+    Attributes:
+        num_heads (int): The number of attention heads.
+        head_dim (int): The dimension of each attention head.
+        key_dim (int): The dimension of the attention key.
+        scale (float): The scaling factor for the attention scores.
+        qkv (Conv): Convolutional layer for computing the query, key, and value.
+        proj (Conv): Convolutional layer for projecting the attended values.
+        pe (Conv): Convolutional layer for positional encoding.
+    """
+
+    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
+        """Initializes multi-head attention module with query, key, and value convolutions and positional encoding."""
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.key_dim = int(self.head_dim * attn_ratio)
+        self.scale = self.key_dim**-0.5
+        nh_kd = nh_kd = self.key_dim * num_heads
+        h = dim + nh_kd * 2
+        self.qkv = Conv(dim, h, 1, act=False)
+        self.proj = Conv(dim, dim, 1, act=False)
+        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
+
+    def forward(self, x):
+        """
+        Forward pass of the Attention module.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            (torch.Tensor): The output tensor after self-attention.
+        """
+        B, C, H, W = x.shape
+        N = H * W
+        qkv = self.qkv(x)
+        q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
+            [self.key_dim, self.key_dim, self.head_dim], dim=2
+        )
+
+        attn = (q.transpose(-2, -1) @ k) * self.scale
+        attn = attn.softmax(dim=-1)
+        x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
+        x = self.proj(x)
+        return x
+
+
+class PSA(nn.Module):
+    """
+    Position-wise Spatial Attention module.
+
+    Args:
+        c1 (int): Number of input channels.
+        c2 (int): Number of output channels.
+        e (float): Expansion factor for the intermediate channels. Default is 0.5.
+
+    Attributes:
+        c (int): Number of intermediate channels.
+        cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
+        cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
+        attn (Attention): Attention module for spatial attention.
+        ffn (nn.Sequential): Feed-forward network module.
+    """
+
+    def __init__(self, c1, c2, e=0.5):
+        """Initializes convolution layers, attention module, and feed-forward network with channel reduction."""
+        super().__init__()
+        assert c1 == c2
+        self.c = int(c1 * e)
+        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
+        self.cv2 = Conv(2 * self.c, c1, 1)
+
+        self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
+        self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))
+
+    def forward(self, x):
+        """
+        Forward pass of the PSA module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            (torch.Tensor): Output tensor.
+        """
+        a, b = self.cv1(x).split((self.c, self.c), dim=1)
+        b = b + self.attn(b)
+        b = b + self.ffn(b)
+        return self.cv2(torch.cat((a, b), 1))
+
+
+class SCDown(nn.Module):
+    def __init__(self, c1, c2, k, s):
+        """
+        Spatial Channel Downsample (SCDown) module.
+
+        Args:
+            c1 (int): Number of input channels.
+            c2 (int): Number of output channels.
+            k (int): Kernel size for the convolutional layer.
+            s (int): Stride for the convolutional layer.
+        """
+        super().__init__()
+        self.cv1 = Conv(c1, c2, 1, 1)
+        self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
+
+    def forward(self, x):
+        """
+        Forward pass of the SCDown module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            (torch.Tensor): Output tensor after applying the SCDown module.
+        """
+        return self.cv2(self.cv1(x))
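Taken together: SCDown does a pointwise channel change followed by a strided depthwise convolution, PSA applies attention plus a small FFN to half the channels, CIB/C2fCIB replace the standard C2f bottlenecks with a cheaper depthwise/pointwise stack, and RepVGGDW is a re-parameterizable depthwise block whose 7x7 and padded 3x3 branches collapse into a single convolution at fuse time. A small shape and fusion sanity check (my own sketch, assuming ultralytics 8.2.38 and PyTorch are installed):

import torch
from ultralytics.nn.modules import C2fCIB, PSA, RepVGGDW, SCDown

x = torch.randn(1, 256, 32, 32)

print(SCDown(256, 512, k=3, s=2)(x).shape)            # channels up, resolution halved: [1, 512, 16, 16]
print(PSA(256, 256)(x).shape)                          # attention + FFN, shape preserved: [1, 256, 32, 32]
print(C2fCIB(256, 256, n=2, shortcut=True)(x).shape)   # stacked CIB bottlenecks: [1, 256, 32, 32]

# RepVGGDW: after fuse(), the single remaining conv should reproduce the two-branch output (float tolerance)
m = RepVGGDW(256).eval()
y_ref = m(x)
m.fuse()
print(torch.allclose(y_ref, m.forward_fuse(x), atol=1e-5))  # expected: True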