dgenerate-ultralytics-headless 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
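
The per-file summary above can be reproduced locally by unpacking both wheels and comparing their member lists. The sketch below is a minimal, illustrative version: it assumes both `.whl` files have already been downloaded into the working directory (the exact filenames shown are an assumption based on standard wheel naming), and it only reports added/removed files, not content changes.

```python
# Minimal sketch: compare the file lists of the two wheel versions.
# Assumes both wheels were fetched locally, e.g. with
#   pip download dgenerate-ultralytics-headless==8.3.137 --no-deps
#   pip download dgenerate-ultralytics-headless==8.3.224 --no-deps
from zipfile import ZipFile

OLD = "dgenerate_ultralytics_headless-8.3.137-py3-none-any.whl"  # assumed filename
NEW = "dgenerate_ultralytics_headless-8.3.224-py3-none-any.whl"  # assumed filename

with ZipFile(OLD) as old_whl, ZipFile(NEW) as new_whl:
    old_files = set(old_whl.namelist())
    new_files = set(new_whl.namelist())

print("added:", sorted(new_files - old_files))
print("removed:", sorted(old_files - new_files))
print("present in both (possibly modified):", len(old_files & new_files))
```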
ultralytics/models/sam/modules/tiny_encoder.py (expanded diff; changed lines only, unchanged context omitted):

@@ -9,8 +9,9 @@
+from __future__ import annotations
+
-from typing import Tuple

@@ -21,35 +22,47 @@ from ultralytics.utils.instance import to_2tuple
-    """
-
+    """A sequential container that performs 2D convolution followed by batch normalization.
+
+    This module combines a 2D convolution layer with batch normalization, providing a common building block for
+    convolutional neural networks. The batch normalization weights and biases are initialized to specific values for
+    optimal training performance.
-    Methods:
-        __init__: Initializes the Conv2d_BN with specified parameters.
-
-    Args:
-        a (int): Number of input channels.
-        b (int): Number of output channels.
-        ks (int): Kernel size for the convolution. Defaults to 1.
-        stride (int): Stride for the convolution. Defaults to 1.
-        pad (int): Padding for the convolution. Defaults to 0.
-        dilation (int): Dilation factor for the convolution. Defaults to 1.
-        groups (int): Number of groups for the convolution. Defaults to 1.
-        bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
-
+        torch.Size([1, 64, 224, 224])
-    def __init__(
-
+    def __init__(
+        self,
+        a: int,
+        b: int,
+        ks: int = 1,
+        stride: int = 1,
+        pad: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bn_weight_init: float = 1,
+    ):
+        """Initialize a sequential container with 2D convolution followed by batch normalization.
+
+        Args:
+            a (int): Number of input channels.
+            b (int): Number of output channels.
+            ks (int, optional): Kernel size for the convolution.
+            stride (int, optional): Stride for the convolution.
+            pad (int, optional): Padding for the convolution.
+            dilation (int, optional): Dilation factor for the convolution.
+            groups (int, optional): Number of groups for the convolution.
+            bn_weight_init (float, optional): Initial value for batch normalization weight.
+        """

@@ -59,31 +72,38 @@ class Conv2d_BN(torch.nn.Sequential):
-    """
-
+    """Embed images into patches and project them into a specified embedding dimension.
+
+    This module converts input images into patch embeddings using a sequence of convolutional layers, effectively
+    downsampling the spatial dimensions while increasing the channel dimension.
-        patches_resolution (
+        patches_resolution (tuple[int, int]): Resolution of the patches after embedding.
-    Methods:
-        forward: Processes the input tensor through the patch embedding sequence.
-
+        torch.Size([1, 96, 56, 56])
-    def __init__(self, in_chans, embed_dim, resolution, activation):
-        """
+    def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):
+        """Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
+
+        Args:
+            in_chans (int): Number of input channels.
+            embed_dim (int): Dimension of the embedding.
+            resolution (int): Input image resolution.
+            activation (nn.Module): Activation function to use between convolutions.
+        """
-        img_size:
+        img_size: tuple[int, int] = to_2tuple(resolution)

@@ -95,30 +115,29 @@ class PatchEmbed(nn.Module):
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input tensor through patch embedding sequence, converting images to patch embeddings."""
-    """
-
+    """Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
+
+    This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution, and
+    projection phases, along with residual connections for improved gradient flow.
-        hidden_chans (int): Number of hidden channels.
+        hidden_chans (int): Number of hidden channels after expansion.
-        conv1 (Conv2d_BN): First convolutional layer.
+        conv1 (Conv2d_BN): First convolutional layer for channel expansion.
-        conv3 (Conv2d_BN): Final convolutional layer.
+        conv3 (Conv2d_BN): Final convolutional layer for projection.
-    Methods:
-        forward: Performs the forward pass through the MBConv layer.
-

@@ -128,8 +147,16 @@ class MBConv(nn.Module):
-    def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
-        """
+    def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):
+        """Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
+
+        Args:
+            in_chans (int): Number of input channels.
+            out_chans (int): Number of output channels.
+            expand_ratio (float): Channel expansion ratio for the hidden layer.
+            activation (nn.Module): Activation function to use.
+            drop_path (float): Drop path rate for stochastic depth.
+        """

@@ -148,8 +175,8 @@ class MBConv(nn.Module):
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Implement the forward pass of MBConv, applying convolutions and skip connection."""

@@ -162,14 +189,14 @@ class MBConv(nn.Module):
-    """
-    Merges neighboring patches in the feature map and projects to a new dimension.
+    """Merge neighboring patches in the feature map and project to a new dimension.
-    This class implements a patch merging operation that combines spatial information and adjusts the feature
-
+    This class implements a patch merging operation that combines spatial information and adjusts the feature dimension
+    using a series of convolutional layers with batch normalization. It effectively reduces spatial resolution while
+    potentially increasing channel dimensions.
-        input_resolution (
+        input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.

@@ -177,19 +204,24 @@ class PatchMerging(nn.Module):
-    Methods:
-        forward: Applies the patch merging operation to the input tensor.
-
+        torch.Size([4, 3136, 128])
-    def __init__(self, input_resolution, dim, out_dim, activation):
-        """
+    def __init__(self, input_resolution: tuple[int, int], dim: int, out_dim: int, activation):
+        """Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
+
+        Args:
+            input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.
+            dim (int): The input dimension of the feature map.
+            out_dim (int): The output dimension after merging and projection.
+            activation (nn.Module): The activation function used between convolutions.
+        """

@@ -201,8 +233,8 @@ class PatchMerging(nn.Module):
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply patch merging and dimension projection to the input feature map."""

@@ -219,63 +251,54 @@ class PatchMerging(nn.Module):
-    """
-    Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
+    """Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
-    This layer optionally applies downsample operations to the output and supports gradient checkpointing
+    This layer optionally applies downsample operations to the output and supports gradient checkpointing for memory
+    efficiency during training.
-        input_resolution (
+        input_resolution (tuple[int, int]): Resolution of the input image.
-        downsample (Optional[
-
-    Methods:
-        forward: Processes the input through the convolutional layers.
+        downsample (Optional[nn.Module]): Function for downsampling the output.
+        torch.Size([1, 3136, 128])
-        dim,
-        input_resolution,
-        depth,
+        dim: int,
+        input_resolution: tuple[int, int],
+        depth: int,
-        drop_path=0.0,
-        downsample=None,
-        use_checkpoint=False,
-        out_dim=None,
-        conv_expand_ratio=4.0,
+        drop_path: float | list[float] = 0.0,
+        downsample: nn.Module | None = None,
+        use_checkpoint: bool = False,
+        out_dim: int | None = None,
+        conv_expand_ratio: float = 4.0,
-        """
-        Initializes the ConvLayer with the given dimensions and settings.
+        """Initialize the ConvLayer with the given dimensions and settings.
-        This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
-
+        This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and optionally
+        applies downsampling to the output.
-            input_resolution (
+            input_resolution (tuple[int, int]): The resolution of the input image.
-            drop_path (float |
-            downsample (Optional[nn.Module]): Function for downsampling the output. None to skip downsampling.
-            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
-            out_dim (Optional[int]):
-            conv_expand_ratio (float): Expansion ratio for the MBConv layers.
-
-        Examples:
-            >>> input_tensor = torch.randn(1, 64, 56, 56)
-            >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
-            >>> output = conv_layer(input_tensor)
-            >>> print(output.shape)
+            drop_path (float | list[float], optional): Drop path rate. Single float or a list of floats for each MBConv.
+            downsample (Optional[nn.Module], optional): Function for downsampling the output. None to skip downsampling.
+            use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
+            out_dim (Optional[int], optional): Output dimensions. None means it will be the same as `dim`.
+            conv_expand_ratio (float, optional): Expansion ratio for the MBConv layers.

@@ -304,19 +327,18 @@ class ConvLayer(nn.Module):
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input through convolutional layers, applying MBConv blocks and optional downsampling."""
-class Mlp(nn.Module):
-    """
-    Multi-layer Perceptron (MLP) module for transformer architectures.
+class MLP(nn.Module):
+    """Multi-layer Perceptron (MLP) module for transformer architectures.
-    This module applies layer normalization, two fully-connected layers with an activation function in between,
-
+    This module applies layer normalization, two fully-connected layers with an activation function in between, and
+    dropout. It is commonly used in transformer-based architectures for processing token embeddings.

@@ -325,32 +347,44 @@ class Mlp(nn.Module):
-    Methods:
-        forward: Applies the MLP operations on the input tensor.
-
-        >>> mlp =
+        >>> mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)
-    def __init__(
-
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int | None = None,
+        out_features: int | None = None,
+        activation=nn.GELU,
+        drop: float = 0.0,
+    ):
+        """Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
+
+        Args:
+            in_features (int): Number of input features.
+            hidden_features (Optional[int], optional): Number of hidden features.
+            out_features (Optional[int], optional): Number of output features.
+            activation (nn.Module): Activation function applied after the first fully-connected layer.
+            drop (float, optional): Dropout probability.
+        """
-        self.act =
+        self.act = activation()
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""

@@ -360,12 +394,11 @@ class Mlp(nn.Module):
-    """
-    Multi-head attention module with spatial awareness and trainable attention biases.
+    """Multi-head attention module with spatial awareness and trainable attention biases.
-    This module implements a multi-head attention mechanism with support for spatial awareness, applying
-
-
+    This module implements a multi-head attention mechanism with support for spatial awareness, applying attention
+    biases based on spatial resolution. It includes trainable attention biases for each unique offset between spatial
+    positions in the resolution grid.

@@ -379,12 +412,8 @@ class Attention(torch.nn.Module):
-        attention_bias_idxs (Tensor): Indices for attention biases.
-        ab (Tensor): Cached attention biases for inference, deleted during training.
-
-    Methods:
-        train: Sets the module in training mode and handles the 'ab' attribute.
-        forward: Performs the forward pass of the attention mechanism.
+        attention_bias_idxs (torch.Tensor): Indices for attention biases.
+        ab (torch.Tensor): Cached attention biases for inference, deleted during training.

@@ -396,32 +425,24 @@ class Attention(torch.nn.Module):
-        dim,
-        key_dim,
-        num_heads=8,
-        attn_ratio=4,
-        resolution=(14, 14),
+        dim: int,
+        key_dim: int,
+        num_heads: int = 8,
+        attn_ratio: float = 4,
+        resolution: tuple[int, int] = (14, 14),
-        """
-        Initializes the Attention module for multi-head attention with spatial awareness.
+        """Initialize the Attention module for multi-head attention with spatial awareness.
-        This module implements a multi-head attention mechanism with support for spatial awareness, applying
-
-
+        This module implements a multi-head attention mechanism with support for spatial awareness, applying attention
+        biases based on spatial resolution. It includes trainable attention biases for each unique offset between
+        spatial positions in the resolution grid.
-            num_heads (int): Number of attention heads.
-            attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors.
-            resolution (
-
-        Examples:
-            >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
-            >>> x = torch.randn(1, 196, 256)
-            >>> output = attn(x)
-            >>> print(output.shape)
-            torch.Size([1, 196, 256])
+            num_heads (int, optional): Number of attention heads.
+            attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors.
+            resolution (tuple[int, int], optional): Spatial resolution of the input feature map.

@@ -453,16 +474,16 @@ class Attention(torch.nn.Module):
-    def train(self, mode=True):
-        """
+    def train(self, mode: bool = True):
+        """Set the module in training mode and handle the 'ab' attribute for cached attention biases."""
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply multi-head attention with spatial awareness and trainable attention biases."""

@@ -486,27 +507,23 @@ class Attention(torch.nn.Module):
-    """
-    TinyViT Block that applies self-attention and a local convolution to the input.
+    """TinyViT Block that applies self-attention and a local convolution to the input.
-    This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
-
+    This block is a key component of the TinyViT architecture, combining self-attention mechanisms with local
+    convolutions to process input features efficiently. It supports windowed attention for computational efficiency and
+    includes residual connections.
-        input_resolution (
+        input_resolution (tuple[int, int]): Spatial resolution of the input feature map.
-        mlp (
+        mlp (MLP): Multi-layer perceptron module.
-    Methods:
-        forward: Processes the input through the TinyViT block.
-        extra_repr: Returns a string with extra information about the block's parameters.
-

@@ -517,43 +534,31 @@ class TinyViTBlock(nn.Module):
-        dim,
-        input_resolution,
-        num_heads,
-        window_size=7,
-        mlp_ratio=4.0,
-        drop=0.0,
-        drop_path=0.0,
-        local_conv_size=3,
+        dim: int,
+        input_resolution: tuple[int, int],
+        num_heads: int,
+        window_size: int = 7,
+        mlp_ratio: float = 4.0,
+        drop: float = 0.0,
+        drop_path: float = 0.0,
+        local_conv_size: int = 3,
-        """
-        Initializes a TinyViT block with self-attention and local convolution.
+        """Initialize a TinyViT block with self-attention and local convolution.
-        This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
-
+        This block is a key component of the TinyViT architecture, combining self-attention mechanisms with local
+        convolutions to process input features efficiently.
-            input_resolution (
+            input_resolution (tuple[int, int]): Spatial resolution of the input feature map (height, width).
-            window_size (int): Size of the attention window. Must be greater than 0.
-            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
-            drop (float): Dropout rate.
-            drop_path (float): Stochastic depth rate.
-            local_conv_size (int): Kernel size of the local convolution.
-            activation (
-
-        Raises:
-            AssertionError: If window_size is not greater than 0.
-            AssertionError: If dim is not divisible by num_heads.
-
-        Examples:
-            >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
-            >>> input_tensor = torch.randn(1, 196, 192)
-            >>> output = block(input_tensor)
-            >>> print(output.shape)
-            torch.Size([1, 196, 192])
+            window_size (int, optional): Size of the attention window. Must be greater than 0.
+            mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float, optional): Dropout rate.
+            drop_path (float, optional): Stochastic depth rate.
+            local_conv_size (int, optional): Kernel size of the local convolution.
+            activation (nn.Module): Activation function for MLP.

@@ -575,13 +580,13 @@ class TinyViTBlock(nn.Module):
-        self.mlp =
+        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, activation=mlp_activation, drop=drop)
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply self-attention, local convolution, and MLP operations to the input tensor."""

@@ -623,8 +628,7 @@ class TinyViTBlock(nn.Module):
-        """
-        Returns a string representation of the TinyViTBlock's parameters.
+        """Return a string representation of the TinyViTBlock's parameters.

@@ -644,24 +648,20 @@ class TinyViTBlock(nn.Module):
-    """
-    A basic TinyViT layer for one stage in a TinyViT architecture.
+    """A basic TinyViT layer for one stage in a TinyViT architecture.
-    This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
-    and
+    This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks and an optional
+    downsampling operation. It processes features at a specific resolution and dimensionality within the overall
+    architecture.
-        input_resolution (
+        input_resolution (tuple[int, int]): Spatial resolution of the input feature map.
-    Methods:
-        forward: Processes the input through the layer's blocks and optional downsampling.
-        extra_repr: Returns a string with the layer's parameters for printing.
-

@@ -672,49 +672,41 @@ class BasicLayer(nn.Module):
-        dim,
-        input_resolution,
-        depth,
-        num_heads,
-        window_size,
-        mlp_ratio=4.0,
-        drop=0.0,
-        drop_path=0.0,
-        downsample=None,
-        use_checkpoint=False,
-        local_conv_size=3,
+        dim: int,
+        input_resolution: tuple[int, int],
+        depth: int,
+        num_heads: int,
+        window_size: int,
+        mlp_ratio: float = 4.0,
+        drop: float = 0.0,
+        drop_path: float | list[float] = 0.0,
+        downsample: nn.Module | None = None,
+        use_checkpoint: bool = False,
+        local_conv_size: int = 3,
-        out_dim=None,
+        out_dim: int | None = None,
-        """
-        Initializes a BasicLayer in the TinyViT architecture.
+        """Initialize a BasicLayer in the TinyViT architecture.
-        This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
-
+        This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to process
+        feature maps at a specific resolution and dimensionality within the TinyViT model.
-            input_resolution (
+            input_resolution (tuple[int, int]): Spatial resolution of the input feature map (height, width).
-            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
-            drop (float): Dropout rate.
-            drop_path (float |
-
-
-
+            mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float, optional): Dropout rate.
+            drop_path (float | list[float], optional): Stochastic depth rate. Can be a float or a list of floats for
+                each block.
+            downsample (nn.Module | None, optional): Downsampling layer at the end of the layer. None to skip
+                downsampling.
+            use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
+            local_conv_size (int, optional): Kernel size for the local convolution in each TinyViT block.
-            out_dim (int | None): Output dimension after downsampling. None means it will be the same as
-
-        Raises:
-            ValueError: If `drop_path` is a list and its length doesn't match `depth`.
-
-        Examples:
-            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
-            >>> x = torch.randn(1, 56 * 56, 96)
-            >>> output = layer(x)
-            >>> print(output.shape)
+            out_dim (int | None, optional): Output dimension after downsampling. None means it will be the same as dim.

@@ -747,97 +739,82 @@ class BasicLayer(nn.Module):
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input through TinyViT blocks and optional downsampling."""
-        """
+        """Return a string with the layer's parameters for printing."""
-    """
-    TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
+    """TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
-    This class implements the TinyViT model, which combines elements of vision transformers and convolutional
-
+    This class implements the TinyViT model, which combines elements of vision transformers and convolutional neural
+    networks for improved efficiency and performance on vision tasks. It features hierarchical processing with patch
+    embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.
-        depths (
+        depths (tuple[int, int, int, int]): Number of blocks in each stage.
-        patches_resolution (
+        patches_resolution (tuple[int, int]): Resolution of embedded patches.
-    Methods:
-        set_layer_lr_decay: Sets layer-wise learning rate decay.
-        _init_weights: Initializes weights for linear and normalization layers.
-        no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
-        forward_features: Processes input through the feature extraction layers.
-        forward: Performs a forward pass through the entire network.
-
-        torch.Size([1, 256,
+        torch.Size([1, 256, 56, 56])
-        img_size=224,
-        in_chans=3,
-        num_classes=1000,
-        embed_dims=(96, 192, 384, 768),
-        depths=(2, 2, 6, 2),
-        num_heads=(3, 6, 12, 24),
-        window_sizes=(7, 7, 14, 7),
-        mlp_ratio=4.0,
-        drop_rate=0.0,
-        drop_path_rate=0.1,
-        use_checkpoint=False,
-        mbconv_expand_ratio=4.0,
-        local_conv_size=3,
-        layer_lr_decay=1.0,
+        img_size: int = 224,
+        in_chans: int = 3,
+        num_classes: int = 1000,
+        embed_dims: tuple[int, int, int, int] = (96, 192, 384, 768),
+        depths: tuple[int, int, int, int] = (2, 2, 6, 2),
+        num_heads: tuple[int, int, int, int] = (3, 6, 12, 24),
+        window_sizes: tuple[int, int, int, int] = (7, 7, 14, 7),
+        mlp_ratio: float = 4.0,
+        drop_rate: float = 0.0,
+        drop_path_rate: float = 0.1,
+        use_checkpoint: bool = False,
+        mbconv_expand_ratio: float = 4.0,
+        local_conv_size: int = 3,
+        layer_lr_decay: float = 1.0,
-        """
-        Initializes the TinyViT model.
+        """Initialize the TinyViT model.
-        This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
-
+        This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of attention and
+        convolution blocks, and a classification head.
-            img_size (int): Size of the input image.
-            in_chans (int): Number of input channels.
-            num_classes (int): Number of classes for classification.
-            embed_dims (
-            depths (
-            num_heads (
-            window_sizes (
-            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
-            drop_rate (float): Dropout rate.
-            drop_path_rate (float): Stochastic depth rate.
-            use_checkpoint (bool): Whether to use checkpointing to save memory.
-            mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
-            local_conv_size (int): Kernel size for local convolutions.
-            layer_lr_decay (float): Layer-wise learning rate decay factor.
-
-        Examples:
-            >>> model = TinyViT(img_size=224, num_classes=1000)
-            >>> x = torch.randn(1, 3, 224, 224)
-            >>> output = model(x)
-            >>> print(output.shape)
-            torch.Size([1, 1000])
+            img_size (int, optional): Size of the input image.
+            in_chans (int, optional): Number of input channels.
+            num_classes (int, optional): Number of classes for classification.
+            embed_dims (tuple[int, int, int, int], optional): Embedding dimensions for each stage.
+            depths (tuple[int, int, int, int], optional): Number of blocks in each stage.
+            num_heads (tuple[int, int, int, int], optional): Number of attention heads in each stage.
+            window_sizes (tuple[int, int, int, int], optional): Window sizes for each stage.
+            mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding dim.
+            drop_rate (float, optional): Dropout rate.
+            drop_path_rate (float, optional): Stochastic depth rate.
+            use_checkpoint (bool, optional): Whether to use checkpointing to save memory.
+            mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer.
+            local_conv_size (int, optional): Kernel size for local convolutions.
+            layer_lr_decay (float, optional): Layer-wise learning rate decay factor.

@@ -914,8 +891,8 @@ class TinyViT(nn.Module):
-    def set_layer_lr_decay(self, layer_lr_decay):
-        """
+    def set_layer_lr_decay(self, layer_lr_decay: float):
+        """Set layer-wise learning rate decay for the TinyViT model based on depth."""

@@ -923,7 +900,7 @@ class TinyViT(nn.Module):
-            """
+            """Set the learning rate scale for each layer in the model based on the layer's depth."""

@@ -936,14 +913,14 @@ class TinyViT(nn.Module):
-        for m in
+        for m in {self.norm_head, self.head}:
-            """
+            """Check if the learning rate scale attribute is present in module's parameters."""

@@ -951,7 +928,7 @@ class TinyViT(nn.Module):
-        """
+        """Initialize weights for linear and normalization layers in the TinyViT model."""

@@ -963,11 +940,11 @@ class TinyViT(nn.Module):
-        """
+        """Return a set of keywords for parameters that should not use weight decay."""
-    def forward_features(self, x):
-        """
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input through feature extraction layers, returning spatial features."""

@@ -981,11 +958,11 @@ class TinyViT(nn.Module):
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform the forward pass through the TinyViT model, extracting features from the input image."""
-    def set_imgsz(self, imgsz=[1024, 1024]):
+    def set_imgsz(self, imgsz: list[int] = [1024, 1024]):