ultralytics 8.3.143__py3-none-any.whl → 8.3.144__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +11 -11
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -157
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +15 -7
- ultralytics/solutions/object_cropper.py +3 -2
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +184 -75
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
- ultralytics-8.3.144.dist-info/RECORD +272 -0
- ultralytics-8.3.143.dist-info/RECORD +0 -272
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.143.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/modules/tiny_encoder.py

@@ -10,7 +10,7 @@
 # --------------------------------------------------------
 
 import itertools
-from typing import Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -24,32 +24,46 @@ class Conv2d_BN(torch.nn.Sequential):
     """
     A sequential container that performs 2D convolution followed by batch normalization.
 
+    This module combines a 2D convolution layer with batch normalization, providing a common building block
+    for convolutional neural networks. The batch normalization weights and biases are initialized to specific
+    values for optimal training performance.
+
     Attributes:
         c (torch.nn.Conv2d): 2D convolution layer.
         bn (torch.nn.BatchNorm2d): Batch normalization layer.
 
-    Methods:
-        __init__: Initializes the Conv2d_BN with specified parameters.
-
-    Args:
-        a (int): Number of input channels.
-        b (int): Number of output channels.
-        ks (int): Kernel size for the convolution. Defaults to 1.
-        stride (int): Stride for the convolution. Defaults to 1.
-        pad (int): Padding for the convolution. Defaults to 0.
-        dilation (int): Dilation factor for the convolution. Defaults to 1.
-        groups (int): Number of groups for the convolution. Defaults to 1.
-        bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
-
     Examples:
         >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
         >>> input_tensor = torch.randn(1, 3, 224, 224)
        >>> output = conv_bn(input_tensor)
         >>> print(output.shape)
+        torch.Size([1, 64, 224, 224])
     """
 
-    def __init__(
-
+    def __init__(
+        self,
+        a: int,
+        b: int,
+        ks: int = 1,
+        stride: int = 1,
+        pad: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bn_weight_init: float = 1,
+    ):
+        """
+        Initialize a sequential container with 2D convolution followed by batch normalization.
+
+        Args:
+            a (int): Number of input channels.
+            b (int): Number of output channels.
+            ks (int, optional): Kernel size for the convolution.
+            stride (int, optional): Stride for the convolution.
+            pad (int, optional): Padding for the convolution.
+            dilation (int, optional): Dilation factor for the convolution.
+            groups (int, optional): Number of groups for the convolution.
+            bn_weight_init (float, optional): Initial value for batch normalization weight.
+        """
         super().__init__()
         self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
         bn = torch.nn.BatchNorm2d(b)
@@ -60,7 +74,10 @@ class Conv2d_BN(torch.nn.Sequential):
 
 class PatchEmbed(nn.Module):
     """
-
+    Embed images into patches and project them into a specified embedding dimension.
+
+    This module converts input images into patch embeddings using a sequence of convolutional layers,
+    effectively downsampling the spatial dimensions while increasing the channel dimension.
 
     Attributes:
         patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
@@ -69,19 +86,25 @@ class PatchEmbed(nn.Module):
         embed_dim (int): Dimension of the embedding.
         seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
 
-    Methods:
-        forward: Processes the input tensor through the patch embedding sequence.
-
     Examples:
         >>> import torch
         >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
         >>> x = torch.randn(1, 3, 224, 224)
         >>> output = patch_embed(x)
         >>> print(output.shape)
+        torch.Size([1, 96, 56, 56])
     """
 
-    def __init__(self, in_chans, embed_dim, resolution, activation):
-        """
+    def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):
+        """
+        Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
+
+        Args:
+            in_chans (int): Number of input channels.
+            embed_dim (int): Dimension of the embedding.
+            resolution (int): Input image resolution.
+            activation (nn.Module): Activation function to use between convolutions.
+        """
         super().__init__()
         img_size: Tuple[int, int] = to_2tuple(resolution)
         self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
@@ -95,8 +118,8 @@ class PatchEmbed(nn.Module):
             Conv2d_BN(n // 2, n, 3, 2, 1),
         )
 
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input tensor through patch embedding sequence, converting images to patch embeddings."""
         return self.seq(x)
 
 
@@ -104,21 +127,21 @@ class MBConv(nn.Module):
     """
     Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
 
+    This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution,
+    and projection phases, along with residual connections for improved gradient flow.
+
     Attributes:
         in_chans (int): Number of input channels.
-        hidden_chans (int): Number of hidden channels.
+        hidden_chans (int): Number of hidden channels after expansion.
         out_chans (int): Number of output channels.
-        conv1 (Conv2d_BN): First convolutional layer.
+        conv1 (Conv2d_BN): First convolutional layer for channel expansion.
         act1 (nn.Module): First activation function.
         conv2 (Conv2d_BN): Depthwise convolutional layer.
         act2 (nn.Module): Second activation function.
-        conv3 (Conv2d_BN): Final convolutional layer.
+        conv3 (Conv2d_BN): Final convolutional layer for projection.
         act3 (nn.Module): Third activation function.
         drop_path (nn.Module): Drop path layer (Identity for inference).
 
-    Methods:
-        forward: Performs the forward pass through the MBConv layer.
-
     Examples:
         >>> in_chans, out_chans = 32, 64
         >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
@@ -128,8 +151,17 @@ class MBConv(nn.Module):
         torch.Size([1, 64, 56, 56])
     """
 
-    def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
-        """
+    def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):
+        """
+        Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
+
+        Args:
+            in_chans (int): Number of input channels.
+            out_chans (int): Number of output channels.
+            expand_ratio (float): Channel expansion ratio for the hidden layer.
+            activation (nn.Module): Activation function to use.
+            drop_path (float): Drop path rate for stochastic depth.
+        """
         super().__init__()
         self.in_chans = in_chans
         self.hidden_chans = int(in_chans * expand_ratio)
@@ -148,8 +180,8 @@ class MBConv(nn.Module):
         # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.drop_path = nn.Identity()
 
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Implement the forward pass of MBConv, applying convolutions and skip connection."""
         shortcut = x
         x = self.conv1(x)
         x = self.act1(x)
@@ -163,10 +195,11 @@ class MBConv(nn.Module):
 
 class PatchMerging(nn.Module):
     """
-
+    Merge neighboring patches in the feature map and project to a new dimension.
 
     This class implements a patch merging operation that combines spatial information and adjusts the feature
-    dimension
+    dimension using a series of convolutional layers with batch normalization. It effectively reduces spatial
+    resolution while potentially increasing channel dimensions.
 
     Attributes:
         input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
@@ -177,19 +210,25 @@ class PatchMerging(nn.Module):
         conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
         conv3 (Conv2d_BN): The third convolutional layer for final projection.
 
-    Methods:
-        forward: Applies the patch merging operation to the input tensor.
-
     Examples:
         >>> input_resolution = (56, 56)
         >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
         >>> x = torch.randn(4, 64, 56, 56)
         >>> output = patch_merging(x)
         >>> print(output.shape)
+        torch.Size([4, 3136, 128])
     """
 
-    def __init__(self, input_resolution, dim, out_dim, activation):
-        """
+    def __init__(self, input_resolution: Tuple[int, int], dim: int, out_dim: int, activation):
+        """
+        Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
+
+        Args:
+            input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
+            dim (int): The input dimension of the feature map.
+            out_dim (int): The output dimension after merging and projection.
+            activation (nn.Module): The activation function used between convolutions.
+        """
         super().__init__()
 
         self.input_resolution = input_resolution
@@ -201,8 +240,8 @@ class PatchMerging(nn.Module):
         self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
         self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
 
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply patch merging and dimension projection to the input feature map."""
         if x.ndim == 3:
             H, W = self.input_resolution
             B = len(x)
@@ -222,7 +261,8 @@ class ConvLayer(nn.Module):
     """
    Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
 
-    This layer optionally applies downsample operations to the output and supports gradient checkpointing
+    This layer optionally applies downsample operations to the output and supports gradient checkpointing
+    for memory efficiency during training.
 
     Attributes:
         dim (int): Dimensionality of the input and output.
@@ -230,32 +270,30 @@ class ConvLayer(nn.Module):
         depth (int): Number of MBConv layers in the block.
         use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
         blocks (nn.ModuleList): List of MBConv layers.
-        downsample (Optional[
-
-    Methods:
-        forward: Processes the input through the convolutional layers.
+        downsample (Optional[nn.Module]): Function for downsampling the output.
 
     Examples:
         >>> input_tensor = torch.randn(1, 64, 56, 56)
         >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
         >>> output = conv_layer(input_tensor)
         >>> print(output.shape)
+        torch.Size([1, 3136, 128])
     """
 
     def __init__(
         self,
-        dim,
-        input_resolution,
-        depth,
+        dim: int,
+        input_resolution: Tuple[int, int],
+        depth: int,
         activation,
-        drop_path=0.0,
-        downsample=None,
-        use_checkpoint=False,
-        out_dim=None,
-        conv_expand_ratio=4.0,
+        drop_path: Union[float, List[float]] = 0.0,
+        downsample: Optional[nn.Module] = None,
+        use_checkpoint: bool = False,
+        out_dim: Optional[int] = None,
+        conv_expand_ratio: float = 4.0,
     ):
         """
-
+        Initialize the ConvLayer with the given dimensions and settings.
 
         This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
         optionally applies downsampling to the output.
@@ -265,17 +303,11 @@ class ConvLayer(nn.Module):
            input_resolution (Tuple[int, int]): The resolution of the input image.
            depth (int): The number of MBConv layers in the block.
            activation (nn.Module): Activation function applied after each convolution.
-            drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
-            downsample (Optional[nn.Module]): Function for downsampling the output. None to skip downsampling.
-            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
-            out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
-            conv_expand_ratio (float): Expansion ratio for the MBConv layers.
-
-        Examples:
-            >>> input_tensor = torch.randn(1, 64, 56, 56)
-            >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
-            >>> output = conv_layer(input_tensor)
-            >>> print(output.shape)
+            drop_path (float | List[float], optional): Drop path rate. Single float or a list of floats for each MBConv.
+            downsample (Optional[nn.Module], optional): Function for downsampling the output. None to skip downsampling.
+            use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
+            out_dim (Optional[int], optional): The dimensionality of the output. None means it will be the same as `dim`.
+            conv_expand_ratio (float, optional): Expansion ratio for the MBConv layers.
         """
         super().__init__()
         self.dim = dim
@@ -304,19 +336,19 @@ class ConvLayer(nn.Module):
             else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
         )
 
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input through convolutional layers, applying MBConv blocks and optional downsampling."""
         for blk in self.blocks:
             x = torch.utils.checkpoint(blk, x) if self.use_checkpoint else blk(x)  # warn: checkpoint is slow import
         return x if self.downsample is None else self.downsample(x)
 
 
-class Mlp(nn.Module):
+class MLP(nn.Module):
     """
     Multi-layer Perceptron (MLP) module for transformer architectures.
 
     This module applies layer normalization, two fully-connected layers with an activation function in between,
-    and dropout. It is commonly used in transformer-based architectures.
+    and dropout. It is commonly used in transformer-based architectures for processing token embeddings.
 
     Attributes:
         norm (nn.LayerNorm): Layer normalization applied to the input.
@@ -325,32 +357,45 @@ class Mlp(nn.Module):
         act (nn.Module): Activation function applied after the first fully-connected layer.
         drop (nn.Dropout): Dropout layer applied after the activation function.
 
-    Methods:
-        forward: Applies the MLP operations on the input tensor.
-
     Examples:
         >>> import torch
         >>> from torch import nn
-        >>> mlp =
+        >>> mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)
         >>> x = torch.randn(32, 100, 256)
         >>> output = mlp(x)
         >>> print(output.shape)
         torch.Size([32, 100, 256])
     """
 
-    def __init__(
-
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        activation=nn.GELU,
+        drop: float = 0.0,
+    ):
+        """
+        Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
+
+        Args:
+            in_features (int): Number of input features.
+            hidden_features (Optional[int], optional): Number of hidden features.
+            out_features (Optional[int], optional): Number of output features.
+            activation (nn.Module): Activation function applied after the first fully-connected layer.
+            drop (float, optional): Dropout probability.
+        """
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         self.norm = nn.LayerNorm(in_features)
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.fc2 = nn.Linear(hidden_features, out_features)
-        self.act =
+        self.act = activation()
         self.drop = nn.Dropout(drop)
 
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
         x = self.norm(x)
         x = self.fc1(x)
         x = self.act(x)
@@ -379,12 +424,8 @@ class Attention(torch.nn.Module):
         qkv (nn.Linear): Linear layer for computing query, key, and value projections.
         proj (nn.Linear): Linear layer for final projection.
         attention_biases (nn.Parameter): Learnable attention biases.
-        attention_bias_idxs (Tensor): Indices for attention biases.
-        ab (Tensor): Cached attention biases for inference, deleted during training.
-
-    Methods:
-        train: Sets the module in training mode and handles the 'ab' attribute.
-        forward: Performs the forward pass of the attention mechanism.
+        attention_bias_idxs (torch.Tensor): Indices for attention biases.
+        ab (torch.Tensor): Cached attention biases for inference, deleted during training.
 
     Examples:
         >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
@@ -396,14 +437,14 @@ class Attention(torch.nn.Module):
 
     def __init__(
         self,
-        dim,
-        key_dim,
-        num_heads=8,
-        attn_ratio=4,
-        resolution=(14, 14),
+        dim: int,
+        key_dim: int,
+        num_heads: int = 8,
+        attn_ratio: float = 4,
+        resolution: Tuple[int, int] = (14, 14),
     ):
         """
-
+        Initialize the Attention module for multi-head attention with spatial awareness.
 
         This module implements a multi-head attention mechanism with support for spatial awareness, applying
         attention biases based on spatial resolution. It includes trainable attention biases for each unique
@@ -412,16 +453,9 @@ class Attention(torch.nn.Module):
         Args:
             dim (int): The dimensionality of the input and output.
            key_dim (int): The dimensionality of the keys and queries.
-            num_heads (int): Number of attention heads.
-            attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors.
-            resolution (Tuple[int, int]): Spatial resolution of the input feature map.
-
-        Examples:
-            >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
-            >>> x = torch.randn(1, 196, 256)
-            >>> output = attn(x)
-            >>> print(output.shape)
-            torch.Size([1, 196, 256])
+            num_heads (int, optional): Number of attention heads.
+            attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors.
+            resolution (Tuple[int, int], optional): Spatial resolution of the input feature map.
         """
         super().__init__()
 
@@ -453,16 +487,16 @@ class Attention(torch.nn.Module):
         self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
 
     @torch.no_grad()
-    def train(self, mode=True):
-        """
+    def train(self, mode: bool = True):
+        """Set the module in training mode and handle the 'ab' attribute for cached attention biases."""
         super().train(mode)
         if mode and hasattr(self, "ab"):
             del self.ab
         else:
             self.ab = self.attention_biases[:, self.attention_bias_idxs]
 
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply multi-head attention with spatial awareness and trainable attention biases."""
         B, N, _ = x.shape  # B, N, C
 
         # Normalization
@@ -490,7 +524,8 @@ class TinyViTBlock(nn.Module):
     TinyViT Block that applies self-attention and a local convolution to the input.
 
     This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
-    local convolutions to process input features efficiently.
+    local convolutions to process input features efficiently. It supports windowed attention for
+    computational efficiency and includes residual connections.
 
     Attributes:
         dim (int): The dimensionality of the input and output.
@@ -500,13 +535,9 @@ class TinyViTBlock(nn.Module):
         mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
         drop_path (nn.Module): Stochastic depth layer, identity function during inference.
         attn (Attention): Self-attention module.
-        mlp (
+        mlp (MLP): Multi-layer perceptron module.
         local_conv (Conv2d_BN): Depth-wise local convolution layer.
 
-    Methods:
-        forward: Processes the input through the TinyViT block.
-        extra_repr: Returns a string with extra information about the block's parameters.
-
     Examples:
         >>> input_tensor = torch.randn(1, 196, 192)
         >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
@@ -517,18 +548,18 @@ class TinyViTBlock(nn.Module):
 
     def __init__(
         self,
-        dim,
-        input_resolution,
-        num_heads,
-        window_size=7,
-        mlp_ratio=4.0,
-        drop=0.0,
-        drop_path=0.0,
-        local_conv_size=3,
+        dim: int,
+        input_resolution: Tuple[int, int],
+        num_heads: int,
+        window_size: int = 7,
+        mlp_ratio: float = 4.0,
+        drop: float = 0.0,
+        drop_path: float = 0.0,
+        local_conv_size: int = 3,
         activation=nn.GELU,
     ):
         """
-
+        Initialize a TinyViT block with self-attention and local convolution.
 
         This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
         local convolutions to process input features efficiently.
@@ -537,23 +568,12 @@ class TinyViTBlock(nn.Module):
            dim (int): Dimensionality of the input and output features.
            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
            num_heads (int): Number of attention heads.
-            window_size (int): Size of the attention window. Must be greater than 0.
-            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
-            drop (float): Dropout rate.
-            drop_path (float): Stochastic depth rate.
-            local_conv_size (int): Kernel size of the local convolution.
-            activation (
-
-        Raises:
-            AssertionError: If window_size is not greater than 0.
-            AssertionError: If dim is not divisible by num_heads.
-
-        Examples:
-            >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
-            >>> input_tensor = torch.randn(1, 196, 192)
-            >>> output = block(input_tensor)
-            >>> print(output.shape)
-            torch.Size([1, 196, 192])
+            window_size (int, optional): Size of the attention window. Must be greater than 0.
+            mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float, optional): Dropout rate.
+            drop_path (float, optional): Stochastic depth rate.
+            local_conv_size (int, optional): Kernel size of the local convolution.
+            activation (nn.Module): Activation function for MLP.
         """
         super().__init__()
         self.dim = dim
@@ -575,13 +595,13 @@ class TinyViTBlock(nn.Module):
 
         mlp_hidden_dim = int(dim * mlp_ratio)
         mlp_activation = activation
-        self.mlp =
+        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, activation=mlp_activation, drop=drop)
 
         pad = local_conv_size // 2
         self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
 
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply self-attention, local convolution, and MLP operations to the input tensor."""
         h, w = self.input_resolution
         b, hw, c = x.shape  # batch, height*width, channels
         assert hw == h * w, "input feature has wrong size"
@@ -624,7 +644,7 @@ class TinyViTBlock(nn.Module):
 
     def extra_repr(self) -> str:
         """
-
+        Return a string representation of the TinyViTBlock's parameters.
 
         This method provides a formatted string containing key information about the TinyViTBlock, including its
         dimension, input resolution, number of attention heads, window size, and MLP ratio.
@@ -648,7 +668,8 @@ class BasicLayer(nn.Module):
     A basic TinyViT layer for one stage in a TinyViT architecture.
 
     This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
-    and an optional downsampling operation.
+    and an optional downsampling operation. It processes features at a specific resolution and
+    dimensionality within the overall architecture.
 
     Attributes:
         dim (int): The dimensionality of the input and output features.
@@ -658,10 +679,6 @@ class BasicLayer(nn.Module):
         blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
         downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
 
-    Methods:
-        forward: Processes the input through the layer's blocks and optional downsampling.
-        extra_repr: Returns a string with the layer's parameters for printing.
-
     Examples:
         >>> input_tensor = torch.randn(1, 3136, 192)
         >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
@@ -672,22 +689,22 @@ class BasicLayer(nn.Module):
 
     def __init__(
         self,
-        dim,
-        input_resolution,
-        depth,
-        num_heads,
-        window_size,
-        mlp_ratio=4.0,
-        drop=0.0,
-        drop_path=0.0,
-        downsample=None,
-        use_checkpoint=False,
-        local_conv_size=3,
+        dim: int,
+        input_resolution: Tuple[int, int],
+        depth: int,
+        num_heads: int,
+        window_size: int,
+        mlp_ratio: float = 4.0,
+        drop: float = 0.0,
+        drop_path: Union[float, List[float]] = 0.0,
+        downsample: Optional[nn.Module] = None,
+        use_checkpoint: bool = False,
+        local_conv_size: int = 3,
         activation=nn.GELU,
-        out_dim=None,
+        out_dim: Optional[int] = None,
    ):
         """
-
+        Initialize a BasicLayer in the TinyViT architecture.
 
         This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
         process feature maps at a specific resolution and dimensionality within the TinyViT model.
@@ -698,23 +715,14 @@ class BasicLayer(nn.Module):
            depth (int): Number of TinyViT blocks in this layer.
            num_heads (int): Number of attention heads in each TinyViT block.
            window_size (int): Size of the local window for attention computation.
-            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
-            drop (float): Dropout rate.
-            drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block.
-            downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling.
-            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
-            local_conv_size (int): Kernel size for the local convolution in each TinyViT block.
+            mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float, optional): Dropout rate.
+            drop_path (float | List[float], optional): Stochastic depth rate. Can be a float or a list of floats for each block.
+            downsample (nn.Module | None, optional): Downsampling layer at the end of the layer. None to skip downsampling.
+            use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
+            local_conv_size (int, optional): Kernel size for the local convolution in each TinyViT block.
            activation (nn.Module): Activation function used in the MLP.
-            out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`.
-
-        Raises:
-            ValueError: If `drop_path` is a list and its length doesn't match `depth`.
-
-        Examples:
-            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
-            >>> x = torch.randn(1, 56 * 56, 96)
-            >>> output = layer(x)
-            >>> print(output.shape)
+            out_dim (int | None, optional): Output dimension after downsampling. None means it will be the same as `dim`.
         """
         super().__init__()
         self.dim = dim
@@ -747,14 +755,14 @@ class BasicLayer(nn.Module):
             else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
         )
 
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input through TinyViT blocks and optional downsampling."""
         for blk in self.blocks:
             x = torch.utils.checkpoint(blk, x) if self.use_checkpoint else blk(x)  # warn: checkpoint is slow import
         return x if self.downsample is None else self.downsample(x)
 
     def extra_repr(self) -> str:
-        """
+        """Return a string with the layer's parameters for printing."""
         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
 
 
@@ -763,12 +771,13 @@ class TinyViT(nn.Module):
     TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
 
     This class implements the TinyViT model, which combines elements of vision transformers and convolutional
-    neural networks for improved efficiency and performance on vision tasks.
+    neural networks for improved efficiency and performance on vision tasks. It features hierarchical processing
+    with patch embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.
 
     Attributes:
         img_size (int): Input image size.
         num_classes (int): Number of classification classes.
-        depths (
+        depths (Tuple[int, int, int, int]): Number of blocks in each stage.
         num_layers (int): Total number of layers in the network.
         mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
         patch_embed (PatchEmbed): Module for patch embedding.
@@ -778,66 +787,52 @@ class TinyViT(nn.Module):
         head (nn.Linear): Linear layer for final classification.
         neck (nn.Sequential): Neck module for feature refinement.
 
-    Methods:
-        set_layer_lr_decay: Sets layer-wise learning rate decay.
-        _init_weights: Initializes weights for linear and normalization layers.
-        no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
-        forward_features: Processes input through the feature extraction layers.
-        forward: Performs a forward pass through the entire network.
-
     Examples:
         >>> model = TinyViT(img_size=224, num_classes=1000)
         >>> x = torch.randn(1, 3, 224, 224)
         >>> features = model.forward_features(x)
         >>> print(features.shape)
-        torch.Size([1, 256,
+        torch.Size([1, 256, 56, 56])
     """
 
     def __init__(
         self,
-        img_size=224,
-        in_chans=3,
-        num_classes=1000,
-        embed_dims=(96, 192, 384, 768),
-        depths=(2, 2, 6, 2),
-        num_heads=(3, 6, 12, 24),
-        window_sizes=(7, 7, 14, 7),
-        mlp_ratio=4.0,
-        drop_rate=0.0,
-        drop_path_rate=0.1,
-        use_checkpoint=False,
-        mbconv_expand_ratio=4.0,
-        local_conv_size=3,
-        layer_lr_decay=1.0,
+        img_size: int = 224,
+        in_chans: int = 3,
+        num_classes: int = 1000,
+        embed_dims: Tuple[int, int, int, int] = (96, 192, 384, 768),
+        depths: Tuple[int, int, int, int] = (2, 2, 6, 2),
+        num_heads: Tuple[int, int, int, int] = (3, 6, 12, 24),
+        window_sizes: Tuple[int, int, int, int] = (7, 7, 14, 7),
+        mlp_ratio: float = 4.0,
+        drop_rate: float = 0.0,
+        drop_path_rate: float = 0.1,
+        use_checkpoint: bool = False,
+        mbconv_expand_ratio: float = 4.0,
+        local_conv_size: int = 3,
+        layer_lr_decay: float = 1.0,
    ):
         """
-
+        Initialize the TinyViT model.
 
         This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
         attention and convolution blocks, and a classification head.
 
         Args:
-            img_size (int): Size of the input image.
-            in_chans (int): Number of input channels.
-            num_classes (int): Number of classes for classification.
-            embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
-            depths (Tuple[int, int, int, int]): Number of blocks in each stage.
-            num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
-            window_sizes (Tuple[int, int, int, int]): Window sizes for each stage.
-            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
-            drop_rate (float): Dropout rate.
-            drop_path_rate (float): Stochastic depth rate.
-            use_checkpoint (bool): Whether to use checkpointing to save memory.
-            mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
-            local_conv_size (int): Kernel size for local convolutions.
-            layer_lr_decay (float): Layer-wise learning rate decay factor.
-
-        Examples:
-            >>> model = TinyViT(img_size=224, num_classes=1000)
-            >>> x = torch.randn(1, 3, 224, 224)
-            >>> output = model(x)
-            >>> print(output.shape)
-            torch.Size([1, 1000])
+            img_size (int, optional): Size of the input image.
+            in_chans (int, optional): Number of input channels.
+            num_classes (int, optional): Number of classes for classification.
+            embed_dims (Tuple[int, int, int, int], optional): Embedding dimensions for each stage.
+            depths (Tuple[int, int, int, int], optional): Number of blocks in each stage.
+            num_heads (Tuple[int, int, int, int], optional): Number of attention heads in each stage.
+            window_sizes (Tuple[int, int, int, int], optional): Window sizes for each stage.
+            mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding dim.
+            drop_rate (float, optional): Dropout rate.
+            drop_path_rate (float, optional): Stochastic depth rate.
+            use_checkpoint (bool, optional): Whether to use checkpointing to save memory.
+            mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer.
+            local_conv_size (int, optional): Kernel size for local convolutions.
+            layer_lr_decay (float, optional): Layer-wise learning rate decay factor.
         """
         super().__init__()
         self.img_size = img_size
@@ -914,8 +909,8 @@ class TinyViT(nn.Module):
             LayerNorm2d(256),
         )
 
-    def set_layer_lr_decay(self, layer_lr_decay):
-        """
+    def set_layer_lr_decay(self, layer_lr_decay: float):
+        """Set layer-wise learning rate decay for the TinyViT model based on depth."""
         decay_rate = layer_lr_decay
 
         # Layers -> blocks (depth)
@@ -923,7 +918,7 @@ class TinyViT(nn.Module):
         lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
 
         def _set_lr_scale(m, scale):
-            """
+            """Set the learning rate scale for each layer in the model based on the layer's depth."""
            for p in m.parameters():
                 p.lr_scale = scale
 
@@ -943,7 +938,7 @@ class TinyViT(nn.Module):
            p.param_name = k
 
         def _check_lr_scale(m):
-            """
+            """Check if the learning rate scale attribute is present in module's parameters."""
            for p in m.parameters():
                 assert hasattr(p, "lr_scale"), p.param_name
 
@@ -951,7 +946,7 @@ class TinyViT(nn.Module):
 
     @staticmethod
    def _init_weights(m):
-        """
+        """Initialize weights for linear and normalization layers in the TinyViT model."""
         if isinstance(m, nn.Linear):
             # NOTE: This initialization is needed only for training.
             # trunc_normal_(m.weight, std=.02)
@@ -963,11 +958,11 @@ class TinyViT(nn.Module):
 
     @torch.jit.ignore
    def no_weight_decay_keywords(self):
-        """
+        """Return a set of keywords for parameters that should not use weight decay."""
         return {"attention_biases"}
 
-    def forward_features(self, x):
-        """
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input through feature extraction layers, returning spatial features."""
         x = self.patch_embed(x)  # x input is (N, C, H, W)
 
         x = self.layers[0](x)
@@ -981,11 +976,11 @@ class TinyViT(nn.Module):
         x = x.permute(0, 3, 1, 2)
         return self.neck(x)
 
-    def forward(self, x):
-        """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform the forward pass through the TinyViT model, extracting features from the input image."""
         return self.forward_features(x)
 
-    def set_imgsz(self, imgsz=[1024, 1024]):
+    def set_imgsz(self, imgsz: List[int] = [1024, 1024]):
         """Set image size to make model compatible with different image sizes."""
         imgsz = [s // 4 for s in imgsz]
         self.patches_resolution = imgsz
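As a quick illustration of the retyped building blocks in this file, here is a minimal usage sketch based on the docstring examples visible in the diff; the import path is assumed from the package layout in the file list above, and the expected shapes follow the documented examples.

```python
# Illustrative sketch only: import path assumed from the file listing above.
import torch
import torch.nn as nn

from ultralytics.models.sam.modules.tiny_encoder import MLP, Conv2d_BN

# Conv2d_BN: 3x3 convolution followed by BatchNorm; pad=1 keeps the 224x224 spatial size.
conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
print(conv_bn(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 224, 224])

# MLP (renamed from Mlp in 8.3.144): layer norm plus two Linear layers with GELU and dropout.
mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)
print(mlp(torch.randn(32, 100, 256)).shape)  # torch.Size([32, 100, 256])
```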