embedl-deploy 0.2.0__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- embedl_deploy-0.3.0/MANIFEST.in +8 -0
- {embedl_deploy-0.2.0/src/embedl_deploy.egg-info → embedl_deploy-0.3.0}/PKG-INFO +1 -1
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/pattern.py +4 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/plan.py +23 -4
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +6 -5
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +10 -4
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +51 -4
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +200 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +56 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +40 -3
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +1584 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/plan.py +9 -1
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/tensorrt/__init__.py +7 -2
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/version/public.py +1 -1
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0/src/embedl_deploy.egg-info}/PKG-INFO +1 -1
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy.egg-info/SOURCES.txt +2 -2
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -819
- embedl_deploy-0.2.0/tests/test_version.py +0 -20
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/LICENSE +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/NOTICE +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/README.md +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/pyproject.toml +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/setup.cfg +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/backend.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/match.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/modules.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/calibrate.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/config.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/main.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/prepare.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/qat.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/stubs.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/utils.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/replace.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/backend.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/attention.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/conv.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/linear.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pointwise.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pool.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/backend/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/py.typed +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/quantize/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/version/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy.egg-info/dependency_links.txt +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy.egg-info/requires.txt +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy.egg-info/top_level.txt +0 -0
|
@@ -287,6 +287,10 @@ class Pattern(ABC):
|
|
|
287
287
|
``symbolic_trace``. This pattern has no effect on graphs exported with
|
|
288
288
|
``torch.export`` because the nodes never appear in those graphs."""
|
|
289
289
|
|
|
290
|
+
export_graph_only: bool = False
|
|
291
|
+
"""If ``True``, this pattern targets nodes that only appear in
|
|
292
|
+
``torch.export`` aten graphs and has no effect on symbolic-trace output."""
|
|
293
|
+
|
|
290
294
|
@abstractmethod
|
|
291
295
|
def match(self, graph_module: fx.GraphModule) -> list["PatternMatch"]:
|
|
292
296
|
"""Find all occurrences of this pattern in `graph_module`.
|
|
@@ -149,6 +149,18 @@ def get_transformation_plan(
|
|
|
149
149
|
graph_module = copy.deepcopy(graph_module)
|
|
150
150
|
setattr(graph_module, "_deep_copy_done", True)
|
|
151
151
|
|
|
152
|
+
# Strip torch.export shape-guard nodes that ShapeProp cannot evaluate.
|
|
153
|
+
guards = [
|
|
154
|
+
n
|
|
155
|
+
for n in graph_module.graph.nodes
|
|
156
|
+
if n.op == "call_module" and n.name.startswith("_guards")
|
|
157
|
+
]
|
|
158
|
+
for node in guards:
|
|
159
|
+
node.replace_all_uses_with(next(iter(node.args)))
|
|
160
|
+
graph_module.graph.erase_node(node)
|
|
161
|
+
if guards:
|
|
162
|
+
graph_module.recompile()
|
|
163
|
+
|
|
152
164
|
pattern_matches: list[PatternMatch] = []
|
|
153
165
|
for pattern in patterns:
|
|
154
166
|
pattern_matches.extend(pattern.match(graph_module))
|
|
@@ -219,10 +231,17 @@ def apply_transformation_plan(
|
|
|
219
231
|
graph_module.recompile()
|
|
220
232
|
graph_module.eval()
|
|
221
233
|
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
234
|
+
fake_args: list[torch.Tensor] = []
|
|
235
|
+
for n in graph_module.graph.nodes:
|
|
236
|
+
if n.op != "placeholder":
|
|
237
|
+
continue
|
|
238
|
+
meta = n.meta.get("tensor_meta")
|
|
239
|
+
if meta is None or not hasattr(meta, "shape"):
|
|
240
|
+
fake_args.clear()
|
|
241
|
+
break
|
|
242
|
+
fake_args.append(torch.randn(meta.shape))
|
|
243
|
+
if fake_args:
|
|
244
|
+
ShapeProp(graph_module).propagate(*fake_args) # type: ignore[no-untyped-call]
|
|
226
245
|
|
|
227
246
|
report = _build_report(enabled, skipped)
|
|
228
247
|
|
|
@@ -12,9 +12,10 @@ import torch.nn.functional as F
|
|
|
12
12
|
from torch import nn
|
|
13
13
|
|
|
14
14
|
from embedl_deploy._internal.core.modules import ConvertedModule, FusedModule
|
|
15
|
-
from embedl_deploy._internal.core.quantize.stubs import
|
|
16
|
-
|
|
17
|
-
|
|
15
|
+
from embedl_deploy._internal.core.quantize.stubs import QuantStub
|
|
16
|
+
from embedl_deploy._internal.tensorrt.modules.linear import (
|
|
17
|
+
attach_int8_weight_quant,
|
|
18
|
+
maybe_quantize_weight,
|
|
18
19
|
)
|
|
19
20
|
|
|
20
21
|
|
|
@@ -153,7 +154,7 @@ class FusedMHAInProjection(FusedModule):
|
|
|
153
154
|
def __init__(self, in_proj: MHAInProjection) -> None:
|
|
154
155
|
super().__init__()
|
|
155
156
|
self.in_proj = in_proj
|
|
156
|
-
self.
|
|
157
|
+
attach_int8_weight_quant(self, in_proj.linear)
|
|
157
158
|
|
|
158
159
|
def forward(
|
|
159
160
|
self,
|
|
@@ -173,7 +174,7 @@ class FusedMHAInProjection(FusedModule):
|
|
|
173
174
|
:returns:
|
|
174
175
|
Tuple ``(Q, K, V)`` each of shape ``[B, num_heads, S, head_dim]``.
|
|
175
176
|
"""
|
|
176
|
-
weight =
|
|
177
|
+
weight = maybe_quantize_weight(self, self.in_proj.linear.weight)
|
|
177
178
|
batch, seq, _ = query.shape
|
|
178
179
|
# pylint: disable-next=not-callable
|
|
179
180
|
qkv = F.linear(query, weight, self.in_proj.linear.bias)
|
{embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py
RENAMED
|
@@ -22,13 +22,19 @@ from embedl_deploy._internal.core.quantize.stubs import (
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
def _is_int8_compatible_conv(conv: nn.Conv2d) -> bool:
|
|
25
|
-
"""Return ``True`` unless
|
|
26
|
-
|
|
27
|
-
TensorRT
|
|
28
|
-
``
|
|
25
|
+
"""Return ``True`` unless *conv* is a grouped conv violating TRT INT8.
|
|
26
|
+
|
|
27
|
+
TensorRT's documented constraint for ``IConvolutionLayer`` is that
|
|
28
|
+
``in_channels / groups`` and ``out_channels / groups`` must both
|
|
29
|
+
be multiples of 4 in INT8 mode. Depthwise convolutions
|
|
30
|
+
(``groups == in_channels``) are an exception: our benchmarks on
|
|
31
|
+
the target devices show they still benefit from INT8 despite
|
|
32
|
+
channels-per-group being 1, so we let them through.
|
|
29
33
|
"""
|
|
30
34
|
if conv.groups <= 1:
|
|
31
35
|
return True
|
|
36
|
+
if conv.groups == conv.in_channels:
|
|
37
|
+
return True
|
|
32
38
|
in_per_group: int = conv.in_channels // conv.groups
|
|
33
39
|
out_per_group: int = conv.out_channels // conv.groups
|
|
34
40
|
return in_per_group % 4 == 0 and out_per_group % 4 == 0
|
{embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py
RENAMED
|
@@ -14,6 +14,53 @@ from embedl_deploy._internal.core.quantize.stubs import (
|
|
|
14
14
|
WeightFakeQuantize,
|
|
15
15
|
)
|
|
16
16
|
|
|
17
|
+
#: Minimum ``K * N / (K + N)`` for INT8 to outperform FP16.
|
|
18
|
+
INT8_LINEAR_MIN_RATIO: int = 256
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def is_int8_beneficial_linear(linear: nn.Linear) -> bool:
|
|
22
|
+
"""Return ``True`` when INT8 quantisation benefits *linear*.
|
|
23
|
+
|
|
24
|
+
Uses the harmonic mean of the weight dimensions ``K * N / (K + N)``
|
|
25
|
+
as a proxy for the ratio of INT8 compute savings to Q/DQ reformat
|
|
26
|
+
overhead. Below :data:`INT8_LINEAR_MIN_RATIO`, the overhead from
|
|
27
|
+
quantise/dequantise boundary layers exceeds any INT8 GEMM speedup
|
|
28
|
+
and the layer is better left in FP16.
|
|
29
|
+
|
|
30
|
+
Reference: NVIDIA benchmarks show INT8 GEMM outperforms FP16 only
|
|
31
|
+
when all three matrix dimensions exceed ~2048 (A100). The harmonic
|
|
32
|
+
mean threshold of 256 conservatively separates mobile-class models
|
|
33
|
+
(MobileViT FFN ratio ≤ 160) from server-class models (ViT-B/16 FFN
|
|
34
|
+
ratio = 614) where INT8 is beneficial.
|
|
35
|
+
"""
|
|
36
|
+
k, n = linear.in_features, linear.out_features
|
|
37
|
+
return k * n / (k + n) >= INT8_LINEAR_MIN_RATIO
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def attach_int8_weight_quant(
|
|
41
|
+
mod: FusedModule,
|
|
42
|
+
linear: nn.Linear,
|
|
43
|
+
) -> None:
|
|
44
|
+
"""Attach a ``WeightFakeQuantize`` to *mod* when INT8 helps *linear*.
|
|
45
|
+
|
|
46
|
+
When INT8 wouldn't pay for its Q/DQ boundary cost, also clear
|
|
47
|
+
``mod.input_quant_stubs`` so the surrounding Q/DQ pass leaves the
|
|
48
|
+
wrapped linear entirely in FP16.
|
|
49
|
+
"""
|
|
50
|
+
if is_int8_beneficial_linear(linear):
|
|
51
|
+
mod.weight_fake_quant = WeightFakeQuantize({mod})
|
|
52
|
+
else:
|
|
53
|
+
mod.input_quant_stubs = {}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def maybe_quantize_weight(
|
|
57
|
+
mod: nn.Module,
|
|
58
|
+
weight: torch.Tensor,
|
|
59
|
+
) -> torch.Tensor:
|
|
60
|
+
"""Fake-quantize *weight* through ``mod.weight_fake_quant`` if present."""
|
|
61
|
+
wfq = getattr(mod, "weight_fake_quant", None)
|
|
62
|
+
return wfq(weight) if wfq is not None else weight
|
|
63
|
+
|
|
17
64
|
|
|
18
65
|
class FusedLinear(FusedModule):
|
|
19
66
|
"""Fused wrapper for a standalone ``Linear`` layer.
|
|
@@ -27,11 +74,11 @@ class FusedLinear(FusedModule):
|
|
|
27
74
|
def __init__(self, linear: nn.Linear) -> None:
|
|
28
75
|
super().__init__()
|
|
29
76
|
self.linear = linear
|
|
30
|
-
self
|
|
77
|
+
attach_int8_weight_quant(self, linear)
|
|
31
78
|
|
|
32
79
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
33
80
|
"""Apply ``linear``, fake-quantizing the weight."""
|
|
34
|
-
weight =
|
|
81
|
+
weight = maybe_quantize_weight(self, self.linear.weight)
|
|
35
82
|
# pylint: disable-next=not-callable
|
|
36
83
|
return F.linear(x, weight, self.linear.bias)
|
|
37
84
|
|
|
@@ -57,11 +104,11 @@ class FusedLinearAct(FusedModule):
|
|
|
57
104
|
super().__init__()
|
|
58
105
|
self.linear = linear
|
|
59
106
|
self.act = act
|
|
60
|
-
self
|
|
107
|
+
attach_int8_weight_quant(self, linear)
|
|
61
108
|
|
|
62
109
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
63
110
|
"""Apply ``linear → activation``, fake-quantizing the weight."""
|
|
64
|
-
weight =
|
|
111
|
+
weight = maybe_quantize_weight(self, self.linear.weight)
|
|
65
112
|
# pylint: disable-next=not-callable
|
|
66
113
|
x = F.linear(x, weight, self.linear.bias)
|
|
67
114
|
return self.act(x)
|
|
@@ -21,7 +21,9 @@ from embedl_deploy._internal.core.pattern import (
|
|
|
21
21
|
Fork,
|
|
22
22
|
Pattern,
|
|
23
23
|
PatternMatch,
|
|
24
|
+
SharedNodeCheck,
|
|
24
25
|
Tree,
|
|
26
|
+
TreeMatch,
|
|
25
27
|
Wildcard,
|
|
26
28
|
get_module,
|
|
27
29
|
node_check,
|
|
@@ -689,3 +691,201 @@ class ComposeScaledDotProductAttentionPattern(Pattern):
|
|
|
689
691
|
# matmul pinned as a non-tree user and block erasure.
|
|
690
692
|
pattern_match.graph_module.graph.eliminate_dead_code()
|
|
691
693
|
return replace_tree(pattern_match, [sdpa])
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
# -- Compose parallel linears into MHAInProjection ----------------------
|
|
697
|
+
|
|
698
|
+
|
|
699
|
+
def _is_transpose_1_2(node: fx.Node) -> bool:
|
|
700
|
+
"""Return ``True`` for ``tensor.transpose(1, 2)``."""
|
|
701
|
+
if node.op != "call_method" or node.target != "transpose":
|
|
702
|
+
return False
|
|
703
|
+
non_node = [a for a in node.args if not isinstance(a, fx.Node)]
|
|
704
|
+
return set(non_node) == {1, 2}
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
@node_check
|
|
708
|
+
def _is_sdpa_module(node: fx.Node) -> bool:
|
|
709
|
+
"""Return ``True`` for a ``ScaledDotProductAttention`` call_module node."""
|
|
710
|
+
return isinstance(get_module(node), ScaledDotProductAttention)
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
#: Shared across the three Q/K/V branches so they are all constrained to
|
|
714
|
+
#: the same physical source tensor. Re-using one instance means the
|
|
715
|
+
#: first branch to run caches the source node and the other two succeed
|
|
716
|
+
#: only when they see that exact node.
|
|
717
|
+
_parallel_linears_shared_input = SharedNodeCheck(lambda _: True)
|
|
718
|
+
|
|
719
|
+
|
|
720
|
+
#: One of the three Q/K/V projection branches: the ``transpose(1, 2)``
|
|
721
|
+
#: tail, a ``view``/``reshape`` with two shape arguments, and the
|
|
722
|
+
#: ``nn.Linear`` whose input is constrained by
|
|
723
|
+
#: :data:`_parallel_linears_shared_input`.
|
|
724
|
+
_parallel_linears_branch: Tree = Fork(
|
|
725
|
+
inputs=(
|
|
726
|
+
(_parallel_linears_shared_input, nn.Linear),
|
|
727
|
+
(),
|
|
728
|
+
(),
|
|
729
|
+
),
|
|
730
|
+
operator=_is_view_or_reshape,
|
|
731
|
+
output=(_is_transpose_1_2,),
|
|
732
|
+
)
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
def _get_parallel_linears_insert(
|
|
736
|
+
shared_input: fx.Node,
|
|
737
|
+
in_proj: MHAInProjection,
|
|
738
|
+
sdpa_node: fx.Node,
|
|
739
|
+
) -> ReplacementFn:
|
|
740
|
+
"""Return a replacement that inserts ``MHAInProjection → getitem×3`` and
|
|
741
|
+
rewires the existing SDPA to consume the Q/K/V getitems.
|
|
742
|
+
|
|
743
|
+
The three old ``transpose → view → nn.Linear`` input chains are part
|
|
744
|
+
of the matched tree and erased by
|
|
745
|
+
:func:`~embedl_deploy._internal.core.replace.replace_tree` once the
|
|
746
|
+
SDPA's args are rewired.
|
|
747
|
+
"""
|
|
748
|
+
|
|
749
|
+
def _insert(
|
|
750
|
+
graph_module: fx.GraphModule,
|
|
751
|
+
prev_args: tuple[fx.Node, ...],
|
|
752
|
+
) -> list[fx.Node]:
|
|
753
|
+
del prev_args # inputs are derived from the shared pre-linear tensor
|
|
754
|
+
replaced = get_replaced_nodes(graph_module)
|
|
755
|
+
resolved_input = replaced.get(shared_input, shared_input)
|
|
756
|
+
graph = graph_module.graph
|
|
757
|
+
|
|
758
|
+
ip_name = get_auto_name(graph_module, in_proj)
|
|
759
|
+
graph_module.add_module(ip_name, in_proj)
|
|
760
|
+
|
|
761
|
+
with graph.inserting_after(resolved_input):
|
|
762
|
+
ip_node = graph.call_module(
|
|
763
|
+
ip_name,
|
|
764
|
+
(resolved_input, resolved_input, resolved_input),
|
|
765
|
+
)
|
|
766
|
+
gis: list[fx.Node] = []
|
|
767
|
+
prev = ip_node
|
|
768
|
+
for i in range(3):
|
|
769
|
+
with graph.inserting_after(prev):
|
|
770
|
+
gi = graph.call_function(
|
|
771
|
+
operator.getitem,
|
|
772
|
+
(ip_node, i),
|
|
773
|
+
)
|
|
774
|
+
gis.append(gi)
|
|
775
|
+
prev = gi
|
|
776
|
+
sdpa_node.args = tuple(gis)
|
|
777
|
+
return [resolved_input, ip_node, *gis, sdpa_node]
|
|
778
|
+
|
|
779
|
+
return _insert
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
def _branch_linear(tree_match: TreeMatch, branch: int) -> nn.Linear:
|
|
783
|
+
"""Return the ``nn.Linear`` module matched in the *branch*-th Q/K/V arm."""
|
|
784
|
+
linear_node = tree_match.get_node(branch, 0, 1)
|
|
785
|
+
return resolve_module(linear_node, nn.Linear)
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
class ComposeParallelLinearsPattern(Pattern):
|
|
789
|
+
"""Compose three parallel ``nn.Linear`` Q/K/V into ``MHAInProjection``.
|
|
790
|
+
|
|
791
|
+
Matches a
|
|
792
|
+
:class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
|
|
793
|
+
node whose three inputs each trace back through
|
|
794
|
+
``transpose(1, 2) → view → nn.Linear`` from the same source tensor.
|
|
795
|
+
The three branches are tied to a single source node by a
|
|
796
|
+
:class:`~embedl_deploy._internal.core.pattern.SharedNodeCheck` shared
|
|
797
|
+
across their data sub-trunks.
|
|
798
|
+
|
|
799
|
+
Packs the three separate linear weights into a single
|
|
800
|
+
``nn.Linear(embed_dim, 3 * embed_dim)`` and wraps it in an
|
|
801
|
+
:class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`.
|
|
802
|
+
|
|
803
|
+
Depends on
|
|
804
|
+
:class:`ComposeScaledDotProductAttentionPattern` having run
|
|
805
|
+
first (handled automatically by the iterative conversion loop).
|
|
806
|
+
"""
|
|
807
|
+
|
|
808
|
+
is_conversion = True
|
|
809
|
+
tree: Tree = Fork(
|
|
810
|
+
inputs=(
|
|
811
|
+
_parallel_linears_branch,
|
|
812
|
+
_parallel_linears_branch,
|
|
813
|
+
_parallel_linears_branch,
|
|
814
|
+
),
|
|
815
|
+
operator=_is_sdpa_module,
|
|
816
|
+
output=(),
|
|
817
|
+
)
|
|
818
|
+
|
|
819
|
+
def match(
|
|
820
|
+
self,
|
|
821
|
+
graph_module: fx.GraphModule,
|
|
822
|
+
) -> list[PatternMatch]:
|
|
823
|
+
matches = match_tree(graph_module, pattern=self)
|
|
824
|
+
return [m for m in matches if self._linears_compatible(m)]
|
|
825
|
+
|
|
826
|
+
@staticmethod
|
|
827
|
+
def _linears_compatible(pattern_match: PatternMatch) -> bool:
|
|
828
|
+
"""Return ``True`` when all three matched Linears are shape-compatible.
|
|
829
|
+
|
|
830
|
+
Required for weight packing: shape/bias constraints can't be
|
|
831
|
+
expressed in the tree grammar, so they are checked here to
|
|
832
|
+
reject otherwise-structural matches before replacement runs.
|
|
833
|
+
"""
|
|
834
|
+
first = _branch_linear(pattern_match.tree_match, 0)
|
|
835
|
+
for i in (1, 2):
|
|
836
|
+
lin = _branch_linear(pattern_match.tree_match, i)
|
|
837
|
+
if lin.in_features != first.in_features:
|
|
838
|
+
return False
|
|
839
|
+
if lin.out_features != first.out_features:
|
|
840
|
+
return False
|
|
841
|
+
if (lin.bias is None) != (first.bias is None):
|
|
842
|
+
return False
|
|
843
|
+
return True
|
|
844
|
+
|
|
845
|
+
def replace(
|
|
846
|
+
self,
|
|
847
|
+
pattern_match: PatternMatch,
|
|
848
|
+
) -> list[fx.Node]:
|
|
849
|
+
assert pattern_match.pattern is self
|
|
850
|
+
tree_match = pattern_match.tree_match
|
|
851
|
+
|
|
852
|
+
sdpa_node = tree_match.pre_trunk_nodes[0]
|
|
853
|
+
sdpa_mod = resolve_module(sdpa_node, ScaledDotProductAttention)
|
|
854
|
+
num_heads = sdpa_mod.num_heads
|
|
855
|
+
head_dim = sdpa_mod.head_dim
|
|
856
|
+
embed_dim = num_heads * head_dim
|
|
857
|
+
|
|
858
|
+
shared_input = tree_match.get_node(0, 0, 0)
|
|
859
|
+
q_lin = _branch_linear(tree_match, 0)
|
|
860
|
+
k_lin = _branch_linear(tree_match, 1)
|
|
861
|
+
v_lin = _branch_linear(tree_match, 2)
|
|
862
|
+
has_bias = q_lin.bias is not None
|
|
863
|
+
|
|
864
|
+
packed = nn.Linear(embed_dim, 3 * embed_dim, bias=has_bias)
|
|
865
|
+
packed.weight = nn.Parameter(
|
|
866
|
+
torch.cat(
|
|
867
|
+
[
|
|
868
|
+
q_lin.weight,
|
|
869
|
+
k_lin.weight,
|
|
870
|
+
v_lin.weight,
|
|
871
|
+
],
|
|
872
|
+
dim=0,
|
|
873
|
+
)
|
|
874
|
+
)
|
|
875
|
+
if has_bias:
|
|
876
|
+
packed.bias = nn.Parameter(
|
|
877
|
+
torch.cat(
|
|
878
|
+
[
|
|
879
|
+
q_lin.bias,
|
|
880
|
+
k_lin.bias,
|
|
881
|
+
v_lin.bias, # type: ignore[arg-type]
|
|
882
|
+
],
|
|
883
|
+
dim=0,
|
|
884
|
+
)
|
|
885
|
+
)
|
|
886
|
+
|
|
887
|
+
in_proj = MHAInProjection(packed, num_heads, head_dim)
|
|
888
|
+
return replace_tree(
|
|
889
|
+
pattern_match,
|
|
890
|
+
[_get_parallel_linears_insert(shared_input, in_proj, sdpa_node)],
|
|
891
|
+
)
|
|
@@ -101,6 +101,62 @@ class RemoveDeadAssertPattern(Pattern):
|
|
|
101
101
|
return replace_tree(pattern_match, [])
|
|
102
102
|
|
|
103
103
|
|
|
104
|
+
def _is_export_assert(node: fx.Node) -> bool:
|
|
105
|
+
"""Return ``True`` for ``torch.export`` device/dtype guard nodes.
|
|
106
|
+
|
|
107
|
+
``torch.export`` inserts ``aten._assert_tensor_metadata.default`` and
|
|
108
|
+
``aten._assert_async.msg`` nodes to enforce the device, dtype, and layout
|
|
109
|
+
of tensors at the export site. These guards always fail when the model is
|
|
110
|
+
moved to a different device (e.g. CPU export → CUDA inference) and must be
|
|
111
|
+
removed for deployment.
|
|
112
|
+
|
|
113
|
+
Both ops return ``None`` and have no downstream users, so removal is safe.
|
|
114
|
+
"""
|
|
115
|
+
if node.op != "call_function":
|
|
116
|
+
return False
|
|
117
|
+
target = node.target
|
|
118
|
+
_metadata = getattr(
|
|
119
|
+
getattr(torch.ops, "aten", None),
|
|
120
|
+
"_assert_tensor_metadata",
|
|
121
|
+
None,
|
|
122
|
+
)
|
|
123
|
+
_async = getattr(
|
|
124
|
+
getattr(torch.ops, "aten", None),
|
|
125
|
+
"_assert_async",
|
|
126
|
+
None,
|
|
127
|
+
)
|
|
128
|
+
if _metadata is not None and target is getattr(_metadata, "default", None):
|
|
129
|
+
return True
|
|
130
|
+
if _async is not None and target is getattr(_async, "msg", None):
|
|
131
|
+
return True
|
|
132
|
+
return False
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class RemoveExportAssertPattern(Pattern):
|
|
136
|
+
"""Remove ``torch.export`` tensor-metadata guard nodes.
|
|
137
|
+
|
|
138
|
+
``torch.export`` inserts ``aten._assert_tensor_metadata`` calls to enforce
|
|
139
|
+
the device/dtype/layout of inputs at export time. These guards always
|
|
140
|
+
raise when the model is run on a device other than the one used during
|
|
141
|
+
export (e.g. CPU-exported model deployed on CUDA).
|
|
142
|
+
|
|
143
|
+
Unlike :class:`RemoveAssertPattern` this pattern is not
|
|
144
|
+
``symbolic_trace_only`` — it targets ``torch.export`` graph modules
|
|
145
|
+
specifically.
|
|
146
|
+
"""
|
|
147
|
+
|
|
148
|
+
is_conversion = True
|
|
149
|
+
export_graph_only = True
|
|
150
|
+
tree: Tree = (_is_export_assert,)
|
|
151
|
+
|
|
152
|
+
def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
|
|
153
|
+
return match_tree(graph_module, pattern=self)
|
|
154
|
+
|
|
155
|
+
def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
|
|
156
|
+
assert pattern_match.pattern is self
|
|
157
|
+
return replace_tree(pattern_match, [])
|
|
158
|
+
|
|
159
|
+
|
|
104
160
|
def _is_flatten(node: fx.Node) -> bool:
|
|
105
161
|
"""Return ``True`` when `node` is a flatten call with shape metadata."""
|
|
106
162
|
if node.op == "call_function":
|
|
@@ -38,6 +38,23 @@ from embedl_deploy._internal.tensorrt.modules.pool import (
|
|
|
38
38
|
from embedl_deploy._internal.tensorrt.modules.swin_attention import (
|
|
39
39
|
FusedSwinAttention,
|
|
40
40
|
)
|
|
41
|
+
from embedl_deploy._internal.tensorrt.patterns.utils import get_input_shape
|
|
42
|
+
|
|
43
|
+
#: Head sizes for which TensorRT has a fused INT8 MHA kernel.
|
|
44
|
+
INT8_MHA_HEAD_SIZES: frozenset[int] = frozenset({16, 32, 64})
|
|
45
|
+
|
|
46
|
+
#: Maximum sequence length supported by the fused INT8 MHA kernel.
|
|
47
|
+
#:
|
|
48
|
+
#: TensorRT's fused INT8 multi-head attention kernel (SM75–SM90, SM120,
|
|
49
|
+
#: SM121) only supports head sizes in :data:`INT8_MHA_HEAD_SIZES` and
|
|
50
|
+
#: sequence lengths at most :data:`INT8_MHA_MAX_SEQ`. Outside those
|
|
51
|
+
#: bounds, quantising the softmax output forces an unfused FP32
|
|
52
|
+
#: softmax + INT8 requantise path that is slower than leaving the
|
|
53
|
+
#: attention block in FP16.
|
|
54
|
+
# Reference:
|
|
55
|
+
# pylint: disable-next=line-too-long
|
|
56
|
+
# https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-with-transformers.html
|
|
57
|
+
INT8_MHA_MAX_SEQ: int = 512
|
|
41
58
|
|
|
42
59
|
|
|
43
60
|
def _has_quant_stubs(node: fx.Node) -> bool:
|
|
@@ -169,16 +186,36 @@ class PropagateQuantStubPattern(Pattern):
|
|
|
169
186
|
|
|
170
187
|
|
|
171
188
|
def _needs_surround(node: fx.Node) -> bool:
|
|
172
|
-
"""Return ``True`` for a surround-type ``FusedModule`` not yet surrounded.
|
|
189
|
+
"""Return ``True`` for a surround-type ``FusedModule`` not yet surrounded.
|
|
190
|
+
|
|
191
|
+
For :class:`FusedScaledDotProductAttention`, surrounding is skipped
|
|
192
|
+
when the head size or sequence length fall outside TensorRT's fused
|
|
193
|
+
INT8 MHA constraints (:data:`INT8_MHA_HEAD_SIZES`,
|
|
194
|
+
:data:`INT8_MHA_MAX_SEQ`). The internal ``softmax_quant`` is also
|
|
195
|
+
disabled so the attention block stays entirely in FP16.
|
|
196
|
+
"""
|
|
173
197
|
mod = get_module(node)
|
|
174
|
-
|
|
198
|
+
if not isinstance(
|
|
175
199
|
mod,
|
|
176
200
|
(
|
|
177
201
|
FusedAdaptiveAvgPool2d,
|
|
178
202
|
FusedScaledDotProductAttention,
|
|
179
203
|
FusedSwinAttention,
|
|
180
204
|
),
|
|
181
|
-
)
|
|
205
|
+
):
|
|
206
|
+
return False
|
|
207
|
+
if getattr(mod, "_surrounded", False):
|
|
208
|
+
return False
|
|
209
|
+
if isinstance(mod, FusedScaledDotProductAttention):
|
|
210
|
+
head_dim = mod.attention.head_dim
|
|
211
|
+
shape = get_input_shape(node)
|
|
212
|
+
seq_len = shape[-2] if shape is not None and len(shape) >= 3 else None
|
|
213
|
+
if head_dim not in INT8_MHA_HEAD_SIZES or (
|
|
214
|
+
seq_len is not None and seq_len > INT8_MHA_MAX_SEQ
|
|
215
|
+
):
|
|
216
|
+
mod.softmax_quant.enabled = False
|
|
217
|
+
return False
|
|
218
|
+
return True
|
|
182
219
|
|
|
183
220
|
|
|
184
221
|
class SurroundWithQuantStubsPattern(Pattern):
|