embedl-deploy-tensorrt 0.5.0__tar.gz → 0.6.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/PKG-INFO +1 -1
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/__init__.py +2 -0
- embedl_deploy_tensorrt-0.6.0/src/embedl_deploy/_internal/tensorrt/modules/AGENTS.md +408 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +22 -20
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +22 -7
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +23 -16
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +1 -1
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +26 -19
- embedl_deploy_tensorrt-0.6.0/src/embedl_deploy/_internal/tensorrt/patterns/AGENTS.md +721 -0
- embedl_deploy_tensorrt-0.6.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +3 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +92 -96
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +16 -16
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions.py +9 -26
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +65 -49
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +17 -15
- embedl_deploy_tensorrt-0.6.0/src/embedl_deploy/_internal/tensorrt/plan.py +98 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/tensorrt/__init__.py +9 -7
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/version/public.py +1 -1
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy_tensorrt.egg-info/SOURCES.txt +2 -5
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -15
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/__init__.py +0 -3
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/attention.py +0 -808
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/functional.py +0 -325
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/general.py +0 -718
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/utils.py +0 -44
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/plan.py +0 -139
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/LICENSE +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/MANIFEST.in +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/NOTICE +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/README.md +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/pyproject.toml +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/setup.cfg +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/backend.py +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.0}/src/embedl_deploy/version/__init__.py +0 -0
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
"""Python package to make AI models deployment-ready for any hardware."""
|
|
4
4
|
|
|
5
5
|
from embedl_deploy._internal.core.plan import (
|
|
6
|
+
Trace,
|
|
6
7
|
TransformationPlan,
|
|
7
8
|
TransformationResult,
|
|
8
9
|
apply_transformation_plan,
|
|
@@ -14,6 +15,7 @@ from embedl_deploy.version.public import PUBLIC_VERSION
|
|
|
14
15
|
__version__ = PUBLIC_VERSION
|
|
15
16
|
|
|
16
17
|
__all__ = [
|
|
18
|
+
"Trace",
|
|
17
19
|
"TransformationPlan",
|
|
18
20
|
"TransformationResult",
|
|
19
21
|
"__version__",
|
|
@@ -0,0 +1,408 @@
|
|
|
1
|
+
# TensorRT Fused Modules — Subsystem Guide
|
|
2
|
+
|
|
3
|
+
## Role & Boundary
|
|
4
|
+
|
|
5
|
+
**This subsystem owns:** concrete `FusedModule` (and `ConvertedModule`) implementations
|
|
6
|
+
for TensorRT. Every class in `modules/` corresponds to a hardware-fusible or
|
|
7
|
+
structurally-decomposed operation the TensorRT backend can emit as a single engine
|
|
8
|
+
layer or kernel.
|
|
9
|
+
|
|
10
|
+
**This subsystem does NOT own:**
|
|
11
|
+
|
|
12
|
+
- When or how to match these modules in a graph — that lives in `patterns/`.
|
|
13
|
+
- Q/DQ stub insertion or calibration logic — that lives in `core/quantize/`.
|
|
14
|
+
- The pattern-matching engine itself — that lives in `core/`.
|
|
15
|
+
|
|
16
|
+
---
|
|
17
|
+
|
|
18
|
+
## Position in the System
|
|
19
|
+
|
|
20
|
+
```
|
|
21
|
+
patterns/fusions.py # creates FusedModule instances
|
|
22
|
+
patterns/conversions/general.py # creates ConvertedModule instances (erasure, flatten-linear)
|
|
23
|
+
patterns/conversions/attention.py # creates ConvertedModule instances (MHA/Swin/SDPA)
|
|
24
|
+
│
|
|
25
|
+
▼
|
|
26
|
+
modules/ (this subsystem)
|
|
27
|
+
│
|
|
28
|
+
▼
|
|
29
|
+
core/quantize/prepare.py # walks the graph, uses isinstance(mod, FusedModule)
|
|
30
|
+
# and mod.inputs_to_quantize to insert Q/DQ stubs
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
**Instantiated by:** pattern grafts and `replace()` methods in `patterns/fusions.py`
|
|
34
|
+
and `patterns/conversions/`. Most fusion patterns declare a `graft` attribute
|
|
35
|
+
pointing to the fused module class; the graft system calls `_collect_modules()`
|
|
36
|
+
to gather matched modules in tree order and passes them as positional arguments
|
|
37
|
+
to the constructor. For example, `ConvBNActPattern` grafts `FusedConvBNAct`;
|
|
38
|
+
`DecomposeMultiheadAttentionPattern.replace()` constructs `MHAInProjection` and
|
|
39
|
+
`ScaledDotProductAttention`.
|
|
40
|
+
|
|
41
|
+
**Used by:** The Q/DQ insertion pass in `core/quantize/prepare.py` via
|
|
42
|
+
`isinstance(mod, FusedModule)` and `mod.inputs_to_quantize`, to determine which
|
|
43
|
+
inputs of each `call_module` node should receive a `QuantStub`.
|
|
44
|
+
|
|
45
|
+
**Also used as intermediate graph nodes by:** conversion patterns, which produce
|
|
46
|
+
`ConvertedModule` subclasses (`MHAInProjection`, `ScaledDotProductAttention`,
|
|
47
|
+
`SwinWindowPartition`, `SwinAttention`, `SwinWindowReverse`). These stay opaque
|
|
48
|
+
during re-tracing so that downstream fusion patterns can see them as atomic nodes.
|
|
49
|
+
|
|
50
|
+
---
|
|
51
|
+
|
|
52
|
+
## FusedModule Contract
|
|
53
|
+
|
|
54
|
+
Every `FusedModule` subclass must satisfy three requirements:
|
|
55
|
+
|
|
56
|
+
1. **Set `inputs_to_quantize` as a class attribute** (not an instance attribute).
|
|
57
|
+
The Q/DQ insertion pass reads this before the module is instantiated to decide
|
|
58
|
+
which positional arguments of the `call_module` FX node will have a `QuantStub`
|
|
59
|
+
inserted in front of them.
|
|
60
|
+
|
|
61
|
+
2. **Call `super().__init__()`**, which creates `self.input_quant_stubs` — a dict
|
|
62
|
+
mapping each index in `inputs_to_quantize` to a fresh `QuantStub`. The Q/DQ
|
|
63
|
+
pass later enables and configures these stubs during `prepare_qdq()`. It also
|
|
64
|
+
initialises `self.surrounded = False`, which is later set to `True` by
|
|
65
|
+
`SurroundWithQuantStubsPattern` to mark modules that have been surrounded
|
|
66
|
+
with input `QuantStub` entries.
|
|
67
|
+
|
|
68
|
+
3. **Implement `forward()`** with the fused computation the module represents.
|
|
69
|
+
|
|
70
|
+
### Graft compatibility
|
|
71
|
+
|
|
72
|
+
When a pattern uses `graft = FusedFoo` (bare class), the graft system calls
|
|
73
|
+
`_collect_modules()` to walk the matched tree and collect the `nn.Module`
|
|
74
|
+
instances corresponding to trunk and fork nodes (nested branches first, then
|
|
75
|
+
trunk nodes). These are passed as positional arguments to the constructor.
|
|
76
|
+
Therefore the constructor signature must accept modules in the same order
|
|
77
|
+
they appear in the pattern tree.
|
|
78
|
+
|
|
79
|
+
### What `inputs_to_quantize` means
|
|
80
|
+
|
|
81
|
+
`inputs_to_quantize` is a set of *positional argument indices* corresponding to
|
|
82
|
+
the `call_module` FX node's arguments — that is, the arguments visible in the
|
|
83
|
+
traced graph, not necessarily the parameter positions in the Python `forward()` signature.
|
|
84
|
+
|
|
85
|
+
For example, `FusedConvBN.inputs_to_quantize = {0}` means the first tensor
|
|
86
|
+
argument to the fused node (the image tensor) gets a `QuantStub`. The convolution
|
|
87
|
+
weight is quantized separately via `WeightFakeQuantize`, which is attached to the
|
|
88
|
+
module directly rather than via `inputs_to_quantize`.
|
|
89
|
+
|
|
90
|
+
`FusedConvBNAddAct.inputs_to_quantize = {0, 1}` means both the main feature map
|
|
91
|
+
and the residual skip tensor each get their own `QuantStub`.
|
|
92
|
+
|
|
93
|
+
### Why `inputs_to_quantize` is a class attribute
|
|
94
|
+
|
|
95
|
+
Because it describes the module *type's* quantization contract, not any particular
|
|
96
|
+
instance's configuration. The Q/DQ insertion pass queries it once per class during
|
|
97
|
+
graph preparation, before modules are constructed. Making it a class attribute
|
|
98
|
+
(rather than an instance attribute set in `__init__`) ensures the value is
|
|
99
|
+
available from the class itself without constructing an instance.
|
|
100
|
+
|
|
101
|
+
---
|
|
102
|
+
|
|
103
|
+
## ConvertedModule vs FusedModule
|
|
104
|
+
|
|
105
|
+
Both base classes are defined in `core/modules.py`. They serve different roles:
|
|
106
|
+
|
|
107
|
+
**`ConvertedModule` subclasses** are *intermediate decomposition products*. They
|
|
108
|
+
replace high-level opaque ops (e.g. `nn.MultiheadAttention`,
|
|
109
|
+
`shifted_window_attention`) with sub-modules that downstream *fusion* patterns can
|
|
110
|
+
then match individually. The custom tracer (`_LeafTracer`) treats them as leaf
|
|
111
|
+
nodes — they are never unwrapped during re-tracing. Examples: `MHAInProjection`,
|
|
112
|
+
`ScaledDotProductAttention`, `SwinWindowPartition`, `SwinAttention`,
|
|
113
|
+
`SwinWindowReverse`.
|
|
114
|
+
|
|
115
|
+
**`FusedModule` subclasses** are the *final quantization target*. The Q/DQ
|
|
116
|
+
insertion pass identifies them via `isinstance(mod, FusedModule)` and wraps their
|
|
117
|
+
inputs/weights with fake-quantization stubs. Examples: `FusedConvBN`,
|
|
118
|
+
`FusedConvBNAct`, `FusedLinear`, `FusedSwinAttention`, `FusedMHAInProjection`.
|
|
119
|
+
|
|
120
|
+
A `FusedModule` may wrap a `ConvertedModule` (e.g. `FusedMHAInProjection` wraps
|
|
121
|
+
`MHAInProjection`, `FusedSwinAttention` wraps `SwinAttention`). In that case the
|
|
122
|
+
conversion pass runs first, then the fusion pass replaces the `ConvertedModule`
|
|
123
|
+
with its `FusedModule` counterpart.
|
|
124
|
+
|
|
125
|
+
---
|
|
126
|
+
|
|
127
|
+
## Module Catalog
|
|
128
|
+
|
|
129
|
+
### Conv family — `conv.py`
|
|
130
|
+
|
|
131
|
+
Patterns that create them: `patterns/fusions.py` (`ConvBNActPattern`,
|
|
132
|
+
`ConvBNPattern`, `StemConvBNActMaxPoolPattern`, `ConvBNAddActPattern`).
|
|
133
|
+
|
|
134
|
+
All conv fused modules store a reference to the original `nn.Conv2d` (and
|
|
135
|
+
optionally `nn.BatchNorm2d`, `ActivationLike`, `nn.MaxPool2d`) — they do not
|
|
136
|
+
pre-compute fused weights. This lets the same module object work correctly in both
|
|
137
|
+
training (with live BN statistics) and eval (where BN is folded by TensorRT at
|
|
138
|
+
export time).
|
|
139
|
+
|
|
140
|
+
**BN folding indicator:** The presence of `self.bn` (not `None`) implicitly
|
|
141
|
+
indicates that BN can be folded into the convolution at export time. The
|
|
142
|
+
`__repr__` methods print `"(foldable)"` when `self.bn is not None`.
|
|
143
|
+
|
|
144
|
+
**`FusedConvBN`** — `Conv2d → [BatchNorm2d]`, `inputs_to_quantize = {0}`.
|
|
145
|
+
|
|
146
|
+
**`FusedConvBNAct`** — `Conv2d → [BatchNorm2d] → Activation`, `inputs_to_quantize = {0}`.
|
|
147
|
+
|
|
148
|
+
**`FusedConvBNActMaxPool`** — `Conv2d → [BatchNorm2d] → Activation → MaxPool2d`,
|
|
149
|
+
`inputs_to_quantize = {0}`. MaxPool is included here (rather than fused
|
|
150
|
+
separately) because TensorRT supports fusing the full stem sequence
|
|
151
|
+
(`Conv → BN → Act → Pool`) as a single engine node in the common 7x7 stem pattern.
|
|
152
|
+
Fusing them together avoids an extra quantized activation between Conv and Pool.
|
|
153
|
+
|
|
154
|
+
**`FusedConvBNAddAct`** — `Conv2d → BatchNorm2d → add(·, residual) → Activation`,
|
|
155
|
+
`inputs_to_quantize = {0, 1}`. The residual tensor is the second input (index 1),
|
|
156
|
+
hence both inputs are quantized. This is the ResNet skip-connection block tail.
|
|
157
|
+
|
|
158
|
+
**INT8 compatibility guard:** grouped convolutions where `in_channels / groups` or
|
|
159
|
+
`out_channels / groups` is not a multiple of 4 cannot be quantized to INT8 in
|
|
160
|
+
TensorRT. For those cases `_is_int8_compatible_conv()` returns `False`, and the
|
|
161
|
+
module sets `self.input_quant_stubs = {}` (overriding the `super().__init__()`
|
|
162
|
+
default), effectively opting out of quantization.
|
|
163
|
+
|
|
164
|
+
### Linear family — `linear.py`
|
|
165
|
+
|
|
166
|
+
Pattern that creates them: `patterns/fusions.py`.
|
|
167
|
+
|
|
168
|
+
**`FusedLinear`** — wraps a single `nn.Linear`, `inputs_to_quantize = {0}`.
|
|
169
|
+
|
|
170
|
+
**`FusedLinearAct`** — wraps `nn.Linear → Activation`, `inputs_to_quantize = {0}`.
|
|
171
|
+
|
|
172
|
+
**`FusedLayerNorm`** — wraps `nn.LayerNorm`, `inputs_to_quantize = set()`.
|
|
173
|
+
LayerNorm is placed in the linear family because it appears in transformer
|
|
174
|
+
architectures immediately before or after linear layers, and the SmoothQuant
|
|
175
|
+
pass needs to reason about LayerNorm and the following linear together.
|
|
176
|
+
Input quantization is disabled (`inputs_to_quantize = set()`, `prefers_fp_input =
|
|
177
|
+
True`) because LayerNorm normalises its input internally — quantizing the input
|
|
178
|
+
before LayerNorm would destroy the statistical properties that normalization relies
|
|
179
|
+
on.
|
|
180
|
+
|
|
181
|
+
**`FusedLayerNorm.smooth_quant_observer`** — holds a `SmoothQuantObserver`
|
|
182
|
+
instance. It is created in `__init__` but is populated (i.e., scale factors are
|
|
183
|
+
computed and migrated) by the SmoothQuant calibration pass in
|
|
184
|
+
`core/quantize/calibrate.py`. The module itself does not use the observer
|
|
185
|
+
during `forward()` — it only provides a hook for the calibration pass to attach to.
|
|
186
|
+
|
|
187
|
+
### Attention family — `attention.py`
|
|
188
|
+
|
|
189
|
+
Pattern that creates the `ConvertedModule` subclasses:
|
|
190
|
+
`patterns/conversions/attention.py` (`DecomposeMultiheadAttentionPattern`).
|
|
191
|
+
Pattern that creates the `FusedModule` wrappers: `patterns/fusions.py`.
|
|
192
|
+
|
|
193
|
+
**Decomposition model:** `nn.MultiheadAttention` is opaque to FX and to TensorRT.
|
|
194
|
+
The conversion pass decomposes it into three explicit sub-modules:
|
|
195
|
+
|
|
196
|
+
1. `MHAInProjection` (in-projection) — packed `Linear(E, 3E)` followed by
|
|
197
|
+
chunking and head splitting.
|
|
198
|
+
2. `ScaledDotProductAttention` (SDPA) — `softmax(Q·Kᵀ / √H) · V`, where `H` is `head_dim`.
|
|
199
|
+
3. `nn.Linear` (out-projection) — the original `mha.out_proj`.
|
|
200
|
+
|
|
201
|
+
**Why in-proj returns a tuple:** The in-projection produces three independent
|
|
202
|
+
tensors (Q, K, V). TensorRT can fuse the single packed `Linear(E, 3E)` and the
|
|
203
|
+
split into an efficient multi-head projection kernel. Returning a tuple of
|
|
204
|
+
`(Q, K, V)` lets the graph represent the split explicitly — the conversion pass
|
|
205
|
+
inserts `operator.getitem` nodes to fan out the tuple into three separate feeds
|
|
206
|
+
for SDPA.
|
|
207
|
+
|
|
208
|
+
**Head splitting:** Inside `MHAInProjection.forward()`, after the packed linear
|
|
209
|
+
op, each of Q, K, V is reshaped from `[B, S, E]` to `[B, num_heads, S, head_dim]`
|
|
210
|
+
via `view` + `transpose`. This puts the head dimension before the sequence
|
|
211
|
+
dimension, which is the layout expected by `F.scaled_dot_product_attention` and
|
|
212
|
+
by matrix-multiply kernels.
|
|
213
|
+
|
|
214
|
+
**`FusedMHAInProjection`** — wraps `MHAInProjection`, `inputs_to_quantize = {0}`.
|
|
215
|
+
Adds a `WeightFakeQuantize` for the packed linear weight. Only `query` (index 0)
|
|
216
|
+
is quantized; `_key` and `_value` are accepted to match the self-attention
|
|
217
|
+
call-site but ignored.
|
|
218
|
+
|
|
219
|
+
**`FusedScaledDotProductAttention`** — wraps `ScaledDotProductAttention`,
|
|
220
|
+
`inputs_to_quantize = set()`. Adds an internal `softmax_quant` stub with a fixed
|
|
221
|
+
calibration of `(1/127, 0)` — i.e., 8-bit symmetric with a fixed scale matched to
|
|
222
|
+
the softmax output range `[0, 1]`. When the stub is disabled the module delegates
|
|
223
|
+
to the plain SDPA; when enabled it performs manual attention with the quantization
|
|
224
|
+
step between softmax and the second batched matrix multiply (BMM2).
|
|
225
|
+
|
|
226
|
+
### Swin Attention family — `swin_attention.py`
|
|
227
|
+
|
|
228
|
+
Pattern that creates the `ConvertedModule` subclasses:
|
|
229
|
+
`patterns/conversions/attention.py` (`DecomposeSwinAttentionPattern`). Pattern
|
|
230
|
+
that creates `FusedSwinAttention`: `patterns/fusions.py`.
|
|
231
|
+
|
|
232
|
+
**Decomposition model:** `torchvision`'s `shifted_window_attention` free function
|
|
233
|
+
is an opaque `fx.wrap`-ped call that includes spatial padding, cyclic shifting,
|
|
234
|
+
window partitioning, QKV projection, attention, output projection, and window
|
|
235
|
+
reversal — all in one node. The conversion pass splits it into five sub-modules
|
|
236
|
+
that downstream fusion patterns can match individually:
|
|
237
|
+
|
|
238
|
+
1. `SwinWindowPartition` — pad, cyclic-shift, partition to windows.
|
|
239
|
+
2. `MHAInProjection` — QKV projection (shared with the MHA attention family).
|
|
240
|
+
3. `SwinAttention` — windowed attention with relative position bias and shifted-window mask.
|
|
241
|
+
4. `nn.Linear` — output projection.
|
|
242
|
+
5. `SwinWindowReverse` — reverse partition, shift, and unpad.
|
|
243
|
+
|
|
244
|
+
**`SwinSpatialState`** — a mutable dataclass shared by all three spatial modules
|
|
245
|
+
(`SwinWindowPartition`, `SwinAttention`, `SwinWindowReverse`). During each
|
|
246
|
+
`forward()` call:
|
|
247
|
+
|
|
248
|
+
- `SwinWindowPartition.forward()` **writes** `batch_size`, `height`, `width`,
|
|
249
|
+
`pad_height`, `pad_width`, and `effective_shift_size` into the state.
|
|
250
|
+
- `SwinAttention.forward()` **reads** `pad_height`, `pad_width`,
|
|
251
|
+
`effective_shift_size`, and `batch_size` to compute the attention mask.
|
|
252
|
+
- `SwinWindowReverse.forward()` **reads** all fields to undo the partition and
|
|
253
|
+
remove padding.
|
|
254
|
+
|
|
255
|
+
The state must be shared (not copied) because `SwinWindowPartition` computes the
|
|
256
|
+
actual padded dimensions and effective shift at runtime — these depend on the input
|
|
257
|
+
spatial size, which is not known until `forward()` is called.
|
|
258
|
+
|
|
259
|
+
**`deepcopy` note:** `copy.deepcopy` preserves the shared reference among the
|
|
260
|
+
three modules. This is intentional: the three modules reference the *same*
|
|
261
|
+
`SwinSpatialState` object; `deepcopy` copies the object once and updates all
|
|
262
|
+
referencing attributes to point to the copy. However, the sharing is not
|
|
263
|
+
thread-safe (see Gotchas).
|
|
264
|
+
|
|
265
|
+
**`FusedSwinAttention`** — wraps `SwinAttention`, `inputs_to_quantize = set()`.
|
|
266
|
+
Mirrors `FusedScaledDotProductAttention`: adds an internal `softmax_quant` stub
|
|
267
|
+
with fixed calibration. When enabled it manually expands the attention computation
|
|
268
|
+
to insert the quantization step between softmax and BMM2.
|
|
269
|
+
|
|
270
|
+
### Pointwise family — `pointwise.py`
|
|
271
|
+
|
|
272
|
+
Pattern that creates it: `patterns/fusions.py` (`ActAddPattern`).
|
|
273
|
+
|
|
274
|
+
**`FusedActAdd`** — `Activation → add(·, residual)`, `inputs_to_quantize = {0, 1}`.
|
|
275
|
+
Constructor takes `(act: ActivationLike)`. `forward(x, residual)` applies
|
|
276
|
+
`act(x) + residual`. This prevents TensorRT from merging the upstream convolution
|
|
277
|
+
into an activation-fused kernel when the activation output is consumed by a
|
|
278
|
+
subsequent add. Both the activated feature map and the skip-connection tensor are
|
|
279
|
+
quantized at their respective scales.
|
|
280
|
+
|
|
281
|
+
### Pool family — `pool.py`
|
|
282
|
+
|
|
283
|
+
Pattern that creates it: `patterns/fusions.py`.
|
|
284
|
+
|
|
285
|
+
**`FusedAdaptiveAvgPool2d`** — wraps `nn.AdaptiveAvgPool2d`,
|
|
286
|
+
`inputs_to_quantize = set()`. No quantization is applied (pooling is a
|
|
287
|
+
linear operation that does not benefit from separate quantization). Exists as a
|
|
288
|
+
`FusedModule` so the Q/DQ pass treats it uniformly without special-casing it.
|
|
289
|
+
|
|
290
|
+
---
|
|
291
|
+
|
|
292
|
+
## BN Folding
|
|
293
|
+
|
|
294
|
+
Batch normalisation folding is the process of absorbing the BN scale and bias into
|
|
295
|
+
the preceding convolution weight and bias at inference time:
|
|
296
|
+
|
|
297
|
+
```
|
|
298
|
+
w_folded = w * (γ / σ)
|
|
299
|
+
b_folded = (b - μ) * (γ / σ) + β
|
|
300
|
+
```
|
|
301
|
+
|
|
302
|
+
where `γ`, `β` are BN's learned weight/bias, `μ` is the running mean, and `σ = sqrt(running_var + eps)` is the standard deviation derived from the running variance.
|
|
303
|
+
|
|
304
|
+
The fused modules store the original `nn.Conv2d` and `nn.BatchNorm2d` as separate
|
|
305
|
+
sub-modules. During training, the BN is applied as a normal operation (live
|
|
306
|
+
statistics). At export, TensorRT performs the folding in the engine. Whether a
|
|
307
|
+
fused module carries a BN is determined implicitly: `self.bn is not None` means
|
|
308
|
+
the module was created from a pattern that included a `BatchNorm2d`, and folding
|
|
309
|
+
is safe. When `self.bn is None`, the convolution weight is used as-is.
|
|
310
|
+
|
|
311
|
+
---
|
|
312
|
+
|
|
313
|
+
## Design Decisions
|
|
314
|
+
|
|
315
|
+
**Why modules store original sub-modules rather than pre-computing fused weights:**
|
|
316
|
+
Because the modules must be correct in both training and eval modes. Training uses
|
|
317
|
+
live BN statistics; the fused weights can only be computed accurately in eval mode
|
|
318
|
+
with frozen running mean/variance. Storing the originals defers the decision to
|
|
319
|
+
the export step while keeping `forward()` correct throughout.
|
|
320
|
+
|
|
321
|
+
**Why `inputs_to_quantize` is a class attribute:** The Q/DQ insertion pass queries
|
|
322
|
+
the set before constructing any instances — it reads it from the class via
|
|
323
|
+
`type(mod).inputs_to_quantize`. Making it a class attribute also prevents
|
|
324
|
+
accidental per-instance divergence.
|
|
325
|
+
|
|
326
|
+
**Why intermediate decomposition modules are `ConvertedModule`:** The custom tracer
|
|
327
|
+
(`_LeafTracer` in `core/modules.py`) checks `isinstance(m, (ConvertedModule,
|
|
328
|
+
FusedModule))` to decide whether to treat a module as a leaf. `ConvertedModule` is
|
|
329
|
+
the marker that tells the tracer "do not recurse into this module's `forward()`".
|
|
330
|
+
Without this, re-tracing the graph after conversion would unwrap the decomposed
|
|
331
|
+
sub-modules, defeating the purpose of decomposition.
|
|
332
|
+
|
|
333
|
+
---
|
|
334
|
+
|
|
335
|
+
## Gotchas & Pitfalls
|
|
336
|
+
|
|
337
|
+
**`SwinSpatialState` thread-safety:** The state is a mutable shared object written
|
|
338
|
+
during every `SwinWindowPartition.forward()` call and read by the other two
|
|
339
|
+
modules in the same forward pass. If the model is run from multiple threads
|
|
340
|
+
concurrently (e.g. in a data-parallel setup without model replication), the writes
|
|
341
|
+
from one thread will corrupt the reads of another. Thread-based `torch.nn.DataParallel`
|
|
342
|
+
(unlike process-based `DistributedDataParallel`) is safe only if each replica gets its own model copy via
|
|
343
|
+
`deepcopy`, which does preserve sharing within a single copy.
|
|
344
|
+
|
|
345
|
+
**`inputs_to_quantize` indices vs. Python `forward()` args:** The indices refer to
|
|
346
|
+
the positional arguments of the FX `call_module` node in the traced graph, not to
|
|
347
|
+
the `forward()` Python signature. In the normal case they are identical. They
|
|
348
|
+
diverge if the FX graph is produced from a non-trivial call-site (e.g. keyword
|
|
349
|
+
arguments re-ordered). Always verify against the actual graph node when debugging
|
|
350
|
+
Q/DQ placement.
|
|
351
|
+
|
|
352
|
+
**Weight sharing:** The conv/linear sub-modules stored inside fused modules hold
|
|
353
|
+
references to the tensors from the *original* model. If the original model's
|
|
354
|
+
weights are modified after fusion, the fused module sees the change. This is
|
|
355
|
+
usually desirable (e.g. for QAT gradient updates), but can be surprising if the
|
|
356
|
+
original model is used independently.
|
|
357
|
+
|
|
358
|
+
**Grouped conv INT8 opt-out:** When `_is_int8_compatible_conv()` returns `False`
|
|
359
|
+
the `__init__` of the conv fused modules sets `self.input_quant_stubs = {}`,
|
|
360
|
+
overriding the dict populated by `FusedModule.__init__()`. This means the module
|
|
361
|
+
is effectively excluded from quantization despite being a `FusedModule`. The
|
|
362
|
+
`weight_fake_quant` attribute is also not created in this path.
|
|
363
|
+
|
|
364
|
+
---
|
|
365
|
+
|
|
366
|
+
## Adding a New Fused Module
|
|
367
|
+
|
|
368
|
+
1. **Subclass `FusedModule`** (from `core/modules.py`).
|
|
369
|
+
2. **Set `inputs_to_quantize`** as a class attribute — a `set[int]` of positional
|
|
370
|
+
argument indices that should receive activation `QuantStub`s.
|
|
371
|
+
3. **Optionally add `self.weight_fake_quant = WeightFakeQuantize({self})`** in
|
|
372
|
+
`__init__` if the module has a learnable weight that should be fake-quantized.
|
|
373
|
+
4. **Implement `forward()`** with the fused computation.
|
|
374
|
+
5. **Write a `Pattern` subclass** in `patterns/fusions.py`. Prefer declaring
|
|
375
|
+
`graft = FusedFoo` (bare class) so the graft system handles replacement
|
|
376
|
+
automatically. The constructor must accept modules in tree order (nested
|
|
377
|
+
branches first, then trunk nodes). If the replacement logic cannot be
|
|
378
|
+
expressed as a bare-class graft, provide a `ReplacementMaker` or a custom
|
|
379
|
+
`replace()` method instead.
|
|
380
|
+
6. **Add the pattern to `TENSORRT_PATTERNS`** (or the appropriate pattern list) in
|
|
381
|
+
`tensorrt/plan.py`.
|
|
382
|
+
7. **Write tests** in `tests/tensorrt/patterns/fusions/` covering: pattern match,
|
|
383
|
+
pattern replace, correct quantisation stub placement.
|
|
384
|
+
|
|
385
|
+
---
|
|
386
|
+
|
|
387
|
+
## Testing
|
|
388
|
+
|
|
389
|
+
Test models for these modules live in two directories, `tests/models/` and `tests/tensorrt/models/`:
|
|
390
|
+
|
|
391
|
+
- `tests/models/conv.py` — `ConvBnRelu`, `ConvBnSiLU`, `ConvBn`, `ConvOnly`,
|
|
392
|
+
`ConvBnAddRelu`, `StemConvBnReluMaxPool`, etc.
|
|
393
|
+
- `tests/models/attention.py` — `SimpleSelfAttention`, `SimpleSwinAttention`, etc.
|
|
394
|
+
- `tests/models/linear.py`, `tests/models/pool.py` — linear and pool variants.
|
|
395
|
+
- `tests/tensorrt/models/attention.py` — `LayerNormMHAInProjection`,
|
|
396
|
+
`SDPALinearOutProjection`, `SwinAttentionLinearOutProjection`.
|
|
397
|
+
|
|
398
|
+
Pattern tests live in `tests/tensorrt/patterns/fusions/`.
|
|
399
|
+
|
|
400
|
+
**What to verify for a new fused module:**
|
|
401
|
+
|
|
402
|
+
- The pattern matches exactly the expected nodes (check `tree_match.serialize()`).
|
|
403
|
+
- The fused node's `target` name is the expected auto-generated name.
|
|
404
|
+
- `resolve_module(fused_node, FusedXxx)` succeeds.
|
|
405
|
+
- `bool(fused_module.input_quant_stubs)` matches `inputs_to_quantize` expectations.
|
|
406
|
+
- `hasattr(fused_module, 'weight_fake_quant')` matches expectations.
|
|
407
|
+
- Graph equivalence: fused model output matches original model output on the same
|
|
408
|
+
input (see `tests/tensorrt/test_equivalence.py`).
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
"""Attention sub-modules introduced by MHA decomposition.
|
|
4
4
|
|
|
5
5
|
These plain ``nn.Module`` subclasses replace the opaque
|
|
6
|
-
``nn.MultiheadAttention`` in the FX graph.
|
|
6
|
+
``nn.MultiheadAttention`` in the FX graph. Phase 2 creates ``Fused*`` wrappers
|
|
7
7
|
around them for Q/DQ insertion.
|
|
8
8
|
"""
|
|
9
9
|
|
|
@@ -138,10 +138,9 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
138
138
|
on this. Passes through to ``F.scaled_dot_product_attention``
|
|
139
139
|
unchanged (``None`` is the no-mask default).
|
|
140
140
|
:returns:
|
|
141
|
-
Output tensor ``[B, num_heads, S, head_dim]``.
|
|
141
|
+
Output tensor ``[B, num_heads, S, head_dim]``. Callers are
|
|
142
142
|
responsible for any subsequent head-flattening reshape.
|
|
143
143
|
"""
|
|
144
|
-
# pylint: disable-next=not-callable
|
|
145
144
|
return F.scaled_dot_product_attention(
|
|
146
145
|
q,
|
|
147
146
|
k,
|
|
@@ -166,8 +165,8 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
166
165
|
class FusedMHAInProjection(FusedModule):
|
|
167
166
|
"""Fused wrapper for ``MHAInProjection``.
|
|
168
167
|
|
|
169
|
-
Allows the Q/DQ insertion pass to place quantize / dequantize stubs
|
|
170
|
-
|
|
168
|
+
Allows the Q/DQ insertion pass to place quantize / dequantize stubs around
|
|
169
|
+
the input projection and to attach a
|
|
171
170
|
:class:`~embedl_deploy._internal.core.quantize.stubs.WeightFakeQuantize`
|
|
172
171
|
for the packed linear weight.
|
|
173
172
|
|
|
@@ -184,6 +183,10 @@ class FusedMHAInProjection(FusedModule):
|
|
|
184
183
|
self.in_proj = in_proj
|
|
185
184
|
attach_int8_weight_quant(self, in_proj.linear)
|
|
186
185
|
|
|
186
|
+
@property
|
|
187
|
+
def quantized_weight(self) -> torch.Tensor | None:
|
|
188
|
+
return self.in_proj.linear.weight
|
|
189
|
+
|
|
187
190
|
def forward(
|
|
188
191
|
self,
|
|
189
192
|
query: torch.Tensor,
|
|
@@ -192,10 +195,10 @@ class FusedMHAInProjection(FusedModule):
|
|
|
192
195
|
) -> tuple[torch.Tensor, ...]:
|
|
193
196
|
"""Project input to per-head ``(Q, K, V)`` tensors.
|
|
194
197
|
|
|
195
|
-
Fake-quantizes the packed projection weight when enabled,
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
198
|
+
Fake-quantizes the packed projection weight when enabled, then performs
|
|
199
|
+
the linear operation. Only ``query`` is used; ``_key`` and ``_value`` are
|
|
200
|
+
accepted to match the call-site signature but ignored for
|
|
201
|
+
self-attention.
|
|
199
202
|
|
|
200
203
|
:param query:
|
|
201
204
|
Input tensor of shape ``[B, S, E]``.
|
|
@@ -204,7 +207,6 @@ class FusedMHAInProjection(FusedModule):
|
|
|
204
207
|
"""
|
|
205
208
|
weight = maybe_quantize_weight(self, self.in_proj.linear.weight)
|
|
206
209
|
batch, seq, _ = query.shape
|
|
207
|
-
# pylint: disable-next=not-callable
|
|
208
210
|
qkv = F.linear(query, weight, self.in_proj.linear.bias)
|
|
209
211
|
q, k, v = qkv.chunk(3, dim=-1)
|
|
210
212
|
num_heads = self.in_proj.num_heads
|
|
@@ -227,13 +229,13 @@ class FusedMHAInProjection(FusedModule):
|
|
|
227
229
|
class FusedScaledDotProductAttention(FusedModule):
|
|
228
230
|
"""Fused wrapper for ``ScaledDotProductAttention``.
|
|
229
231
|
|
|
230
|
-
Allows the Q/DQ insertion pass to place quantize / dequantize stubs
|
|
231
|
-
|
|
232
|
+
Allows the Q/DQ insertion pass to place quantize / dequantize stubs on each
|
|
233
|
+
of the three inputs (Q, K, V).
|
|
232
234
|
|
|
233
235
|
Additionally holds an internal
|
|
234
|
-
:class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` between
|
|
235
|
-
|
|
236
|
-
|
|
236
|
+
:class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` between the
|
|
237
|
+
softmax output and the second batched matrix multiply (BMM2). When that
|
|
238
|
+
stub is disabled the forward pass delegates to the unwrapped
|
|
237
239
|
:class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`;
|
|
238
240
|
when enabled it performs manual attention with the quantization step.
|
|
239
241
|
|
|
@@ -266,9 +268,9 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
266
268
|
|
|
267
269
|
When the SDPA has been surrounded by ``QuantStub``\ s on its Q/K/V
|
|
268
270
|
inputs *and* the internal softmax quant stub is enabled, performs
|
|
269
|
-
manual attention with a quantization step between softmax and
|
|
270
|
-
|
|
271
|
-
|
|
271
|
+
manual attention with a quantization step between softmax and BMM2.
|
|
272
|
+
Otherwise delegates to the wrapped attention module so TensorRT can
|
|
273
|
+
fuse it into its native FP16 MHA kernel.
|
|
272
274
|
|
|
273
275
|
:param q:
|
|
274
276
|
Query tensor ``[B, num_heads, S, head_dim]``.
|
|
@@ -281,7 +283,7 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
281
283
|
additive float mask broadcastable to ``[B, num_heads, S, S]``
|
|
282
284
|
or a bool mask where ``True`` means "attend".
|
|
283
285
|
:returns:
|
|
284
|
-
Output tensor ``[B, num_heads, S, head_dim]``.
|
|
286
|
+
Output tensor ``[B, num_heads, S, head_dim]``. Callers are
|
|
285
287
|
responsible for any subsequent head-flattening reshape.
|
|
286
288
|
"""
|
|
287
289
|
# Manual attention is only beneficial when this SDPA was
|
|
@@ -292,7 +294,7 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
292
294
|
# MHA kernel onto the slower INT8-aware variant for no gain.
|
|
293
295
|
if not self.surrounded or not self.softmax_quant.enabled:
|
|
294
296
|
return self.attention(q, k, v, attn_mask)
|
|
295
|
-
#
|
|
297
|
+
# Honor the wrapped attention module's explicit ``scale`` if
|
|
296
298
|
# set — models that pre-scale Q themselves (chronos-2 + RoPE,
|
|
297
299
|
# for example) build with ``scale=1.0`` to disable the default
|
|
298
300
|
# ``1/sqrt(head_dim)`` scaling. Falling back to the default
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
"""Fused ``nn.Module`` replacements for convolution-based patterns.
|
|
4
4
|
|
|
5
5
|
Each class represents a hardware-fusible operation that replaces a multi-op
|
|
6
|
-
chain found by the pattern matcher.
|
|
6
|
+
chain found by the pattern matcher. The fused module keeps the original
|
|
7
7
|
sub-modules (``Conv``, ``BN``, ``ReLU``) as children so that:
|
|
8
8
|
|
|
9
9
|
* Weights are trivially transferred from the original model.
|
|
@@ -25,11 +25,11 @@ def _is_int8_compatible_conv(conv: nn.Conv2d) -> bool:
|
|
|
25
25
|
"""Return ``True`` unless *conv* is a grouped conv violating TRT INT8.
|
|
26
26
|
|
|
27
27
|
TensorRT's documented constraint for ``IConvolutionLayer`` is that
|
|
28
|
-
``in_channels / groups`` and ``out_channels / groups`` must both
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
28
|
+
``in_channels / groups`` and ``out_channels / groups`` must both be
|
|
29
|
+
multiples of 4 in INT8 mode. Depthwise convolutions (``groups ==
|
|
30
|
+
in_channels``) are an exception: our benchmarks on the target devices show
|
|
31
|
+
they still benefit from INT8 despite channels-per-group being 1, so we let
|
|
32
|
+
them through.
|
|
33
33
|
"""
|
|
34
34
|
if conv.groups <= 1:
|
|
35
35
|
return True
|
|
@@ -51,7 +51,6 @@ def _conv_weight_forward(
|
|
|
51
51
|
if weight_fake_quant is not None
|
|
52
52
|
else conv.weight
|
|
53
53
|
)
|
|
54
|
-
# pylint: disable-next=not-callable
|
|
55
54
|
return F.conv2d(
|
|
56
55
|
x,
|
|
57
56
|
weight,
|
|
@@ -83,6 +82,10 @@ class FusedConvBNAct(FusedModule):
|
|
|
83
82
|
else:
|
|
84
83
|
self.input_quant_stubs = {}
|
|
85
84
|
|
|
85
|
+
@property
|
|
86
|
+
def quantized_weight(self) -> torch.Tensor | None:
|
|
87
|
+
return self.conv.weight
|
|
88
|
+
|
|
86
89
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
87
90
|
"""Apply ``conv → [bn] → act``."""
|
|
88
91
|
wfq = getattr(self, 'weight_fake_quant', None)
|
|
@@ -121,6 +124,10 @@ class FusedConvBN(FusedModule):
|
|
|
121
124
|
else:
|
|
122
125
|
self.input_quant_stubs = {}
|
|
123
126
|
|
|
127
|
+
@property
|
|
128
|
+
def quantized_weight(self) -> torch.Tensor | None:
|
|
129
|
+
return self.conv.weight
|
|
130
|
+
|
|
124
131
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
125
132
|
"""Apply ``conv → [bn]``."""
|
|
126
133
|
wfq = getattr(self, 'weight_fake_quant', None)
|
|
@@ -160,6 +167,10 @@ class FusedConvBNActMaxPool(FusedModule):
|
|
|
160
167
|
self.maxpool = maxpool
|
|
161
168
|
self.weight_fake_quant = WeightFakeQuantize({self})
|
|
162
169
|
|
|
170
|
+
@property
|
|
171
|
+
def quantized_weight(self) -> torch.Tensor | None:
|
|
172
|
+
return self.conv.weight
|
|
173
|
+
|
|
163
174
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
164
175
|
"""Apply ``conv → [bn] → act → maxpool``."""
|
|
165
176
|
x = _conv_weight_forward(self.conv, self.weight_fake_quant, x)
|
|
@@ -206,6 +217,10 @@ class FusedConvBNAddAct(FusedModule):
|
|
|
206
217
|
else:
|
|
207
218
|
self.input_quant_stubs = {}
|
|
208
219
|
|
|
220
|
+
@property
|
|
221
|
+
def quantized_weight(self) -> torch.Tensor | None:
|
|
222
|
+
return self.conv.weight
|
|
223
|
+
|
|
209
224
|
def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
|
210
225
|
"""Apply ``conv → bn → add(·, residual) → act``."""
|
|
211
226
|
wfq = getattr(self, 'weight_fake_quant', None)
|