embedl-deploy-tensorrt 0.4.1__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/PKG-INFO +7 -6
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/README.md +6 -5
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/backend.py +1 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +35 -10
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +4 -4
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +3 -3
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +1 -1
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +1 -1
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +56 -64
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +123 -171
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +13 -82
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions.py +26 -3
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +80 -15
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/__init__.py +3 -0
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/attention.py +808 -0
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/functional.py +325 -0
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/general.py +718 -0
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/utils.py +44 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +6 -4
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/plan.py +24 -12
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/__init__.py +2 -2
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -2
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/version/public.py +1 -1
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy_tensorrt.egg-info/SOURCES.txt +5 -2
- embedl_deploy_tensorrt-0.4.1/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -1475
- embedl_deploy_tensorrt-0.4.1/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -89
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/LICENSE +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/MANIFEST.in +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/NOTICE +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/pyproject.toml +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/setup.cfg +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/version/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: embedl-deploy-tensorrt
|
|
3
|
-
Version: 0.4.1
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: TensorRT backend for embedl-deploy.
|
|
5
5
|
Author-email: Embedl AB <support@embedl.com>
|
|
6
6
|
Project-URL: Homepage, https://www.embedl.com/
|
|
@@ -54,9 +54,10 @@ hardware target ensuring correct quantization and compilation.
|
|
|
54
54
|
|
|
55
55
|
## Supported Backends
|
|
56
56
|
|
|
57
|
-
| Backend
|
|
58
|
-
|
|
59
|
-
| NVIDIA TensorRT
|
|
57
|
+
| Backend | Status |
|
|
58
|
+
|---------------------------|-----------------|
|
|
59
|
+
| NVIDIA TensorRT (v10.3) | Supported |
|
|
60
|
+
| Lattice SensAI (v8.0) | In Development |
|
|
60
61
|
|
|
61
62
|
Contact Embedl for other backends.
|
|
62
63
|
|
|
@@ -71,7 +72,7 @@ intermediate.
|
|
|
71
72
|
|
|
72
73
|
---
|
|
73
74
|
|
|
74
|
-
## Quick Start
|
|
75
|
+
## Quick Start for TensorRT Backend
|
|
75
76
|
|
|
76
77
|
```python
|
|
77
78
|
import torch
|
|
@@ -85,7 +86,7 @@ model = Model().eval()
|
|
|
85
86
|
example_input = torch.randn(1, 3, 224, 224)
|
|
86
87
|
|
|
87
88
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
88
|
-
# For more
|
|
89
|
+
# For more compatibility you can trace your model with torch.export.export
|
|
89
90
|
# as follows:
|
|
90
91
|
# model = torch.export.export(model, (example_input,)).module()
|
|
91
92
|
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
@@ -35,9 +35,10 @@ hardware target ensuring correct quantization and compilation.
|
|
|
35
35
|
|
|
36
36
|
## Supported Backends
|
|
37
37
|
|
|
38
|
-
| Backend
|
|
39
|
-
|
|
40
|
-
| NVIDIA TensorRT
|
|
38
|
+
| Backend | Status |
|
|
39
|
+
|---------------------------|-----------------|
|
|
40
|
+
| NVIDIA TensorRT (v10.3) | Supported |
|
|
41
|
+
| Lattice SensAI (v8.0) | In Development |
|
|
41
42
|
|
|
42
43
|
Contact Embedl for other backends.
|
|
43
44
|
|
|
@@ -52,7 +53,7 @@ intermediate.
|
|
|
52
53
|
|
|
53
54
|
---
|
|
54
55
|
|
|
55
|
-
## Quick Start
|
|
56
|
+
## Quick Start for TensorRT Backend
|
|
56
57
|
|
|
57
58
|
```python
|
|
58
59
|
import torch
|
|
@@ -66,7 +67,7 @@ model = Model().eval()
|
|
|
66
67
|
example_input = torch.randn(1, 3, 224, 224)
|
|
67
68
|
|
|
68
69
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
69
|
-
# For more
|
|
70
|
+
# For more compatibility you can trace your model with torch.export.export
|
|
70
71
|
# as follows:
|
|
71
72
|
# model = torch.export.export(model, (example_input,)).module()
|
|
72
73
|
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
@@ -69,7 +69,7 @@ class MHAInProjection(ConvertedModule):
|
|
|
69
69
|
v = v.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
|
|
70
70
|
return q, k, v
|
|
71
71
|
|
|
72
|
-
def __repr__(self) -> str:
|
|
72
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
73
73
|
embed_dim = self.num_heads * self.head_dim
|
|
74
74
|
return (
|
|
75
75
|
f"MHAInProjection("
|
|
@@ -80,7 +80,7 @@ class MHAInProjection(ConvertedModule):
|
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
class ScaledDotProductAttention(ConvertedModule):
|
|
83
|
-
"""Core attention: ``softmax(Q · Kᵀ
|
|
83
|
+
"""Core attention: ``softmax(Q · Kᵀ · scale) · V``.
|
|
84
84
|
|
|
85
85
|
:param num_heads:
|
|
86
86
|
Number of attention heads.
|
|
@@ -88,6 +88,14 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
88
88
|
Dimension of each head.
|
|
89
89
|
:param dropout:
|
|
90
90
|
Dropout probability (applied during training only).
|
|
91
|
+
:param is_causal:
|
|
92
|
+
Whether to apply a causal mask. Mirrors the ``is_causal`` kwarg
|
|
93
|
+
of ``F.scaled_dot_product_attention``.
|
|
94
|
+
:param scale:
|
|
95
|
+
Explicit attention score scale (multiplied on Q·Kᵀ). When
|
|
96
|
+
``None`` the PyTorch default ``1/√head_dim`` is used. Models
|
|
97
|
+
that pre-scale Q themselves (e.g. chronos-2 + RoPE) must pass
|
|
98
|
+
``scale=1.0`` so the default scaling does not apply twice.
|
|
91
99
|
"""
|
|
92
100
|
|
|
93
101
|
def __init__(
|
|
@@ -95,11 +103,15 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
95
103
|
num_heads: int,
|
|
96
104
|
head_dim: int,
|
|
97
105
|
dropout: float = 0.0,
|
|
106
|
+
is_causal: bool = False,
|
|
107
|
+
scale: float | None = None,
|
|
98
108
|
) -> None:
|
|
99
109
|
super().__init__()
|
|
100
110
|
self.num_heads = num_heads
|
|
101
111
|
self.head_dim = head_dim
|
|
102
112
|
self.dropout = dropout
|
|
113
|
+
self.is_causal = is_causal
|
|
114
|
+
self.scale = scale
|
|
103
115
|
|
|
104
116
|
def forward(
|
|
105
117
|
self,
|
|
@@ -117,8 +129,9 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
117
129
|
:param v:
|
|
118
130
|
Value tensor ``[B, num_heads, S, head_dim]``.
|
|
119
131
|
:param attn_mask:
|
|
120
|
-
Optional attention mask.
|
|
121
|
-
takes an
|
|
132
|
+
Optional attention mask.
|
|
133
|
+
``torch.nn.functional.scaled_dot_product_attention`` takes an
|
|
134
|
+
optional 4th positional arg; ``WrapFunctionalSDPAPattern``
|
|
122
135
|
forwards whatever positional args were on the source node, so
|
|
123
136
|
this module accepts the mask too. SAM3, masked-LM, and
|
|
124
137
|
similar models that compile with mixed-mask attention rely
|
|
@@ -135,14 +148,18 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
135
148
|
v,
|
|
136
149
|
attn_mask=attn_mask,
|
|
137
150
|
dropout_p=self.dropout if self.training else 0.0,
|
|
151
|
+
is_causal=self.is_causal,
|
|
152
|
+
scale=self.scale,
|
|
138
153
|
)
|
|
139
154
|
|
|
140
|
-
def __repr__(self) -> str:
|
|
155
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
141
156
|
return (
|
|
142
157
|
f"ScaledDotProductAttention("
|
|
143
158
|
f"num_heads={self.num_heads}, "
|
|
144
159
|
f"head_dim={self.head_dim}, "
|
|
145
|
-
f"dropout={self.dropout})"
|
|
160
|
+
f"dropout={self.dropout}, "
|
|
161
|
+
f"is_causal={self.is_causal}, "
|
|
162
|
+
f"scale={self.scale})"
|
|
146
163
|
)
|
|
147
164
|
|
|
148
165
|
|
|
@@ -197,7 +214,7 @@ class FusedMHAInProjection(FusedModule):
|
|
|
197
214
|
v = v.view(batch, seq, num_heads, head_dim).transpose(1, 2)
|
|
198
215
|
return q, k, v
|
|
199
216
|
|
|
200
|
-
def __repr__(self) -> str:
|
|
217
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
201
218
|
embed_dim = self.in_proj.num_heads * self.in_proj.head_dim
|
|
202
219
|
return (
|
|
203
220
|
f"FusedMHAInProjection("
|
|
@@ -275,10 +292,18 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
275
292
|
# MHA kernel onto the slower INT8-aware variant for no gain.
|
|
276
293
|
if not self.surrounded or not self.softmax_quant.enabled:
|
|
277
294
|
return self.attention(q, k, v, attn_mask)
|
|
278
|
-
#
|
|
295
|
+
# Honour the wrapped attention module's explicit ``scale`` if
|
|
296
|
+
# set — models that pre-scale Q themselves (chronos-2 + RoPE,
|
|
297
|
+
# for example) build with ``scale=1.0`` to disable the default
|
|
298
|
+
# ``1/sqrt(head_dim)`` scaling. Falling back to the default
|
|
299
|
+
# here would apply it twice and collapse softmax.
|
|
300
|
+
# Note on ``1/sqrt(head_dim)`` vs ``head_dim ** -0.5``: the
|
|
279
301
|
# tensor Pow with a negative float exponent traces to ONNX as a
|
|
280
302
|
# ``Cast → complex128`` node that TRT 10.x can't parse.
|
|
281
|
-
scale
|
|
303
|
+
if self.attention.scale is not None:
|
|
304
|
+
scale = self.attention.scale
|
|
305
|
+
else:
|
|
306
|
+
scale = 1.0 / math.sqrt(q.shape[-1])
|
|
282
307
|
attn_weight = torch.matmul(q, k.transpose(-2, -1)) * scale
|
|
283
308
|
if attn_mask is not None:
|
|
284
309
|
if attn_mask.dtype == torch.bool:
|
|
@@ -297,7 +322,7 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
297
322
|
)
|
|
298
323
|
return torch.matmul(attn_weight, v)
|
|
299
324
|
|
|
300
|
-
def __repr__(self) -> str:
|
|
325
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
301
326
|
a = self.attention
|
|
302
327
|
qdq = "yes" if self.softmax_quant.enabled else "no"
|
|
303
328
|
return (
|
|
@@ -91,7 +91,7 @@ class FusedConvBNAct(FusedModule):
|
|
|
91
91
|
x = self.bn(x)
|
|
92
92
|
return self.act(x)
|
|
93
93
|
|
|
94
|
-
def __repr__(self) -> str:
|
|
94
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
95
95
|
bn_info = ""
|
|
96
96
|
if self.bn is not None:
|
|
97
97
|
bn_info = f", bn={self.bn.num_features} (foldable)"
|
|
@@ -129,7 +129,7 @@ class FusedConvBN(FusedModule):
|
|
|
129
129
|
x = self.bn(x)
|
|
130
130
|
return x
|
|
131
131
|
|
|
132
|
-
def __repr__(self) -> str:
|
|
132
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
133
133
|
bn_info = ""
|
|
134
134
|
if self.bn is not None:
|
|
135
135
|
bn_info = f", bn={self.bn.num_features} (foldable)"
|
|
@@ -168,7 +168,7 @@ class FusedConvBNActMaxPool(FusedModule):
|
|
|
168
168
|
x = self.act(x)
|
|
169
169
|
return self.maxpool(x)
|
|
170
170
|
|
|
171
|
-
def __repr__(self) -> str:
|
|
171
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
172
172
|
bn_info = ""
|
|
173
173
|
if self.bn is not None:
|
|
174
174
|
bn_info = f", bn={self.bn.num_features} (foldable)"
|
|
@@ -213,7 +213,7 @@ class FusedConvBNAddAct(FusedModule):
|
|
|
213
213
|
x = self.bn(x)
|
|
214
214
|
return self.act(x + residual)
|
|
215
215
|
|
|
216
|
-
def __repr__(self) -> str:
|
|
216
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
217
217
|
return (
|
|
218
218
|
f"FusedConvBNAddAct("
|
|
219
219
|
f"{self.conv.in_channels}→{self.conv.out_channels}, "
|
|
@@ -82,7 +82,7 @@ class FusedLinear(FusedModule):
|
|
|
82
82
|
# pylint: disable-next=not-callable
|
|
83
83
|
return F.linear(x, weight, self.linear.bias)
|
|
84
84
|
|
|
85
|
-
def __repr__(self) -> str:
|
|
85
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
86
86
|
return (
|
|
87
87
|
f"FusedLinear("
|
|
88
88
|
f"{self.linear.in_features}→{self.linear.out_features})"
|
|
@@ -113,7 +113,7 @@ class FusedLinearAct(FusedModule):
|
|
|
113
113
|
x = F.linear(x, weight, self.linear.bias)
|
|
114
114
|
return self.act(x)
|
|
115
115
|
|
|
116
|
-
def __repr__(self) -> str:
|
|
116
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
117
117
|
act_name = type(self.act).__name__
|
|
118
118
|
return (
|
|
119
119
|
f"FusedLinearAct("
|
|
@@ -151,7 +151,7 @@ class FusedLayerNorm(FusedModule):
|
|
|
151
151
|
"""Apply ``layer_norm``."""
|
|
152
152
|
return self.layer_norm(x)
|
|
153
153
|
|
|
154
|
-
def __repr__(self) -> str:
|
|
154
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
155
155
|
return (
|
|
156
156
|
f"FusedLayerNorm("
|
|
157
157
|
f"normalized_shape={self.layer_norm.normalized_shape}, "
|
|
@@ -9,7 +9,7 @@ are spatial rearrangements that need no quantization.
|
|
|
9
9
|
computes the shifted-window attention mask.
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
from dataclasses import dataclass, field
|
|
12
|
+
from dataclasses import dataclass
|
|
13
13
|
|
|
14
14
|
import torch
|
|
15
15
|
import torch.nn.functional as F
|
|
@@ -19,7 +19,7 @@ from embedl_deploy._internal.core.modules import ConvertedModule, FusedModule
|
|
|
19
19
|
from embedl_deploy._internal.core.quantize.stubs import QuantStub
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
@dataclass
|
|
22
|
+
@dataclass(eq=False)
|
|
23
23
|
class SwinSpatialState:
|
|
24
24
|
"""Shared mutable state for spatial dimensions.
|
|
25
25
|
|
|
@@ -34,9 +34,7 @@ class SwinSpatialState:
|
|
|
34
34
|
pad_height: int = 0
|
|
35
35
|
pad_width: int = 0
|
|
36
36
|
#: Effective shift size after clamping for small feature maps.
|
|
37
|
-
effective_shift_size:
|
|
38
|
-
default_factory=lambda: [0, 0],
|
|
39
|
-
)
|
|
37
|
+
effective_shift_size: tuple[int, int] = (0, 0)
|
|
40
38
|
|
|
41
39
|
|
|
42
40
|
class SwinWindowPartition(ConvertedModule):
|
|
@@ -71,31 +69,29 @@ class SwinWindowPartition(ConvertedModule):
|
|
|
71
69
|
:returns:
|
|
72
70
|
Windowed tensor ``[B*nW, Ws*Ws, C]``.
|
|
73
71
|
"""
|
|
74
|
-
|
|
75
|
-
B, H, W, C = x.shape
|
|
72
|
+
b, h, w, c = x.shape
|
|
76
73
|
|
|
77
74
|
# Pad to multiples of window size.
|
|
78
75
|
ws_h, ws_w = self.window_size
|
|
79
|
-
pad_b = (ws_h -
|
|
80
|
-
pad_r = (ws_w -
|
|
76
|
+
pad_b = (ws_h - h % ws_h) % ws_h
|
|
77
|
+
pad_r = (ws_w - w % ws_w) % ws_w
|
|
81
78
|
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
|
|
82
|
-
|
|
83
|
-
_, pad_H, pad_W, _ = x.shape
|
|
79
|
+
_, pad_h, pad_w, _ = x.shape
|
|
84
80
|
|
|
85
81
|
# Clamp shift size when the window covers the whole feature map.
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
82
|
+
sh, sw = self.shift_size
|
|
83
|
+
eff_shift = (
|
|
84
|
+
0 if ws_h >= pad_h else sh,
|
|
85
|
+
0 if ws_w >= pad_w else sw,
|
|
86
|
+
)
|
|
91
87
|
|
|
92
88
|
# Write spatial state for downstream modules.
|
|
93
89
|
st = self._spatial_state
|
|
94
|
-
st.batch_size =
|
|
95
|
-
st.height =
|
|
96
|
-
st.width =
|
|
97
|
-
st.pad_height =
|
|
98
|
-
st.pad_width =
|
|
90
|
+
st.batch_size = b
|
|
91
|
+
st.height = h
|
|
92
|
+
st.width = w
|
|
93
|
+
st.pad_height = pad_h
|
|
94
|
+
st.pad_width = pad_w
|
|
99
95
|
st.effective_shift_size = eff_shift
|
|
100
96
|
|
|
101
97
|
# Cyclic shift.
|
|
@@ -108,22 +104,22 @@ class SwinWindowPartition(ConvertedModule):
|
|
|
108
104
|
|
|
109
105
|
# Window partition.
|
|
110
106
|
x = x.view(
|
|
111
|
-
|
|
112
|
-
|
|
107
|
+
b,
|
|
108
|
+
pad_h // ws_h,
|
|
113
109
|
ws_h,
|
|
114
|
-
|
|
110
|
+
pad_w // ws_w,
|
|
115
111
|
ws_w,
|
|
116
|
-
|
|
112
|
+
c,
|
|
117
113
|
)
|
|
118
|
-
num_windows = (
|
|
114
|
+
num_windows = (pad_h // ws_h) * (pad_w // ws_w)
|
|
119
115
|
x = x.permute(0, 1, 3, 2, 4, 5).reshape(
|
|
120
|
-
|
|
116
|
+
b * num_windows,
|
|
121
117
|
ws_h * ws_w,
|
|
122
|
-
|
|
118
|
+
c,
|
|
123
119
|
)
|
|
124
120
|
return x
|
|
125
121
|
|
|
126
|
-
def __repr__(self) -> str:
|
|
122
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
127
123
|
return (
|
|
128
124
|
f"SwinWindowPartition("
|
|
129
125
|
f"window_size={self.window_size}, "
|
|
@@ -177,11 +173,10 @@ class SwinAttention(ConvertedModule):
|
|
|
177
173
|
|
|
178
174
|
def _get_relative_position_bias(self) -> torch.Tensor:
|
|
179
175
|
"""Compute relative position bias ``[1, nH, N, N]``."""
|
|
180
|
-
|
|
181
|
-
N = self.window_size[0] * self.window_size[1]
|
|
176
|
+
n = self.window_size[0] * self.window_size[1]
|
|
182
177
|
assert isinstance(self.relative_position_index, torch.Tensor)
|
|
183
178
|
bias = self.relative_position_bias_table[self.relative_position_index]
|
|
184
|
-
bias = bias.view(
|
|
179
|
+
bias = bias.view(n, n, -1).permute(2, 0, 1).contiguous()
|
|
185
180
|
return bias.unsqueeze(0)
|
|
186
181
|
|
|
187
182
|
def _compute_attn_mask( # pylint: disable=too-many-locals
|
|
@@ -193,13 +188,13 @@ class SwinAttention(ConvertedModule):
|
|
|
193
188
|
if sum(eff_shift) == 0:
|
|
194
189
|
return None
|
|
195
190
|
|
|
196
|
-
|
|
197
|
-
|
|
191
|
+
pad_h = st.pad_height
|
|
192
|
+
pad_w = st.pad_width
|
|
198
193
|
ws_h, ws_w = self.window_size
|
|
199
|
-
num_windows = (
|
|
194
|
+
num_windows = (pad_h // ws_h) * (pad_w // ws_w)
|
|
200
195
|
|
|
201
196
|
attn_mask = torch.zeros(
|
|
202
|
-
(
|
|
197
|
+
(pad_h, pad_w),
|
|
203
198
|
device=self.relative_position_bias_table.device,
|
|
204
199
|
)
|
|
205
200
|
h_slices = (
|
|
@@ -219,9 +214,9 @@ class SwinAttention(ConvertedModule):
|
|
|
219
214
|
count += 1
|
|
220
215
|
|
|
221
216
|
attn_mask = attn_mask.view(
|
|
222
|
-
|
|
217
|
+
pad_h // ws_h,
|
|
223
218
|
ws_h,
|
|
224
|
-
|
|
219
|
+
pad_w // ws_w,
|
|
225
220
|
ws_w,
|
|
226
221
|
)
|
|
227
222
|
attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(
|
|
@@ -259,10 +254,9 @@ class SwinAttention(ConvertedModule):
|
|
|
259
254
|
|
|
260
255
|
attn_mask = self._compute_attn_mask()
|
|
261
256
|
if attn_mask is not None:
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
attn = attn.view(B, nW, self.num_heads, -1, attn.size(-1))
|
|
257
|
+
b = self._spatial_state.batch_size
|
|
258
|
+
n_w = attn.size(0) // b
|
|
259
|
+
attn = attn.view(b, n_w, self.num_heads, -1, attn.size(-1))
|
|
266
260
|
attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
|
|
267
261
|
attn = attn.view(-1, self.num_heads, attn.size(-2), attn.size(-1))
|
|
268
262
|
|
|
@@ -282,7 +276,7 @@ class SwinAttention(ConvertedModule):
|
|
|
282
276
|
)
|
|
283
277
|
return x
|
|
284
278
|
|
|
285
|
-
def __repr__(self) -> str:
|
|
279
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
286
280
|
embed_dim = self.num_heads * self.head_dim
|
|
287
281
|
return (
|
|
288
282
|
f"SwinAttention("
|
|
@@ -320,27 +314,25 @@ class SwinWindowReverse(ConvertedModule):
|
|
|
320
314
|
Spatial tensor ``[B, H, W, C]``.
|
|
321
315
|
"""
|
|
322
316
|
st = self._spatial_state
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
W = st.width
|
|
329
|
-
# pylint: enable=invalid-name
|
|
317
|
+
b = st.batch_size
|
|
318
|
+
pad_h = st.pad_height
|
|
319
|
+
pad_w = st.pad_width
|
|
320
|
+
h = st.height
|
|
321
|
+
w = st.width
|
|
330
322
|
eff_shift = st.effective_shift_size
|
|
331
323
|
ws_h, ws_w = self.window_size
|
|
332
|
-
|
|
324
|
+
c = x.size(-1)
|
|
333
325
|
|
|
334
326
|
# Reverse window partition.
|
|
335
327
|
x = x.view(
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
328
|
+
b,
|
|
329
|
+
pad_h // ws_h,
|
|
330
|
+
pad_w // ws_w,
|
|
339
331
|
ws_h,
|
|
340
332
|
ws_w,
|
|
341
|
-
|
|
333
|
+
c,
|
|
342
334
|
)
|
|
343
|
-
x = x.permute(0, 1, 3, 2, 4, 5).reshape(
|
|
335
|
+
x = x.permute(0, 1, 3, 2, 4, 5).reshape(b, pad_h, pad_w, c)
|
|
344
336
|
|
|
345
337
|
# Reverse cyclic shift.
|
|
346
338
|
if sum(eff_shift) > 0:
|
|
@@ -351,10 +343,10 @@ class SwinWindowReverse(ConvertedModule):
|
|
|
351
343
|
)
|
|
352
344
|
|
|
353
345
|
# Remove padding.
|
|
354
|
-
x = x[:, :
|
|
346
|
+
x = x[:, :h, :w, :].contiguous()
|
|
355
347
|
return x
|
|
356
348
|
|
|
357
|
-
def __repr__(self) -> str:
|
|
349
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
358
350
|
return f"SwinWindowReverse(window_size={self.window_size})"
|
|
359
351
|
|
|
360
352
|
|
|
@@ -411,13 +403,13 @@ class FusedSwinAttention(FusedModule):
|
|
|
411
403
|
# pylint: disable-next=protected-access
|
|
412
404
|
attn_mask = a._compute_attn_mask() # noqa: SLF001
|
|
413
405
|
if attn_mask is not None:
|
|
414
|
-
# pylint: disable-next=protected-access
|
|
415
|
-
|
|
416
|
-
|
|
406
|
+
# pylint: disable-next=protected-access
|
|
407
|
+
b = a._spatial_state.batch_size # noqa: SLF001
|
|
408
|
+
n_w = attn_weight.size(0) // b
|
|
417
409
|
n = attn_weight.size(-2)
|
|
418
410
|
attn_weight = attn_weight.view(
|
|
419
|
-
|
|
420
|
-
|
|
411
|
+
b,
|
|
412
|
+
n_w,
|
|
421
413
|
a.num_heads,
|
|
422
414
|
n,
|
|
423
415
|
attn_weight.size(-1),
|
|
@@ -448,7 +440,7 @@ class FusedSwinAttention(FusedModule):
|
|
|
448
440
|
)
|
|
449
441
|
)
|
|
450
442
|
|
|
451
|
-
def __repr__(self) -> str:
|
|
443
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
452
444
|
a = self.attention
|
|
453
445
|
qdq = "yes" if self.softmax_quant.enabled else "no"
|
|
454
446
|
return (
|