embedl-deploy-tensorrt 0.5.0__tar.gz → 0.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/PKG-INFO +3 -6
  2. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/README.md +2 -5
  3. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/__init__.py +4 -0
  4. embedl_deploy_tensorrt-0.6.1/src/embedl_deploy/_internal/tensorrt/modules/AGENTS.md +408 -0
  5. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +22 -20
  6. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +22 -7
  7. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +23 -16
  8. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +1 -1
  9. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +26 -19
  10. embedl_deploy_tensorrt-0.6.1/src/embedl_deploy/_internal/tensorrt/patterns/AGENTS.md +721 -0
  11. embedl_deploy_tensorrt-0.6.1/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +3 -0
  12. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +92 -96
  13. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +16 -16
  14. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/patterns/fusions.py +9 -26
  15. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +65 -49
  16. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +17 -15
  17. embedl_deploy_tensorrt-0.6.1/src/embedl_deploy/_internal/tensorrt/plan.py +98 -0
  18. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/tensorrt/__init__.py +9 -7
  19. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/version/public.py +1 -1
  20. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy_tensorrt.egg-info/SOURCES.txt +2 -5
  21. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -15
  22. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/__init__.py +0 -3
  23. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/attention.py +0 -808
  24. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/functional.py +0 -325
  25. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/general.py +0 -718
  26. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/utils.py +0 -44
  27. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/plan.py +0 -139
  28. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/LICENSE +0 -0
  29. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/MANIFEST.in +0 -0
  30. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/NOTICE +0 -0
  31. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/pyproject.toml +0 -0
  32. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/setup.cfg +0 -0
  33. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/__init__.py +0 -0
  34. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
  35. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/backend.py +0 -0
  36. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
  37. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +0 -0
  38. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
  39. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
  40. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -0
  41. {embedl_deploy_tensorrt-0.5.0 → embedl_deploy_tensorrt-0.6.1}/src/embedl_deploy/version/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedl-deploy-tensorrt
3
- Version: 0.5.0
3
+ Version: 0.6.1
4
4
  Summary: TensorRT backend for embedl-deploy.
5
5
  Author-email: Embedl AB <support@embedl.com>
6
6
  Project-URL: Homepage, https://www.embedl.com/
@@ -86,10 +86,7 @@ model = Model().eval()
86
86
  example_input = torch.randn(1, 3, 224, 224)
87
87
 
88
88
  # 2. Transform — fuse and optimize for TensorRT in one call
89
- # For more compatibility you can trace your model with torch.export.export
90
- # as follows:
91
- # model = torch.export.export(model, (example_input)).module()
92
- res = transform(model, patterns=TENSORRT_PATTERNS)
89
+ res = transform(model, (example_input,), patterns=TENSORRT_PATTERNS)
93
90
  print("Model\n", res.model.print_readable())
94
91
  print("Matches", "\n".join([str(match) for match in res.matches]))
95
92
 
@@ -149,7 +146,7 @@ the reference **from the fused graph**, not from the original model:
149
146
  ```python
150
147
  from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
151
148
 
152
- res = transform(model, patterns=TENSORRT_PATTERNS)
149
+ res = transform(model, (example_input,), patterns=TENSORRT_PATTERNS)
153
150
 
154
151
  # Grab the conv instance from the fused graph (not from the original model)
155
152
  first_conv = res.model.FusedConvBNActMaxPool_0.conv
@@ -67,10 +67,7 @@ model = Model().eval()
67
67
  example_input = torch.randn(1, 3, 224, 224)
68
68
 
69
69
  # 2. Transform — fuse and optimize for TensorRT in one call
70
- # For more compatibility you can trace your model with torch.export.export
71
- # as follows:
72
- # model = torch.export.export(model, (example_input)).module()
73
- res = transform(model, patterns=TENSORRT_PATTERNS)
70
+ res = transform(model, (example_input,), patterns=TENSORRT_PATTERNS)
74
71
  print("Model\n", res.model.print_readable())
75
72
  print("Matches", "\n".join([str(match) for match in res.matches]))
76
73
 
@@ -130,7 +127,7 @@ the reference **from the fused graph**, not from the original model:
130
127
  ```python
131
128
  from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
132
129
 
133
- res = transform(model, patterns=TENSORRT_PATTERNS)
130
+ res = transform(model, (example_input,), patterns=TENSORRT_PATTERNS)
134
131
 
135
132
  # Grab the conv instance from the fused graph (not from the original model)
136
133
  first_conv = res.model.FusedConvBNActMaxPool_0.conv
@@ -3,10 +3,12 @@
3
3
  """Python package to make AI models deployment-ready for any hardware."""
4
4
 
5
5
  from embedl_deploy._internal.core.plan import (
6
+ Trace,
6
7
  TransformationPlan,
7
8
  TransformationResult,
8
9
  apply_transformation_plan,
9
10
  get_transformation_plan,
11
+ prepare_graph,
10
12
  transform,
11
13
  )
12
14
  from embedl_deploy.version.public import PUBLIC_VERSION
@@ -14,10 +16,12 @@ from embedl_deploy.version.public import PUBLIC_VERSION
14
16
  __version__ = PUBLIC_VERSION
15
17
 
16
18
  __all__ = [
19
+ "Trace",
17
20
  "TransformationPlan",
18
21
  "TransformationResult",
19
22
  "__version__",
20
23
  "apply_transformation_plan",
21
24
  "get_transformation_plan",
25
+ "prepare_graph",
22
26
  "transform",
23
27
  ]
@@ -0,0 +1,408 @@
1
+ # TensorRT Fused Modules — Subsystem Guide
2
+
3
+ ## Role & Boundary
4
+
5
+ **This subsystem owns:** concrete `FusedModule` (and `ConvertedModule`) implementations
6
+ for TensorRT. Every class in `modules/` corresponds to a hardware-fusible or
7
+ structurally-decomposed operation the TensorRT backend can emit as a single engine
8
+ layer or kernel.
9
+
10
+ **This subsystem does NOT own:**
11
+
12
+ - When or how to match these modules in a graph — that lives in `patterns/`.
13
+ - Q/DQ stub insertion or calibration logic — that lives in `core/quantize/`.
14
+ - The pattern-matching engine itself — that lives in `core/`.
15
+
16
+ ---
17
+
18
+ ## Position in the System
19
+
20
+ ```
21
+ patterns/fusions.py # creates FusedModule instances
22
+ patterns/conversions/general.py # creates ConvertedModule instances (erasure, flatten-linear)
23
+ patterns/conversions/attention.py # creates ConvertedModule instances (MHA/Swin/SDPA)
24
+
25
+
26
+ modules/ (this subsystem)
27
+
28
+
29
+ core/quantize/prepare.py # walks the graph, uses isinstance(mod, FusedModule)
30
+ # and mod.inputs_to_quantize to insert Q/DQ stubs
31
+ ```
32
+
33
+ **Instantiated by:** pattern grafts and `replace()` methods in `patterns/fusions.py`
34
+ and `patterns/conversions/`. Most fusion patterns declare a `graft` attribute
35
+ pointing to the fused module class; the graft system calls `_collect_modules()`
36
+ to gather matched modules in tree order and passes them as positional arguments
37
+ to the constructor. For example, `ConvBNActPattern` grafts `FusedConvBNAct`;
38
+ `DecomposeMultiheadAttentionPattern.replace()` constructs `MHAInProjection` and
39
+ `ScaledDotProductAttention`.
40
+
41
+ **Used by:** The Q/DQ insertion pass in `core/quantize/prepare.py` via
42
+ `isinstance(mod, FusedModule)` and `mod.inputs_to_quantize`, to determine which
43
+ inputs of each `call_module` node should receive a `QuantStub`.
44
+
45
+ **Also used as intermediate graph nodes by:** conversion patterns, which produce
46
+ `ConvertedModule` subclasses (`MHAInProjection`, `ScaledDotProductAttention`,
47
+ `SwinWindowPartition`, `SwinAttention`, `SwinWindowReverse`). These stay opaque
48
+ during re-tracing so that downstream fusion patterns can see them as atomic nodes.
49
+
50
+ ---
51
+
52
+ ## FusedModule Contract
53
+
54
+ Every `FusedModule` subclass must satisfy three requirements:
55
+
56
+ 1. **Set `inputs_to_quantize` as a class attribute** (not an instance attribute).
57
+ The Q/DQ insertion pass reads this before the module is instantiated to decide
58
+ which positional arguments of the `call_module` FX node will have a `QuantStub`
59
+ inserted in front of them.
60
+
61
+ 2. **Call `super().__init__()`**, which creates `self.input_quant_stubs` — a dict
62
+ mapping each index in `inputs_to_quantize` to a fresh `QuantStub`. The Q/DQ
63
+ pass later enables and configures these stubs during `prepare_qdq()`. It also
64
+ initialises `self.surrounded = False`, which is later set to `True` by
65
+ `SurroundWithQuantStubsPattern` to mark modules that have been surrounded
66
+ with input `QuantStub` entries.
67
+
68
+ 3. **Implement `forward()`** with the fused computation the module represents.
69
+
70
+ ### Graft compatibility
71
+
72
+ When a pattern uses `graft = FusedFoo` (bare class), the graft system calls
73
+ `_collect_modules()` to walk the matched tree and collect the `nn.Module`
74
+ instances corresponding to trunk and fork nodes (nested branches first, then
75
+ trunk nodes). These are passed as positional arguments to the constructor.
76
+ Therefore the constructor signature must accept modules in the same order
77
+ they appear in the pattern tree.
78
+
79
+ ### What `inputs_to_quantize` means
80
+
81
+ `inputs_to_quantize` is a set of *positional argument indices* corresponding to
82
+ the `call_module` FX node's arguments — that is, the arguments visible in the
83
+ traced graph, not necessarily the Python keyword positions in `forward()`.
84
+
85
+ For example, `FusedConvBN.inputs_to_quantize = {0}` means the first tensor
86
+ argument to the fused node (the image tensor) gets a `QuantStub`. The convolution
87
+ weight is quantized separately via `WeightFakeQuantize`, which is attached to the
88
+ module directly rather than via `inputs_to_quantize`.
89
+
90
+ `FusedConvBNAddAct.inputs_to_quantize = {0, 1}` means both the main feature map
91
+ and the residual skip tensor each get their own `QuantStub`.
92
+
93
+ ### Why `inputs_to_quantize` is a class attribute
94
+
95
+ Because it describes the module *type's* quantization contract, not any particular
96
+ instance's configuration. The Q/DQ insertion pass queries it once per class during
97
+ graph preparation, before modules are constructed. Making it a class attribute
98
+ (rather than an instance attribute set in `__init__`) ensures the value is
99
+ available from the class itself without constructing an instance.
100
+
101
+ ---
102
+
103
+ ## ConvertedModule vs FusedModule
104
+
105
+ Both base classes are defined in `core/modules.py`. They serve different roles:
106
+
107
+ **`ConvertedModule` subclasses** are *intermediate decomposition products*. They
108
+ replace high-level opaque ops (e.g. `nn.MultiheadAttention`,
109
+ `shifted_window_attention`) with sub-modules that downstream *fusion* patterns can
110
+ then match individually. The custom tracer (`_LeafTracer`) treats them as leaf
111
+ nodes — they are never unwrapped during re-tracing. Examples: `MHAInProjection`,
112
+ `ScaledDotProductAttention`, `SwinWindowPartition`, `SwinAttention`,
113
+ `SwinWindowReverse`.
114
+
115
+ **`FusedModule` subclasses** are the *final quantization target*. The Q/DQ
116
+ insertion pass identifies them via `isinstance(mod, FusedModule)` and wraps their
117
+ inputs/weights with fake-quantization stubs. Examples: `FusedConvBN`,
118
+ `FusedConvBNAct`, `FusedLinear`, `FusedSwinAttention`, `FusedMHAInProjection`.
119
+
120
+ A `FusedModule` may wrap a `ConvertedModule` (e.g. `FusedMHAInProjection` wraps
121
+ `MHAInProjection`, `FusedSwinAttention` wraps `SwinAttention`). In that case the
122
+ conversion pass runs first, then the fusion pass replaces the `ConvertedModule`
123
+ with its `FusedModule` counterpart.
124
+
125
+ ---
126
+
127
+ ## Module Catalog
128
+
129
+ ### Conv family — `conv.py`
130
+
131
+ Patterns that create them: `patterns/fusions.py` (`ConvBNActPattern`,
132
+ `ConvBNPattern`, `StemConvBNActMaxPoolPattern`, `ConvBNAddActPattern`).
133
+
134
+ All conv fused modules store a reference to the original `nn.Conv2d` (and
135
+ optionally `nn.BatchNorm2d`, `ActivationLike`, `nn.MaxPool2d`) — they do not
136
+ pre-compute fused weights. This lets the same module object work correctly in both
137
+ training (with live BN statistics) and eval (where BN is folded by TensorRT at
138
+ export time).
139
+
140
+ **BN folding indicator:** The presence of `self.bn` (not `None`) implicitly
141
+ indicates that BN can be folded into the convolution at export time. The
142
+ `__repr__` methods print `"(foldable)"` when `self.bn is not None`.
143
+
144
+ **`FusedConvBN`** — `Conv2d → [BatchNorm2d]`, `inputs_to_quantize = {0}`.
145
+
146
+ **`FusedConvBNAct`** — `Conv2d → [BatchNorm2d] → Activation`, `inputs_to_quantize = {0}`.
147
+
148
+ **`FusedConvBNActMaxPool`** — `Conv2d → [BatchNorm2d] → Activation → MaxPool2d`,
149
+ `inputs_to_quantize = {0}`. MaxPool is included here (rather than fused
150
+ separately) because TensorRT supports fusing the full stem sequence
151
+ (`Conv → BN → Act → Pool`) as a single engine node in the common 7x7 stem pattern.
152
+ Fusing them together avoids an extra quantized activation between Conv and Pool.
153
+
154
+ **`FusedConvBNAddAct`** — `Conv2d → BatchNorm2d → add(·, residual) → Activation`,
155
+ `inputs_to_quantize = {0, 1}`. The residual tensor is the second input (index 1),
156
+ hence both inputs are quantized. This is the ResNet skip-connection block tail.
157
+
158
+ **INT8 compatibility guard:** grouped convolutions where `in_channels / groups` or
159
+ `out_channels / groups` is not a multiple of 4 cannot be quantized to INT8 in
160
+ TensorRT. For those cases `_is_int8_compatible_conv()` returns `False`, and the
161
+ module sets `self.input_quant_stubs = {}` (overriding the `super().__init__()`
162
+ default), effectively opting out of quantization.
163
+
164
+ ### Linear family — `linear.py`
165
+
166
+ Pattern that creates them: `patterns/fusions.py`.
167
+
168
+ **`FusedLinear`** — wraps a single `nn.Linear`, `inputs_to_quantize = {0}`.
169
+
170
+ **`FusedLinearAct`** — wraps `nn.Linear → Activation`, `inputs_to_quantize = {0}`.
171
+
172
+ **`FusedLayerNorm`** — wraps `nn.LayerNorm`, `inputs_to_quantize = set()`.
173
+ LayerNorm is placed in the linear family because it appears in transformer
174
+ architectures immediately before or after linear layers, and the SmoothQuant
175
+ pass needs to reason about LayerNorm and the following linear together.
176
+ Input quantization is disabled (`inputs_to_quantize = set()`, `prefers_fp_input =
177
+ True`) because LayerNorm normalises its input internally — quantizing the input
178
+ before LayerNorm would destroy the statistical properties that normalization relies
179
+ on.
180
+
181
+ **`FusedLayerNorm.smooth_quant_observer`** — holds a `SmoothQuantObserver`
182
+ instance. It is created in `__init__` but is populated (i.e., scale factors are
183
+ computed and migrated) by the SmoothQuant calibration pass in
184
+ `core/quantize/calibrate.py`. The module itself does not use the observer
185
+ during `forward()` — it only provides a hook for the calibration pass to attach to.
186
+
187
+ ### Attention family — `attention.py`
188
+
189
+ Pattern that creates the `ConvertedModule` subclasses:
190
+ `patterns/conversions/attention.py` (`DecomposeMultiheadAttentionPattern`).
191
+ Pattern that creates the `FusedModule` wrappers: `patterns/fusions.py`.
192
+
193
+ **Decomposition model:** `nn.MultiheadAttention` is opaque to FX and to TensorRT.
194
+ The conversion pass decomposes it into three explicit sub-modules:
195
+
196
+ 1. `MHAInProjection` (in-projection) — packed `Linear(E, 3E)` followed by
197
+ chunking and head splitting.
198
+ 2. `ScaledDotProductAttention` (SDPA) — `softmax(Q·Kᵀ / √head_dim) · V`.
199
+ 3. `nn.Linear` (out-projection) — the original `mha.out_proj`.
200
+
201
+ **Why in-proj returns a tuple:** The in-projection produces three independent
202
+ tensors (Q, K, V). TensorRT can fuse the single packed `Linear(E, 3E)` and the
203
+ split into an efficient multi-head projection kernel. Returning a tuple of
204
+ `(Q, K, V)` lets the graph represent the split explicitly — the conversion pass
205
+ inserts `operator.getitem` nodes to fan out the tuple into three separate feeds
206
+ for SDPA.
207
+
208
+ **Head splitting:** Inside `MHAInProjection.forward()`, after the packed linear
209
+ op, each of Q, K, V is reshaped from `[B, S, E]` to `[B, num_heads, S, head_dim]`
210
+ via `view` + `transpose`. This puts the head dimension before the sequence
211
+ dimension, which is the layout expected by `F.scaled_dot_product_attention` and
212
+ by matrix-multiply kernels.
213
+
214
+ **`FusedMHAInProjection`** — wraps `MHAInProjection`, `inputs_to_quantize = {0}`.
215
+ Adds a `WeightFakeQuantize` for the packed linear weight. Only `query` (index 0)
216
+ is quantized; `_key` and `_value` are accepted to match the self-attention
217
+ call-site but ignored.
218
+
219
+ **`FusedScaledDotProductAttention`** — wraps `ScaledDotProductAttention`,
220
+ `inputs_to_quantize = set()`. Adds an internal `softmax_quant` stub with a fixed
221
+ calibration of `(1/127, 0)` — i.e., 8-bit symmetric with a fixed scale matched to
222
+ the softmax output range `[0, 1]`. When the stub is disabled the module delegates
223
+ to the plain SDPA; when enabled it performs manual attention with the quantization
224
+ step between softmax and the second batched matrix multiply (BMM2).
225
+
226
+ ### Swin Attention family — `swin_attention.py`
227
+
228
+ Pattern that creates the `ConvertedModule` subclasses:
229
+ `patterns/conversions/attention.py` (`DecomposeSwinAttentionPattern`). Pattern
230
+ that creates `FusedSwinAttention`: `patterns/fusions.py`.
231
+
232
+ **Decomposition model:** `torchvision`'s `shifted_window_attention` free function
233
+ is an opaque `fx.wrap`-ped call that includes spatial padding, cyclic shifting,
234
+ window partitioning, QKV projection, attention, output projection, and window
235
+ reversal — all in one node. The conversion pass splits it into five sub-modules
236
+ that downstream fusion patterns can match individually:
237
+
238
+ 1. `SwinWindowPartition` — pad, cyclic-shift, partition to windows.
239
+ 2. `MHAInProjection` — QKV projection (shared with the MHA attention family).
240
+ 3. `SwinAttention` — windowed attention with relative position bias and shifted-window mask.
241
+ 4. `nn.Linear` — output projection.
242
+ 5. `SwinWindowReverse` — reverse partition, shift, and unpad.
243
+
244
+ **`SwinSpatialState`** — a mutable dataclass shared by all three spatial modules
245
+ (`SwinWindowPartition`, `SwinAttention`, `SwinWindowReverse`). During each
246
+ `forward()` call:
247
+
248
+ - `SwinWindowPartition.forward()` **writes** `batch_size`, `height`, `width`,
249
+ `pad_height`, `pad_width`, and `effective_shift_size` into the state.
250
+ - `SwinAttention.forward()` **reads** `pad_height`, `pad_width`,
251
+ `effective_shift_size`, and `batch_size` to compute the attention mask.
252
+ - `SwinWindowReverse.forward()` **reads** all fields to undo the partition and
253
+ remove padding.
254
+
255
+ The state must be shared (not copied) because `SwinWindowPartition` computes the
256
+ actual padded dimensions and effective shift at runtime — these depend on the input
257
+ spatial size, which is not known until `forward()` is called.
258
+
259
+ **`deepcopy` note:** `copy.deepcopy` preserves the shared reference among the
260
+ three modules when they are copied together in a single call (e.g. by deep-copying the enclosing model). This is intentional: the three modules reference the *same*
261
+ `SwinSpatialState` object; `deepcopy` copies the object once and updates all
262
+ referencing attributes to point to the copy. However, the sharing is not
263
+ thread-safe (see Gotchas).
264
+
265
+ **`FusedSwinAttention`** — wraps `SwinAttention`, `inputs_to_quantize = set()`.
266
+ Mirrors `FusedScaledDotProductAttention`: adds an internal `softmax_quant` stub
267
+ with fixed calibration. When enabled it manually expands the attention computation
268
+ to insert the quantization step between softmax and BMM2.
269
+
270
+ ### Pointwise family — `pointwise.py`
271
+
272
+ Pattern that creates it: `patterns/fusions.py` (`ActAddPattern`).
273
+
274
+ **`FusedActAdd`** — `Activation → add(·, residual)`, `inputs_to_quantize = {0, 1}`.
275
+ Constructor takes `(act: ActivationLike)`. `forward(x, residual)` applies
276
+ `act(x) + residual`. This prevents TensorRT from merging the upstream convolution
277
+ into an activation-fused kernel when the activation output is consumed by a
278
+ subsequent add. Both the activated feature map and the skip-connection tensor are
279
+ quantized at their respective scales.
280
+
281
+ ### Pool family — `pool.py`
282
+
283
+ Pattern that creates it: `patterns/fusions.py`.
284
+
285
+ **`FusedAdaptiveAvgPool2d`** — wraps `nn.AdaptiveAvgPool2d`,
286
+ `inputs_to_quantize = set()`. No quantization is applied (pooling is a
287
+ linear operation that does not benefit from separate quantization). Exists as a
288
+ `FusedModule` so the Q/DQ pass treats it uniformly without special-casing it.
289
+
290
+ ---
291
+
292
+ ## BN Folding
293
+
294
+ Batch normalisation folding is the process of absorbing the BN scale and bias into
295
+ the preceding convolution weight and bias at inference time:
296
+
297
+ ```
298
+ w_folded = w * (γ / σ)
299
+ b_folded = (b - μ) * (γ / σ) + β
300
+ ```
301
+
302
+ where `γ`, `β` are BN's learned weight/bias, `μ` is the running mean, and `σ = sqrt(running_var + ε)` is the standard deviation derived from the running variance.
303
+
304
+ The fused modules store the original `nn.Conv2d` and `nn.BatchNorm2d` as separate
305
+ sub-modules. During training, the BN is applied as a normal operation (live
306
+ statistics). At export, TensorRT performs the folding in the engine. Whether a
307
+ fused module carries a BN is determined implicitly: `self.bn is not None` means
308
+ the module was created from a pattern that included a `BatchNorm2d`, and folding
309
+ is safe. When `self.bn is None`, the convolution weight is used as-is.
310
+
311
+ ---
312
+
313
+ ## Design Decisions
314
+
315
+ **Why modules store original sub-modules rather than pre-computing fused weights:**
316
+ Because the modules must be correct in both training and eval modes. Training uses
317
+ live BN statistics; the fused weights can only be computed accurately in eval mode
318
+ with frozen running mean/variance. Storing the originals defers the decision to
319
+ the export step while keeping `forward()` correct throughout.
320
+
321
+ **Why `inputs_to_quantize` is a class attribute:** The Q/DQ insertion pass queries
322
+ the set before constructing any instances — it reads it from the class via
323
+ `type(mod).inputs_to_quantize`. Making it a class attribute also prevents
324
+ accidental per-instance divergence.
325
+
326
+ **Why intermediate decomposition modules are `ConvertedModule`:** The custom tracer
327
+ (`_LeafTracer` in `core/modules.py`) checks `isinstance(m, (ConvertedModule,
328
+ FusedModule))` to decide whether to treat a module as a leaf. `ConvertedModule` is
329
+ the marker that tells the tracer "do not recurse into this module's `forward()`".
330
+ Without this, re-tracing the graph after conversion would unwrap the decomposed
331
+ sub-modules, defeating the purpose of decomposition.
332
+
333
+ ---
334
+
335
+ ## Gotchas & Pitfalls
336
+
337
+ **`SwinSpatialState` thread-safety:** The state is a mutable shared object written
338
+ during every `SwinWindowPartition.forward()` call and read by the other two
339
+ modules in the same forward pass. If the model is run from multiple threads
340
+ concurrently (e.g. in a data-parallel setup without model replication), the writes
341
+ from one thread will corrupt the reads of another. `torch.nn.DataParallel` runs its
342
+ replicas in threads of a single process (unlike `DistributedDataParallel`, which uses one process per replica), so use it only if each replica gets its own model copy via
343
+ `deepcopy`, which does preserve sharing within a single copy.
344
+
345
+ **`inputs_to_quantize` indices vs. Python `forward()` args:** The indices refer to
346
+ the positional arguments of the FX `call_module` node in the traced graph, not to
347
+ the `forward()` Python signature. In the normal case they are identical. They
348
+ diverge if the FX graph is produced from a non-trivial call-site (e.g. keyword
349
+ arguments re-ordered). Always verify against the actual graph node when debugging
350
+ Q/DQ placement.
351
+
352
+ **Weight sharing:** The conv/linear sub-modules stored inside fused modules hold
353
+ references to the tensors from the *original* model. If the original model's
354
+ weights are modified after fusion, the fused module sees the change. This is
355
+ usually desirable (e.g. for QAT gradient updates), but can be surprising if the
356
+ original model is used independently.
357
+
358
+ **Grouped conv INT8 opt-out:** When `_is_int8_compatible_conv()` returns `False`
359
+ the `__init__` of the conv fused modules sets `self.input_quant_stubs = {}`,
360
+ overriding the dict populated by `FusedModule.__init__()`. This means the module
361
+ is effectively excluded from quantization despite being a `FusedModule`. The
362
+ `weight_fake_quant` attribute is also not created in this path.
363
+
364
+ ---
365
+
366
+ ## Adding a New Fused Module
367
+
368
+ 1. **Subclass `FusedModule`** (from `core/modules.py`).
369
+ 2. **Set `inputs_to_quantize`** as a class attribute — a `set[int]` of positional
370
+ argument indices that should receive activation `QuantStub`s.
371
+ 3. **Optionally add `self.weight_fake_quant = WeightFakeQuantize({self})`** in
372
+ `__init__` if the module has a learnable weight that should be fake-quantized.
373
+ 4. **Implement `forward()`** with the fused computation.
374
+ 5. **Write a `Pattern` subclass** in `patterns/fusions.py`. Prefer declaring
375
+ `graft = FusedFoo` (bare class) so the graft system handles replacement
376
+ automatically. The constructor must accept modules in tree order (nested
377
+ branches first, then trunk nodes). If the replacement logic cannot be
378
+ expressed as a bare-class graft, provide a `ReplacementMaker` or a custom
379
+ `replace()` method instead.
380
+ 6. **Add the pattern to `TENSORRT_PATTERNS`** (or the appropriate pattern list) in
381
+ `tensorrt/plan.py`.
382
+ 7. **Write tests** in `tests/tensorrt/patterns/fusions/` covering: pattern match,
383
+ pattern replace, correct quantisation stub placement.
384
+
385
+ ---
386
+
387
+ ## Testing
388
+
389
+ Test models for these modules live in the following locations:
390
+
391
+ - `tests/models/conv.py` — `ConvBnRelu`, `ConvBnSiLU`, `ConvBn`, `ConvOnly`,
392
+ `ConvBnAddRelu`, `StemConvBnReluMaxPool`, etc.
393
+ - `tests/models/attention.py` — `SimpleSelfAttention`, `SimpleSwinAttention`, etc.
394
+ - `tests/models/linear.py`, `tests/models/pool.py` — linear and pool variants.
395
+ - `tests/tensorrt/models/attention.py` — `LayerNormMHAInProjection`,
396
+ `SDPALinearOutProjection`, `SwinAttentionLinearOutProjection`.
397
+
398
+ Pattern tests live in `tests/tensorrt/patterns/fusions/`.
399
+
400
+ **What to verify for a new fused module:**
401
+
402
+ - The pattern matches exactly the expected nodes (check `tree_match.serialize()`).
403
+ - The fused node's `target` name is the expected auto-generated name.
404
+ - `resolve_module(fused_node, FusedXxx)` succeeds.
405
+ - `bool(fused_module.input_quant_stubs)` matches `inputs_to_quantize` expectations.
406
+ - `hasattr(fused_module, 'weight_fake_quant')` matches expectations.
407
+ - Graph equivalence: fused model output matches original model output on the same
408
+ input (see `tests/tensorrt/test_equivalence.py`).
@@ -3,7 +3,7 @@
3
3
  """Attention sub-modules introduced by MHA decomposition.
4
4
 
5
5
  These plain ``nn.Module`` subclasses replace the opaque
6
- ``nn.MultiheadAttention`` in the FX graph. Phase 2 creates ``Fused*`` wrappers
6
+ ``nn.MultiheadAttention`` in the FX graph. Phase 2 creates ``Fused*`` wrappers
7
7
  around them for Q/DQ insertion.
8
8
  """
9
9
 
@@ -138,10 +138,9 @@ class ScaledDotProductAttention(ConvertedModule):
138
138
  on this. Passes through to ``F.scaled_dot_product_attention``
139
139
  unchanged (``None`` is the no-mask default).
140
140
  :returns:
141
- Output tensor ``[B, num_heads, S, head_dim]``. Callers are
141
+ Output tensor ``[B, num_heads, S, head_dim]``. Callers are
142
142
  responsible for any subsequent head-flattening reshape.
143
143
  """
144
- # pylint: disable-next=not-callable
145
144
  return F.scaled_dot_product_attention(
146
145
  q,
147
146
  k,
@@ -166,8 +165,8 @@ class ScaledDotProductAttention(ConvertedModule):
166
165
  class FusedMHAInProjection(FusedModule):
167
166
  """Fused wrapper for ``MHAInProjection``.
168
167
 
169
- Allows the Q/DQ insertion pass to place quantize / dequantize stubs
170
- around the input projection and to attach a
168
+ Allows the Q/DQ insertion pass to place quantize / dequantize stubs around
169
+ the input projection and to attach a
171
170
  :class:`~embedl_deploy._internal.core.quantize.stubs.WeightFakeQuantize`
172
171
  for the packed linear weight.
173
172
 
@@ -184,6 +183,10 @@ class FusedMHAInProjection(FusedModule):
184
183
  self.in_proj = in_proj
185
184
  attach_int8_weight_quant(self, in_proj.linear)
186
185
 
186
+ @property
187
+ def quantized_weight(self) -> torch.Tensor | None:
188
+ return self.in_proj.linear.weight
189
+
187
190
  def forward(
188
191
  self,
189
192
  query: torch.Tensor,
@@ -192,10 +195,10 @@ class FusedMHAInProjection(FusedModule):
192
195
  ) -> tuple[torch.Tensor, ...]:
193
196
  """Project input to per-head ``(Q, K, V)`` tensors.
194
197
 
195
- Fake-quantizes the packed projection weight when enabled,
196
- then performs the linear operation. Only `query` is used;
197
- `_key` and `_value` are accepted to match the call-site
198
- signature but ignored for self-attention.
198
+ Fake-quantizes the packed projection weight when enabled, then performs
199
+ the linear operation. Only `query` is used; `_key` and `_value` are
200
+ accepted to match the call-site signature but ignored for
201
+ self-attention.
199
202
 
200
203
  :param query:
201
204
  Input tensor of shape ``[B, S, E]``.
@@ -204,7 +207,6 @@ class FusedMHAInProjection(FusedModule):
204
207
  """
205
208
  weight = maybe_quantize_weight(self, self.in_proj.linear.weight)
206
209
  batch, seq, _ = query.shape
207
- # pylint: disable-next=not-callable
208
210
  qkv = F.linear(query, weight, self.in_proj.linear.bias)
209
211
  q, k, v = qkv.chunk(3, dim=-1)
210
212
  num_heads = self.in_proj.num_heads
@@ -227,13 +229,13 @@ class FusedMHAInProjection(FusedModule):
227
229
  class FusedScaledDotProductAttention(FusedModule):
228
230
  """Fused wrapper for ``ScaledDotProductAttention``.
229
231
 
230
- Allows the Q/DQ insertion pass to place quantize / dequantize stubs
231
- on each of the three inputs (Q, K, V).
232
+ Allows the Q/DQ insertion pass to place quantize / dequantize stubs on each
233
+ of the three inputs (Q, K, V).
232
234
 
233
235
  Additionally holds an internal
234
- :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` between
235
- the softmax output and the second batched matrix multiply (BMM2). When
236
- that stub is disabled the forward pass delegates to the unwrapped
236
+ :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` between the
237
+ softmax output and the second batched matrix multiply (BMM2). When that
238
+ stub is disabled the forward pass delegates to the unwrapped
237
239
  :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`;
238
240
  when enabled it performs manual attention with the quantization step.
239
241
 
@@ -266,9 +268,9 @@ class FusedScaledDotProductAttention(FusedModule):
266
268
 
267
269
  When the SDPA has been surrounded by ``QuantStub``\ s on its Q/K/V
268
270
  inputs *and* the internal softmax quant stub is enabled, performs
269
- manual attention with a quantization step between softmax and
270
- BMM2. Otherwise delegates to the wrapped attention module so
271
- TensorRT can fuse it into its native FP16 MHA kernel.
271
+ manual attention with a quantization step between softmax and BMM2.
272
+ Otherwise delegates to the wrapped attention module so TensorRT can
273
+ fuse it into its native FP16 MHA kernel.
272
274
 
273
275
  :param q:
274
276
  Query tensor ``[B, num_heads, S, head_dim]``.
@@ -281,7 +283,7 @@ class FusedScaledDotProductAttention(FusedModule):
281
283
  additive float mask broadcastable to ``[B, num_heads, S, S]``
282
284
  or a bool mask where ``True`` means "attend".
283
285
  :returns:
284
- Output tensor ``[B, num_heads, S, head_dim]``. Callers are
286
+ Output tensor ``[B, num_heads, S, head_dim]``. Callers are
285
287
  responsible for any subsequent head-flattening reshape.
286
288
  """
287
289
  # Manual attention is only beneficial when this SDPA was
@@ -292,7 +294,7 @@ class FusedScaledDotProductAttention(FusedModule):
292
294
  # MHA kernel onto the slower INT8-aware variant for no gain.
293
295
  if not self.surrounded or not self.softmax_quant.enabled:
294
296
  return self.attention(q, k, v, attn_mask)
295
- # Honour the wrapped attention module's explicit ``scale`` if
297
+ # Honor the wrapped attention module's explicit ``scale`` if
296
298
  # set — models that pre-scale Q themselves (chronos-2 + RoPE,
297
299
  # for example) build with ``scale=1.0`` to disable the default
298
300
  # ``1/sqrt(head_dim)`` scaling. Falling back to the default
@@ -3,7 +3,7 @@
3
3
  """Fused ``nn.Module`` replacements for convolution-based patterns.
4
4
 
5
5
  Each class represents a hardware-fusible operation that replaces a multi-op
6
- chain found by the pattern matcher. The fused module keeps the original sub-
6
+ chain found by the pattern matcher. The fused module keeps the original sub-
7
7
  modules (``Conv``, ``BN``, ``ReLU``) as children so that:
8
8
 
9
9
  * Weights are trivially transferred from the original model.
@@ -25,11 +25,11 @@ def _is_int8_compatible_conv(conv: nn.Conv2d) -> bool:
25
25
  """Return ``True`` unless *conv* is a grouped conv violating TRT INT8.
26
26
 
27
27
  TensorRT's documented constraint for ``IConvolutionLayer`` is that
28
- ``in_channels / groups`` and ``out_channels / groups`` must both
29
- be multiples of 4 in INT8 mode. Depthwise convolutions
30
- (``groups == in_channels``) are an exception: our benchmarks on
31
- the target devices show they still benefit from INT8 despite
32
- channels-per-group being 1, so we let them through.
28
+ ``in_channels / groups`` and ``out_channels / groups`` must both be
29
+ multiples of 4 in INT8 mode. Depthwise convolutions (``groups ==
30
+ in_channels``) are an exception: our benchmarks on the target devices show
31
+ they still benefit from INT8 despite channels-per-group being 1, so we let
32
+ them through.
33
33
  """
34
34
  if conv.groups <= 1:
35
35
  return True
@@ -51,7 +51,6 @@ def _conv_weight_forward(
51
51
  if weight_fake_quant is not None
52
52
  else conv.weight
53
53
  )
54
- # pylint: disable-next=not-callable
55
54
  return F.conv2d(
56
55
  x,
57
56
  weight,
@@ -83,6 +82,10 @@ class FusedConvBNAct(FusedModule):
83
82
  else:
84
83
  self.input_quant_stubs = {}
85
84
 
85
+ @property
86
+ def quantized_weight(self) -> torch.Tensor | None:
87
+ return self.conv.weight
88
+
86
89
  def forward(self, x: torch.Tensor) -> torch.Tensor:
87
90
  """Apply ``conv → [bn] → act``."""
88
91
  wfq = getattr(self, 'weight_fake_quant', None)
@@ -121,6 +124,10 @@ class FusedConvBN(FusedModule):
121
124
  else:
122
125
  self.input_quant_stubs = {}
123
126
 
127
+ @property
128
+ def quantized_weight(self) -> torch.Tensor | None:
129
+ return self.conv.weight
130
+
124
131
  def forward(self, x: torch.Tensor) -> torch.Tensor:
125
132
  """Apply ``conv → [bn]``."""
126
133
  wfq = getattr(self, 'weight_fake_quant', None)
@@ -160,6 +167,10 @@ class FusedConvBNActMaxPool(FusedModule):
160
167
  self.maxpool = maxpool
161
168
  self.weight_fake_quant = WeightFakeQuantize({self})
162
169
 
170
+ @property
171
+ def quantized_weight(self) -> torch.Tensor | None:
172
+ return self.conv.weight
173
+
163
174
  def forward(self, x: torch.Tensor) -> torch.Tensor:
164
175
  """Apply ``conv → [bn] → act → maxpool``."""
165
176
  x = _conv_weight_forward(self.conv, self.weight_fake_quant, x)
@@ -206,6 +217,10 @@ class FusedConvBNAddAct(FusedModule):
206
217
  else:
207
218
  self.input_quant_stubs = {}
208
219
 
220
+ @property
221
+ def quantized_weight(self) -> torch.Tensor | None:
222
+ return self.conv.weight
223
+
209
224
  def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
210
225
  """Apply ``conv → bn → add(·, residual) → act``."""
211
226
  wfq = getattr(self, 'weight_fake_quant', None)