embedl-deploy 0.2.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. embedl_deploy-0.3.0/MANIFEST.in +8 -0
  2. {embedl_deploy-0.2.0/src/embedl_deploy.egg-info → embedl_deploy-0.3.0}/PKG-INFO +1 -1
  3. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/pattern.py +4 -0
  4. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/plan.py +23 -4
  5. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +6 -5
  6. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +10 -4
  7. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +51 -4
  8. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +200 -0
  9. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +56 -0
  10. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +40 -3
  11. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +1584 -0
  12. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/plan.py +9 -1
  13. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/tensorrt/__init__.py +7 -2
  14. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/version/public.py +1 -1
  15. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0/src/embedl_deploy.egg-info}/PKG-INFO +1 -1
  16. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy.egg-info/SOURCES.txt +2 -2
  17. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -819
  18. embedl_deploy-0.2.0/tests/test_version.py +0 -20
  19. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/LICENSE +0 -0
  20. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/NOTICE +0 -0
  21. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/README.md +0 -0
  22. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/pyproject.toml +0 -0
  23. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/setup.cfg +0 -0
  24. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/__init__.py +0 -0
  25. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/__init__.py +0 -0
  26. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/__init__.py +0 -0
  27. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/backend.py +0 -0
  28. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/match.py +0 -0
  29. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/modules.py +0 -0
  30. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/__init__.py +0 -0
  31. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/calibrate.py +0 -0
  32. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/config.py +0 -0
  33. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/main.py +0 -0
  34. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/prepare.py +0 -0
  35. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/qat.py +0 -0
  36. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/stubs.py +0 -0
  37. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/quantize/utils.py +0 -0
  38. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/core/replace.py +0 -0
  39. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
  40. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/backend.py +0 -0
  41. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
  42. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +0 -0
  43. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +0 -0
  44. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +0 -0
  45. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
  46. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -0
  47. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/__init__.py +0 -0
  48. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/attention.py +0 -0
  49. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/conv.py +0 -0
  50. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/linear.py +0 -0
  51. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pointwise.py +0 -0
  52. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pool.py +0 -0
  53. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +0 -0
  54. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -0
  55. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/backend/__init__.py +0 -0
  56. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/py.typed +0 -0
  57. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/quantize/__init__.py +0 -0
  58. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
  59. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -0
  60. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy/version/__init__.py +0 -0
  61. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy.egg-info/dependency_links.txt +0 -0
  62. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy.egg-info/requires.txt +0 -0
  63. {embedl_deploy-0.2.0 → embedl_deploy-0.3.0}/src/embedl_deploy.egg-info/top_level.txt +0 -0
@@ -0,0 +1,8 @@
1
+ prune *
2
+ graft src
3
+ include LICENSE
4
+ include NOTICE
5
+ include README.md
6
+ global-exclude CLAUDE.md
7
+ global-exclude *.pyc
8
+ global-exclude __pycache__
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedl-deploy
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Python package to make AI models deployment-ready for any hardware.
5
5
  Author-email: Embedl AB <support@embedl.com>
6
6
  Project-URL: Homepage, https://www.embedl.com/
@@ -287,6 +287,10 @@ class Pattern(ABC):
287
287
  ``symbolic_trace``. This pattern has no effect on graphs exported with
288
288
  ``torch.export`` because the nodes never appear in those graphs."""
289
289
 
290
+ export_graph_only: bool = False
291
+ """If ``True``, this pattern targets nodes that only appear in
292
+ ``torch.export`` aten graphs and has no effect on symbolic-trace output."""
293
+
290
294
  @abstractmethod
291
295
  def match(self, graph_module: fx.GraphModule) -> list["PatternMatch"]:
292
296
  """Find all occurrences of this pattern in `graph_module`.
@@ -149,6 +149,18 @@ def get_transformation_plan(
149
149
  graph_module = copy.deepcopy(graph_module)
150
150
  setattr(graph_module, "_deep_copy_done", True)
151
151
 
152
+ # Strip torch.export shape-guard nodes that ShapeProp cannot evaluate.
153
+ guards = [
154
+ n
155
+ for n in graph_module.graph.nodes
156
+ if n.op == "call_module" and n.name.startswith("_guards")
157
+ ]
158
+ for node in guards:
159
+ node.replace_all_uses_with(next(iter(node.args)))
160
+ graph_module.graph.erase_node(node)
161
+ if guards:
162
+ graph_module.recompile()
163
+
152
164
  pattern_matches: list[PatternMatch] = []
153
165
  for pattern in patterns:
154
166
  pattern_matches.extend(pattern.match(graph_module))
@@ -219,10 +231,17 @@ def apply_transformation_plan(
219
231
  graph_module.recompile()
220
232
  graph_module.eval()
221
233
 
222
- input_node = next(iter(graph_module.graph.nodes))
223
- meta = input_node.meta.get("tensor_meta")
224
- if meta is not None and hasattr(meta, "shape"):
225
- ShapeProp(graph_module).propagate(torch.randn(meta.shape)) # type: ignore[no-untyped-call]
234
+ fake_args: list[torch.Tensor] = []
235
+ for n in graph_module.graph.nodes:
236
+ if n.op != "placeholder":
237
+ continue
238
+ meta = n.meta.get("tensor_meta")
239
+ if meta is None or not hasattr(meta, "shape"):
240
+ fake_args.clear()
241
+ break
242
+ fake_args.append(torch.randn(meta.shape))
243
+ if fake_args:
244
+ ShapeProp(graph_module).propagate(*fake_args) # type: ignore[no-untyped-call]
226
245
 
227
246
  report = _build_report(enabled, skipped)
228
247
 
@@ -12,9 +12,10 @@ import torch.nn.functional as F
12
12
  from torch import nn
13
13
 
14
14
  from embedl_deploy._internal.core.modules import ConvertedModule, FusedModule
15
- from embedl_deploy._internal.core.quantize.stubs import (
16
- QuantStub,
17
- WeightFakeQuantize,
15
+ from embedl_deploy._internal.core.quantize.stubs import QuantStub
16
+ from embedl_deploy._internal.tensorrt.modules.linear import (
17
+ attach_int8_weight_quant,
18
+ maybe_quantize_weight,
18
19
  )
19
20
 
20
21
 
@@ -153,7 +154,7 @@ class FusedMHAInProjection(FusedModule):
153
154
  def __init__(self, in_proj: MHAInProjection) -> None:
154
155
  super().__init__()
155
156
  self.in_proj = in_proj
156
- self.weight_fake_quant = WeightFakeQuantize({self})
157
+ attach_int8_weight_quant(self, in_proj.linear)
157
158
 
158
159
  def forward(
159
160
  self,
@@ -173,7 +174,7 @@ class FusedMHAInProjection(FusedModule):
173
174
  :returns:
174
175
  Tuple ``(Q, K, V)`` each of shape ``[B, num_heads, S, head_dim]``.
175
176
  """
176
- weight = self.weight_fake_quant(self.in_proj.linear.weight)
177
+ weight = maybe_quantize_weight(self, self.in_proj.linear.weight)
177
178
  batch, seq, _ = query.shape
178
179
  # pylint: disable-next=not-callable
179
180
  qkv = F.linear(query, weight, self.in_proj.linear.bias)
@@ -22,13 +22,19 @@ from embedl_deploy._internal.core.quantize.stubs import (
22
22
 
23
23
 
24
24
  def _is_int8_compatible_conv(conv: nn.Conv2d) -> bool:
25
- """Return ``True`` unless `conv` is a grouped conv violating TRT INT8.
26
-
27
- TensorRT requires ``in_channels / groups`` and
28
- ``out_channels / groups`` to both be multiples of 4 for INT8.
25
+ """Return ``True`` unless *conv* is a grouped conv violating TRT INT8.
26
+
27
+ TensorRT's documented constraint for ``IConvolutionLayer`` is that
28
+ ``in_channels / groups`` and ``out_channels / groups`` must both
29
+ be multiples of 4 in INT8 mode. Depthwise convolutions
30
+ (``groups == in_channels``) are an exception: our benchmarks on
31
+ the target devices show they still benefit from INT8 despite
32
+ channels-per-group being 1, so we let them through.
29
33
  """
30
34
  if conv.groups <= 1:
31
35
  return True
36
+ if conv.groups == conv.in_channels:
37
+ return True
32
38
  in_per_group: int = conv.in_channels // conv.groups
33
39
  out_per_group: int = conv.out_channels // conv.groups
34
40
  return in_per_group % 4 == 0 and out_per_group % 4 == 0
@@ -14,6 +14,53 @@ from embedl_deploy._internal.core.quantize.stubs import (
14
14
  WeightFakeQuantize,
15
15
  )
16
16
 
17
+ #: Minimum ``K * N / (K + N)`` for INT8 to outperform FP16.
18
+ INT8_LINEAR_MIN_RATIO: int = 256
19
+
20
+
21
+ def is_int8_beneficial_linear(linear: nn.Linear) -> bool:
22
+ """Return ``True`` when INT8 quantisation benefits *linear*.
23
+
24
+ Uses the harmonic mean of the weight dimensions ``K * N / (K + N)``
25
+ as a proxy for the ratio of INT8 compute savings to Q/DQ reformat
26
+ overhead. Below :data:`INT8_LINEAR_MIN_RATIO`, the overhead from
27
+ quantise/dequantise boundary layers exceeds any INT8 GEMM speedup
28
+ and the layer is better left in FP16.
29
+
30
+ Reference: NVIDIA benchmarks show INT8 GEMM outperforms FP16 only
31
+ when all three matrix dimensions exceed ~2048 (A100). The harmonic
32
+ mean threshold of 256 conservatively separates mobile-class models
33
+ (MobileViT FFN ratio ≤ 160) from server-class models (ViT-B/16 FFN
34
+ ratio = 614) where INT8 is beneficial.
35
+ """
36
+ k, n = linear.in_features, linear.out_features
37
+ return k * n / (k + n) >= INT8_LINEAR_MIN_RATIO
38
+
39
+
40
+ def attach_int8_weight_quant(
41
+ mod: FusedModule,
42
+ linear: nn.Linear,
43
+ ) -> None:
44
+ """Attach a ``WeightFakeQuantize`` to *mod* when INT8 helps *linear*.
45
+
46
+ When INT8 wouldn't pay for its Q/DQ boundary cost, also clear
47
+ ``mod.input_quant_stubs`` so the surrounding Q/DQ pass leaves the
48
+ wrapped linear entirely in FP16.
49
+ """
50
+ if is_int8_beneficial_linear(linear):
51
+ mod.weight_fake_quant = WeightFakeQuantize({mod})
52
+ else:
53
+ mod.input_quant_stubs = {}
54
+
55
+
56
+ def maybe_quantize_weight(
57
+ mod: nn.Module,
58
+ weight: torch.Tensor,
59
+ ) -> torch.Tensor:
60
+ """Fake-quantize *weight* through ``mod.weight_fake_quant`` if present."""
61
+ wfq = getattr(mod, "weight_fake_quant", None)
62
+ return wfq(weight) if wfq is not None else weight
63
+
17
64
 
18
65
  class FusedLinear(FusedModule):
19
66
  """Fused wrapper for a standalone ``Linear`` layer.
@@ -27,11 +74,11 @@ class FusedLinear(FusedModule):
27
74
  def __init__(self, linear: nn.Linear) -> None:
28
75
  super().__init__()
29
76
  self.linear = linear
30
- self.weight_fake_quant = WeightFakeQuantize({self})
77
+ attach_int8_weight_quant(self, linear)
31
78
 
32
79
  def forward(self, x: torch.Tensor) -> torch.Tensor:
33
80
  """Apply ``linear``, fake-quantizing the weight."""
34
- weight = self.weight_fake_quant(self.linear.weight)
81
+ weight = maybe_quantize_weight(self, self.linear.weight)
35
82
  # pylint: disable-next=not-callable
36
83
  return F.linear(x, weight, self.linear.bias)
37
84
 
@@ -57,11 +104,11 @@ class FusedLinearAct(FusedModule):
57
104
  super().__init__()
58
105
  self.linear = linear
59
106
  self.act = act
60
- self.weight_fake_quant = WeightFakeQuantize({self})
107
+ attach_int8_weight_quant(self, linear)
61
108
 
62
109
  def forward(self, x: torch.Tensor) -> torch.Tensor:
63
110
  """Apply ``linear → activation``, fake-quantizing the weight."""
64
- weight = self.weight_fake_quant(self.linear.weight)
111
+ weight = maybe_quantize_weight(self, self.linear.weight)
65
112
  # pylint: disable-next=not-callable
66
113
  x = F.linear(x, weight, self.linear.bias)
67
114
  return self.act(x)
@@ -21,7 +21,9 @@ from embedl_deploy._internal.core.pattern import (
21
21
  Fork,
22
22
  Pattern,
23
23
  PatternMatch,
24
+ SharedNodeCheck,
24
25
  Tree,
26
+ TreeMatch,
25
27
  Wildcard,
26
28
  get_module,
27
29
  node_check,
@@ -689,3 +691,201 @@ class ComposeScaledDotProductAttentionPattern(Pattern):
689
691
  # matmul pinned as a non-tree user and block erasure.
690
692
  pattern_match.graph_module.graph.eliminate_dead_code()
691
693
  return replace_tree(pattern_match, [sdpa])
694
+
695
+
696
+ # -- Compose parallel linears into MHAInProjection ----------------------
697
+
698
+
699
+ def _is_transpose_1_2(node: fx.Node) -> bool:
700
+ """Return ``True`` for ``tensor.transpose(1, 2)``."""
701
+ if node.op != "call_method" or node.target != "transpose":
702
+ return False
703
+ non_node = [a for a in node.args if not isinstance(a, fx.Node)]
704
+ return set(non_node) == {1, 2}
705
+
706
+
707
+ @node_check
708
+ def _is_sdpa_module(node: fx.Node) -> bool:
709
+ """Return ``True`` for a ``ScaledDotProductAttention`` call_module node."""
710
+ return isinstance(get_module(node), ScaledDotProductAttention)
711
+
712
+
713
+ #: Shared across the three Q/K/V branches so they are all constrained to
714
+ #: the same physical source tensor. Re-using one instance means the
715
+ #: first branch to run caches the source node and the other two succeed
716
+ #: only when they see that exact node.
717
+ _parallel_linears_shared_input = SharedNodeCheck(lambda _: True)
718
+
719
+
720
+ #: One of the three Q/K/V projection branches: the ``transpose(1, 2)``
721
+ #: tail, a ``view``/``reshape`` with two shape arguments, and the
722
+ #: ``nn.Linear`` whose input is constrained by
723
+ #: :data:`_parallel_linears_shared_input`.
724
+ _parallel_linears_branch: Tree = Fork(
725
+ inputs=(
726
+ (_parallel_linears_shared_input, nn.Linear),
727
+ (),
728
+ (),
729
+ ),
730
+ operator=_is_view_or_reshape,
731
+ output=(_is_transpose_1_2,),
732
+ )
733
+
734
+
735
+ def _get_parallel_linears_insert(
736
+ shared_input: fx.Node,
737
+ in_proj: MHAInProjection,
738
+ sdpa_node: fx.Node,
739
+ ) -> ReplacementFn:
740
+ """Return a replacement that inserts ``MHAInProjection → getitem×3`` and
741
+ rewires the existing SDPA to consume the Q/K/V getitems.
742
+
743
+ The three old ``transpose → view → nn.Linear`` input chains are part
744
+ of the matched tree and erased by
745
+ :func:`~embedl_deploy._internal.core.replace.replace_tree` once the
746
+ SDPA's args are rewired.
747
+ """
748
+
749
+ def _insert(
750
+ graph_module: fx.GraphModule,
751
+ prev_args: tuple[fx.Node, ...],
752
+ ) -> list[fx.Node]:
753
+ del prev_args # inputs are derived from the shared pre-linear tensor
754
+ replaced = get_replaced_nodes(graph_module)
755
+ resolved_input = replaced.get(shared_input, shared_input)
756
+ graph = graph_module.graph
757
+
758
+ ip_name = get_auto_name(graph_module, in_proj)
759
+ graph_module.add_module(ip_name, in_proj)
760
+
761
+ with graph.inserting_after(resolved_input):
762
+ ip_node = graph.call_module(
763
+ ip_name,
764
+ (resolved_input, resolved_input, resolved_input),
765
+ )
766
+ gis: list[fx.Node] = []
767
+ prev = ip_node
768
+ for i in range(3):
769
+ with graph.inserting_after(prev):
770
+ gi = graph.call_function(
771
+ operator.getitem,
772
+ (ip_node, i),
773
+ )
774
+ gis.append(gi)
775
+ prev = gi
776
+ sdpa_node.args = tuple(gis)
777
+ return [resolved_input, ip_node, *gis, sdpa_node]
778
+
779
+ return _insert
780
+
781
+
782
+ def _branch_linear(tree_match: TreeMatch, branch: int) -> nn.Linear:
783
+ """Return the ``nn.Linear`` module matched in the *branch*-th Q/K/V arm."""
784
+ linear_node = tree_match.get_node(branch, 0, 1)
785
+ return resolve_module(linear_node, nn.Linear)
786
+
787
+
788
+ class ComposeParallelLinearsPattern(Pattern):
789
+ """Compose three parallel ``nn.Linear`` Q/K/V into ``MHAInProjection``.
790
+
791
+ Matches a
792
+ :class:`~embedl_deploy._internal.tensorrt.modules.attention.ScaledDotProductAttention`
793
+ node whose three inputs each trace back through
794
+ ``transpose(1, 2) → view → nn.Linear`` from the same source tensor.
795
+ The three branches are tied to a single source node by a
796
+ :class:`~embedl_deploy._internal.core.pattern.SharedNodeCheck` shared
797
+ across their data sub-trunks.
798
+
799
+ Packs the three separate linear weights into a single
800
+ ``nn.Linear(embed_dim, 3 * embed_dim)`` and wraps it in an
801
+ :class:`~embedl_deploy._internal.tensorrt.modules.attention.MHAInProjection`.
802
+
803
+ Depends on
804
+ :class:`ComposeScaledDotProductAttentionPattern` having run
805
+ first (handled automatically by the iterative conversion loop).
806
+ """
807
+
808
+ is_conversion = True
809
+ tree: Tree = Fork(
810
+ inputs=(
811
+ _parallel_linears_branch,
812
+ _parallel_linears_branch,
813
+ _parallel_linears_branch,
814
+ ),
815
+ operator=_is_sdpa_module,
816
+ output=(),
817
+ )
818
+
819
+ def match(
820
+ self,
821
+ graph_module: fx.GraphModule,
822
+ ) -> list[PatternMatch]:
823
+ matches = match_tree(graph_module, pattern=self)
824
+ return [m for m in matches if self._linears_compatible(m)]
825
+
826
+ @staticmethod
827
+ def _linears_compatible(pattern_match: PatternMatch) -> bool:
828
+ """Return ``True`` when all three matched Linears are shape-compatible.
829
+
830
+ Required for weight packing: shape/bias constraints can't be
831
+ expressed in the tree grammar, so they are checked here to
832
+ reject otherwise-structural matches before replacement runs.
833
+ """
834
+ first = _branch_linear(pattern_match.tree_match, 0)
835
+ for i in (1, 2):
836
+ lin = _branch_linear(pattern_match.tree_match, i)
837
+ if lin.in_features != first.in_features:
838
+ return False
839
+ if lin.out_features != first.out_features:
840
+ return False
841
+ if (lin.bias is None) != (first.bias is None):
842
+ return False
843
+ return True
844
+
845
+ def replace(
846
+ self,
847
+ pattern_match: PatternMatch,
848
+ ) -> list[fx.Node]:
849
+ assert pattern_match.pattern is self
850
+ tree_match = pattern_match.tree_match
851
+
852
+ sdpa_node = tree_match.pre_trunk_nodes[0]
853
+ sdpa_mod = resolve_module(sdpa_node, ScaledDotProductAttention)
854
+ num_heads = sdpa_mod.num_heads
855
+ head_dim = sdpa_mod.head_dim
856
+ embed_dim = num_heads * head_dim
857
+
858
+ shared_input = tree_match.get_node(0, 0, 0)
859
+ q_lin = _branch_linear(tree_match, 0)
860
+ k_lin = _branch_linear(tree_match, 1)
861
+ v_lin = _branch_linear(tree_match, 2)
862
+ has_bias = q_lin.bias is not None
863
+
864
+ packed = nn.Linear(embed_dim, 3 * embed_dim, bias=has_bias)
865
+ packed.weight = nn.Parameter(
866
+ torch.cat(
867
+ [
868
+ q_lin.weight,
869
+ k_lin.weight,
870
+ v_lin.weight,
871
+ ],
872
+ dim=0,
873
+ )
874
+ )
875
+ if has_bias:
876
+ packed.bias = nn.Parameter(
877
+ torch.cat(
878
+ [
879
+ q_lin.bias,
880
+ k_lin.bias,
881
+ v_lin.bias, # type: ignore[arg-type]
882
+ ],
883
+ dim=0,
884
+ )
885
+ )
886
+
887
+ in_proj = MHAInProjection(packed, num_heads, head_dim)
888
+ return replace_tree(
889
+ pattern_match,
890
+ [_get_parallel_linears_insert(shared_input, in_proj, sdpa_node)],
891
+ )
@@ -101,6 +101,62 @@ class RemoveDeadAssertPattern(Pattern):
101
101
  return replace_tree(pattern_match, [])
102
102
 
103
103
 
104
+ def _is_export_assert(node: fx.Node) -> bool:
105
+ """Return ``True`` for ``torch.export`` device/dtype guard nodes.
106
+
107
+ ``torch.export`` inserts ``aten._assert_tensor_metadata.default`` and
108
+ ``aten._assert_async.msg`` nodes to enforce the device, dtype, and layout
109
+ of tensors at the export site. These guards always fail when the model is
110
+ moved to a different device (e.g. CPU export → CUDA inference) and must be
111
+ removed for deployment.
112
+
113
+ Both ops return ``None`` and have no downstream users, so removal is safe.
114
+ """
115
+ if node.op != "call_function":
116
+ return False
117
+ target = node.target
118
+ _metadata = getattr(
119
+ getattr(torch.ops, "aten", None),
120
+ "_assert_tensor_metadata",
121
+ None,
122
+ )
123
+ _async = getattr(
124
+ getattr(torch.ops, "aten", None),
125
+ "_assert_async",
126
+ None,
127
+ )
128
+ if _metadata is not None and target is getattr(_metadata, "default", None):
129
+ return True
130
+ if _async is not None and target is getattr(_async, "msg", None):
131
+ return True
132
+ return False
133
+
134
+
135
+ class RemoveExportAssertPattern(Pattern):
136
+ """Remove ``torch.export`` tensor-metadata guard nodes.
137
+
138
+ ``torch.export`` inserts ``aten._assert_tensor_metadata`` calls to enforce
139
+ the device/dtype/layout of inputs at export time. These guards always
140
+ raise when the model is run on a device other than the one used during
141
+ export (e.g. CPU-exported model deployed on CUDA).
142
+
143
+ Unlike :class:`RemoveAssertPattern` this pattern is not
144
+ ``symbolic_trace_only`` — it targets ``torch.export`` graph modules
145
+ specifically.
146
+ """
147
+
148
+ is_conversion = True
149
+ export_graph_only = True
150
+ tree: Tree = (_is_export_assert,)
151
+
152
+ def match(self, graph_module: fx.GraphModule) -> list[PatternMatch]:
153
+ return match_tree(graph_module, pattern=self)
154
+
155
+ def replace(self, pattern_match: PatternMatch) -> list[fx.Node]:
156
+ assert pattern_match.pattern is self
157
+ return replace_tree(pattern_match, [])
158
+
159
+
104
160
  def _is_flatten(node: fx.Node) -> bool:
105
161
  """Return ``True`` when `node` is a flatten call with shape metadata."""
106
162
  if node.op == "call_function":
@@ -38,6 +38,23 @@ from embedl_deploy._internal.tensorrt.modules.pool import (
38
38
  from embedl_deploy._internal.tensorrt.modules.swin_attention import (
39
39
  FusedSwinAttention,
40
40
  )
41
+ from embedl_deploy._internal.tensorrt.patterns.utils import get_input_shape
42
+
43
+ #: Head sizes for which TensorRT has a fused INT8 MHA kernel.
44
+ INT8_MHA_HEAD_SIZES: frozenset[int] = frozenset({16, 32, 64})
45
+
46
+ #: Maximum sequence length supported by the fused INT8 MHA kernel.
47
+ #:
48
+ #: TensorRT's fused INT8 multi-head attention kernel (SM75–SM90, SM120,
49
+ #: SM121) only supports head sizes in :data:`INT8_MHA_HEAD_SIZES` and
50
+ #: sequence lengths at most :data:`INT8_MHA_MAX_SEQ`. Outside those
51
+ #: bounds, quantising the softmax output forces an unfused FP32
52
+ #: softmax + INT8 requantise path that is slower than leaving the
53
+ #: attention block in FP16.
54
+ # Reference:
55
+ # pylint: disable-next=line-too-long
56
+ # https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-with-transformers.html
57
+ INT8_MHA_MAX_SEQ: int = 512
41
58
 
42
59
 
43
60
  def _has_quant_stubs(node: fx.Node) -> bool:
@@ -169,16 +186,36 @@ class PropagateQuantStubPattern(Pattern):
169
186
 
170
187
 
171
188
  def _needs_surround(node: fx.Node) -> bool:
172
- """Return ``True`` for a surround-type ``FusedModule`` not yet surrounded."""
189
+ """Return ``True`` for a surround-type ``FusedModule`` not yet surrounded.
190
+
191
+ For :class:`FusedScaledDotProductAttention`, surrounding is skipped
192
+ when the head size or sequence length fall outside TensorRT's fused
193
+ INT8 MHA constraints (:data:`INT8_MHA_HEAD_SIZES`,
194
+ :data:`INT8_MHA_MAX_SEQ`). The internal ``softmax_quant`` is also
195
+ disabled so the attention block stays entirely in FP16.
196
+ """
173
197
  mod = get_module(node)
174
- return isinstance(
198
+ if not isinstance(
175
199
  mod,
176
200
  (
177
201
  FusedAdaptiveAvgPool2d,
178
202
  FusedScaledDotProductAttention,
179
203
  FusedSwinAttention,
180
204
  ),
181
- ) and not getattr(mod, "_surrounded", False)
205
+ ):
206
+ return False
207
+ if getattr(mod, "_surrounded", False):
208
+ return False
209
+ if isinstance(mod, FusedScaledDotProductAttention):
210
+ head_dim = mod.attention.head_dim
211
+ shape = get_input_shape(node)
212
+ seq_len = shape[-2] if shape is not None and len(shape) >= 3 else None
213
+ if head_dim not in INT8_MHA_HEAD_SIZES or (
214
+ seq_len is not None and seq_len > INT8_MHA_MAX_SEQ
215
+ ):
216
+ mod.softmax_quant.enabled = False
217
+ return False
218
+ return True
182
219
 
183
220
 
184
221
  class SurroundWithQuantStubsPattern(Pattern):