embedl-deploy-tensorrt 0.6.1__tar.gz → 0.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/PKG-INFO +2 -1
  2. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/README.md +1 -0
  3. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/AGENTS.md +41 -29
  4. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +32 -27
  5. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +110 -45
  6. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +10 -54
  7. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +7 -3
  8. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +8 -3
  9. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/AGENTS.md +68 -18
  10. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions.py +50 -16
  11. embedl_deploy_tensorrt-0.7.0/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +590 -0
  12. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/plan.py +10 -0
  13. embedl_deploy_tensorrt-0.7.0/src/embedl_deploy/_internal/tensorrt/precision.py +197 -0
  14. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/tensorrt/__init__.py +4 -0
  15. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/version/public.py +1 -1
  16. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy_tensorrt.egg-info/SOURCES.txt +1 -0
  17. embedl_deploy_tensorrt-0.6.1/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +0 -424
  18. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/LICENSE +0 -0
  19. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/MANIFEST.in +0 -0
  20. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/NOTICE +0 -0
  21. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/pyproject.toml +0 -0
  22. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/setup.cfg +0 -0
  23. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/__init__.py +0 -0
  24. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/__init__.py +0 -0
  25. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
  26. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/backend.py +0 -0
  27. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
  28. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +0 -0
  29. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
  30. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -0
  31. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +0 -0
  32. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +0 -0
  33. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +0 -0
  34. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
  35. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -0
  36. {embedl_deploy_tensorrt-0.6.1 → embedl_deploy_tensorrt-0.7.0}/src/embedl_deploy/version/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedl-deploy-tensorrt
3
- Version: 0.6.1
3
+ Version: 0.7.0
4
4
  Summary: TensorRT backend for embedl-deploy.
5
5
  Author-email: Embedl AB <support@embedl.com>
6
6
  Project-URL: Homepage, https://www.embedl.com/
@@ -58,6 +58,7 @@ hardware target ensuring correct quantization and compilation.
58
58
  |---------------------------|-----------------|
59
59
  | NVIDIA TensorRT (v10.3) | Supported |
60
60
  | Lattice SensAI (v8.0) | In Development |
61
+ | AMD Vitis AI | In Development |
61
62
 
62
63
  Contact Embedl for other backends.
63
64
 
@@ -39,6 +39,7 @@ hardware target ensuring correct quantization and compilation.
39
39
  |---------------------------|-----------------|
40
40
  | NVIDIA TensorRT (v10.3) | Supported |
41
41
  | Lattice SensAI (v8.0) | In Development |
42
+ | AMD Vitis AI | In Development |
42
43
 
43
44
  Contact Embedl for other backends.
44
45
 
@@ -31,10 +31,10 @@ core/quantize/prepare.py # walks the graph, uses isinstance(mod, FusedModu
31
31
  ```
32
32
 
33
33
  **Instantiated by:** pattern grafts and `replace()` methods in `patterns/fusions.py`
34
- and `patterns/conversions/`. Most fusion patterns declare a `graft` attribute
35
- pointing to the fused module class; the graft system calls `_collect_modules()`
36
- to gather matched modules in tree order and passes them as positional arguments
37
- to the constructor. For example, `ConvBNActPattern` grafts `FusedConvBNAct`;
34
+ and `patterns/conversions/`. Most fusion patterns declare `graft = (make_fused(FusedFoo),)`;
35
+ `make_fused` gathers the matched modules in tree order (via `_collect_modules()`)
36
+ and passes them as positional arguments to the constructor. For example,
37
+ `ConvBNActPattern` grafts `(make_fused(FusedConvBNAct),)`;
38
38
  `DecomposeMultiheadAttentionPattern.replace()` constructs `MHAInProjection` and
39
39
  `ScaledDotProductAttention`.
40
40
 
@@ -61,15 +61,18 @@ Every `FusedModule` subclass must satisfy three requirements:
61
61
  2. **Call `super().__init__()`**, which creates `self.input_quant_stubs` — a dict
62
62
  mapping each index in `inputs_to_quantize` to a fresh `QuantStub`. The Q/DQ
63
63
  pass later enables and configures these stubs during `prepare_qdq()`. It also
64
- initialises `self.surrounded = False`, which is later set to `True` by
65
- `SurroundWithQuantStubsPattern` to mark modules that have been surrounded
66
- with input `QuantStub` entries.
64
+ initialises `self.output_precision = Precision.UNSET`. Surround-type modules
65
+ override this to `Precision.DEFERRED` in their constructors to signal that
66
+ their INT8 decision depends on graph context (see
67
+ `FusedAdaptiveAvgPool2d`, `FusedScaledDotProductAttention`,
68
+ `FusedSwinAttention`). Quantization patterns read and update
69
+ `output_precision` to track precision flow through the graph.
67
70
 
68
71
  3. **Implement `forward()`** with the fused computation the module represents.
69
72
 
70
73
  ### Graft compatibility
71
74
 
72
- When a pattern uses `graft = FusedFoo` (bare class), the graft system calls
75
+ When a pattern uses `graft = (make_fused(FusedFoo),)`, `make_fused` calls
73
76
  `_collect_modules()` to walk the matched tree and collect the `nn.Module`
74
77
  instances corresponding to trunk and fork nodes (nested branches first, then
75
78
  trunk nodes). These are passed as positional arguments to the constructor.
@@ -155,11 +158,14 @@ Fusing them together avoids an extra quantized activation between Conv and Pool.
155
158
  `inputs_to_quantize = {0, 1}`. The residual tensor is the second input (index 1),
156
159
  hence both inputs are quantized. This is the ResNet skip-connection block tail.
157
160
 
158
- **INT8 compatibility guard:** grouped convolutions where `in_channels / groups` or
159
- `out_channels / groups` is not a multiple of 4 cannot be quantized to INT8 in
160
- TensorRT. For those cases `_is_int8_compatible_conv()` returns `False`, and the
161
- module sets `self.input_quant_stubs = {}` (overriding the `super().__init__()`
162
- default), effectively opting out of quantization.
161
+ **INT8 compatibility:** all conv fused modules unconditionally create a
162
+ `WeightFakeQuantize` and `input_quant_stubs` in their constructors. INT8
163
+ compatibility filtering (grouped convolutions where `in_channels / groups` or
164
+ `out_channels / groups` is not a multiple of 4, depthwise convolutions) is
165
+ deferred to `DisableInt8Pattern` in the quantization pass, which
166
+ calls `disable_int8()` on incompatible modules. Depthwise convolutions set
167
+ `output_precision = Precision.DEFERRED` so the surround pass can decide
168
+ contextually whether to enable INT8.
163
169
 
164
170
  ### Linear family — `linear.py`
165
171
 
@@ -217,7 +223,9 @@ is quantized; `_key` and `_value` are accepted to match the self-attention
217
223
  call-site but ignored.
218
224
 
219
225
  **`FusedScaledDotProductAttention`** — wraps `ScaledDotProductAttention`,
220
- `inputs_to_quantize = set()`. Adds an internal `softmax_quant` stub with a fixed
226
+ `inputs_to_quantize = {0, 1, 2}`. The pre-declared stubs (Q, K, V) are created
227
+ disabled; `SurroundWithQuantStubsPattern` enables them when the module is
228
+ surrounded with Q/DQ stubs. Adds an internal `softmax_quant` stub with a fixed
221
229
  calibration of `(1/127, 0)` — i.e., 8-bit symmetric with a fixed scale matched to
222
230
  the softmax output range `[0, 1]`. When the stub is disabled the module delegates
223
231
  to the plain SDPA; when enabled it performs manual attention with the quantization
@@ -262,10 +270,12 @@ three modules. This is intentional: the three modules reference the *same*
262
270
  referencing attributes to point to the copy. However, the sharing is not
263
271
  thread-safe (see Gotchas).
264
272
 
265
- **`FusedSwinAttention`** — wraps `SwinAttention`, `inputs_to_quantize = set()`.
266
- Mirrors `FusedScaledDotProductAttention`: adds an internal `softmax_quant` stub
267
- with fixed calibration. When enabled it manually expands the attention computation
268
- to insert the quantization step between softmax and BMM2.
273
+ **`FusedSwinAttention`** — wraps `SwinAttention`, `inputs_to_quantize = {0, 1, 2}`.
274
+ The pre-declared stubs (Q, K, V) are created disabled;
275
+ `SurroundWithQuantStubsPattern` enables them. Mirrors
276
+ `FusedScaledDotProductAttention`: adds an internal `softmax_quant` stub with fixed
277
+ calibration. When enabled it manually expands the attention computation to insert
278
+ the quantization step between softmax and BMM2.
269
279
 
270
280
  ### Pointwise family — `pointwise.py`
271
281
 
@@ -283,9 +293,9 @@ quantized at their respective scales.
283
293
  Pattern that creates it: `patterns/fusions.py`.
284
294
 
285
295
  **`FusedAdaptiveAvgPool2d`** — wraps `nn.AdaptiveAvgPool2d`,
286
- `inputs_to_quantize = set()`. No quantization is applied (pooling is a
287
- linear operation that does not benefit from separate quantization). Exists as a
288
- `FusedModule` so the Q/DQ pass treats it uniformly without special-casing it.
296
+ `inputs_to_quantize = {0}`. Defers output precision to the surround pattern.
297
+ The pre-declared stub is created disabled; `SurroundWithQuantStubsPattern`
298
+ enables it when the module is surrounded with Q/DQ stubs.
289
299
 
290
300
  ---
291
301
 
@@ -355,11 +365,13 @@ weights are modified after fusion, the fused module sees the change. This is
355
365
  usually desirable (e.g. for QAT gradient updates), but can be surprising if the
356
366
  original model is used independently.
357
367
 
358
- **Grouped conv INT8 opt-out:** When `_is_int8_compatible_conv()` returns `False`
359
- the `__init__` of the conv fused modules sets `self.input_quant_stubs = {}`,
360
- overriding the dict populated by `FusedModule.__init__()`. This means the module
361
- is effectively excluded from quantization despite being a `FusedModule`. The
362
- `weight_fake_quant` attribute is also not created in this path.
368
+ **Grouped conv INT8 opt-out:** INT8 compatibility for grouped convolutions is
369
+ no longer handled in the module constructor. Instead, `DisableInt8Pattern`
370
+ calls `disable_int8()` on modules where `is_int8_beneficial_conv()` returns
371
+ `False`. The module still has `input_quant_stubs` and `weight_fake_quant` after
372
+ construction; they are disabled by the pattern pass. This means
373
+ `bool(mod.input_quant_stubs)` is `True` even for incompatible convolutions until
374
+ the quantization pass runs.
363
375
 
364
376
  ---
365
377
 
@@ -372,11 +384,11 @@ is effectively excluded from quantization despite being a `FusedModule`. The
372
384
  `__init__` if the module has a learnable weight that should be fake-quantized.
373
385
  4. **Implement `forward()`** with the fused computation.
374
386
  5. **Write a `Pattern` subclass** in `patterns/fusions.py`. Prefer declaring
375
- `graft = FusedFoo` (bare class) so the graft system handles replacement
387
+ `graft = (make_fused(FusedFoo),)` so the graft system handles replacement
376
388
  automatically. The constructor must accept modules in tree order (nested
377
389
  branches first, then trunk nodes). If the replacement logic cannot be
378
- expressed as a bare-class graft, provide a `ReplacementMaker` or a custom
379
- `replace()` method instead.
390
+ expressed with `make_fused`, provide a different `ReplacementMaker` or a
391
+ custom `replace()` method instead.
380
392
  6. **Add the pattern to `TENSORRT_PATTERNS`** (or the appropriate pattern list) in
381
393
  `tensorrt/plan.py`.
382
394
  7. **Write tests** in `tests/tensorrt/patterns/fusions/` covering: pattern match,
@@ -13,11 +13,14 @@ import torch
13
13
  import torch.nn.functional as F
14
14
  from torch import nn
15
15
 
16
- from embedl_deploy._internal.core.modules import ConvertedModule, FusedModule
17
- from embedl_deploy._internal.core.quantize.stubs import QuantStub
18
- from embedl_deploy._internal.tensorrt.modules.linear import (
19
- attach_int8_weight_quant,
20
- maybe_quantize_weight,
16
+ from embedl_deploy._internal.core.modules import (
17
+ ConvertedModule,
18
+ FusedModule,
19
+ Precision,
20
+ )
21
+ from embedl_deploy._internal.core.quantize.stubs import (
22
+ QuantStub,
23
+ WeightFakeQuantize,
21
24
  )
22
25
 
23
26
 
@@ -181,7 +184,7 @@ class FusedMHAInProjection(FusedModule):
181
184
  def __init__(self, in_proj: MHAInProjection) -> None:
182
185
  super().__init__()
183
186
  self.in_proj = in_proj
184
- attach_int8_weight_quant(self, in_proj.linear)
187
+ self.weight_fake_quant = WeightFakeQuantize({self})
185
188
 
186
189
  @property
187
190
  def quantized_weight(self) -> torch.Tensor | None:
@@ -205,7 +208,7 @@ class FusedMHAInProjection(FusedModule):
205
208
  :returns:
206
209
  Tuple ``(Q, K, V)`` each of shape ``[B, num_heads, S, head_dim]``.
207
210
  """
208
- weight = maybe_quantize_weight(self, self.in_proj.linear.weight)
211
+ weight = self.weight_fake_quant(self.in_proj.linear.weight)
209
212
  batch, seq, _ = query.shape
210
213
  qkv = F.linear(query, weight, self.in_proj.linear.bias)
211
214
  q, k, v = qkv.chunk(3, dim=-1)
@@ -229,15 +232,13 @@ class FusedMHAInProjection(FusedModule):
229
232
  class FusedScaledDotProductAttention(FusedModule):
230
233
  """Fused wrapper for ``ScaledDotProductAttention``.
231
234
 
232
- Allows the Q/DQ insertion pass to place quantize / dequantize stubs on each
233
- of the three inputs (Q, K, V).
234
-
235
- Additionally holds an internal
235
+ Quantizes the three inputs (Q, K, V) and holds an internal
236
236
  :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` between the
237
- softmax output and the second batched matrix multiply (BMM2). When that
238
- stub is disabled the forward pass delegates to the unwrapped
237
+ softmax output and the second batched matrix multiply (BMM2). When
238
+ ``output_precision`` is still ``DEFERRED`` or the softmax stub is disabled,
239
+ the forward pass delegates to the unwrapped
239
240
  :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`;
240
- when enabled it performs manual attention with the quantization step.
241
+ otherwise it performs manual attention with the quantization step.
241
242
 
242
243
  :param attention:
243
244
  The
@@ -245,11 +246,12 @@ class FusedScaledDotProductAttention(FusedModule):
245
246
  from the decomposed MHA.
246
247
  """
247
248
 
248
- inputs_to_quantize: set[int] = set()
249
+ inputs_to_quantize: set[int] = {0, 1, 2}
249
250
 
250
251
  def __init__(self, attention: ScaledDotProductAttention) -> None:
251
252
  super().__init__()
252
253
  self.attention = attention
254
+ self.output_precision = Precision.DEFERRED
253
255
  self.softmax_quant = QuantStub(
254
256
  consumers={self},
255
257
  n_bits=8,
@@ -266,11 +268,11 @@ class FusedScaledDotProductAttention(FusedModule):
266
268
  ) -> torch.Tensor:
267
269
  r"""Compute scaled dot-product attention.
268
270
 
269
- When the SDPA has been surrounded by ``QuantStub``\ s on its Q/K/V
270
- inputs *and* the internal softmax quant stub is enabled, performs
271
- manual attention with a quantization step between softmax and BMM2.
272
- Otherwise delegates to the wrapped attention module so TensorRT can
273
- fuse it into its native FP16 MHA kernel.
271
+ When ``output_precision`` has been resolved (no longer ``DEFERRED``)
272
+ and the internal softmax quant stub is enabled, performs manual
273
+ attention with a quantization step between softmax and BMM2. Otherwise
274
+ delegates to the wrapped attention module so TensorRT can fuse it into
275
+ its native FP16 MHA kernel.
274
276
 
275
277
  :param q:
276
278
  Query tensor ``[B, num_heads, S, head_dim]``.
@@ -286,13 +288,16 @@ class FusedScaledDotProductAttention(FusedModule):
286
288
  Output tensor ``[B, num_heads, S, head_dim]``. Callers are
287
289
  responsible for any subsequent head-flattening reshape.
288
290
  """
289
- # Manual attention is only beneficial when this SDPA was
290
- # surrounded with input ``QuantStub``s (i.e. Q/K/V are arriving
291
- # in INT8). Without surround, ``configure`` may still have left
292
- # ``softmax_quant`` enabled — running manual attention then adds
293
- # a softmax Q/DQ pair that pushes TensorRT off its FP16 fused
294
- # MHA kernel onto the slower INT8-aware variant for no gain.
295
- if not self.surrounded or not self.softmax_quant.enabled:
291
+ # Manual attention is only beneficial when output_precision has
292
+ # been resolved (i.e. Q/K/V stubs are active and arriving in
293
+ # INT8). While still DEFERRED, ``softmax_quant`` may be enabled
294
+ # but running manual attention would add a softmax Q/DQ pair
295
+ # that pushes TensorRT off its FP16 fused MHA kernel onto the
296
+ # slower INT8-aware variant for no gain.
297
+ if (
298
+ self.output_precision == Precision.DEFERRED
299
+ or not self.softmax_quant.enabled
300
+ ):
296
301
  return self.attention(q, k, v, attn_mask)
297
302
  # Honor the wrapped attention module's explicit ``scale`` if
298
303
  # set — models that pre-scale Q themselves (chronos-2 + RoPE,
@@ -15,29 +15,20 @@ import torch
15
15
  import torch.nn.functional as F
16
16
  from torch import nn
17
17
 
18
- from embedl_deploy._internal.core.modules import ActivationLike, FusedModule
18
+ from embedl_deploy._internal.core.modules import (
19
+ ActivationLike,
20
+ FusedModule,
21
+ Precision,
22
+ )
19
23
  from embedl_deploy._internal.core.quantize.stubs import (
24
+ QuantStub,
20
25
  WeightFakeQuantize,
21
26
  )
22
27
 
23
28
 
24
- def _is_int8_compatible_conv(conv: nn.Conv2d) -> bool:
25
- """Return ``True`` unless *conv* is a grouped conv violating TRT INT8.
26
-
27
- TensorRT's documented constraint for ``IConvolutionLayer`` is that
28
- ``in_channels / groups`` and ``out_channels / groups`` must both be
29
- multiples of 4 in INT8 mode. Depthwise convolutions (``groups ==
30
- in_channels``) are an exception: our benchmarks on the target devices show
31
- they still benefit from INT8 despite channels-per-group being 1, so we let
32
- them through.
33
- """
34
- if conv.groups <= 1:
35
- return True
36
- if conv.groups == conv.in_channels:
37
- return True
38
- in_per_group: int = conv.in_channels // conv.groups
39
- out_per_group: int = conv.out_channels // conv.groups
40
- return in_per_group % 4 == 0 and out_per_group % 4 == 0
29
+ def is_depthwise_conv(conv: nn.Conv2d) -> bool:
30
+ """Return ``True`` when *conv* is depthwise."""
31
+ return conv.groups == conv.in_channels
41
32
 
42
33
 
43
34
  def _conv_weight_forward(
@@ -51,6 +42,7 @@ def _conv_weight_forward(
51
42
  if weight_fake_quant is not None
52
43
  else conv.weight
53
44
  )
45
+ # pylint: disable-next=not-callable
54
46
  return F.conv2d(
55
47
  x,
56
48
  weight,
@@ -63,7 +55,10 @@ def _conv_weight_forward(
63
55
 
64
56
 
65
57
  class FusedConvBNAct(FusedModule):
66
- """Fused ``Conv2d → [BatchNorm2d] → Act``."""
58
+ """Fused ``Conv2d → [BatchNorm2d] → Act``.
59
+
60
+ Depthwise convolutions defer output precision to the surround pattern.
61
+ """
67
62
 
68
63
  inputs_to_quantize: set[int] = {0}
69
64
 
@@ -77,10 +72,9 @@ class FusedConvBNAct(FusedModule):
77
72
  self.conv = conv
78
73
  self.bn = bn
79
74
  self.act = act
80
- if _is_int8_compatible_conv(conv):
81
- self.weight_fake_quant = WeightFakeQuantize({self})
82
- else:
83
- self.input_quant_stubs = {}
75
+ self.weight_fake_quant = WeightFakeQuantize({self})
76
+ if is_depthwise_conv(conv):
77
+ self.output_precision = Precision.DEFERRED
84
78
 
85
79
  @property
86
80
  def quantized_weight(self) -> torch.Tensor | None:
@@ -88,7 +82,7 @@ class FusedConvBNAct(FusedModule):
88
82
 
89
83
  def forward(self, x: torch.Tensor) -> torch.Tensor:
90
84
  """Apply ``conv → [bn] → act``."""
91
- wfq = getattr(self, 'weight_fake_quant', None)
85
+ wfq = getattr(self, "weight_fake_quant", None)
92
86
  x = _conv_weight_forward(self.conv, wfq, x)
93
87
  if self.bn is not None:
94
88
  x = self.bn(x)
@@ -107,7 +101,10 @@ class FusedConvBNAct(FusedModule):
107
101
 
108
102
 
109
103
  class FusedConvBN(FusedModule):
110
- """Fused ``Conv2d → [BatchNorm2d]`` (no activation)."""
104
+ """Fused ``Conv2d → [BatchNorm2d]`` (no activation).
105
+
106
+ Depthwise convolutions defer output precision to the surround pattern.
107
+ """
111
108
 
112
109
  inputs_to_quantize: set[int] = {0}
113
110
 
@@ -119,10 +116,9 @@ class FusedConvBN(FusedModule):
119
116
  super().__init__()
120
117
  self.conv = conv
121
118
  self.bn = bn
122
- if _is_int8_compatible_conv(conv):
123
- self.weight_fake_quant = WeightFakeQuantize({self})
124
- else:
125
- self.input_quant_stubs = {}
119
+ self.weight_fake_quant = WeightFakeQuantize({self})
120
+ if is_depthwise_conv(conv):
121
+ self.output_precision = Precision.DEFERRED
126
122
 
127
123
  @property
128
124
  def quantized_weight(self) -> torch.Tensor | None:
@@ -130,7 +126,7 @@ class FusedConvBN(FusedModule):
130
126
 
131
127
  def forward(self, x: torch.Tensor) -> torch.Tensor:
132
128
  """Apply ``conv → [bn]``."""
133
- wfq = getattr(self, 'weight_fake_quant', None)
129
+ wfq = getattr(self, "weight_fake_quant", None)
134
130
  x = _conv_weight_forward(self.conv, wfq, x)
135
131
  if self.bn is not None:
136
132
  x = self.bn(x)
@@ -173,7 +169,8 @@ class FusedConvBNActMaxPool(FusedModule):
173
169
 
174
170
  def forward(self, x: torch.Tensor) -> torch.Tensor:
175
171
  """Apply ``conv → [bn] → act → maxpool``."""
176
- x = _conv_weight_forward(self.conv, self.weight_fake_quant, x)
172
+ wfq = getattr(self, "weight_fake_quant", None)
173
+ x = _conv_weight_forward(self.conv, wfq, x)
177
174
  if self.bn is not None:
178
175
  x = self.bn(x)
179
176
  x = self.act(x)
@@ -194,10 +191,12 @@ class FusedConvBNActMaxPool(FusedModule):
194
191
 
195
192
 
196
193
  class FusedConvBNAddAct(FusedModule):
197
- """Fused ``Conv2d → BatchNorm2d → add(·, residual) → Activation``.
194
+ """Fused ``Conv2d → [BatchNorm2d] → add(·, residual) → [Activation]``.
198
195
 
199
196
  ``forward()`` accepts two inputs: the main tensor ``x`` and the
200
- ``residual`` tensor.
197
+ ``residual`` tensor. Both the ``BatchNorm2d`` and the trailing activation
198
+ are optional so that EfficientNet-style ``Conv → BN → Add`` blocks (no
199
+ activation) are captured.
201
200
  """
202
201
 
203
202
  inputs_to_quantize: set[int] = {0, 1}
@@ -205,33 +204,99 @@ class FusedConvBNAddAct(FusedModule):
205
204
  def __init__(
206
205
  self,
207
206
  conv: nn.Conv2d,
208
- bn: nn.BatchNorm2d,
209
- act: ActivationLike,
207
+ bn: nn.BatchNorm2d | None,
208
+ act: ActivationLike | None,
210
209
  ) -> None:
211
210
  super().__init__()
212
211
  self.conv = conv
213
212
  self.bn = bn
214
213
  self.act = act
215
- if _is_int8_compatible_conv(conv):
216
- self.weight_fake_quant = WeightFakeQuantize({self})
217
- else:
218
- self.input_quant_stubs = {}
214
+ self.weight_fake_quant = WeightFakeQuantize({self})
219
215
 
220
216
  @property
221
217
  def quantized_weight(self) -> torch.Tensor | None:
222
218
  return self.conv.weight
223
219
 
224
220
  def forward(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
225
- """Apply ``conv → bn → add(·, residual) → act``."""
226
- wfq = getattr(self, 'weight_fake_quant', None)
221
+ """Apply ``conv → [bn] → add(·, residual) → [act]``."""
222
+ wfq = getattr(self, "weight_fake_quant", None)
227
223
  x = _conv_weight_forward(self.conv, wfq, x)
228
- x = self.bn(x)
229
- return self.act(x + residual)
224
+ if self.bn is not None:
225
+ x = self.bn(x)
226
+ x = x + residual
227
+ if self.act is not None:
228
+ x = self.act(x)
229
+ return x
230
230
 
231
231
  def __repr__(self) -> str: # pragma: no cover
232
+ bn_info = ""
233
+ if self.bn is not None:
234
+ bn_info = f", bn={self.bn.num_features} (foldable)"
235
+ act_info = ""
236
+ if self.act is not None:
237
+ act_info = f", act={type(self.act).__name__}"
232
238
  return (
233
239
  f"FusedConvBNAddAct("
234
240
  f"{self.conv.in_channels}→{self.conv.out_channels}, "
235
- f"k={self.conv.kernel_size}, s={self.conv.stride}, "
236
- f"bn={self.bn.num_features} (foldable))"
241
+ f"k={self.conv.kernel_size}, s={self.conv.stride}"
242
+ f"{bn_info}{act_info})"
243
+ )
244
+
245
+
246
+ class FusedConvBNSigmoidMul(FusedModule):
247
+ """Fused ``Conv2d → [BatchNorm2d] → Sigmoid → Mul(·, skip)``.
248
+
249
+ Captures the SE gate pattern where an expand convolution produces channel
250
+ attention weights via Sigmoid, then element-wise multiplies with the skip
251
+ connection. An internal
252
+ :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub` between the
253
+ conv/BN output and the Sigmoid produces the Q/DQ pair that enables
254
+ TensorRT's ``PWN(Sigmoid, Mul)`` fusion.
255
+
256
+ ``forward()`` accepts two inputs: the main tensor ``x`` feeding the
257
+ convolution, and the ``skip`` tensor multiplied by the sigmoid gate.
258
+ """
259
+
260
+ inputs_to_quantize: set[int] = {0, 1}
261
+
262
+ def __init__(
263
+ self,
264
+ conv: nn.Conv2d,
265
+ bn: nn.BatchNorm2d | None,
266
+ sigmoid: nn.Sigmoid,
267
+ ) -> None:
268
+ super().__init__()
269
+ self.conv = conv
270
+ self.bn = bn
271
+ self.sigmoid = sigmoid
272
+ self.weight_fake_quant = WeightFakeQuantize({self})
273
+ self.gate_quant = QuantStub({self})
274
+
275
+ @property
276
+ def quantized_weight(self) -> torch.Tensor | None:
277
+ return self.conv.weight
278
+
279
+ def forward(
280
+ self,
281
+ x: torch.Tensor,
282
+ skip: torch.Tensor,
283
+ ) -> torch.Tensor:
284
+ """Apply ``conv → [bn] → gate_quant → sigmoid → mul(·, skip)``."""
285
+ wfq = getattr(self, "weight_fake_quant", None)
286
+ x = _conv_weight_forward(self.conv, wfq, x)
287
+ if self.bn is not None:
288
+ x = self.bn(x)
289
+ x = self.gate_quant(x)
290
+ x = self.sigmoid(x)
291
+ return x * skip
292
+
293
+ def __repr__(self) -> str: # pragma: no cover
294
+ bn_info = ""
295
+ if self.bn is not None:
296
+ bn_info = f", bn={self.bn.num_features} (foldable)"
297
+ return (
298
+ f"FusedConvBNSigmoidMul("
299
+ f"{self.conv.in_channels}→{self.conv.out_channels}, "
300
+ f"k={self.conv.kernel_size}, s={self.conv.stride}"
301
+ f"{bn_info})"
237
302
  )
@@ -8,60 +8,16 @@ import torch
8
8
  import torch.nn.functional as F
9
9
  from torch import nn
10
10
 
11
- from embedl_deploy._internal.core.modules import ActivationLike, FusedModule
11
+ from embedl_deploy._internal.core.modules import (
12
+ ActivationLike,
13
+ FusedModule,
14
+ Precision,
15
+ )
12
16
  from embedl_deploy._internal.core.quantize.stubs import (
13
17
  SmoothQuantObserver,
14
18
  WeightFakeQuantize,
15
19
  )
16
20
 
17
- #: Minimum ``K * N / (K + N)`` for INT8 to outperform FP16.
18
- INT8_LINEAR_MIN_RATIO: int = 256
19
-
20
-
21
- def is_int8_beneficial_linear(linear: nn.Linear) -> bool:
22
- """Return ``True`` when INT8 quantization benefits *linear*.
23
-
24
- Uses the harmonic mean of the weight dimensions ``K * N / (K + N)`` as a
25
- proxy for the ratio of INT8 compute savings to Q/DQ reformat overhead.
26
- Below
27
- :data:`~embedl_deploy._internal.tensorrt.modules.linear.INT8_LINEAR_MIN_RATIO`,
28
- the overhead from quantize/dequantize boundary layers exceeds any INT8 GEMM
29
- speedup and the layer is better left in FP16.
30
-
31
- Reference: NVIDIA benchmarks show INT8 GEMM outperforms FP16 only when all
32
- three matrix dimensions exceed ~2048 (A100). The harmonic mean threshold of
33
- 256 conservatively separates mobile-class models (MobileViT FFN ratio ≤
34
- 160) from server-class models (ViT-B/16 FFN ratio = 614) where INT8 is
35
- beneficial.
36
- """
37
- k, n = linear.in_features, linear.out_features
38
- return k * n / (k + n) >= INT8_LINEAR_MIN_RATIO
39
-
40
-
41
- def attach_int8_weight_quant(
42
- mod: FusedModule,
43
- linear: nn.Linear,
44
- ) -> None:
45
- """Attach a ``WeightFakeQuantize`` to *mod* when INT8 helps *linear*.
46
-
47
- When INT8 wouldn't pay for its Q/DQ boundary cost, also clear
48
- ``mod.input_quant_stubs`` so the surrounding Q/DQ pass leaves the wrapped
49
- linear entirely in FP16.
50
- """
51
- if is_int8_beneficial_linear(linear):
52
- mod.weight_fake_quant = WeightFakeQuantize({mod})
53
- else:
54
- mod.input_quant_stubs = {}
55
-
56
-
57
- def maybe_quantize_weight(
58
- mod: nn.Module,
59
- weight: torch.Tensor,
60
- ) -> torch.Tensor:
61
- """Fake-quantize *weight* through ``mod.weight_fake_quant`` if present."""
62
- wfq = getattr(mod, "weight_fake_quant", None)
63
- return wfq(weight) if wfq is not None else weight
64
-
65
21
 
66
22
  class FusedLinear(FusedModule):
67
23
  """Fused wrapper for a standalone ``Linear`` layer.
@@ -75,7 +31,7 @@ class FusedLinear(FusedModule):
75
31
  def __init__(self, linear: nn.Linear) -> None:
76
32
  super().__init__()
77
33
  self.linear = linear
78
- attach_int8_weight_quant(self, linear)
34
+ self.weight_fake_quant = WeightFakeQuantize({self})
79
35
 
80
36
  @property
81
37
  def quantized_weight(self) -> torch.Tensor | None:
@@ -83,7 +39,7 @@ class FusedLinear(FusedModule):
83
39
 
84
40
  def forward(self, x: torch.Tensor) -> torch.Tensor:
85
41
  """Apply ``linear``, fake-quantizing the weight."""
86
- weight = maybe_quantize_weight(self, self.linear.weight)
42
+ weight = self.weight_fake_quant(self.linear.weight)
87
43
  return F.linear(x, weight, self.linear.bias)
88
44
 
89
45
  def __repr__(self) -> str: # pragma: no cover
@@ -108,7 +64,7 @@ class FusedLinearAct(FusedModule):
108
64
  super().__init__()
109
65
  self.linear = linear
110
66
  self.act = act
111
- attach_int8_weight_quant(self, linear)
67
+ self.weight_fake_quant = WeightFakeQuantize({self})
112
68
 
113
69
  @property
114
70
  def quantized_weight(self) -> torch.Tensor | None:
@@ -116,7 +72,7 @@ class FusedLinearAct(FusedModule):
116
72
 
117
73
  def forward(self, x: torch.Tensor) -> torch.Tensor:
118
74
  """Apply ``linear → activation``, fake-quantizing the weight."""
119
- weight = maybe_quantize_weight(self, self.linear.weight)
75
+ weight = self.weight_fake_quant(self.linear.weight)
120
76
  x = F.linear(x, weight, self.linear.bias)
121
77
  return self.act(x)
122
78
 
@@ -143,12 +99,12 @@ class FusedLayerNorm(FusedModule):
143
99
  The ``nn.LayerNorm`` from the matched chain.
144
100
  """
145
101
 
146
- prefers_fp_input: bool = True
147
102
  inputs_to_quantize: set[int] = set()
148
103
 
149
104
  def __init__(self, layer_norm: nn.LayerNorm) -> None:
150
105
  super().__init__()
151
106
  self.layer_norm = layer_norm
107
+ self.output_precision = Precision.INT8
152
108
  self.smooth_quant_observer = SmoothQuantObserver(
153
109
  consumers={self},
154
110
  layer_norm=layer_norm,
@@ -5,17 +5,21 @@
5
5
  import torch
6
6
  from torch import nn
7
7
 
8
- from embedl_deploy._internal.core.modules import FusedModule
8
+ from embedl_deploy._internal.core.modules import FusedModule, Precision
9
9
 
10
10
 
11
11
  class FusedAdaptiveAvgPool2d(FusedModule):
12
- """Fused wrapper for ``AdaptiveAvgPool2d``."""
12
+ """Fused wrapper for ``AdaptiveAvgPool2d``.
13
13
 
14
- inputs_to_quantize: set[int] = set()
14
+ Defers output precision to the surround pattern.
15
+ """
16
+
17
+ inputs_to_quantize: set[int] = {0}
15
18
 
16
19
  def __init__(self, pool: nn.AdaptiveAvgPool2d) -> None:
17
20
  super().__init__()
18
21
  self.pool = pool
22
+ self.output_precision = Precision.DEFERRED
19
23
 
20
24
  def forward(self, x: torch.Tensor) -> torch.Tensor:
21
25
  """Apply adaptive average pooling."""