embedl-deploy-tensorrt 0.4.0__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/PKG-INFO +65 -33
  2. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/README.md +64 -31
  3. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/pyproject.toml +1 -1
  4. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/backend.py +1 -0
  5. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +35 -10
  6. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +4 -4
  7. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +3 -3
  8. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +1 -1
  9. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +1 -1
  10. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +56 -64
  11. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +124 -172
  12. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +29 -83
  13. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions.py +26 -3
  14. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +80 -15
  15. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/__init__.py +3 -0
  16. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/attention.py +808 -0
  17. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/functional.py +325 -0
  18. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/general.py +718 -0
  19. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/utils.py +44 -0
  20. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +6 -4
  21. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/plan.py +24 -12
  22. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/__init__.py +2 -2
  23. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -2
  24. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/version/public.py +1 -1
  25. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy_tensorrt.egg-info/SOURCES.txt +5 -2
  26. embedl_deploy_tensorrt-0.4.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -1475
  27. embedl_deploy_tensorrt-0.4.0/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -89
  28. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/LICENSE +0 -0
  29. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/MANIFEST.in +0 -0
  30. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/NOTICE +0 -0
  31. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/setup.cfg +0 -0
  32. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/__init__.py +0 -0
  33. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/__init__.py +0 -0
  34. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
  35. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
  36. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
  37. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -0
  38. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
  39. {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/version/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedl-deploy-tensorrt
3
- Version: 0.4.0
3
+ Version: 0.5.0
4
4
  Summary: TensorRT backend for embedl-deploy.
5
5
  Author-email: Embedl AB <support@embedl.com>
6
6
  Project-URL: Homepage, https://www.embedl.com/
@@ -13,7 +13,6 @@ Requires-Python: >=3.10
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
15
  License-File: NOTICE
16
- Requires-Dist: tensorrt
17
16
  Provides-Extra: core
18
17
  Requires-Dist: embedl-deploy; extra == "core"
19
18
  Dynamic: license-file
@@ -55,16 +54,17 @@ hardware target ensuring correct quantization and compilation.
55
54
 
56
55
  ## Supported Backends
57
56
 
58
- | Backend | Status |
59
- |---------------------|-------------|
60
- | NVIDIA TensorRT | Supported |
57
+ | Backend | Status |
58
+ |---------------------------|-----------------|
59
+ | NVIDIA TensorRT (v10.3) | Supported |
60
+ | Lattice SensAI (v8.0) | In Development |
61
61
 
62
- Contact us for other backends.
62
+ Contact Embedl for other backends.
63
63
 
64
64
  ## Installation
65
65
 
66
66
  ```bash
67
- pip install embedl-deploy
67
+ pip install "embedl-deploy[tensorrt]"
68
68
  ```
69
69
  Note that you may need to also install `onnx` and `onnx-simplifier` to export
70
70
  and get the exported model compiled with TensorRT if using ONNX as an
@@ -72,7 +72,7 @@ intermediate.
72
72
 
73
73
  ---
74
74
 
75
- ## Quick Start
75
+ ## Quick Start for TensorRT Backend
76
76
 
77
77
  ```python
78
78
  import torch
@@ -86,6 +86,9 @@ model = Model().eval()
86
86
  example_input = torch.randn(1, 3, 224, 224)
87
87
 
88
88
  # 2. Transform — fuse and optimize for TensorRT in one call
89
+ # For more compatibility you can trace your model with torch.export.export
90
+ # as follows:
91
+ # model = torch.export.export(model, (example_input,)).module()
89
92
  res = transform(model, patterns=TENSORRT_PATTERNS)
90
93
  print("Model\n", res.model.print_readable())
91
94
  print("Matches", "\n".join([str(match) for match in res.matches]))
@@ -112,28 +115,54 @@ torch.onnx.export(
112
115
  qat_model = quantized_model.train()
113
116
  # Freeze BatchNorm, or apply other QAT utilities as needed
114
117
  # train(qat_model)
118
+ ```
119
+
120
+ ### Compile
121
+
122
+ Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
123
+ model and compile it for inference. The exported layer info and profile can
124
+ be used for debugging, optimization and visualization.
125
+
126
+ Note: the ONNX model might need to be simplified with onnx-simplifier to
127
+ make trtexec compile it. Dynamo exported models may have compilation issues,
128
+ so it's recommended to export with dynamo=False.
129
+
130
+ ```bash
131
+ onnxsim model.onnx model.onnx
132
+ /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --fp16 --int8 --useCudaGraph
133
+ ```
134
+
135
+ Optionally you can get the layer profile with the following flags:
136
+ ```
137
+ --exportLayerInfo=layer_info.json
138
+ --exportProfile=profile.json
139
+ --profilingVerbosity=detailed
140
+ ```
115
141
 
116
- # Compile
117
- # -------
118
- # Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
119
- # model and compile it for inference. The exported layer info and profile can
120
- # be used for debugging, optimization and visualization.
121
- #
122
- # Note: that the ONNX model might need to be simplified with onnx-simplifier to
123
- # make trtexec compile it. Dynamo exported models may have compilation issues,
124
- # so it's recommended to export with dynamo=False.
125
- #
126
- # We are working on a Aten-based export path that should be more robust and
127
- # support more models in the future.
128
-
129
- # >> onnxsim model.onnx model.onnx
130
- # >> trtexec \
131
- # --onnx=model.onnx \
132
- # --exportLayerInfo=layer_info.json \
133
- # --exportProfile=profile.json \
134
- # --profilingVerbosity=detailed
135
-
136
- # More benchmarking scripts can be found in the examples/ directory
142
+ ## Mixed Precision
143
+
144
+ To keep a specific layer in higher precision while quantizing the rest to INT8,
145
+ pass its `nn.Conv2d` instance to `ModulesToSkip` after `transform`. Note that
146
+ `torch.fx.GraphModule` deep-copies submodules during tracing, so you must take
147
+ the reference **from the fused graph**, not from the original model:
148
+
149
+ ```python
150
+ from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
151
+
152
+ res = transform(model, patterns=TENSORRT_PATTERNS)
153
+
154
+ # Grab the conv instance from the fused graph (not from the original model)
155
+ first_conv = res.model.FusedConvBNActMaxPool_0.conv
156
+
157
+ config = QuantConfig(
158
+ skip=ModulesToSkip(
159
+ stub={first_conv}, # disables input activation quantization
160
+ weight={first_conv}, # disables weight fake-quantization
161
+ )
162
+ )
163
+ quantized_model = quantize(
164
+ res.model, (example_input,), config=config, forward_loop=calibration_loop
165
+ )
137
166
  ```
138
167
 
139
168
  ## Design Principles
@@ -150,10 +179,13 @@ qat_model = quantized_model.train()
150
179
  `transform()` is a convenience for the common case where you want
151
180
  everything applied.
152
181
 
153
- 3. **FX-graph-based.**
154
- All graph analysis and surgery uses `torch.fx`. Models are traced once
155
- and manipulated as `fx.GraphModule` objects. Support for Aten graphs
156
- produced by `torch.export.export` is planned for the future.
182
+ 3. **Graph-based models (torch.export.export and symbolically traced).**
183
+ All graph analysis and surgery uses traced graphs. Models are traced once
184
+ and manipulated as `fx.GraphModule` objects with support for tracing via both
185
+ `torch.fx` (symbolic) as well as `torch.export.export` (Aten). Support for
186
+ Aten graphs is automatically enabled using Aten recomposition
187
+ patterns that compose Aten operations into equivalent `torch.nn` modules
188
+ before conversions and fusions.
157
189
 
158
190
  ## Support
159
191
 
@@ -35,16 +35,17 @@ hardware target ensuring correct quantization and compilation.
35
35
 
36
36
  ## Supported Backends
37
37
 
38
- | Backend | Status |
39
- |---------------------|-------------|
40
- | NVIDIA TensorRT | Supported |
38
+ | Backend | Status |
39
+ |---------------------------|-----------------|
40
+ | NVIDIA TensorRT (v10.3) | Supported |
41
+ | Lattice SensAI (v8.0) | In Development |
41
42
 
42
- Contact us for other backends.
43
+ Contact Embedl for other backends.
43
44
 
44
45
  ## Installation
45
46
 
46
47
  ```bash
47
- pip install embedl-deploy
48
+ pip install "embedl-deploy[tensorrt]"
48
49
  ```
49
50
  Note that you may need to also install `onnx` and `onnx-simplifier` to export
50
51
  and get the exported model compiled with TensorRT if using ONNX as an
@@ -52,7 +53,7 @@ intermediate.
52
53
 
53
54
  ---
54
55
 
55
- ## Quick Start
56
+ ## Quick Start for TensorRT Backend
56
57
 
57
58
  ```python
58
59
  import torch
@@ -66,6 +67,9 @@ model = Model().eval()
66
67
  example_input = torch.randn(1, 3, 224, 224)
67
68
 
68
69
  # 2. Transform — fuse and optimize for TensorRT in one call
70
+ # For more compatibility you can trace your model with torch.export.export
71
+ # as follows:
72
+ # model = torch.export.export(model, (example_input,)).module()
69
73
  res = transform(model, patterns=TENSORRT_PATTERNS)
70
74
  print("Model\n", res.model.print_readable())
71
75
  print("Matches", "\n".join([str(match) for match in res.matches]))
@@ -92,28 +96,54 @@ torch.onnx.export(
92
96
  qat_model = quantized_model.train()
93
97
  # Freeze BatchNorm, or apply other QAT utilities as needed
94
98
  # train(qat_model)
99
+ ```
100
+
101
+ ### Compile
102
+
103
+ Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
104
+ model and compile it for inference. The exported layer info and profile can
105
+ be used for debugging, optimization and visualization.
106
+
107
+ Note: the ONNX model might need to be simplified with onnx-simplifier to
108
+ make trtexec compile it. Dynamo exported models may have compilation issues,
109
+ so it's recommended to export with dynamo=False.
110
+
111
+ ```bash
112
+ onnxsim model.onnx model.onnx
113
+ /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --fp16 --int8 --useCudaGraph
114
+ ```
115
+
116
+ Optionally you can get the layer profile with the following flags:
117
+ ```
118
+ --exportLayerInfo=layer_info.json
119
+ --exportProfile=profile.json
120
+ --profilingVerbosity=detailed
121
+ ```
95
122
 
96
- # Compile
97
- # -------
98
- # Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
99
- # model and compile it for inference. The exported layer info and profile can
100
- # be used for debugging, optimization and visualization.
101
- #
102
- # Note: that the ONNX model might need to be simplified with onnx-simplifier to
103
- # make trtexec compile it. Dynamo exported models may have compilation issues,
104
- # so it's recommended to export with dynamo=False.
105
- #
106
- # We are working on a Aten-based export path that should be more robust and
107
- # support more models in the future.
108
-
109
- # >> onnxsim model.onnx model.onnx
110
- # >> trtexec \
111
- # --onnx=model.onnx \
112
- # --exportLayerInfo=layer_info.json \
113
- # --exportProfile=profile.json \
114
- # --profilingVerbosity=detailed
115
-
116
- # More benchmarking scripts can be found in the examples/ directory
123
+ ## Mixed Precision
124
+
125
+ To keep a specific layer in higher precision while quantizing the rest to INT8,
126
+ pass its `nn.Conv2d` instance to `ModulesToSkip` after `transform`. Note that
127
+ `torch.fx.GraphModule` deep-copies submodules during tracing, so you must take
128
+ the reference **from the fused graph**, not from the original model:
129
+
130
+ ```python
131
+ from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
132
+
133
+ res = transform(model, patterns=TENSORRT_PATTERNS)
134
+
135
+ # Grab the conv instance from the fused graph (not from the original model)
136
+ first_conv = res.model.FusedConvBNActMaxPool_0.conv
137
+
138
+ config = QuantConfig(
139
+ skip=ModulesToSkip(
140
+ stub={first_conv}, # disables input activation quantization
141
+ weight={first_conv}, # disables weight fake-quantization
142
+ )
143
+ )
144
+ quantized_model = quantize(
145
+ res.model, (example_input,), config=config, forward_loop=calibration_loop
146
+ )
117
147
  ```
118
148
 
119
149
  ## Design Principles
@@ -130,10 +160,13 @@ qat_model = quantized_model.train()
130
160
  `transform()` is a convenience for the common case where you want
131
161
  everything applied.
132
162
 
133
- 3. **FX-graph-based.**
134
- All graph analysis and surgery uses `torch.fx`. Models are traced once
135
- and manipulated as `fx.GraphModule` objects. Support for Aten graphs
136
- produced by `torch.export.export` is planned for the future.
163
+ 3. **Graph-based models (torch.export.export and symbolically traced).**
164
+ All graph analysis and surgery uses traced graphs. Models are traced once
165
+ and manipulated as `fx.GraphModule` objects with support for tracing via both
166
+ `torch.fx` (symbolic) as well as `torch.export.export` (Aten). Support for
167
+ Aten graphs is automatically enabled using Aten recomposition
168
+ patterns that compose Aten operations into equivalent `torch.nn` modules
169
+ before conversions and fusions.
137
170
 
138
171
  ## Support
139
172
 
@@ -24,7 +24,7 @@ license-files = [
24
24
  readme = "README.md"
25
25
  description = "TensorRT backend for embedl-deploy."
26
26
  dynamic = ["version"]
27
- dependencies = ["tensorrt"]
27
+ dependencies = []
28
28
 
29
29
  [project.optional-dependencies]
30
30
  core = ["embedl-deploy"]
@@ -11,6 +11,7 @@ from embedl_deploy._internal.tensorrt.plan import (
11
11
  )
12
12
 
13
13
  BACKEND = Backend(
14
+ name="tensorrt",
14
15
  conversion_patterns=TENSORRT_CONVERSION_PATTERNS,
15
16
  fusion_patterns=TENSORRT_FUSION_PATTERNS,
16
17
  smooth_patterns=TENSORRT_SMOOTH_PATTERNS,
@@ -69,7 +69,7 @@ class MHAInProjection(ConvertedModule):
69
69
  v = v.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
70
70
  return q, k, v
71
71
 
72
- def __repr__(self) -> str:
72
+ def __repr__(self) -> str: # pragma: no cover
73
73
  embed_dim = self.num_heads * self.head_dim
74
74
  return (
75
75
  f"MHAInProjection("
@@ -80,7 +80,7 @@ class MHAInProjection(ConvertedModule):
80
80
 
81
81
 
82
82
  class ScaledDotProductAttention(ConvertedModule):
83
- """Core attention: ``softmax(Q · Kᵀ / √H) · V``.
83
+ """Core attention: ``softmax(Q · Kᵀ · scale) · V``.
84
84
 
85
85
  :param num_heads:
86
86
  Number of attention heads.
@@ -88,6 +88,14 @@ class ScaledDotProductAttention(ConvertedModule):
88
88
  Dimension of each head.
89
89
  :param dropout:
90
90
  Dropout probability (applied during training only).
91
+ :param is_causal:
92
+ Whether to apply a causal mask. Mirrors the ``is_causal`` kwarg
93
+ of ``F.scaled_dot_product_attention``.
94
+ :param scale:
95
+ Explicit attention score scale (multiplied onto Q·Kᵀ). When
96
+ ``None`` the PyTorch default ``1/√head_dim`` is used. Models
97
+ that pre-scale Q themselves (e.g. chronos-2 + RoPE) must pass
98
+ ``scale=1.0`` so the default scaling does not apply twice.
91
99
  """
92
100
 
93
101
  def __init__(
@@ -95,11 +103,15 @@ class ScaledDotProductAttention(ConvertedModule):
95
103
  num_heads: int,
96
104
  head_dim: int,
97
105
  dropout: float = 0.0,
106
+ is_causal: bool = False,
107
+ scale: float | None = None,
98
108
  ) -> None:
99
109
  super().__init__()
100
110
  self.num_heads = num_heads
101
111
  self.head_dim = head_dim
102
112
  self.dropout = dropout
113
+ self.is_causal = is_causal
114
+ self.scale = scale
103
115
 
104
116
  def forward(
105
117
  self,
@@ -117,8 +129,9 @@ class ScaledDotProductAttention(ConvertedModule):
117
129
  :param v:
118
130
  Value tensor ``[B, num_heads, S, head_dim]``.
119
131
  :param attn_mask:
120
- Optional attention mask. ``aten.scaled_dot_product_attention``
121
- takes an optional 4th positional arg; ``WrapAtenSDPAPattern``
132
+ Optional attention mask.
133
+ ``torch.nn.functional.scaled_dot_product_attention`` takes an
134
+ optional 4th positional arg; ``WrapFunctionalSDPAPattern``
122
135
  forwards whatever positional args were on the source node, so
123
136
  this module accepts the mask too. SAM3, masked-LM, and
124
137
  similar models that compile with mixed-mask attention rely
@@ -135,14 +148,18 @@ class ScaledDotProductAttention(ConvertedModule):
135
148
  v,
136
149
  attn_mask=attn_mask,
137
150
  dropout_p=self.dropout if self.training else 0.0,
151
+ is_causal=self.is_causal,
152
+ scale=self.scale,
138
153
  )
139
154
 
140
- def __repr__(self) -> str:
155
+ def __repr__(self) -> str: # pragma: no cover
141
156
  return (
142
157
  f"ScaledDotProductAttention("
143
158
  f"num_heads={self.num_heads}, "
144
159
  f"head_dim={self.head_dim}, "
145
- f"dropout={self.dropout})"
160
+ f"dropout={self.dropout}, "
161
+ f"is_causal={self.is_causal}, "
162
+ f"scale={self.scale})"
146
163
  )
147
164
 
148
165
 
@@ -197,7 +214,7 @@ class FusedMHAInProjection(FusedModule):
197
214
  v = v.view(batch, seq, num_heads, head_dim).transpose(1, 2)
198
215
  return q, k, v
199
216
 
200
- def __repr__(self) -> str:
217
+ def __repr__(self) -> str: # pragma: no cover
201
218
  embed_dim = self.in_proj.num_heads * self.in_proj.head_dim
202
219
  return (
203
220
  f"FusedMHAInProjection("
@@ -275,10 +292,18 @@ class FusedScaledDotProductAttention(FusedModule):
275
292
  # MHA kernel onto the slower INT8-aware variant for no gain.
276
293
  if not self.surrounded or not self.softmax_quant.enabled:
277
294
  return self.attention(q, k, v, attn_mask)
278
- # Use ``1/sqrt(head_dim)`` rather than ``head_dim ** -0.5``: the
295
+ # Honour the wrapped attention module's explicit ``scale`` if
296
+ # set — models that pre-scale Q themselves (chronos-2 + RoPE,
297
+ # for example) build with ``scale=1.0`` to disable the default
298
+ # ``1/sqrt(head_dim)`` scaling. Falling back to the default
299
+ # here would apply it twice and collapse softmax.
300
+ # Note on ``1/sqrt(head_dim)`` vs ``head_dim ** -0.5``: the
279
301
  # tensor Pow with a negative float exponent traces to ONNX as a
280
302
  # ``Cast → complex128`` node that TRT 10.x can't parse.
281
- scale = 1.0 / math.sqrt(q.shape[-1])
303
+ if self.attention.scale is not None:
304
+ scale = self.attention.scale
305
+ else:
306
+ scale = 1.0 / math.sqrt(q.shape[-1])
282
307
  attn_weight = torch.matmul(q, k.transpose(-2, -1)) * scale
283
308
  if attn_mask is not None:
284
309
  if attn_mask.dtype == torch.bool:
@@ -297,7 +322,7 @@ class FusedScaledDotProductAttention(FusedModule):
297
322
  )
298
323
  return torch.matmul(attn_weight, v)
299
324
 
300
- def __repr__(self) -> str:
325
+ def __repr__(self) -> str: # pragma: no cover
301
326
  a = self.attention
302
327
  qdq = "yes" if self.softmax_quant.enabled else "no"
303
328
  return (
@@ -91,7 +91,7 @@ class FusedConvBNAct(FusedModule):
91
91
  x = self.bn(x)
92
92
  return self.act(x)
93
93
 
94
- def __repr__(self) -> str:
94
+ def __repr__(self) -> str: # pragma: no cover
95
95
  bn_info = ""
96
96
  if self.bn is not None:
97
97
  bn_info = f", bn={self.bn.num_features} (foldable)"
@@ -129,7 +129,7 @@ class FusedConvBN(FusedModule):
129
129
  x = self.bn(x)
130
130
  return x
131
131
 
132
- def __repr__(self) -> str:
132
+ def __repr__(self) -> str: # pragma: no cover
133
133
  bn_info = ""
134
134
  if self.bn is not None:
135
135
  bn_info = f", bn={self.bn.num_features} (foldable)"
@@ -168,7 +168,7 @@ class FusedConvBNActMaxPool(FusedModule):
168
168
  x = self.act(x)
169
169
  return self.maxpool(x)
170
170
 
171
- def __repr__(self) -> str:
171
+ def __repr__(self) -> str: # pragma: no cover
172
172
  bn_info = ""
173
173
  if self.bn is not None:
174
174
  bn_info = f", bn={self.bn.num_features} (foldable)"
@@ -213,7 +213,7 @@ class FusedConvBNAddAct(FusedModule):
213
213
  x = self.bn(x)
214
214
  return self.act(x + residual)
215
215
 
216
- def __repr__(self) -> str:
216
+ def __repr__(self) -> str: # pragma: no cover
217
217
  return (
218
218
  f"FusedConvBNAddAct("
219
219
  f"{self.conv.in_channels}→{self.conv.out_channels}, "
@@ -82,7 +82,7 @@ class FusedLinear(FusedModule):
82
82
  # pylint: disable-next=not-callable
83
83
  return F.linear(x, weight, self.linear.bias)
84
84
 
85
- def __repr__(self) -> str:
85
+ def __repr__(self) -> str: # pragma: no cover
86
86
  return (
87
87
  f"FusedLinear("
88
88
  f"{self.linear.in_features}→{self.linear.out_features})"
@@ -113,7 +113,7 @@ class FusedLinearAct(FusedModule):
113
113
  x = F.linear(x, weight, self.linear.bias)
114
114
  return self.act(x)
115
115
 
116
- def __repr__(self) -> str:
116
+ def __repr__(self) -> str: # pragma: no cover
117
117
  act_name = type(self.act).__name__
118
118
  return (
119
119
  f"FusedLinearAct("
@@ -151,7 +151,7 @@ class FusedLayerNorm(FusedModule):
151
151
  """Apply ``layer_norm``."""
152
152
  return self.layer_norm(x)
153
153
 
154
- def __repr__(self) -> str:
154
+ def __repr__(self) -> str: # pragma: no cover
155
155
  return (
156
156
  f"FusedLayerNorm("
157
157
  f"normalized_shape={self.layer_norm.normalized_shape}, "
@@ -34,5 +34,5 @@ class FusedActAdd(FusedModule):
34
34
  """Apply ``act(x) + residual``."""
35
35
  return self.act(x) + residual
36
36
 
37
- def __repr__(self) -> str:
37
+ def __repr__(self) -> str: # pragma: no cover
38
38
  return f"FusedActAdd({type(self.act).__name__})"
@@ -21,5 +21,5 @@ class FusedAdaptiveAvgPool2d(FusedModule):
21
21
  """Apply adaptive average pooling."""
22
22
  return self.pool(x)
23
23
 
24
- def __repr__(self) -> str:
24
+ def __repr__(self) -> str: # pragma: no cover
25
25
  return f"FusedAdaptiveAvgPool2d(output_size={self.pool.output_size})"