embedl-deploy-tensorrt 0.6.0__tar.gz → 0.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/PKG-INFO +4 -6
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/README.md +3 -5
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/__init__.py +2 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/AGENTS.md +41 -29
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +32 -27
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +110 -45
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +10 -54
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +7 -3
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +8 -3
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/AGENTS.md +68 -18
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions.py +50 -16
- embedl_deploy_tensorrt-0.7.0/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +590 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/plan.py +10 -0
- embedl_deploy_tensorrt-0.7.0/src/embedl_deploy/_internal/tensorrt/precision.py +197 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/tensorrt/__init__.py +4 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/version/public.py +1 -1
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy_tensorrt.egg-info/SOURCES.txt +1 -0
- embedl_deploy_tensorrt-0.6.0/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +0 -424
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/LICENSE +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/MANIFEST.in +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/NOTICE +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/pyproject.toml +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/setup.cfg +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/backend.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.6.0 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/version/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: embedl-deploy-tensorrt
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.7.0
|
|
4
4
|
Summary: TensorRT backend for embedl-deploy.
|
|
5
5
|
Author-email: Embedl AB <support@embedl.com>
|
|
6
6
|
Project-URL: Homepage, https://www.embedl.com/
|
|
@@ -58,6 +58,7 @@ hardware target ensuring correct quantization and compilation.
|
|
|
58
58
|
|---------------------------|-----------------|
|
|
59
59
|
| NVIDIA TensorRT (v10.3) | Supported |
|
|
60
60
|
| Lattice SensAI (v8.0) | In Development |
|
|
61
|
+
| AMD Vitis AI | In Development |
|
|
61
62
|
|
|
62
63
|
Contact Embedl for other backends.
|
|
63
64
|
|
|
@@ -86,10 +87,7 @@ model = Model().eval()
|
|
|
86
87
|
example_input = torch.randn(1, 3, 224, 224)
|
|
87
88
|
|
|
88
89
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
89
|
-
|
|
90
|
-
# as follows:
|
|
91
|
-
# model = torch.export.export(model, (example_input)).module()
|
|
92
|
-
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
90
|
+
res = transform(model, (example_input,), patterns=TENSORRT_PATTERNS)
|
|
93
91
|
print("Model\n", res.model.print_readable())
|
|
94
92
|
print("Matches", "\n".join([str(match) for match in res.matches]))
|
|
95
93
|
|
|
@@ -149,7 +147,7 @@ the reference **from the fused graph**, not from the original model:
|
|
|
149
147
|
```python
|
|
150
148
|
from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
|
|
151
149
|
|
|
152
|
-
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
150
|
+
res = transform(model, (example_input,), patterns=TENSORRT_PATTERNS)
|
|
153
151
|
|
|
154
152
|
# Grab the conv instance from the fused graph (not from the original model)
|
|
155
153
|
first_conv = res.model.FusedConvBNActMaxPool_0.conv
|
|
@@ -39,6 +39,7 @@ hardware target ensuring correct quantization and compilation.
|
|
|
39
39
|
|---------------------------|-----------------|
|
|
40
40
|
| NVIDIA TensorRT (v10.3) | Supported |
|
|
41
41
|
| Lattice SensAI (v8.0) | In Development |
|
|
42
|
+
| AMD Vitis AI | In Development |
|
|
42
43
|
|
|
43
44
|
Contact Embedl for other backends.
|
|
44
45
|
|
|
@@ -67,10 +68,7 @@ model = Model().eval()
|
|
|
67
68
|
example_input = torch.randn(1, 3, 224, 224)
|
|
68
69
|
|
|
69
70
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
70
|
-
|
|
71
|
-
# as follows:
|
|
72
|
-
# model = torch.export.export(model, (example_input)).module()
|
|
73
|
-
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
71
|
+
res = transform(model, (example_input,), patterns=TENSORRT_PATTERNS)
|
|
74
72
|
print("Model\n", res.model.print_readable())
|
|
75
73
|
print("Matches", "\n".join([str(match) for match in res.matches]))
|
|
76
74
|
|
|
@@ -130,7 +128,7 @@ the reference **from the fused graph**, not from the original model:
|
|
|
130
128
|
```python
|
|
131
129
|
from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
|
|
132
130
|
|
|
133
|
-
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
131
|
+
res = transform(model, (example_input,), patterns=TENSORRT_PATTERNS)
|
|
134
132
|
|
|
135
133
|
# Grab the conv instance from the fused graph (not from the original model)
|
|
136
134
|
first_conv = res.model.FusedConvBNActMaxPool_0.conv
|
|
@@ -8,6 +8,7 @@ from embedl_deploy._internal.core.plan import (
|
|
|
8
8
|
TransformationResult,
|
|
9
9
|
apply_transformation_plan,
|
|
10
10
|
get_transformation_plan,
|
|
11
|
+
prepare_graph,
|
|
11
12
|
transform,
|
|
12
13
|
)
|
|
13
14
|
from embedl_deploy.version.public import PUBLIC_VERSION
|
|
@@ -21,5 +22,6 @@ __all__ = [
|
|
|
21
22
|
"__version__",
|
|
22
23
|
"apply_transformation_plan",
|
|
23
24
|
"get_transformation_plan",
|
|
25
|
+
"prepare_graph",
|
|
24
26
|
"transform",
|
|
25
27
|
]
|
|
@@ -31,10 +31,10 @@ core/quantize/prepare.py # walks the graph, uses isinstance(mod, FusedModu
|
|
|
31
31
|
```
|
|
32
32
|
|
|
33
33
|
**Instantiated by:** pattern grafts and `replace()` methods in `patterns/fusions.py`
|
|
34
|
-
and `patterns/conversions/`. Most fusion patterns declare
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
34
|
+
and `patterns/conversions/`. Most fusion patterns declare `graft = (make_fused(FusedFoo),)`;
|
|
35
|
+
`make_fused` gathers the matched modules in tree order (via `_collect_modules()`)
|
|
36
|
+
and passes them as positional arguments to the constructor. For example,
|
|
37
|
+
`ConvBNActPattern` grafts `(make_fused(FusedConvBNAct),)`;
|
|
38
38
|
`DecomposeMultiheadAttentionPattern.replace()` constructs `MHAInProjection` and
|
|
39
39
|
`ScaledDotProductAttention`.
|
|
40
40
|
|
|
@@ -61,15 +61,18 @@ Every `FusedModule` subclass must satisfy three requirements:
|
|
|
61
61
|
2. **Call `super().__init__()`**, which creates `self.input_quant_stubs` — a dict
|
|
62
62
|
mapping each index in `inputs_to_quantize` to a fresh `QuantStub`. The Q/DQ
|
|
63
63
|
pass later enables and configures these stubs during `prepare_qdq()`. It also
|
|
64
|
-
initialises `self.
|
|
65
|
-
`
|
|
66
|
-
|
|
64
|
+
initialises `self.output_precision = Precision.UNSET`. Surround-type modules
|
|
65
|
+
override this to `Precision.DEFERRED` in their constructors to signal that
|
|
66
|
+
their INT8 decision depends on graph context (see
|
|
67
|
+
`FusedAdaptiveAvgPool2d`, `FusedScaledDotProductAttention`,
|
|
68
|
+
`FusedSwinAttention`). Quantization patterns read and update
|
|
69
|
+
`output_precision` to track precision flow through the graph.
|
|
67
70
|
|
|
68
71
|
3. **Implement `forward()`** with the fused computation the module represents.
|
|
69
72
|
|
|
70
73
|
### Graft compatibility
|
|
71
74
|
|
|
72
|
-
When a pattern uses `graft = FusedFoo
|
|
75
|
+
When a pattern uses `graft = (make_fused(FusedFoo),)`, `make_fused` calls
|
|
73
76
|
`_collect_modules()` to walk the matched tree and collect the `nn.Module`
|
|
74
77
|
instances corresponding to trunk and fork nodes (nested branches first, then
|
|
75
78
|
trunk nodes). These are passed as positional arguments to the constructor.
|
|
@@ -155,11 +158,14 @@ Fusing them together avoids an extra quantized activation between Conv and Pool.
|
|
|
155
158
|
`inputs_to_quantize = {0, 1}`. The residual tensor is the second input (index 1),
|
|
156
159
|
hence both inputs are quantized. This is the ResNet skip-connection block tail.
|
|
157
160
|
|
|
158
|
-
**INT8 compatibility
|
|
159
|
-
`
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
161
|
+
**INT8 compatibility:** all conv fused modules unconditionally create a
|
|
162
|
+
`WeightFakeQuantize` and `input_quant_stubs` in their constructors. INT8
|
|
163
|
+
compatibility filtering (grouped convolutions where `in_channels / groups` or
|
|
164
|
+
`out_channels / groups` is not a multiple of 4, depthwise convolutions) is
|
|
165
|
+
deferred to `DisableInt8Pattern` in the quantization pass, which
|
|
166
|
+
calls `disable_int8()` on incompatible modules. Depthwise convolutions set
|
|
167
|
+
`output_precision = Precision.DEFERRED` so the surround pass can decide
|
|
168
|
+
contextually whether to enable INT8.
|
|
163
169
|
|
|
164
170
|
### Linear family — `linear.py`
|
|
165
171
|
|
|
@@ -217,7 +223,9 @@ is quantized; `_key` and `_value` are accepted to match the self-attention
|
|
|
217
223
|
call-site but ignored.
|
|
218
224
|
|
|
219
225
|
**`FusedScaledDotProductAttention`** — wraps `ScaledDotProductAttention`,
|
|
220
|
-
`inputs_to_quantize =
|
|
226
|
+
`inputs_to_quantize = {0, 1, 2}`. The pre-declared stubs (Q, K, V) are created
|
|
227
|
+
disabled; `SurroundWithQuantStubsPattern` enables them when the module is
|
|
228
|
+
surrounded with Q/DQ stubs. Adds an internal `softmax_quant` stub with a fixed
|
|
221
229
|
calibration of `(1/127, 0)` — i.e., 8-bit symmetric with a fixed scale matched to
|
|
222
230
|
the softmax output range `[0, 1]`. When the stub is disabled the module delegates
|
|
223
231
|
to the plain SDPA; when enabled it performs manual attention with the quantization
|
|
@@ -262,10 +270,12 @@ three modules. This is intentional: the three modules reference the *same*
|
|
|
262
270
|
referencing attributes to point to the copy. However, the sharing is not
|
|
263
271
|
thread-safe (see Gotchas).
|
|
264
272
|
|
|
265
|
-
**`FusedSwinAttention`** — wraps `SwinAttention`, `inputs_to_quantize =
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
273
|
+
**`FusedSwinAttention`** — wraps `SwinAttention`, `inputs_to_quantize = {0, 1, 2}`.
|
|
274
|
+
The pre-declared stubs (Q, K, V) are created disabled;
|
|
275
|
+
`SurroundWithQuantStubsPattern` enables them. Mirrors
|
|
276
|
+
`FusedScaledDotProductAttention`: adds an internal `softmax_quant` stub with fixed
|
|
277
|
+
calibration. When enabled it manually expands the attention computation to insert
|
|
278
|
+
the quantization step between softmax and BMM2.
|
|
269
279
|
|
|
270
280
|
### Pointwise family — `pointwise.py`
|
|
271
281
|
|
|
@@ -283,9 +293,9 @@ quantized at their respective scales.
|
|
|
283
293
|
Pattern that creates it: `patterns/fusions.py`.
|
|
284
294
|
|
|
285
295
|
**`FusedAdaptiveAvgPool2d`** — wraps `nn.AdaptiveAvgPool2d`,
|
|
286
|
-
`inputs_to_quantize =
|
|
287
|
-
|
|
288
|
-
|
|
296
|
+
`inputs_to_quantize = {0}`. Defers output precision to the surround pattern.
|
|
297
|
+
The pre-declared stub is created disabled; `SurroundWithQuantStubsPattern`
|
|
298
|
+
enables it when the module is surrounded with Q/DQ stubs.
|
|
289
299
|
|
|
290
300
|
---
|
|
291
301
|
|
|
@@ -355,11 +365,13 @@ weights are modified after fusion, the fused module sees the change. This is
|
|
|
355
365
|
usually desirable (e.g. for QAT gradient updates), but can be surprising if the
|
|
356
366
|
original model is used independently.
|
|
357
367
|
|
|
358
|
-
**Grouped conv INT8 opt-out:**
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
368
|
+
**Grouped conv INT8 opt-out:** INT8 compatibility for grouped convolutions is
|
|
369
|
+
no longer handled in the module constructor. Instead, `DisableInt8Pattern`
|
|
370
|
+
calls `disable_int8()` on modules where `is_int8_beneficial_conv()` returns
|
|
371
|
+
`False`. The module still has `input_quant_stubs` and `weight_fake_quant` after
|
|
372
|
+
construction — they are disabled by the pattern pass. This means
|
|
373
|
+
`bool(mod.input_quant_stubs)` is `True` even for incompatible convolutions until
|
|
374
|
+
the quantization pass runs.
|
|
363
375
|
|
|
364
376
|
---
|
|
365
377
|
|
|
@@ -372,11 +384,11 @@ is effectively excluded from quantization despite being a `FusedModule`. The
|
|
|
372
384
|
`__init__` if the module has a learnable weight that should be fake-quantized.
|
|
373
385
|
4. **Implement `forward()`** with the fused computation.
|
|
374
386
|
5. **Write a `Pattern` subclass** in `patterns/fusions.py`. Prefer declaring
|
|
375
|
-
`graft = FusedFoo`
|
|
387
|
+
`graft = (make_fused(FusedFoo),)` so the graft system handles replacement
|
|
376
388
|
automatically. The constructor must accept modules in tree order (nested
|
|
377
389
|
branches first, then trunk nodes). If the replacement logic cannot be
|
|
378
|
-
expressed
|
|
379
|
-
`replace()` method instead.
|
|
390
|
+
expressed with `make_fused`, provide a different `ReplacementMaker` or a
|
|
391
|
+
custom `replace()` method instead.
|
|
380
392
|
6. **Add the pattern to `TENSORRT_PATTERNS`** (or the appropriate pattern list) in
|
|
381
393
|
`tensorrt/plan.py`.
|
|
382
394
|
7. **Write tests** in `tests/tensorrt/patterns/fusions/` covering: pattern match,
|
|
@@ -13,11 +13,14 @@ import torch
|
|
|
13
13
|
import torch.nn.functional as F
|
|
14
14
|
from torch import nn
|
|
15
15
|
|
|
16
|
-
from embedl_deploy._internal.core.modules import
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
16
|
+
from embedl_deploy._internal.core.modules import (
|
|
17
|
+
ConvertedModule,
|
|
18
|
+
FusedModule,
|
|
19
|
+
Precision,
|
|
20
|
+
)
|
|
21
|
+
from embedl_deploy._internal.core.quantize.stubs import (
|
|
22
|
+
QuantStub,
|
|
23
|
+
WeightFakeQuantize,
|
|
21
24
|
)
|
|
22
25
|
|
|
23
26
|
|
|
@@ -181,7 +184,7 @@ class FusedMHAInProjection(FusedModule):
|
|
|
181
184
|
def __init__(self, in_proj: MHAInProjection) -> None:
|
|
182
185
|
super().__init__()
|
|
183
186
|
self.in_proj = in_proj
|
|
184
|
-
|
|
187
|
+
self.weight_fake_quant = WeightFakeQuantize({self})
|
|
185
188
|
|
|
186
189
|
@property
|
|
187
190
|
def quantized_weight(self) -> torch.Tensor | None:
|
|
@@ -205,7 +208,7 @@ class FusedMHAInProjection(FusedModule):
|
|
|
205
208
|
:returns:
|
|
206
209
|
Tuple ``(Q, K, V)`` each of shape ``[B, num_heads, S, head_dim]``.
|
|
207
210
|
"""
|
|
208
|
-
weight =
|
|
211
|
+
weight = self.weight_fake_quant(self.in_proj.linear.weight)
|
|
209
212
|
batch, seq, _ = query.shape
|
|
210
213
|
qkv = F.linear(query, weight, self.in_proj.linear.bias)
|
|
211
214
|
q, k, v = qkv.chunk(3, dim=-1)
|
|
@@ -229,15 +232,13 @@ class FusedMHAInProjection(FusedModule):
|
|
|
229
232
|
class FusedScaledDotProductAttention(FusedModule):
|
|
230
233
|
"""Fused wrapper for ``ScaledDotProductAttention``.
|
|
231
234
|
|
|
232
|
-
|
|
233
|
-
of the three inputs (Q, K, V).
|
|
234
|
-
|
|
235
|
-
Additionally holds an internal
|
|
235
|
+
Quantizes the three inputs (Q, K, V) and holds an internal
|
|
236
236
|
:class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` between the
|
|
237
|
-
softmax output and the second batched matrix multiply (BMM2). When
|
|
238
|
-
|
|
237
|
+
softmax output and the second batched matrix multiply (BMM2). When
|
|
238
|
+
``output_precision`` is still ``DEFERRED`` or the softmax stub is disabled,
|
|
239
|
+
the forward pass delegates to the unwrapped
|
|
239
240
|
:class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`;
|
|
240
|
-
|
|
241
|
+
otherwise it performs manual attention with the quantization step.
|
|
241
242
|
|
|
242
243
|
:param attention:
|
|
243
244
|
The
|
|
@@ -245,11 +246,12 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
245
246
|
from the decomposed MHA.
|
|
246
247
|
"""
|
|
247
248
|
|
|
248
|
-
inputs_to_quantize: set[int] =
|
|
249
|
+
inputs_to_quantize: set[int] = {0, 1, 2}
|
|
249
250
|
|
|
250
251
|
def __init__(self, attention: ScaledDotProductAttention) -> None:
|
|
251
252
|
super().__init__()
|
|
252
253
|
self.attention = attention
|
|
254
|
+
self.output_precision = Precision.DEFERRED
|
|
253
255
|
self.softmax_quant = QuantStub(
|
|
254
256
|
consumers={self},
|
|
255
257
|
n_bits=8,
|
|
@@ -266,11 +268,11 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
266
268
|
) -> torch.Tensor:
|
|
267
269
|
r"""Compute scaled dot-product attention.
|
|
268
270
|
|
|
269
|
-
When
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
271
|
+
When ``output_precision`` has been resolved (no longer ``DEFERRED``)
|
|
272
|
+
and the internal softmax quant stub is enabled, performs manual
|
|
273
|
+
attention with a quantization step between softmax and BMM2. Otherwise
|
|
274
|
+
delegates to the wrapped attention module so TensorRT can fuse it into
|
|
275
|
+
its native FP16 MHA kernel.
|
|
274
276
|
|
|
275
277
|
:param q:
|
|
276
278
|
Query tensor ``[B, num_heads, S, head_dim]``.
|
|
@@ -286,13 +288,16 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
286
288
|
Output tensor ``[B, num_heads, S, head_dim]``. Callers are
|
|
287
289
|
responsible for any subsequent head-flattening reshape.
|
|
288
290
|
"""
|
|
289
|
-
# Manual attention is only beneficial when
|
|
290
|
-
#
|
|
291
|
-
#
|
|
292
|
-
#
|
|
293
|
-
#
|
|
294
|
-
#
|
|
295
|
-
if
|
|
291
|
+
# Manual attention is only beneficial when output_precision has
|
|
292
|
+
# been resolved (i.e. Q/K/V stubs are active and arriving in
|
|
293
|
+
# INT8). While still DEFERRED, ``softmax_quant`` may be enabled
|
|
294
|
+
# but running manual attention would add a softmax Q/DQ pair
|
|
295
|
+
# that pushes TensorRT off its FP16 fused MHA kernel onto the
|
|
296
|
+
# slower INT8-aware variant for no gain.
|
|
297
|
+
if (
|
|
298
|
+
self.output_precision == Precision.DEFERRED
|
|
299
|
+
or not self.softmax_quant.enabled
|
|
300
|
+
):
|
|
296
301
|
return self.attention(q, k, v, attn_mask)
|
|
297
302
|
# Honor the wrapped attention module's explicit ``scale`` if
|
|
298
303
|
# set — models that pre-scale Q themselves (chronos-2 + RoPE,
|
|
@@ -15,29 +15,20 @@ import torch
|
|
|
15
15
|
import torch.nn.functional as F
|
|
16
16
|
from torch import nn
|
|
17
17
|
|
|
18
|
-
from embedl_deploy._internal.core.modules import
|
|
18
|
+
from embedl_deploy._internal.core.modules import (
|
|
19
|
+
ActivationLike,
|
|
20
|
+
FusedModule,
|
|
21
|
+
Precision,
|
|
22
|
+
)
|
|
19
23
|
from embedl_deploy._internal.core.quantize.stubs import (
|
|
24
|
+
QuantStub,
|
|
20
25
|
WeightFakeQuantize,
|
|
21
26
|
)
|
|
22
27
|
|
|
23
28
|
|
|
24
|
-
def
|
|
25
|
-
"""Return ``True``
|
|
26
|
-
|
|
27
|
-
TensorRT's documented constraint for ``IConvolutionLayer`` is that
|
|
28
|
-
``in_channels / groups`` and ``out_channels / groups`` must both be
|
|
29
|
-
multiples of 4 in INT8 mode. Depthwise convolutions (``groups ==
|
|
30
|
-
in_channels``) are an exception: our benchmarks on the target devices show
|
|
31
|
-
they still benefit from INT8 despite channels-per-group being 1, so we let
|
|
32
|
-
them through.
|
|
33
|
-
"""
|
|
34
|
-
if conv.groups <= 1:
|
|
35
|
-
return True
|
|
36
|
-
if conv.groups == conv.in_channels:
|
|
37
|
-
return True
|
|
38
|
-
in_per_group: int = conv.in_channels // conv.groups
|
|
39
|
-
out_per_group: int = conv.out_channels // conv.groups
|
|
40
|
-
return in_per_group % 4 == 0 and out_per_group % 4 == 0
|
|
29
|
+
def is_depthwise_conv(conv: nn.Conv2d) -> bool:
|
|
30
|
+
"""Return ``True`` when *conv* is depthwise."""
|
|
31
|
+
return conv.groups == conv.in_channels
|
|
41
32
|
|
|
42
33
|
|
|
43
34
|
def _conv_weight_forward(
|
|
@@ -51,6 +42,7 @@ def _conv_weight_forward(
|
|
|
51
42
|
if weight_fake_quant is not None
|
|
52
43
|
else conv.weight
|
|
53
44
|
)
|
|
45
|
+
# pylint: disable-next=not-callable
|
|
54
46
|
return F.conv2d(
|
|
55
47
|
x,
|
|
56
48
|
weight,
|
|
@@ -63,7 +55,10 @@ def _conv_weight_forward(
|
|
|
63
55
|
|
|
64
56
|
|
|
65
57
|
class FusedConvBNAct(FusedModule):
|
|
66
|
-
"""Fused ``Conv2d → [BatchNorm2d] → Act``.
|
|
58
|
+
"""Fused ``Conv2d → [BatchNorm2d] → Act``.
|
|
59
|
+
|
|
60
|
+
Depthwise convolutions defer output precision to the surround pattern.
|
|
61
|
+
"""
|
|
67
62
|
|
|
68
63
|
inputs_to_quantize: set[int] = {0}
|
|
69
64
|
|
|
@@ -77,10 +72,9 @@ class FusedConvBNAct(FusedModule):
|
|
|
77
72
|
self.conv = conv
|
|
78
73
|
self.bn = bn
|
|
79
74
|
self.act = act
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
self.input_quant_stubs = {}
|
|
75
|
+
self.weight_fake_quant = WeightFakeQuantize({self})
|
|
76
|
+
if is_depthwise_conv(conv):
|
|
77
|
+
self.output_precision = Precision.DEFERRED
|
|
84
78
|
|
|
85
79
|
@property
|
|
86
80
|
def quantized_weight(self) -> torch.Tensor | None:
|
|
@@ -88,7 +82,7 @@ class FusedConvBNAct(FusedModule):
|
|
|
88
82
|
|
|
89
83
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
90
84
|
"""Apply ``conv → [bn] → act``."""
|
|
91
|
-
wfq = getattr(self,
|
|
85
|
+
wfq = getattr(self, "weight_fake_quant", None)
|
|
92
86
|
x = _conv_weight_forward(self.conv, wfq, x)
|
|
93
87
|
if self.bn is not None:
|
|
94
88
|
x = self.bn(x)
|
|
@@ -107,7 +101,10 @@ class FusedConvBNAct(FusedModule):
|
|
|
107
101
|
|
|
108
102
|
|
|
109
103
|
class FusedConvBN(FusedModule):
|
|
110
|
-
"""Fused ``Conv2d → [BatchNorm2d]`` (no activation).
|
|
104
|
+
"""Fused ``Conv2d → [BatchNorm2d]`` (no activation).
|
|
105
|
+
|
|
106
|
+
Depthwise convolutions defer output precision to the surround pattern.
|
|
107
|
+
"""
|
|
111
108
|
|
|
112
109
|
inputs_to_quantize: set[int] = {0}
|
|
113
110
|
|
|
@@ -119,10 +116,9 @@ class FusedConvBN(FusedModule):
|
|
|
119
116
|
super().__init__()
|
|
120
117
|
self.conv = conv
|
|
121
118
|
self.bn = bn
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
self.input_quant_stubs = {}
|
|
119
|
+
self.weight_fake_quant = WeightFakeQuantize({self})
|
|
120
|
+
if is_depthwise_conv(conv):
|
|
121
|
+
self.output_precision = Precision.DEFERRED
|
|
126
122
|
|
|
127
123
|
@property
|
|
128
124
|
def quantized_weight(self) -> torch.Tensor | None:
|
|
@@ -130,7 +126,7 @@ class FusedConvBN(FusedModule):
|
|
|
130
126
|
|
|
131
127
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
132
128
|
"""Apply ``conv → [bn]``."""
|
|
133
|
-
wfq = getattr(self,
|
|
129
|
+
wfq = getattr(self, "weight_fake_quant", None)
|
|
134
130
|
x = _conv_weight_forward(self.conv, wfq, x)
|
|
135
131
|
if self.bn is not None:
|
|
136
132
|
x = self.bn(x)
|
|
@@ -173,7 +169,8 @@ class FusedConvBNActMaxPool(FusedModule):
|
|
|
173
169
|
|
|
174
170
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
175
171
|
"""Apply ``conv → [bn] → act → maxpool``."""
|
|
176
|
-
|
|
172
|
+
wfq = getattr(self, "weight_fake_quant", None)
|
|
173
|
+
x = _conv_weight_forward(self.conv, wfq, x)
|
|
177
174
|
if self.bn is not None:
|
|
178
175
|
x = self.bn(x)
|
|
179
176
|
x = self.act(x)
|
|
@@ -194,10 +191,12 @@ class FusedConvBNActMaxPool(FusedModule):
|
|
|
194
191
|
|
|
195
192
|
|
|
196
193
|
class FusedConvBNAddAct(FusedModule):
|
|
197
|
-
"""Fused ``Conv2d → BatchNorm2d → add(·, residual) → Activation``.
|
|
194
|
+
"""Fused ``Conv2d → [BatchNorm2d] → add(·, residual) → [Activation]``.
|
|
198
195
|
|
|
199
196
|
``forward()`` accepts two inputs: the main tensor ``x`` and the
|
|
200
|
-
``residual`` tensor.
|
|
197
|
+
``residual`` tensor. Both the ``BatchNorm2d`` and the trailing activation
|
|
198
|
+
are optional so that EfficientNet-style ``Conv → BN → Add`` blocks (no
|
|
199
|
+
activation) are captured.
|
|
201
200
|
"""
|
|
202
201
|
|
|
203
202
|
inputs_to_quantize: set[int] = {0, 1}
|
|
@@ -205,33 +204,99 @@ class FusedConvBNAddAct(FusedModule):
|
|
|
205
204
|
def __init__(
|
|
206
205
|
self,
|
|
207
206
|
conv: nn.Conv2d,
|
|
208
|
-
bn: nn.BatchNorm2d,
|
|
209
|
-
act: ActivationLike,
|
|
207
|
+
bn: nn.BatchNorm2d | None,
|
|
208
|
+
act: ActivationLike | None,
|
|
210
209
|
) -> None:
|
|
211
210
|
super().__init__()
|
|
212
211
|
self.conv = conv
|
|
213
212
|
self.bn = bn
|
|
214
213
|
self.act = act
|
|
215
|
-
|
|
216
|
-
self.weight_fake_quant = WeightFakeQuantize({self})
|
|
217
|
-
else:
|
|
218
|
-
self.input_quant_stubs = {}
|
|
214
|
+
self.weight_fake_quant = WeightFakeQuantize({self})
|
|
219
215
|
|
|
220
216
|
@property
|
|
221
217
|
def quantized_weight(self) -> torch.Tensor | None:
|
|
222
218
|
return self.conv.weight
|
|
223
219
|
|
|
224
220
|
def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
|
225
|
-
"""Apply ``conv → bn → add(·, residual) → act``."""
|
|
226
|
-
wfq = getattr(self,
|
|
221
|
+
"""Apply ``conv → [bn] → add(·, residual) → [act]``."""
|
|
222
|
+
wfq = getattr(self, "weight_fake_quant", None)
|
|
227
223
|
x = _conv_weight_forward(self.conv, wfq, x)
|
|
228
|
-
|
|
229
|
-
|
|
224
|
+
if self.bn is not None:
|
|
225
|
+
x = self.bn(x)
|
|
226
|
+
x = x + residual
|
|
227
|
+
if self.act is not None:
|
|
228
|
+
x = self.act(x)
|
|
229
|
+
return x
|
|
230
230
|
|
|
231
231
|
def __repr__(self) -> str: # pragma: no cover
|
|
232
|
+
bn_info = ""
|
|
233
|
+
if self.bn is not None:
|
|
234
|
+
bn_info = f", bn={self.bn.num_features} (foldable)"
|
|
235
|
+
act_info = ""
|
|
236
|
+
if self.act is not None:
|
|
237
|
+
act_info = f", act={type(self.act).__name__}"
|
|
232
238
|
return (
|
|
233
239
|
f"FusedConvBNAddAct("
|
|
234
240
|
f"{self.conv.in_channels}→{self.conv.out_channels}, "
|
|
235
|
-
f"k={self.conv.kernel_size}, s={self.conv.stride}
|
|
236
|
-
f"
|
|
241
|
+
f"k={self.conv.kernel_size}, s={self.conv.stride}"
|
|
242
|
+
f"{bn_info}{act_info})"
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class FusedConvBNSigmoidMul(FusedModule):
|
|
247
|
+
"""Fused ``Conv2d → [BatchNorm2d] → Sigmoid → Mul(·, skip)``.
|
|
248
|
+
|
|
249
|
+
Captures the SE gate pattern where an expand convolution produces channel
|
|
250
|
+
attention weights via Sigmoid, then element-wise multiplies with the skip
|
|
251
|
+
connection. An internal
|
|
252
|
+
:class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` between the
|
|
253
|
+
conv/BN output and the Sigmoid produces the Q/DQ pair that enables
|
|
254
|
+
TensorRT's ``PWN(Sigmoid, Mul)`` fusion.
|
|
255
|
+
|
|
256
|
+
``forward()`` accepts two inputs: the main tensor ``x`` feeding the
|
|
257
|
+
convolution, and the ``skip`` tensor multiplied by the sigmoid gate.
|
|
258
|
+
"""
|
|
259
|
+
|
|
260
|
+
inputs_to_quantize: set[int] = {0, 1}
|
|
261
|
+
|
|
262
|
+
def __init__(
|
|
263
|
+
self,
|
|
264
|
+
conv: nn.Conv2d,
|
|
265
|
+
bn: nn.BatchNorm2d | None,
|
|
266
|
+
sigmoid: nn.Sigmoid,
|
|
267
|
+
) -> None:
|
|
268
|
+
super().__init__()
|
|
269
|
+
self.conv = conv
|
|
270
|
+
self.bn = bn
|
|
271
|
+
self.sigmoid = sigmoid
|
|
272
|
+
self.weight_fake_quant = WeightFakeQuantize({self})
|
|
273
|
+
self.gate_quant = QuantStub({self})
|
|
274
|
+
|
|
275
|
+
@property
|
|
276
|
+
def quantized_weight(self) -> torch.Tensor | None:
|
|
277
|
+
return self.conv.weight
|
|
278
|
+
|
|
279
|
+
def forward(
|
|
280
|
+
self,
|
|
281
|
+
x: torch.Tensor,
|
|
282
|
+
skip: torch.Tensor,
|
|
283
|
+
) -> torch.Tensor:
|
|
284
|
+
"""Apply ``conv → [bn] → gate_quant → sigmoid → mul(·, skip)``."""
|
|
285
|
+
wfq = getattr(self, "weight_fake_quant", None)
|
|
286
|
+
x = _conv_weight_forward(self.conv, wfq, x)
|
|
287
|
+
if self.bn is not None:
|
|
288
|
+
x = self.bn(x)
|
|
289
|
+
x = self.gate_quant(x)
|
|
290
|
+
x = self.sigmoid(x)
|
|
291
|
+
return x * skip
|
|
292
|
+
|
|
293
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
294
|
+
bn_info = ""
|
|
295
|
+
if self.bn is not None:
|
|
296
|
+
bn_info = f", bn={self.bn.num_features} (foldable)"
|
|
297
|
+
return (
|
|
298
|
+
f"FusedConvBNSigmoidMul("
|
|
299
|
+
f"{self.conv.in_channels}→{self.conv.out_channels}, "
|
|
300
|
+
f"k={self.conv.kernel_size}, s={self.conv.stride}"
|
|
301
|
+
f"{bn_info})"
|
|
237
302
|
)
|