embedl-deploy-tensorrt 0.4.1__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/PKG-INFO +7 -6
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/README.md +6 -5
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/__init__.py +2 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/backend.py +1 -0
- embedl_deploy_tensorrt-0.6.0/src/embedl_deploy/_internal/tensorrt/modules/AGENTS.md +408 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +56 -29
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +26 -11
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +26 -19
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +2 -2
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +1 -1
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +81 -82
- embedl_deploy_tensorrt-0.6.0/src/embedl_deploy/_internal/tensorrt/patterns/AGENTS.md +721 -0
- embedl_deploy_tensorrt-0.6.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +3 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +194 -246
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +29 -98
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions.py +33 -27
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +122 -41
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +23 -19
- embedl_deploy_tensorrt-0.6.0/src/embedl_deploy/_internal/tensorrt/plan.py +98 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/tensorrt/__init__.py +11 -9
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -2
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/version/public.py +1 -1
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy_tensorrt.egg-info/SOURCES.txt +2 -2
- embedl_deploy_tensorrt-0.4.1/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -15
- embedl_deploy_tensorrt-0.4.1/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -1475
- embedl_deploy_tensorrt-0.4.1/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -89
- embedl_deploy_tensorrt-0.4.1/src/embedl_deploy/_internal/tensorrt/plan.py +0 -127
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/LICENSE +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/MANIFEST.in +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/NOTICE +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/pyproject.toml +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/setup.cfg +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/version/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: embedl-deploy-tensorrt
|
|
3
|
-
Version: 0.4.1
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: TensorRT backend for embedl-deploy.
|
|
5
5
|
Author-email: Embedl AB <support@embedl.com>
|
|
6
6
|
Project-URL: Homepage, https://www.embedl.com/
|
|
@@ -54,9 +54,10 @@ hardware target ensuring correct quantization and compilation.
|
|
|
54
54
|
|
|
55
55
|
## Supported Backends
|
|
56
56
|
|
|
57
|
-
| Backend
|
|
58
|
-
|
|
59
|
-
| NVIDIA TensorRT
|
|
57
|
+
| Backend | Status |
|
|
58
|
+
|---------------------------|-----------------|
|
|
59
|
+
| NVIDIA TensorRT (v10.3) | Supported |
|
|
60
|
+
| Lattice SensAI (v8.0) | In Development |
|
|
60
61
|
|
|
61
62
|
Contact Embedl for other backends.
|
|
62
63
|
|
|
@@ -71,7 +72,7 @@ intermediate.
|
|
|
71
72
|
|
|
72
73
|
---
|
|
73
74
|
|
|
74
|
-
## Quick Start
|
|
75
|
+
## Quick Start for TensorRT Backend
|
|
75
76
|
|
|
76
77
|
```python
|
|
77
78
|
import torch
|
|
@@ -85,7 +86,7 @@ model = Model().eval()
|
|
|
85
86
|
example_input = torch.randn(1, 3, 224, 224)
|
|
86
87
|
|
|
87
88
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
88
|
-
# For more
|
|
89
|
+
# For broader compatibility, you can first trace your model with torch.export.export
|
|
89
90
|
# as follows:
|
|
90
91
|
# model = torch.export.export(model, (example_input,)).module()
|
|
91
92
|
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
@@ -35,9 +35,10 @@ hardware target ensuring correct quantization and compilation.
|
|
|
35
35
|
|
|
36
36
|
## Supported Backends
|
|
37
37
|
|
|
38
|
-
| Backend
|
|
39
|
-
|
|
40
|
-
| NVIDIA TensorRT
|
|
38
|
+
| Backend | Status |
|
|
39
|
+
|---------------------------|-----------------|
|
|
40
|
+
| NVIDIA TensorRT (v10.3) | Supported |
|
|
41
|
+
| Lattice SensAI (v8.0) | In Development |
|
|
41
42
|
|
|
42
43
|
Contact Embedl for other backends.
|
|
43
44
|
|
|
@@ -52,7 +53,7 @@ intermediate.
|
|
|
52
53
|
|
|
53
54
|
---
|
|
54
55
|
|
|
55
|
-
## Quick Start
|
|
56
|
+
## Quick Start for TensorRT Backend
|
|
56
57
|
|
|
57
58
|
```python
|
|
58
59
|
import torch
|
|
@@ -66,7 +67,7 @@ model = Model().eval()
|
|
|
66
67
|
example_input = torch.randn(1, 3, 224, 224)
|
|
67
68
|
|
|
68
69
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
69
|
-
# For more
|
|
70
|
+
# For broader compatibility, you can first trace your model with torch.export.export
|
|
70
71
|
# as follows:
|
|
71
72
|
# model = torch.export.export(model, (example_input,)).module()
|
|
72
73
|
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
"""Python package to make AI models deployment-ready for any hardware."""
|
|
4
4
|
|
|
5
5
|
from embedl_deploy._internal.core.plan import (
|
|
6
|
+
Trace,
|
|
6
7
|
TransformationPlan,
|
|
7
8
|
TransformationResult,
|
|
8
9
|
apply_transformation_plan,
|
|
@@ -14,6 +15,7 @@ from embedl_deploy.version.public import PUBLIC_VERSION
|
|
|
14
15
|
__version__ = PUBLIC_VERSION
|
|
15
16
|
|
|
16
17
|
__all__ = [
|
|
18
|
+
"Trace",
|
|
17
19
|
"TransformationPlan",
|
|
18
20
|
"TransformationResult",
|
|
19
21
|
"__version__",
|
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
# TensorRT Fused Modules — Subsystem Guide
|
|
2
|
+
|
|
3
|
+
## Role & Boundary
|
|
4
|
+
|
|
5
|
+
**This subsystem owns:** concrete `FusedModule` (and `ConvertedModule`) implementations
|
|
6
|
+
for TensorRT. Every class in `modules/` corresponds to a hardware-fusible or
|
|
7
|
+
structurally-decomposed operation the TensorRT backend can emit as a single engine
|
|
8
|
+
layer or kernel.
|
|
9
|
+
|
|
10
|
+
**This subsystem does NOT own:**
|
|
11
|
+
|
|
12
|
+
- When or how to match these modules in a graph — that lives in `patterns/`.
|
|
13
|
+
- Q/DQ stub insertion or calibration logic — that lives in `core/quantize/`.
|
|
14
|
+
- The pattern-matching engine itself — that lives in `core/`.
|
|
15
|
+
|
|
16
|
+
---
|
|
17
|
+
|
|
18
|
+
## Position in the System
|
|
19
|
+
|
|
20
|
+
```
|
|
21
|
+
patterns/fusions.py # creates FusedModule instances
|
|
22
|
+
patterns/conversions/general.py # creates ConvertedModule instances (erasure, flatten-linear)
|
|
23
|
+
patterns/conversions/attention.py # creates ConvertedModule instances (MHA/Swin/SDPA)
|
|
24
|
+
│
|
|
25
|
+
▼
|
|
26
|
+
modules/ (this subsystem)
|
|
27
|
+
│
|
|
28
|
+
▼
|
|
29
|
+
core/quantize/prepare.py # walks the graph, uses isinstance(mod, FusedModule)
|
|
30
|
+
# and mod.inputs_to_quantize to insert Q/DQ stubs
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
**Instantiated by:** pattern grafts and `replace()` methods in `patterns/fusions.py`
|
|
34
|
+
and `patterns/conversions/`. Most fusion patterns declare a `graft` attribute
|
|
35
|
+
pointing to the fused module class; the graft system calls `_collect_modules()`
|
|
36
|
+
to gather matched modules in tree order and passes them as positional arguments
|
|
37
|
+
to the constructor. For example, `ConvBNActPattern` grafts `FusedConvBNAct`;
|
|
38
|
+
`DecomposeMultiheadAttentionPattern.replace()` constructs `MHAInProjection` and
|
|
39
|
+
`ScaledDotProductAttention`.
|
|
40
|
+
|
|
41
|
+
**Used by:** The Q/DQ insertion pass in `core/quantize/prepare.py` via
|
|
42
|
+
`isinstance(mod, FusedModule)` and `mod.inputs_to_quantize`, to determine which
|
|
43
|
+
inputs of each `call_module` node should receive a `QuantStub`.
|
|
44
|
+
|
|
45
|
+
**Also used as intermediate graph nodes by:** conversion patterns, which produce
|
|
46
|
+
`ConvertedModule` subclasses (`MHAInProjection`, `ScaledDotProductAttention`,
|
|
47
|
+
`SwinWindowPartition`, `SwinAttention`, `SwinWindowReverse`). These stay opaque
|
|
48
|
+
during re-tracing so that downstream fusion patterns can see them as atomic nodes.
|
|
49
|
+
|
|
50
|
+
---
|
|
51
|
+
|
|
52
|
+
## FusedModule Contract
|
|
53
|
+
|
|
54
|
+
Every `FusedModule` subclass must satisfy three requirements:
|
|
55
|
+
|
|
56
|
+
1. **Set `inputs_to_quantize` as a class attribute** (not an instance attribute).
|
|
57
|
+
`FusedModule.__init__()` and the Q/DQ insertion pass read this to decide
|
|
58
|
+
which positional arguments of the `call_module` FX node will have a `QuantStub`
|
|
59
|
+
inserted in front of them.
|
|
60
|
+
|
|
61
|
+
2. **Call `super().__init__()`**, which creates `self.input_quant_stubs` — a dict
|
|
62
|
+
mapping each index in `inputs_to_quantize` to a fresh `QuantStub`. The Q/DQ
|
|
63
|
+
pass later enables and configures these stubs during `prepare_qdq()`. It also
|
|
64
|
+
initialises `self.surrounded = False`, which is later set to `True` by
|
|
65
|
+
`SurroundWithQuantStubsPattern` to mark modules that have been surrounded
|
|
66
|
+
with input `QuantStub` entries.
|
|
67
|
+
|
|
68
|
+
3. **Implement `forward()`** with the fused computation the module represents.
|
|
69
|
+
|
|
70
|
+
### Graft compatibility
|
|
71
|
+
|
|
72
|
+
When a pattern uses `graft = FusedFoo` (bare class), the graft system calls
|
|
73
|
+
`_collect_modules()` to walk the matched tree and collect the `nn.Module`
|
|
74
|
+
instances corresponding to trunk and fork nodes (nested branches first, then
|
|
75
|
+
trunk nodes). These are passed as positional arguments to the constructor.
|
|
76
|
+
Therefore the constructor signature must accept modules in the same order
|
|
77
|
+
they appear in the pattern tree.
|
|
78
|
+
|
|
79
|
+
### What `inputs_to_quantize` means
|
|
80
|
+
|
|
81
|
+
`inputs_to_quantize` is a set of *positional argument indices* corresponding to
|
|
82
|
+
the `call_module` FX node's arguments — that is, the arguments visible in the
|
|
83
|
+
traced graph, not necessarily the Python keyword positions in `forward()`.
|
|
84
|
+
|
|
85
|
+
For example, `FusedConvBN.inputs_to_quantize = {0}` means the first tensor
|
|
86
|
+
argument to the fused node (the image tensor) gets a `QuantStub`. The convolution
|
|
87
|
+
weight is quantized separately via `WeightFakeQuantize`, which is attached to the
|
|
88
|
+
module directly rather than via `inputs_to_quantize`.
|
|
89
|
+
|
|
90
|
+
`FusedConvBNAddAct.inputs_to_quantize = {0, 1}` means both the main feature map
|
|
91
|
+
and the residual skip tensor each get their own `QuantStub`.
|
|
92
|
+
|
|
93
|
+
### Why `inputs_to_quantize` is a class attribute
|
|
94
|
+
|
|
95
|
+
Because it describes the module *type's* quantization contract, not any particular
|
|
96
|
+
instance's configuration. The Q/DQ insertion pass queries it once per class during
|
|
97
|
+
graph preparation. Making it a class attribute
|
|
98
|
+
(rather than an instance attribute set in `__init__`) ensures the value is
|
|
99
|
+
available from the class itself without constructing an instance.
|
|
100
|
+
|
|
101
|
+
---
|
|
102
|
+
|
|
103
|
+
## ConvertedModule vs FusedModule
|
|
104
|
+
|
|
105
|
+
Both base classes are defined in `core/modules.py`. They serve different roles:
|
|
106
|
+
|
|
107
|
+
**`ConvertedModule` subclasses** are *intermediate decomposition products*. They
|
|
108
|
+
replace high-level opaque ops (e.g. `nn.MultiheadAttention`,
|
|
109
|
+
`shifted_window_attention`) with sub-modules that downstream *fusion* patterns can
|
|
110
|
+
then match individually. The custom tracer (`_LeafTracer`) treats them as leaf
|
|
111
|
+
nodes — they are never unwrapped during re-tracing. Examples: `MHAInProjection`,
|
|
112
|
+
`ScaledDotProductAttention`, `SwinWindowPartition`, `SwinAttention`,
|
|
113
|
+
`SwinWindowReverse`.
|
|
114
|
+
|
|
115
|
+
**`FusedModule` subclasses** are the *final quantization target*. The Q/DQ
|
|
116
|
+
insertion pass identifies them via `isinstance(mod, FusedModule)` and wraps their
|
|
117
|
+
inputs/weights with fake-quantization stubs. Examples: `FusedConvBN`,
|
|
118
|
+
`FusedConvBNAct`, `FusedLinear`, `FusedSwinAttention`, `FusedMHAInProjection`.
|
|
119
|
+
|
|
120
|
+
A `FusedModule` may wrap a `ConvertedModule` (e.g. `FusedMHAInProjection` wraps
|
|
121
|
+
`MHAInProjection`, `FusedSwinAttention` wraps `SwinAttention`). In that case the
|
|
122
|
+
conversion pass runs first, then the fusion pass replaces the `ConvertedModule`
|
|
123
|
+
with its `FusedModule` counterpart.
|
|
124
|
+
|
|
125
|
+
---
|
|
126
|
+
|
|
127
|
+
## Module Catalog
|
|
128
|
+
|
|
129
|
+
### Conv family — `conv.py`
|
|
130
|
+
|
|
131
|
+
Pattern that creates them: `patterns/fusions.py` (`ConvBNActPattern`,
|
|
132
|
+
`ConvBNPattern`, `StemConvBNActMaxPoolPattern`, `ConvBNAddActPattern`).
|
|
133
|
+
|
|
134
|
+
All conv fused modules store a reference to the original `nn.Conv2d` (and
|
|
135
|
+
optionally `nn.BatchNorm2d`, `ActivationLike`, `nn.MaxPool2d`) — they do not
|
|
136
|
+
pre-compute fused weights. This lets the same module object work correctly in both
|
|
137
|
+
training (with live BN statistics) and eval (where BN is folded by TensorRT at
|
|
138
|
+
export time).
|
|
139
|
+
|
|
140
|
+
**BN folding indicator:** The presence of `self.bn` (not `None`) implicitly
|
|
141
|
+
indicates that BN can be folded into the convolution at export time. The
|
|
142
|
+
`__repr__` methods print `"(foldable)"` when `self.bn is not None`.
|
|
143
|
+
|
|
144
|
+
**`FusedConvBN`** — `Conv2d → [BatchNorm2d]`, `inputs_to_quantize = {0}`.
|
|
145
|
+
|
|
146
|
+
**`FusedConvBNAct`** — `Conv2d → [BatchNorm2d] → Activation`, `inputs_to_quantize = {0}`.
|
|
147
|
+
|
|
148
|
+
**`FusedConvBNActMaxPool`** — `Conv2d → [BatchNorm2d] → Activation → MaxPool2d`,
|
|
149
|
+
`inputs_to_quantize = {0}`. MaxPool is included here (rather than fused
|
|
150
|
+
separately) because TensorRT supports fusing the full stem sequence
|
|
151
|
+
(`Conv → BN → Act → Pool`) as a single engine node in the common 7x7 stem pattern.
|
|
152
|
+
Fusing them together avoids an extra quantized activation between Conv and Pool.
|
|
153
|
+
|
|
154
|
+
**`FusedConvBNAddAct`** — `Conv2d → BatchNorm2d → add(·, residual) → Activation`,
|
|
155
|
+
`inputs_to_quantize = {0, 1}`. The residual tensor is the second input (index 1),
|
|
156
|
+
hence both inputs are quantized. This is the ResNet skip-connection block tail.
|
|
157
|
+
|
|
158
|
+
**INT8 compatibility guard:** grouped convolutions where `in_channels / groups` or
|
|
159
|
+
`out_channels / groups` is not a multiple of 4 cannot be quantized to INT8 in
|
|
160
|
+
TensorRT. For those cases `_is_int8_compatible_conv()` returns `False`, and the
|
|
161
|
+
module sets `self.input_quant_stubs = {}` (overriding the `super().__init__()`
|
|
162
|
+
default), effectively opting out of quantization.
|
|
163
|
+
|
|
164
|
+
### Linear family — `linear.py`
|
|
165
|
+
|
|
166
|
+
Pattern that creates them: `patterns/fusions.py`.
|
|
167
|
+
|
|
168
|
+
**`FusedLinear`** — wraps a single `nn.Linear`, `inputs_to_quantize = {0}`.
|
|
169
|
+
|
|
170
|
+
**`FusedLinearAct`** — wraps `nn.Linear → Activation`, `inputs_to_quantize = {0}`.
|
|
171
|
+
|
|
172
|
+
**`FusedLayerNorm`** — wraps `nn.LayerNorm`, `inputs_to_quantize = set()`.
|
|
173
|
+
LayerNorm is placed in the linear family because it appears in transformer
|
|
174
|
+
architectures immediately before or after linear layers, and the SmoothQuant
|
|
175
|
+
pass needs to reason about LayerNorm and the following linear together.
|
|
176
|
+
Input quantization is disabled (`inputs_to_quantize = set()`, `prefers_fp_input =
|
|
177
|
+
True`) because LayerNorm normalises its input internally — quantizing the input
|
|
178
|
+
before LayerNorm would destroy the statistical properties that normalization relies
|
|
179
|
+
on.
|
|
180
|
+
|
|
181
|
+
**`FusedLayerNorm.smooth_quant_observer`** — holds a `SmoothQuantObserver`
|
|
182
|
+
instance. It is created in `__init__` but is populated (i.e., scale factors are
|
|
183
|
+
computed and migrated) by the SmoothQuant calibration pass in
|
|
184
|
+
`core/quantize/calibrate.py`. The module itself does not use the observer
|
|
185
|
+
during `forward()` — it only provides a hook for the calibration pass to attach to.
|
|
186
|
+
|
|
187
|
+
### Attention family — `attention.py`
|
|
188
|
+
|
|
189
|
+
Pattern that creates the `ConvertedModule` subclasses:
|
|
190
|
+
`patterns/conversions/attention.py` (`DecomposeMultiheadAttentionPattern`).
|
|
191
|
+
Pattern that creates the `FusedModule` wrappers: `patterns/fusions.py`.
|
|
192
|
+
|
|
193
|
+
**Decomposition model:** `nn.MultiheadAttention` is opaque to FX and to TensorRT.
|
|
194
|
+
The conversion pass decomposes it into three explicit sub-modules:
|
|
195
|
+
|
|
196
|
+
1. `MHAInProjection` (in-projection) — packed `Linear(E, 3E)` followed by
|
|
197
|
+
chunking and head splitting.
|
|
198
|
+
2. `ScaledDotProductAttention` (SDPA) — `softmax(Q·Kᵀ / √head_dim) · V`.
|
|
199
|
+
3. `nn.Linear` (out-projection) — the original `mha.out_proj`.
|
|
200
|
+
|
|
201
|
+
**Why in-proj returns a tuple:** The in-projection produces three independent
|
|
202
|
+
tensors (Q, K, V). TensorRT can fuse the single packed `Linear(E, 3E)` and the
|
|
203
|
+
split into an efficient multi-head projection kernel. Returning a tuple of
|
|
204
|
+
`(Q, K, V)` lets the graph represent the split explicitly — the conversion pass
|
|
205
|
+
inserts `operator.getitem` nodes to fan out the tuple into three separate feeds
|
|
206
|
+
for SDPA.
|
|
207
|
+
|
|
208
|
+
**Head splitting:** Inside `MHAInProjection.forward()`, after the packed linear
|
|
209
|
+
op, each of Q, K, V is reshaped from `[B, S, E]` to `[B, num_heads, S, head_dim]`
|
|
210
|
+
via `view` + `transpose`. This puts the head dimension before the sequence
|
|
211
|
+
dimension, which is the layout expected by `F.scaled_dot_product_attention` and
|
|
212
|
+
by matrix-multiply kernels.
|
|
213
|
+
|
|
214
|
+
**`FusedMHAInProjection`** — wraps `MHAInProjection`, `inputs_to_quantize = {0}`.
|
|
215
|
+
Adds a `WeightFakeQuantize` for the packed linear weight. Only `query` (index 0)
|
|
216
|
+
is quantized; `_key` and `_value` are accepted to match the self-attention
|
|
217
|
+
call-site but ignored.
|
|
218
|
+
|
|
219
|
+
**`FusedScaledDotProductAttention`** — wraps `ScaledDotProductAttention`,
|
|
220
|
+
`inputs_to_quantize = set()`. Adds an internal `softmax_quant` stub with a fixed
|
|
221
|
+
calibration of `(1/127, 0)` — i.e., 8-bit symmetric with a fixed scale matched to
|
|
222
|
+
the softmax output range `[0, 1]`. When the stub is disabled the module delegates
|
|
223
|
+
to the plain SDPA; when enabled it performs manual attention with the quantization
|
|
224
|
+
step between softmax and the second batched matrix multiply (BMM2).
|
|
225
|
+
|
|
226
|
+
### Swin Attention family — `swin_attention.py`
|
|
227
|
+
|
|
228
|
+
Pattern that creates the `ConvertedModule` subclasses:
|
|
229
|
+
`patterns/conversions/attention.py` (`DecomposeSwinAttentionPattern`). Pattern
|
|
230
|
+
that creates `FusedSwinAttention`: `patterns/fusions.py`.
|
|
231
|
+
|
|
232
|
+
**Decomposition model:** `torchvision`'s `shifted_window_attention` free function
|
|
233
|
+
is an opaque `fx.wrap`-ped call that includes spatial padding, cyclic shifting,
|
|
234
|
+
window partitioning, QKV projection, attention, output projection, and window
|
|
235
|
+
reversal — all in one node. The conversion pass splits it into five sub-modules
|
|
236
|
+
that downstream fusion patterns can match individually:
|
|
237
|
+
|
|
238
|
+
1. `SwinWindowPartition` — pad, cyclic-shift, partition to windows.
|
|
239
|
+
2. `MHAInProjection` — QKV projection (shared with the MHA attention family).
|
|
240
|
+
3. `SwinAttention` — windowed attention with relative position bias and shifted-window mask.
|
|
241
|
+
4. `nn.Linear` — output projection.
|
|
242
|
+
5. `SwinWindowReverse` — reverse partition, shift, and unpad.
|
|
243
|
+
|
|
244
|
+
**`SwinSpatialState`** — a mutable dataclass shared by all three spatial modules
|
|
245
|
+
(`SwinWindowPartition`, `SwinAttention`, `SwinWindowReverse`). During each
|
|
246
|
+
`forward()` call:
|
|
247
|
+
|
|
248
|
+
- `SwinWindowPartition.forward()` **writes** `batch_size`, `height`, `width`,
|
|
249
|
+
`pad_height`, `pad_width`, and `effective_shift_size` into the state.
|
|
250
|
+
- `SwinAttention.forward()` **reads** `pad_height`, `pad_width`,
|
|
251
|
+
`effective_shift_size`, and `batch_size` to compute the attention mask.
|
|
252
|
+
- `SwinWindowReverse.forward()` **reads** all fields to undo the partition and
|
|
253
|
+
remove padding.
|
|
254
|
+
|
|
255
|
+
The state must be shared (not copied) because `SwinWindowPartition` computes the
|
|
256
|
+
actual padded dimensions and effective shift at runtime — these depend on the input
|
|
257
|
+
spatial size, which is not known until `forward()` is called.
|
|
258
|
+
|
|
259
|
+
**`deepcopy` note:** `copy.deepcopy` preserves the shared reference among the
|
|
260
|
+
three modules. This is intentional: the three modules reference the *same*
|
|
261
|
+
`SwinSpatialState` object; `deepcopy` copies the object once and updates all
|
|
262
|
+
referencing attributes to point to the copy. However, the sharing is not
|
|
263
|
+
thread-safe (see Gotchas).
|
|
264
|
+
|
|
265
|
+
**`FusedSwinAttention`** — wraps `SwinAttention`, `inputs_to_quantize = set()`.
|
|
266
|
+
Mirrors `FusedScaledDotProductAttention`: adds an internal `softmax_quant` stub
|
|
267
|
+
with fixed calibration. When enabled it manually expands the attention computation
|
|
268
|
+
to insert the quantization step between softmax and BMM2.
|
|
269
|
+
|
|
270
|
+
### Pointwise family — `pointwise.py`
|
|
271
|
+
|
|
272
|
+
Pattern that creates it: `patterns/fusions.py` (`ActAddPattern`).
|
|
273
|
+
|
|
274
|
+
**`FusedActAdd`** — `Activation → add(·, residual)`, `inputs_to_quantize = {0, 1}`.
|
|
275
|
+
Constructor takes `(act: ActivationLike)`. `forward(x, residual)` applies
|
|
276
|
+
`act(x) + residual`. This prevents TensorRT from merging the upstream convolution
|
|
277
|
+
into an activation-fused kernel when the activation output is consumed by a
|
|
278
|
+
subsequent add. Both the activated feature map and the skip-connection tensor are
|
|
279
|
+
quantized at their respective scales.
|
|
280
|
+
|
|
281
|
+
### Pool family — `pool.py`
|
|
282
|
+
|
|
283
|
+
Pattern that creates it: `patterns/fusions.py`.
|
|
284
|
+
|
|
285
|
+
**`FusedAdaptiveAvgPool2d`** — wraps `nn.AdaptiveAvgPool2d`,
|
|
286
|
+
`inputs_to_quantize = set()`. No quantization is applied (pooling is a
|
|
287
|
+
linear operation that does not benefit from separate quantization). Exists as a
|
|
288
|
+
`FusedModule` so the Q/DQ pass treats it uniformly without special-casing it.
|
|
289
|
+
|
|
290
|
+
---
|
|
291
|
+
|
|
292
|
+
## BN Folding
|
|
293
|
+
|
|
294
|
+
Batch normalisation folding is the process of absorbing the BN scale and bias into
|
|
295
|
+
the preceding convolution weight and bias at inference time:
|
|
296
|
+
|
|
297
|
+
```
|
|
298
|
+
w_folded = w * (γ / σ)
|
|
299
|
+
b_folded = (b - μ) * (γ / σ) + β
|
|
300
|
+
```
|
|
301
|
+
|
|
302
|
+
where `γ`, `β` are BN's learned weight/bias, `μ` is the running mean, and `σ = √(running_var + eps)` is derived from the running variance and BN's `eps`.
|
|
303
|
+
|
|
304
|
+
The fused modules store the original `nn.Conv2d` and `nn.BatchNorm2d` as separate
|
|
305
|
+
sub-modules. During training, the BN is applied as a normal operation (live
|
|
306
|
+
statistics). At export, TensorRT performs the folding in the engine. Whether a
|
|
307
|
+
fused module carries a BN is determined implicitly: `self.bn is not None` means
|
|
308
|
+
the module was created from a pattern that included a `BatchNorm2d`, and folding
|
|
309
|
+
is safe. When `self.bn is None`, the convolution weight is used as-is.
|
|
310
|
+
|
|
311
|
+
---
|
|
312
|
+
|
|
313
|
+
## Design Decisions
|
|
314
|
+
|
|
315
|
+
**Why modules store original sub-modules rather than pre-computing fused weights:**
|
|
316
|
+
Because the modules must be correct in both training and eval modes. Training uses
|
|
317
|
+
live BN statistics; the fused weights can only be computed accurately in eval mode
|
|
318
|
+
with frozen running mean/variance. Storing the originals defers the decision to
|
|
319
|
+
the export step while keeping `forward()` correct throughout.
|
|
320
|
+
|
|
321
|
+
**Why `inputs_to_quantize` is a class attribute:** The Q/DQ insertion pass queries
|
|
322
|
+
the set from the class rather than from per-instance state — it reads it via
|
|
323
|
+
`type(mod).inputs_to_quantize`. Making it a class attribute also prevents
|
|
324
|
+
accidental per-instance divergence.
|
|
325
|
+
|
|
326
|
+
**Why intermediate decomposition modules are `ConvertedModule`:** The custom tracer
|
|
327
|
+
(`_LeafTracer` in `core/modules.py`) checks `isinstance(m, (ConvertedModule,
|
|
328
|
+
FusedModule))` to decide whether to treat a module as a leaf. `ConvertedModule` is
|
|
329
|
+
the marker that tells the tracer "do not recurse into this module's `forward()`".
|
|
330
|
+
Without this, re-tracing the graph after conversion would unwrap the decomposed
|
|
331
|
+
sub-modules, defeating the purpose of decomposition.
|
|
332
|
+
|
|
333
|
+
---
|
|
334
|
+
|
|
335
|
+
## Gotchas & Pitfalls
|
|
336
|
+
|
|
337
|
+
**`SwinSpatialState` thread-safety:** The state is a mutable shared object written
|
|
338
|
+
during every `SwinWindowPartition.forward()` call and read by the other two
|
|
339
|
+
modules in the same forward pass. If the model is run from multiple threads
|
|
340
|
+
concurrently (e.g. in a data-parallel setup without model replication), the writes
|
|
341
|
+
from one thread will corrupt the reads of another. Give each concurrent caller its own
|
|
342
|
+
`copy.deepcopy` of the model (which preserves sharing within a single copy), or use one
|
|
343
|
+
process per replica (e.g. `DistributedDataParallel`); `torch.nn.DataParallel` runs replicas in threads, so confirm its replicas do not share this state.
|
|
344
|
+
|
|
345
|
+
**`inputs_to_quantize` indices vs. Python `forward()` args:** The indices refer to
|
|
346
|
+
the positional arguments of the FX `call_module` node in the traced graph, not to
|
|
347
|
+
the `forward()` Python signature. In the normal case they are identical. They
|
|
348
|
+
diverge if the FX graph is produced from a non-trivial call-site (e.g. keyword
|
|
349
|
+
arguments re-ordered). Always verify against the actual graph node when debugging
|
|
350
|
+
Q/DQ placement.
|
|
351
|
+
|
|
352
|
+
**Weight sharing:** The conv/linear sub-modules stored inside fused modules hold
|
|
353
|
+
references to the tensors from the *original* model. If the original model's
|
|
354
|
+
weights are modified after fusion, the fused module sees the change. This is
|
|
355
|
+
usually desirable (e.g. for QAT gradient updates), but can be surprising if the
|
|
356
|
+
original model is used independently.
|
|
357
|
+
|
|
358
|
+
**Grouped conv INT8 opt-out:** When `_is_int8_compatible_conv()` returns `False`
|
|
359
|
+
the `__init__` of the conv fused modules sets `self.input_quant_stubs = {}`,
|
|
360
|
+
overriding the dict populated by `FusedModule.__init__()`. This means the module
|
|
361
|
+
is effectively excluded from quantization despite being a `FusedModule`. The
|
|
362
|
+
`weight_fake_quant` attribute is also not created in this path.
|
|
363
|
+
|
|
364
|
+
---
|
|
365
|
+
|
|
366
|
+
## Adding a New Fused Module
|
|
367
|
+
|
|
368
|
+
1. **Subclass `FusedModule`** (from `core/modules.py`).
|
|
369
|
+
2. **Set `inputs_to_quantize`** as a class attribute — a `set[int]` of positional
|
|
370
|
+
argument indices that should receive activation `QuantStub`s.
|
|
371
|
+
3. **Optionally add `self.weight_fake_quant = WeightFakeQuantize({self})`** in
|
|
372
|
+
`__init__` if the module has a learnable weight that should be fake-quantized.
|
|
373
|
+
4. **Implement `forward()`** with the fused computation.
|
|
374
|
+
5. **Write a `Pattern` subclass** in `patterns/fusions.py`. Prefer declaring
|
|
375
|
+
`graft = FusedFoo` (bare class) so the graft system handles replacement
|
|
376
|
+
automatically. The constructor must accept modules in tree order (nested
|
|
377
|
+
branches first, then trunk nodes). If the replacement logic cannot be
|
|
378
|
+
expressed as a bare-class graft, provide a `ReplacementMaker` or a custom
|
|
379
|
+
`replace()` method instead.
|
|
380
|
+
6. **Add the pattern to `TENSORRT_PATTERNS`** (or the appropriate pattern list) in
|
|
381
|
+
`tensorrt/plan.py`.
|
|
382
|
+
7. **Write tests** in `tests/tensorrt/patterns/fusions/` covering: pattern match,
|
|
383
|
+
pattern replace, correct quantisation stub placement.
|
|
384
|
+
|
|
385
|
+
---
|
|
386
|
+
|
|
387
|
+
## Testing
|
|
388
|
+
|
|
389
|
+
Test models for these modules live in two directories, `tests/models/` and `tests/tensorrt/models/`:
|
|
390
|
+
|
|
391
|
+
- `tests/models/conv.py` — `ConvBnRelu`, `ConvBnSiLU`, `ConvBn`, `ConvOnly`,
|
|
392
|
+
`ConvBnAddRelu`, `StemConvBnReluMaxPool`, etc.
|
|
393
|
+
- `tests/models/attention.py` — `SimpleSelfAttention`, `SimpleSwinAttention`, etc.
|
|
394
|
+
- `tests/models/linear.py`, `tests/models/pool.py` — linear and pool variants.
|
|
395
|
+
- `tests/tensorrt/models/attention.py` — `LayerNormMHAInProjection`,
|
|
396
|
+
`SDPALinearOutProjection`, `SwinAttentionLinearOutProjection`.
|
|
397
|
+
|
|
398
|
+
Pattern tests live in `tests/tensorrt/patterns/fusions/`.
|
|
399
|
+
|
|
400
|
+
**What to verify for a new fused module:**
|
|
401
|
+
|
|
402
|
+
- The pattern matches exactly the expected nodes (check `tree_match.serialize()`).
|
|
403
|
+
- The fused node's `target` name is the expected auto-generated name.
|
|
404
|
+
- `resolve_module(fused_node, FusedXxx)` succeeds.
|
|
405
|
+
- `bool(fused_module.input_quant_stubs)` matches `inputs_to_quantize` expectations.
|
|
406
|
+
- `hasattr(fused_module, 'weight_fake_quant')` matches expectations.
|
|
407
|
+
- Graph equivalence: fused model output matches original model output on the same
|
|
408
|
+
input (see `tests/tensorrt/test_equivalence.py`).
|