embedl-deploy-tensorrt 0.4.0__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/PKG-INFO +65 -33
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/README.md +64 -31
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/pyproject.toml +1 -1
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/backend.py +1 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +35 -10
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +4 -4
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +3 -3
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +1 -1
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +1 -1
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +56 -64
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +124 -172
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +29 -83
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions.py +26 -3
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +80 -15
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/__init__.py +3 -0
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/attention.py +808 -0
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/functional.py +325 -0
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/general.py +718 -0
- embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/utils.py +44 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +6 -4
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/plan.py +24 -12
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/__init__.py +2 -2
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -2
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/version/public.py +1 -1
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy_tensorrt.egg-info/SOURCES.txt +5 -2
- embedl_deploy_tensorrt-0.4.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -1475
- embedl_deploy_tensorrt-0.4.0/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -89
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/LICENSE +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/MANIFEST.in +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/NOTICE +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/setup.cfg +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy_tensorrt-0.4.0 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/version/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: embedl-deploy-tensorrt
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: TensorRT backend for embedl-deploy.
|
|
5
5
|
Author-email: Embedl AB <support@embedl.com>
|
|
6
6
|
Project-URL: Homepage, https://www.embedl.com/
|
|
@@ -13,7 +13,6 @@ Requires-Python: >=3.10
|
|
|
13
13
|
Description-Content-Type: text/markdown
|
|
14
14
|
License-File: LICENSE
|
|
15
15
|
License-File: NOTICE
|
|
16
|
-
Requires-Dist: tensorrt
|
|
17
16
|
Provides-Extra: core
|
|
18
17
|
Requires-Dist: embedl-deploy; extra == "core"
|
|
19
18
|
Dynamic: license-file
|
|
@@ -55,16 +54,17 @@ hardware target ensuring correct quantization and compilation.
|
|
|
55
54
|
|
|
56
55
|
## Supported Backends
|
|
57
56
|
|
|
58
|
-
| Backend
|
|
59
|
-
|
|
60
|
-
| NVIDIA TensorRT
|
|
57
|
+
| Backend | Status |
|
|
58
|
+
|---------------------------|-----------------|
|
|
59
|
+
| NVIDIA TensorRT (v10.3) | Supported |
|
|
60
|
+
| Lattice SensAI (v8.0) | In Development |
|
|
61
61
|
|
|
62
|
-
Contact
|
|
62
|
+
Contact Embedl for other backends.
|
|
63
63
|
|
|
64
64
|
## Installation
|
|
65
65
|
|
|
66
66
|
```bash
|
|
67
|
-
pip install embedl-deploy
|
|
67
|
+
pip install "embedl-deploy[tensorrt]"
|
|
68
68
|
```
|
|
69
69
|
Note that you may need to also install `onnx` and `onnx-simplifier` to export
|
|
70
70
|
and get the exported model compiled with TensorRT if using ONNX as an
|
|
@@ -72,7 +72,7 @@ intermediate.
|
|
|
72
72
|
|
|
73
73
|
---
|
|
74
74
|
|
|
75
|
-
## Quick Start
|
|
75
|
+
## Quick Start for TensorRT Backend
|
|
76
76
|
|
|
77
77
|
```python
|
|
78
78
|
import torch
|
|
@@ -86,6 +86,9 @@ model = Model().eval()
|
|
|
86
86
|
example_input = torch.randn(1, 3, 224, 224)
|
|
87
87
|
|
|
88
88
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
89
|
+
# For broader compatibility, you can trace your model with torch.export.export
|
|
90
|
+
# as follows:
|
|
91
|
+
# model = torch.export.export(model, (example_input,)).module()
|
|
89
92
|
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
90
93
|
print("Model\n", res.model.print_readable())
|
|
91
94
|
print("Matches", "\n".join([str(match) for match in res.matches]))
|
|
@@ -112,28 +115,54 @@ torch.onnx.export(
|
|
|
112
115
|
qat_model = quantized_model.train()
|
|
113
116
|
# Freeze BatchNorm, or apply other QAT utilities as needed
|
|
114
117
|
# train(qat_model)
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
### Compile
|
|
121
|
+
|
|
122
|
+
Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
|
|
123
|
+
model and compile it for inference. The exported layer info and profile can
|
|
124
|
+
be used for debugging, optimization and visualization.
|
|
125
|
+
|
|
126
|
+
Note: the ONNX model might need to be simplified with onnx-simplifier to
|
|
127
|
+
make trtexec compile it. Dynamo exported models may have compilation issues,
|
|
128
|
+
so it's recommended to export with dynamo=False.
|
|
129
|
+
|
|
130
|
+
```bash
|
|
131
|
+
onnxsim model.onnx model.onnx
|
|
132
|
+
/usr/src/tensorrt/bin/trtexec --onnx=model.onnx --fp16 --int8 --useCudaGraph
|
|
133
|
+
```
|
|
134
|
+
|
|
135
|
+
Optionally, you can export the layer info and profile with the following flags:
|
|
136
|
+
```
|
|
137
|
+
--exportLayerInfo=layer_info.json
|
|
138
|
+
--exportProfile=profile.json
|
|
139
|
+
--profilingVerbosity=detailed
|
|
140
|
+
```
|
|
115
141
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
#
|
|
134
|
-
#
|
|
135
|
-
|
|
136
|
-
|
|
142
|
+
## Mixed Precision
|
|
143
|
+
|
|
144
|
+
To keep a specific layer in higher precision while quantizing the rest to INT8,
|
|
145
|
+
pass its `nn.Conv2d` instance to `ModulesToSkip` after `transform`. Note that
|
|
146
|
+
`torch.fx.GraphModule` deep-copies submodules during tracing, so you must take
|
|
147
|
+
the reference **from the fused graph**, not from the original model:
|
|
148
|
+
|
|
149
|
+
```python
|
|
150
|
+
from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
|
|
151
|
+
|
|
152
|
+
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
153
|
+
|
|
154
|
+
# Grab the conv instance from the fused graph (not from the original model)
|
|
155
|
+
first_conv = res.model.FusedConvBNActMaxPool_0.conv
|
|
156
|
+
|
|
157
|
+
config = QuantConfig(
|
|
158
|
+
skip=ModulesToSkip(
|
|
159
|
+
stub={first_conv}, # disables input activation quantization
|
|
160
|
+
weight={first_conv}, # disables weight fake-quantization
|
|
161
|
+
)
|
|
162
|
+
)
|
|
163
|
+
quantized_model = quantize(
|
|
164
|
+
res.model, (example_input,), config=config, forward_loop=calibration_loop
|
|
165
|
+
)
|
|
137
166
|
```
|
|
138
167
|
|
|
139
168
|
## Design Principles
|
|
@@ -150,10 +179,13 @@ qat_model = quantized_model.train()
|
|
|
150
179
|
`transform()` is a convenience for the common case where you want
|
|
151
180
|
everything applied.
|
|
152
181
|
|
|
153
|
-
3. **
|
|
154
|
-
All graph analysis and surgery uses
|
|
155
|
-
and manipulated as `fx.GraphModule` objects
|
|
156
|
-
|
|
182
|
+
3. **Graph-based models (torch.export.export and symbolically traced).**
|
|
183
|
+
All graph analysis and surgery uses traced graphs. Models are traced once
|
|
184
|
+
and manipulated as `fx.GraphModule` objects with support for tracing via both
|
|
185
|
+
`torch.fx` (symbolic) as well as `torch.export.export` (Aten). Support for
|
|
186
|
+
Aten graphs is automatically enabled using Aten recomposition
|
|
187
|
+
patterns that compose Aten operations into equivalent `torch.nn` modules
|
|
188
|
+
automatically before conversions and fusions.
|
|
157
189
|
|
|
158
190
|
## Support
|
|
159
191
|
|
|
@@ -35,16 +35,17 @@ hardware target ensuring correct quantization and compilation.
|
|
|
35
35
|
|
|
36
36
|
## Supported Backends
|
|
37
37
|
|
|
38
|
-
| Backend
|
|
39
|
-
|
|
40
|
-
| NVIDIA TensorRT
|
|
38
|
+
| Backend | Status |
|
|
39
|
+
|---------------------------|-----------------|
|
|
40
|
+
| NVIDIA TensorRT (v10.3) | Supported |
|
|
41
|
+
| Lattice SensAI (v8.0) | In Development |
|
|
41
42
|
|
|
42
|
-
Contact
|
|
43
|
+
Contact Embedl for other backends.
|
|
43
44
|
|
|
44
45
|
## Installation
|
|
45
46
|
|
|
46
47
|
```bash
|
|
47
|
-
pip install embedl-deploy
|
|
48
|
+
pip install "embedl-deploy[tensorrt]"
|
|
48
49
|
```
|
|
49
50
|
Note that you may need to also install `onnx` and `onnx-simplifier` to export
|
|
50
51
|
and get the exported model compiled with TensorRT if using ONNX as an
|
|
@@ -52,7 +53,7 @@ intermediate.
|
|
|
52
53
|
|
|
53
54
|
---
|
|
54
55
|
|
|
55
|
-
## Quick Start
|
|
56
|
+
## Quick Start for TensorRT Backend
|
|
56
57
|
|
|
57
58
|
```python
|
|
58
59
|
import torch
|
|
@@ -66,6 +67,9 @@ model = Model().eval()
|
|
|
66
67
|
example_input = torch.randn(1, 3, 224, 224)
|
|
67
68
|
|
|
68
69
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
70
|
+
# For broader compatibility, you can trace your model with torch.export.export
|
|
71
|
+
# as follows:
|
|
72
|
+
# model = torch.export.export(model, (example_input,)).module()
|
|
69
73
|
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
70
74
|
print("Model\n", res.model.print_readable())
|
|
71
75
|
print("Matches", "\n".join([str(match) for match in res.matches]))
|
|
@@ -92,28 +96,54 @@ torch.onnx.export(
|
|
|
92
96
|
qat_model = quantized_model.train()
|
|
93
97
|
# Freeze BatchNorm, or apply other QAT utilities as needed
|
|
94
98
|
# train(qat_model)
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
### Compile
|
|
102
|
+
|
|
103
|
+
Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
|
|
104
|
+
model and compile it for inference. The exported layer info and profile can
|
|
105
|
+
be used for debugging, optimization and visualization.
|
|
106
|
+
|
|
107
|
+
Note: the ONNX model might need to be simplified with onnx-simplifier to
|
|
108
|
+
make trtexec compile it. Dynamo exported models may have compilation issues,
|
|
109
|
+
so it's recommended to export with dynamo=False.
|
|
110
|
+
|
|
111
|
+
```bash
|
|
112
|
+
onnxsim model.onnx model.onnx
|
|
113
|
+
/usr/src/tensorrt/bin/trtexec --onnx=model.onnx --fp16 --int8 --useCudaGraph
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
Optionally, you can export the layer info and profile with the following flags:
|
|
117
|
+
```
|
|
118
|
+
--exportLayerInfo=layer_info.json
|
|
119
|
+
--exportProfile=profile.json
|
|
120
|
+
--profilingVerbosity=detailed
|
|
121
|
+
```
|
|
95
122
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
#
|
|
114
|
-
#
|
|
115
|
-
|
|
116
|
-
|
|
123
|
+
## Mixed Precision
|
|
124
|
+
|
|
125
|
+
To keep a specific layer in higher precision while quantizing the rest to INT8,
|
|
126
|
+
pass its `nn.Conv2d` instance to `ModulesToSkip` after `transform`. Note that
|
|
127
|
+
`torch.fx.GraphModule` deep-copies submodules during tracing, so you must take
|
|
128
|
+
the reference **from the fused graph**, not from the original model:
|
|
129
|
+
|
|
130
|
+
```python
|
|
131
|
+
from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
|
|
132
|
+
|
|
133
|
+
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
134
|
+
|
|
135
|
+
# Grab the conv instance from the fused graph (not from the original model)
|
|
136
|
+
first_conv = res.model.FusedConvBNActMaxPool_0.conv
|
|
137
|
+
|
|
138
|
+
config = QuantConfig(
|
|
139
|
+
skip=ModulesToSkip(
|
|
140
|
+
stub={first_conv}, # disables input activation quantization
|
|
141
|
+
weight={first_conv}, # disables weight fake-quantization
|
|
142
|
+
)
|
|
143
|
+
)
|
|
144
|
+
quantized_model = quantize(
|
|
145
|
+
res.model, (example_input,), config=config, forward_loop=calibration_loop
|
|
146
|
+
)
|
|
117
147
|
```
|
|
118
148
|
|
|
119
149
|
## Design Principles
|
|
@@ -130,10 +160,13 @@ qat_model = quantized_model.train()
|
|
|
130
160
|
`transform()` is a convenience for the common case where you want
|
|
131
161
|
everything applied.
|
|
132
162
|
|
|
133
|
-
3. **
|
|
134
|
-
All graph analysis and surgery uses
|
|
135
|
-
and manipulated as `fx.GraphModule` objects
|
|
136
|
-
|
|
163
|
+
3. **Graph-based models (torch.export.export and symbolically traced).**
|
|
164
|
+
All graph analysis and surgery uses traced graphs. Models are traced once
|
|
165
|
+
and manipulated as `fx.GraphModule` objects with support for tracing via both
|
|
166
|
+
`torch.fx` (symbolic) as well as `torch.export.export` (Aten). Support for
|
|
167
|
+
Aten graphs is automatically enabled using Aten recomposition
|
|
168
|
+
patterns that compose Aten operations into equivalent `torch.nn` modules
|
|
169
|
+
automatically before conversions and fusions.
|
|
137
170
|
|
|
138
171
|
## Support
|
|
139
172
|
|
|
@@ -69,7 +69,7 @@ class MHAInProjection(ConvertedModule):
|
|
|
69
69
|
v = v.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
|
|
70
70
|
return q, k, v
|
|
71
71
|
|
|
72
|
-
def __repr__(self) -> str:
|
|
72
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
73
73
|
embed_dim = self.num_heads * self.head_dim
|
|
74
74
|
return (
|
|
75
75
|
f"MHAInProjection("
|
|
@@ -80,7 +80,7 @@ class MHAInProjection(ConvertedModule):
|
|
|
80
80
|
|
|
81
81
|
|
|
82
82
|
class ScaledDotProductAttention(ConvertedModule):
|
|
83
|
-
"""Core attention: ``softmax(Q · Kᵀ
|
|
83
|
+
"""Core attention: ``softmax(Q · Kᵀ · scale) · V``.
|
|
84
84
|
|
|
85
85
|
:param num_heads:
|
|
86
86
|
Number of attention heads.
|
|
@@ -88,6 +88,14 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
88
88
|
Dimension of each head.
|
|
89
89
|
:param dropout:
|
|
90
90
|
Dropout probability (applied during training only).
|
|
91
|
+
:param is_causal:
|
|
92
|
+
Whether to apply a causal mask. Mirrors the ``is_causal`` kwarg
|
|
93
|
+
of ``F.scaled_dot_product_attention``.
|
|
94
|
+
:param scale:
|
|
95
|
+
Explicit attention score scale (multiplied on Q·Kᵀ). When
|
|
96
|
+
``None`` the PyTorch default ``1/√head_dim`` is used. Models
|
|
97
|
+
that pre-scale Q themselves (e.g. chronos-2 + RoPE) must pass
|
|
98
|
+
``scale=1.0`` so the default scaling does not apply twice.
|
|
91
99
|
"""
|
|
92
100
|
|
|
93
101
|
def __init__(
|
|
@@ -95,11 +103,15 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
95
103
|
num_heads: int,
|
|
96
104
|
head_dim: int,
|
|
97
105
|
dropout: float = 0.0,
|
|
106
|
+
is_causal: bool = False,
|
|
107
|
+
scale: float | None = None,
|
|
98
108
|
) -> None:
|
|
99
109
|
super().__init__()
|
|
100
110
|
self.num_heads = num_heads
|
|
101
111
|
self.head_dim = head_dim
|
|
102
112
|
self.dropout = dropout
|
|
113
|
+
self.is_causal = is_causal
|
|
114
|
+
self.scale = scale
|
|
103
115
|
|
|
104
116
|
def forward(
|
|
105
117
|
self,
|
|
@@ -117,8 +129,9 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
117
129
|
:param v:
|
|
118
130
|
Value tensor ``[B, num_heads, S, head_dim]``.
|
|
119
131
|
:param attn_mask:
|
|
120
|
-
Optional attention mask.
|
|
121
|
-
takes an
|
|
132
|
+
Optional attention mask.
|
|
133
|
+
``torch.nn.functional.scaled_dot_product_attention`` takes an
|
|
134
|
+
optional 4th positional arg; ``WrapFunctionalSDPAPattern``
|
|
122
135
|
forwards whatever positional args were on the source node, so
|
|
123
136
|
this module accepts the mask too. SAM3, masked-LM, and
|
|
124
137
|
similar models that compile with mixed-mask attention rely
|
|
@@ -135,14 +148,18 @@ class ScaledDotProductAttention(ConvertedModule):
|
|
|
135
148
|
v,
|
|
136
149
|
attn_mask=attn_mask,
|
|
137
150
|
dropout_p=self.dropout if self.training else 0.0,
|
|
151
|
+
is_causal=self.is_causal,
|
|
152
|
+
scale=self.scale,
|
|
138
153
|
)
|
|
139
154
|
|
|
140
|
-
def __repr__(self) -> str:
|
|
155
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
141
156
|
return (
|
|
142
157
|
f"ScaledDotProductAttention("
|
|
143
158
|
f"num_heads={self.num_heads}, "
|
|
144
159
|
f"head_dim={self.head_dim}, "
|
|
145
|
-
f"dropout={self.dropout}
|
|
160
|
+
f"dropout={self.dropout}, "
|
|
161
|
+
f"is_causal={self.is_causal}, "
|
|
162
|
+
f"scale={self.scale})"
|
|
146
163
|
)
|
|
147
164
|
|
|
148
165
|
|
|
@@ -197,7 +214,7 @@ class FusedMHAInProjection(FusedModule):
|
|
|
197
214
|
v = v.view(batch, seq, num_heads, head_dim).transpose(1, 2)
|
|
198
215
|
return q, k, v
|
|
199
216
|
|
|
200
|
-
def __repr__(self) -> str:
|
|
217
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
201
218
|
embed_dim = self.in_proj.num_heads * self.in_proj.head_dim
|
|
202
219
|
return (
|
|
203
220
|
f"FusedMHAInProjection("
|
|
@@ -275,10 +292,18 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
275
292
|
# MHA kernel onto the slower INT8-aware variant for no gain.
|
|
276
293
|
if not self.surrounded or not self.softmax_quant.enabled:
|
|
277
294
|
return self.attention(q, k, v, attn_mask)
|
|
278
|
-
#
|
|
295
|
+
# Honour the wrapped attention module's explicit ``scale`` if
|
|
296
|
+
# set — models that pre-scale Q themselves (chronos-2 + RoPE,
|
|
297
|
+
# for example) build with ``scale=1.0`` to disable the default
|
|
298
|
+
# ``1/sqrt(head_dim)`` scaling. Falling back to the default
|
|
299
|
+
# here would apply it twice and collapse softmax.
|
|
300
|
+
# Note on ``1/sqrt(head_dim)`` vs ``head_dim ** -0.5``: the
|
|
279
301
|
# tensor Pow with a negative float exponent traces to ONNX as a
|
|
280
302
|
# ``Cast → complex128`` node that TRT 10.x can't parse.
|
|
281
|
-
scale
|
|
303
|
+
if self.attention.scale is not None:
|
|
304
|
+
scale = self.attention.scale
|
|
305
|
+
else:
|
|
306
|
+
scale = 1.0 / math.sqrt(q.shape[-1])
|
|
282
307
|
attn_weight = torch.matmul(q, k.transpose(-2, -1)) * scale
|
|
283
308
|
if attn_mask is not None:
|
|
284
309
|
if attn_mask.dtype == torch.bool:
|
|
@@ -297,7 +322,7 @@ class FusedScaledDotProductAttention(FusedModule):
|
|
|
297
322
|
)
|
|
298
323
|
return torch.matmul(attn_weight, v)
|
|
299
324
|
|
|
300
|
-
def __repr__(self) -> str:
|
|
325
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
301
326
|
a = self.attention
|
|
302
327
|
qdq = "yes" if self.softmax_quant.enabled else "no"
|
|
303
328
|
return (
|
|
@@ -91,7 +91,7 @@ class FusedConvBNAct(FusedModule):
|
|
|
91
91
|
x = self.bn(x)
|
|
92
92
|
return self.act(x)
|
|
93
93
|
|
|
94
|
-
def __repr__(self) -> str:
|
|
94
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
95
95
|
bn_info = ""
|
|
96
96
|
if self.bn is not None:
|
|
97
97
|
bn_info = f", bn={self.bn.num_features} (foldable)"
|
|
@@ -129,7 +129,7 @@ class FusedConvBN(FusedModule):
|
|
|
129
129
|
x = self.bn(x)
|
|
130
130
|
return x
|
|
131
131
|
|
|
132
|
-
def __repr__(self) -> str:
|
|
132
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
133
133
|
bn_info = ""
|
|
134
134
|
if self.bn is not None:
|
|
135
135
|
bn_info = f", bn={self.bn.num_features} (foldable)"
|
|
@@ -168,7 +168,7 @@ class FusedConvBNActMaxPool(FusedModule):
|
|
|
168
168
|
x = self.act(x)
|
|
169
169
|
return self.maxpool(x)
|
|
170
170
|
|
|
171
|
-
def __repr__(self) -> str:
|
|
171
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
172
172
|
bn_info = ""
|
|
173
173
|
if self.bn is not None:
|
|
174
174
|
bn_info = f", bn={self.bn.num_features} (foldable)"
|
|
@@ -213,7 +213,7 @@ class FusedConvBNAddAct(FusedModule):
|
|
|
213
213
|
x = self.bn(x)
|
|
214
214
|
return self.act(x + residual)
|
|
215
215
|
|
|
216
|
-
def __repr__(self) -> str:
|
|
216
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
217
217
|
return (
|
|
218
218
|
f"FusedConvBNAddAct("
|
|
219
219
|
f"{self.conv.in_channels}→{self.conv.out_channels}, "
|
|
@@ -82,7 +82,7 @@ class FusedLinear(FusedModule):
|
|
|
82
82
|
# pylint: disable-next=not-callable
|
|
83
83
|
return F.linear(x, weight, self.linear.bias)
|
|
84
84
|
|
|
85
|
-
def __repr__(self) -> str:
|
|
85
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
86
86
|
return (
|
|
87
87
|
f"FusedLinear("
|
|
88
88
|
f"{self.linear.in_features}→{self.linear.out_features})"
|
|
@@ -113,7 +113,7 @@ class FusedLinearAct(FusedModule):
|
|
|
113
113
|
x = F.linear(x, weight, self.linear.bias)
|
|
114
114
|
return self.act(x)
|
|
115
115
|
|
|
116
|
-
def __repr__(self) -> str:
|
|
116
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
117
117
|
act_name = type(self.act).__name__
|
|
118
118
|
return (
|
|
119
119
|
f"FusedLinearAct("
|
|
@@ -151,7 +151,7 @@ class FusedLayerNorm(FusedModule):
|
|
|
151
151
|
"""Apply ``layer_norm``."""
|
|
152
152
|
return self.layer_norm(x)
|
|
153
153
|
|
|
154
|
-
def __repr__(self) -> str:
|
|
154
|
+
def __repr__(self) -> str: # pragma: no cover
|
|
155
155
|
return (
|
|
156
156
|
f"FusedLayerNorm("
|
|
157
157
|
f"normalized_shape={self.layer_norm.normalized_shape}, "
|