embedl-deploy 0.3.0__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/MANIFEST.in +2 -0
- {embedl_deploy-0.3.0/src/embedl_deploy.egg-info → embedl_deploy-0.4.1}/PKG-INFO +64 -32
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/README.md +62 -30
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/pyproject.toml +63 -12
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/backend.py +2 -2
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/modules.py +4 -0
- embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/pattern.py +204 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/plan.py +51 -11
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/calibrate.py +5 -4
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/main.py +15 -1
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/stubs.py +2 -1
- embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree/__init__.py +3 -0
- embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree/match.py +334 -0
- {embedl_deploy-0.3.0/src/embedl_deploy/_internal/core → embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree}/replace.py +93 -45
- embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree/types.py +325 -0
- embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree/utils.py +64 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/version/public.py +1 -1
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1/src/embedl_deploy.egg-info}/PKG-INFO +64 -32
- embedl_deploy-0.4.1/src/embedl_deploy.egg-info/SOURCES.txt +35 -0
- embedl_deploy-0.4.1/src/embedl_deploy.egg-info/requires.txt +4 -0
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/core/match.py +0 -256
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/core/pattern.py +0 -480
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -3
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/backend.py +0 -18
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -3
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/attention.py +0 -275
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/conv.py +0 -238
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/linear.py +0 -159
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +0 -39
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/pool.py +0 -25
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +0 -460
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -3
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -15
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +0 -891
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +0 -356
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/__init__.py +0 -3
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/attention.py +0 -87
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/conv.py +0 -196
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/linear.py +0 -86
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pointwise.py +0 -55
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pool.py +0 -50
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +0 -329
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -1584
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +0 -123
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -81
- embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/plan.py +0 -131
- embedl_deploy-0.3.0/src/embedl_deploy/tensorrt/__init__.py +0 -45
- embedl_deploy-0.3.0/src/embedl_deploy/tensorrt/modules/__init__.py +0 -40
- embedl_deploy-0.3.0/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -60
- embedl_deploy-0.3.0/src/embedl_deploy.egg-info/SOURCES.txt +0 -59
- embedl_deploy-0.3.0/src/embedl_deploy.egg-info/requires.txt +0 -4
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/LICENSE +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/NOTICE +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/setup.cfg +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/__init__.py +1 -1
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/__init__.py +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/__init__.py +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/__init__.py +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/config.py +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/prepare.py +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/qat.py +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/utils.py +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/backend/__init__.py +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/py.typed +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/quantize/__init__.py +1 -1
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/version/__init__.py +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy.egg-info/dependency_links.txt +0 -0
- {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: embedl-deploy
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.1
|
|
4
4
|
Summary: Python package to make AI models deployment-ready for any hardware.
|
|
5
5
|
Author-email: Embedl AB <support@embedl.com>
|
|
6
6
|
Project-URL: Homepage, https://www.embedl.com/
|
|
@@ -15,7 +15,7 @@ License-File: LICENSE
|
|
|
15
15
|
License-File: NOTICE
|
|
16
16
|
Requires-Dist: torch
|
|
17
17
|
Provides-Extra: tensorrt
|
|
18
|
-
Requires-Dist: tensorrt; extra == "tensorrt"
|
|
18
|
+
Requires-Dist: embedl-deploy-tensorrt; extra == "tensorrt"
|
|
19
19
|
Dynamic: license-file
|
|
20
20
|
|
|
21
21
|
# embedl-deploy
|
|
@@ -55,16 +55,16 @@ hardware target ensuring correct quantization and compilation.
|
|
|
55
55
|
|
|
56
56
|
## Supported Backends
|
|
57
57
|
|
|
58
|
-
| Backend
|
|
59
|
-
|
|
60
|
-
| NVIDIA TensorRT
|
|
58
|
+
| Backend | Status |
|
|
59
|
+
|-------------------------|-------------|
|
|
60
|
+
| NVIDIA TensorRT (v10.3) | Supported |
|
|
61
61
|
|
|
62
|
-
Contact
|
|
62
|
+
Contact Embedl for other backends.
|
|
63
63
|
|
|
64
64
|
## Installation
|
|
65
65
|
|
|
66
66
|
```bash
|
|
67
|
-
pip install embedl-deploy
|
|
67
|
+
pip install "embedl-deploy[tensorrt]"
|
|
68
68
|
```
|
|
69
69
|
Note that you may need to also install `onnx` and `onnx-simplifier` to export
|
|
70
70
|
and get the exported model compiled with TensorRT if using ONNX as an
|
|
@@ -86,6 +86,9 @@ model = Model().eval()
|
|
|
86
86
|
example_input = torch.randn(1, 3, 224, 224)
|
|
87
87
|
|
|
88
88
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
89
|
+
# For more compatibility you can trace your model with torch.export.export
|
|
90
|
+
# as follows:
|
|
91
|
+
# model = torch.export.export(model, (example_input,)).module()
|
|
89
92
|
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
90
93
|
print("Model\n", res.model.print_readable())
|
|
91
94
|
print("Matches", "\n".join([str(match) for match in res.matches]))
|
|
@@ -112,28 +115,54 @@ torch.onnx.export(
|
|
|
112
115
|
qat_model = quantized_model.train()
|
|
113
116
|
# Freeze BatchNorm, or apply other QAT utilities as needed
|
|
114
117
|
# train(qat_model)
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
### Compile
|
|
121
|
+
|
|
122
|
+
Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
|
|
123
|
+
model and compile it for inference. The exported layer info and profile can
|
|
124
|
+
be used for debugging, optimization and visualization.
|
|
125
|
+
|
|
126
|
+
Note that the ONNX model might need to be simplified with onnx-simplifier to
|
|
127
|
+
make trtexec compile it. Dynamo exported models may have compilation issues,
|
|
128
|
+
so it's recommended to export with dynamo=False.
|
|
129
|
+
|
|
130
|
+
```bash
|
|
131
|
+
onnxsim model.onnx model.onnx
|
|
132
|
+
/usr/src/tensorrt/bin/trtexec --onnx=model.onnx --fp16 --int8 --useCudaGraph
|
|
133
|
+
```
|
|
134
|
+
|
|
135
|
+
Optionally you can get the layer profile with the following flags:
|
|
136
|
+
```
|
|
137
|
+
--exportLayerInfo=layer_info.json
|
|
138
|
+
--exportProfile=profile.json
|
|
139
|
+
--profilingVerbosity=detailed
|
|
140
|
+
```
|
|
115
141
|
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
#
|
|
134
|
-
#
|
|
135
|
-
|
|
136
|
-
|
|
142
|
+
## Mixed Precision
|
|
143
|
+
|
|
144
|
+
To keep a specific layer in higher precision while quantizing the rest to INT8,
|
|
145
|
+
pass its `nn.Conv2d` instance to `ModulesToSkip` after `transform`. Note that
|
|
146
|
+
`torch.fx.GraphModule` deep-copies submodules during tracing, so you must take
|
|
147
|
+
the reference **from the fused graph**, not from the original model:
|
|
148
|
+
|
|
149
|
+
```python
|
|
150
|
+
from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
|
|
151
|
+
|
|
152
|
+
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
153
|
+
|
|
154
|
+
# Grab the conv instance from the fused graph (not from the original model)
|
|
155
|
+
first_conv = res.model.FusedConvBNActMaxPool_0.conv
|
|
156
|
+
|
|
157
|
+
config = QuantConfig(
|
|
158
|
+
skip=ModulesToSkip(
|
|
159
|
+
stub={first_conv}, # disables input activation quantization
|
|
160
|
+
weight={first_conv}, # disables weight fake-quantization
|
|
161
|
+
)
|
|
162
|
+
)
|
|
163
|
+
quantized_model = quantize(
|
|
164
|
+
res.model, (example_input,), config=config, forward_loop=calibration_loop
|
|
165
|
+
)
|
|
137
166
|
```
|
|
138
167
|
|
|
139
168
|
## Design Principles
|
|
@@ -150,10 +179,13 @@ qat_model = quantized_model.train()
|
|
|
150
179
|
`transform()` is a convenience for the common case where you want
|
|
151
180
|
everything applied.
|
|
152
181
|
|
|
153
|
-
3. **
|
|
154
|
-
All graph analysis and surgery uses
|
|
155
|
-
and manipulated as `fx.GraphModule` objects
|
|
156
|
-
|
|
182
|
+
3. **Graph-based models (torch.export.export and symbolic traced).**
|
|
183
|
+
All graph analysis and surgery uses traced graphs. Models are traced once
|
|
184
|
+
and manipulated as `fx.GraphModule` objects with support for tracing via both
|
|
185
|
+
`torch.fx` (symbolic) as well as `torch.export.export` (Aten). Support for
|
|
186
|
+
Aten graphs is automatically enabled using Aten recomposition
|
|
187
|
+
patterns that compose Aten operations into equivalent `torch.nn` modules
|
|
188
|
+
automatically before conversions and fusions.
|
|
157
189
|
|
|
158
190
|
## Support
|
|
159
191
|
|
|
@@ -35,16 +35,16 @@ hardware target ensuring correct quantization and compilation.
|
|
|
35
35
|
|
|
36
36
|
## Supported Backends
|
|
37
37
|
|
|
38
|
-
| Backend
|
|
39
|
-
|
|
40
|
-
| NVIDIA TensorRT
|
|
38
|
+
| Backend | Status |
|
|
39
|
+
|-------------------------|-------------|
|
|
40
|
+
| NVIDIA TensorRT (v10.3) | Supported |
|
|
41
41
|
|
|
42
|
-
Contact
|
|
42
|
+
Contact Embedl for other backends.
|
|
43
43
|
|
|
44
44
|
## Installation
|
|
45
45
|
|
|
46
46
|
```bash
|
|
47
|
-
pip install embedl-deploy
|
|
47
|
+
pip install "embedl-deploy[tensorrt]"
|
|
48
48
|
```
|
|
49
49
|
Note that you may need to also install `onnx` and `onnx-simplifier` to export
|
|
50
50
|
and get the exported model compiled with TensorRT if using ONNX as an
|
|
@@ -66,6 +66,9 @@ model = Model().eval()
|
|
|
66
66
|
example_input = torch.randn(1, 3, 224, 224)
|
|
67
67
|
|
|
68
68
|
# 2. Transform — fuse and optimize for TensorRT in one call
|
|
69
|
+
# For more compatibility you can trace your model with torch.export.export
|
|
70
|
+
# as follows:
|
|
71
|
+
# model = torch.export.export(model, (example_input,)).module()
|
|
69
72
|
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
70
73
|
print("Model\n", res.model.print_readable())
|
|
71
74
|
print("Matches", "\n".join([str(match) for match in res.matches]))
|
|
@@ -92,28 +95,54 @@ torch.onnx.export(
|
|
|
92
95
|
qat_model = quantized_model.train()
|
|
93
96
|
# Freeze BatchNorm, or apply other QAT utilities as needed
|
|
94
97
|
# train(qat_model)
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
### Compile
|
|
101
|
+
|
|
102
|
+
Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
|
|
103
|
+
model and compile it for inference. The exported layer info and profile can
|
|
104
|
+
be used for debugging, optimization and visualization.
|
|
105
|
+
|
|
106
|
+
Note that the ONNX model might need to be simplified with onnx-simplifier to
|
|
107
|
+
make trtexec compile it. Dynamo exported models may have compilation issues,
|
|
108
|
+
so it's recommended to export with dynamo=False.
|
|
109
|
+
|
|
110
|
+
```bash
|
|
111
|
+
onnxsim model.onnx model.onnx
|
|
112
|
+
/usr/src/tensorrt/bin/trtexec --onnx=model.onnx --fp16 --int8 --useCudaGraph
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
Optionally you can get the layer profile with the following flags:
|
|
116
|
+
```
|
|
117
|
+
--exportLayerInfo=layer_info.json
|
|
118
|
+
--exportProfile=profile.json
|
|
119
|
+
--profilingVerbosity=detailed
|
|
120
|
+
```
|
|
95
121
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
#
|
|
114
|
-
#
|
|
115
|
-
|
|
116
|
-
|
|
122
|
+
## Mixed Precision
|
|
123
|
+
|
|
124
|
+
To keep a specific layer in higher precision while quantizing the rest to INT8,
|
|
125
|
+
pass its `nn.Conv2d` instance to `ModulesToSkip` after `transform`. Note that
|
|
126
|
+
`torch.fx.GraphModule` deep-copies submodules during tracing, so you must take
|
|
127
|
+
the reference **from the fused graph**, not from the original model:
|
|
128
|
+
|
|
129
|
+
```python
|
|
130
|
+
from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
|
|
131
|
+
|
|
132
|
+
res = transform(model, patterns=TENSORRT_PATTERNS)
|
|
133
|
+
|
|
134
|
+
# Grab the conv instance from the fused graph (not from the original model)
|
|
135
|
+
first_conv = res.model.FusedConvBNActMaxPool_0.conv
|
|
136
|
+
|
|
137
|
+
config = QuantConfig(
|
|
138
|
+
skip=ModulesToSkip(
|
|
139
|
+
stub={first_conv}, # disables input activation quantization
|
|
140
|
+
weight={first_conv}, # disables weight fake-quantization
|
|
141
|
+
)
|
|
142
|
+
)
|
|
143
|
+
quantized_model = quantize(
|
|
144
|
+
res.model, (example_input,), config=config, forward_loop=calibration_loop
|
|
145
|
+
)
|
|
117
146
|
```
|
|
118
147
|
|
|
119
148
|
## Design Principles
|
|
@@ -130,10 +159,13 @@ qat_model = quantized_model.train()
|
|
|
130
159
|
`transform()` is a convenience for the common case where you want
|
|
131
160
|
everything applied.
|
|
132
161
|
|
|
133
|
-
3. **
|
|
134
|
-
All graph analysis and surgery uses
|
|
135
|
-
and manipulated as `fx.GraphModule` objects
|
|
136
|
-
|
|
162
|
+
3. **Graph-based models (torch.export.export and symbolic traced).**
|
|
163
|
+
All graph analysis and surgery uses traced graphs. Models are traced once
|
|
164
|
+
and manipulated as `fx.GraphModule` objects with support for tracing via both
|
|
165
|
+
`torch.fx` (symbolic) as well as `torch.export.export` (Aten). Support for
|
|
166
|
+
Aten graphs is automatically enabled using Aten recomposition
|
|
167
|
+
patterns that compose Aten operations into equivalent `torch.nn` modules
|
|
168
|
+
automatically before conversions and fusions.
|
|
137
169
|
|
|
138
170
|
## Support
|
|
139
171
|
|
|
@@ -27,16 +27,11 @@ dynamic = ["version"]
|
|
|
27
27
|
dependencies = ["torch"]
|
|
28
28
|
|
|
29
29
|
[project.optional-dependencies]
|
|
30
|
-
tensorrt = ["tensorrt"]
|
|
30
|
+
tensorrt = ["embedl-deploy-tensorrt"]
|
|
31
31
|
|
|
32
32
|
[project.urls]
|
|
33
33
|
Homepage = "https://www.embedl.com/"
|
|
34
34
|
|
|
35
|
-
[tool.black]
|
|
36
|
-
line-length = 79
|
|
37
|
-
target-version = ["py310"]
|
|
38
|
-
skip-string-normalization = true
|
|
39
|
-
|
|
40
35
|
[tool.coverage.html]
|
|
41
36
|
show_contexts = true
|
|
42
37
|
|
|
@@ -100,19 +95,72 @@ line-length = 79
|
|
|
100
95
|
quote-style = "preserve"
|
|
101
96
|
|
|
102
97
|
[tool.ruff.lint]
|
|
103
|
-
select = [
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
98
|
+
select = ["ALL"]
|
|
99
|
+
ignore = [
|
|
100
|
+
# Dynamic attributes on fx.Node require string-based access for mypy
|
|
101
|
+
"B009", "B010",
|
|
102
|
+
# Conflicts with ruff format
|
|
103
|
+
"COM812",
|
|
104
|
+
# Descriptive exception messages preferred
|
|
105
|
+
"EM", "TRY003",
|
|
106
|
+
# Allow long lines for URLs, Sphinx cross-references, and imports
|
|
107
|
+
"E501",
|
|
108
|
+
# Too many false positives
|
|
109
|
+
"ERA001",
|
|
110
|
+
# Common in PyTorch-style APIs
|
|
111
|
+
"FBT",
|
|
112
|
+
# TODOs are fine
|
|
113
|
+
"FIX002",
|
|
114
|
+
# PyTorch naming conventions (N, C, H, W; import F)
|
|
115
|
+
"N806", "N812",
|
|
116
|
+
# Allow magic value comparisons
|
|
117
|
+
"PLR2004",
|
|
118
|
+
# Intermediate variables before return aid readability
|
|
119
|
+
"RET504",
|
|
120
|
+
# Conflicts with quote-style = "preserve"
|
|
121
|
+
"Q000",
|
|
122
|
+
# Intentional Unicode in docstrings and comments
|
|
123
|
+
"RUF002", "RUF003",
|
|
124
|
+
# Explicit if/return True/return False is clearer for predicate functions
|
|
125
|
+
"SIM103",
|
|
126
|
+
# Type-only imports are fine as regular imports
|
|
127
|
+
"TC001",
|
|
128
|
+
# Non-cryptographic random is expected in ML code
|
|
129
|
+
"S311",
|
|
130
|
+
# Prefer unquoted type expressions in cast()
|
|
131
|
+
"TC006",
|
|
132
|
+
# Clashes with dataclass and nn.Module patterns
|
|
133
|
+
"RUF012",
|
|
134
|
+
# Too prescriptive about TODO format
|
|
135
|
+
"TD",
|
|
136
|
+
# D203/D211 and D212/D213 are mutually exclusive pairs
|
|
137
|
+
"D203", "D213",
|
|
108
138
|
]
|
|
109
139
|
|
|
140
|
+
[tool.ruff.lint.per-file-ignores]
|
|
141
|
+
"src/**/*.py" = ["S101"]
|
|
142
|
+
"tests/**/*.py" = ["ANN", "D103", "S101"]
|
|
143
|
+
"docs/**/*.py" = ["ANN", "E402", "INP001", "S", "T201"]
|
|
144
|
+
"examples/**/*.py" = ["INP001", "T201"]
|
|
145
|
+
".claude/**/*.py" = ["ALL"]
|
|
146
|
+
|
|
147
|
+
[tool.ruff.lint.pylint]
|
|
148
|
+
max-args = 8
|
|
149
|
+
|
|
110
150
|
[tool.mypy]
|
|
111
151
|
ignore_missing_imports = false
|
|
112
152
|
strict = true
|
|
113
153
|
|
|
114
154
|
[[tool.mypy.overrides]]
|
|
115
|
-
module = [
|
|
155
|
+
module = [
|
|
156
|
+
"torch.*",
|
|
157
|
+
"pytest.*",
|
|
158
|
+
"torchvision.*",
|
|
159
|
+
"tensorrt.*",
|
|
160
|
+
"onnx.*",
|
|
161
|
+
"onnxsim.*",
|
|
162
|
+
"embedl_studio.*",
|
|
163
|
+
]
|
|
116
164
|
ignore_missing_imports = true
|
|
117
165
|
|
|
118
166
|
[[tool.mypy.overrides]]
|
|
@@ -125,5 +173,8 @@ disable_error_code = ["misc", "no-any-return"]
|
|
|
125
173
|
module = ["embedl_deploy._internal.tensorrt.modules.*"]
|
|
126
174
|
disable_error_code = ["no-any-return"]
|
|
127
175
|
|
|
176
|
+
[tool.setuptools.package-data]
|
|
177
|
+
embedl_deploy = ["py.typed"]
|
|
178
|
+
|
|
128
179
|
[tool.setuptools.dynamic]
|
|
129
180
|
version = { attr = "embedl_deploy.version.public.PUBLIC_VERSION" }
|
|
@@ -22,7 +22,7 @@ class Backend:
|
|
|
22
22
|
fusion_patterns: Sequence[Pattern]
|
|
23
23
|
#: SmoothQuant preparation patterns.
|
|
24
24
|
smooth_patterns: Sequence[Pattern]
|
|
25
|
-
#: Q/DQ stub insertion patterns for
|
|
25
|
+
#: Q/DQ stub insertion patterns for quantization.
|
|
26
26
|
quantized_patterns: Sequence[Pattern]
|
|
27
27
|
|
|
28
28
|
|
|
@@ -120,6 +120,6 @@ def set_backend(name: str) -> None:
|
|
|
120
120
|
if name not in backends:
|
|
121
121
|
available = ", ".join(sorted(backends)) or "(none)"
|
|
122
122
|
raise ValueError(
|
|
123
|
-
f"Backend {name!r} not found.
|
|
123
|
+
f"Backend {name!r} not found. Available backends: {available}"
|
|
124
124
|
)
|
|
125
125
|
_BackendState.backend = backends[name]
|
|
@@ -63,6 +63,10 @@ class FusedModule(nn.Module, ABC):
|
|
|
63
63
|
self.input_quant_stubs: dict[int, QuantStub] = {
|
|
64
64
|
idx: QuantStub({self}) for idx in self.inputs_to_quantize
|
|
65
65
|
}
|
|
66
|
+
#: Whether this module has been surrounded with input
|
|
67
|
+
#: ``QuantStub`` entries by
|
|
68
|
+
#: :class:`~embedl_deploy._internal.tensorrt.patterns.quantizations.SurroundWithQuantStubsPattern`.
|
|
69
|
+
self.surrounded: bool = False
|
|
66
70
|
|
|
67
71
|
|
|
68
72
|
class _LeafTracer(fx.Tracer):
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
# Copyright (C) 2026 Embedl AB
|
|
2
|
+
|
|
3
|
+
"""Core abstractions: Pattern base class and PatternMatch dataclass.
|
|
4
|
+
|
|
5
|
+
Every fusion, conversion, and quantization rule is a
|
|
6
|
+
:class:`~embedl_deploy._internal.core.pattern.Pattern` subclass. The two
|
|
7
|
+
methods — :meth:`~embedl_deploy._internal.core.pattern.Pattern.match` and
|
|
8
|
+
:meth:`~embedl_deploy._internal.core.pattern.Pattern.replace` — encapsulate
|
|
9
|
+
what to look for and how to rewrite the graph.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
from torch import fx, nn
|
|
15
|
+
|
|
16
|
+
from embedl_deploy._internal.core.tree.match import match_tree
|
|
17
|
+
from embedl_deploy._internal.core.tree.replace import replace_tree
|
|
18
|
+
from embedl_deploy._internal.core.tree.types import (
|
|
19
|
+
Graft,
|
|
20
|
+
Replacement,
|
|
21
|
+
Tree,
|
|
22
|
+
TreeMatch,
|
|
23
|
+
Wildcard,
|
|
24
|
+
)
|
|
25
|
+
from embedl_deploy._internal.core.tree.utils import get_module
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _collect_modules(tree_match: TreeMatch) -> list[nn.Module | None]:
|
|
29
|
+
"""Resolve matched modules from a tree match.
|
|
30
|
+
|
|
31
|
+
Walks nested branches first (in input order), then
|
|
32
|
+
trunk nodes. For a
|
|
33
|
+
:class:`~embedl_deploy._internal.core.tree.types.Fork`
|
|
34
|
+
tree this means the fork-input branches precede the output
|
|
35
|
+
trunk, so the resulting list matches a constructor signature
|
|
36
|
+
like
|
|
37
|
+
``FusedModule(branch0_mod, branch1_mod, …, output_mod)``.
|
|
38
|
+
|
|
39
|
+
:class:`~embedl_deploy._internal.core.tree.types.Wildcard`
|
|
40
|
+
entries with ``"?"`` quantifier that matched nothing
|
|
41
|
+
contribute ``None``.
|
|
42
|
+
|
|
43
|
+
:raises TypeError:
|
|
44
|
+
If a matched node is not a ``call_module`` node.
|
|
45
|
+
"""
|
|
46
|
+
modules: list[nn.Module | None] = []
|
|
47
|
+
for nested in tree_match.nested:
|
|
48
|
+
modules.extend(_collect_modules(nested))
|
|
49
|
+
for entry in tree_match.trunk_nodes:
|
|
50
|
+
if isinstance(entry, Wildcard):
|
|
51
|
+
if entry.quantifier != "?":
|
|
52
|
+
raise TypeError(
|
|
53
|
+
f"wildcard with quantifier"
|
|
54
|
+
f" {entry.quantifier!r} is not"
|
|
55
|
+
f" supported — graft only supports"
|
|
56
|
+
f" '?' wildcards"
|
|
57
|
+
)
|
|
58
|
+
node = entry.nodes[0] if entry.nodes else None
|
|
59
|
+
else:
|
|
60
|
+
node = entry
|
|
61
|
+
if node is None:
|
|
62
|
+
modules.append(None)
|
|
63
|
+
else:
|
|
64
|
+
mod = get_module(node)
|
|
65
|
+
if mod is None:
|
|
66
|
+
raise TypeError(
|
|
67
|
+
f"node {node.name!r} is not a call_module "
|
|
68
|
+
f"node — graft only works with "
|
|
69
|
+
f"module-only trees"
|
|
70
|
+
)
|
|
71
|
+
modules.append(mod)
|
|
72
|
+
return modules
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _get_replacements(
|
|
76
|
+
graft: Graft,
|
|
77
|
+
tree_match: TreeMatch,
|
|
78
|
+
) -> list[Replacement]:
|
|
79
|
+
"""Build the replacement list from a graft specification."""
|
|
80
|
+
if isinstance(graft, tuple):
|
|
81
|
+
replacements: list[Replacement] = []
|
|
82
|
+
for rep_maker in graft:
|
|
83
|
+
replacements.extend(rep_maker(tree_match))
|
|
84
|
+
return replacements
|
|
85
|
+
modules = _collect_modules(tree_match)
|
|
86
|
+
try:
|
|
87
|
+
return [graft(*modules)]
|
|
88
|
+
except TypeError as exc:
|
|
89
|
+
raise TypeError(
|
|
90
|
+
f"{graft.__name__}() got"
|
|
91
|
+
f" {len(modules)} modules from"
|
|
92
|
+
f" the tree match — check that"
|
|
93
|
+
f" the tree shape matches the"
|
|
94
|
+
f" constructor signature"
|
|
95
|
+
) from exc
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class Pattern:
|
|
99
|
+
"""A graph transformation rule: find a sub-graph and replace it.
|
|
100
|
+
|
|
101
|
+
The default :meth:`match` delegates to
|
|
102
|
+
:func:`~embedl_deploy._internal.core.tree.match.match_tree` using the
|
|
103
|
+
class's :attr:`tree`. The default :meth:`replace` constructs
|
|
104
|
+
replacements from :attr:`graft` and delegates to
|
|
105
|
+
:func:`~embedl_deploy._internal.core.tree.replace.replace_tree`.
|
|
106
|
+
Subclasses override either method when they need custom logic
|
|
107
|
+
(pre/post side-effects, post-match filtering, etc.).
|
|
108
|
+
|
|
109
|
+
Patterns with
|
|
110
|
+
:attr:`~embedl_deploy._internal.core.pattern.Pattern.is_conversion` set to
|
|
111
|
+
``True`` are applied in a first pass to rewrite graph topology before
|
|
112
|
+
fusion patterns are matched.
|
|
113
|
+
"""
|
|
114
|
+
|
|
115
|
+
tree: Tree | None = None
|
|
116
|
+
"""The pattern topology to match, if using tree-based matching."""
|
|
117
|
+
|
|
118
|
+
graft: Graft | None = None
|
|
119
|
+
"""The factories to make replacements for each matched tree, if used."""
|
|
120
|
+
|
|
121
|
+
is_conversion: bool = False
|
|
122
|
+
"""If ``True``, this pattern is a structural conversion that must
|
|
123
|
+
be applied before fusion matching."""
|
|
124
|
+
|
|
125
|
+
symbolic_trace_only: bool = False
|
|
126
|
+
"""If ``True``, this pattern removes nodes that are artifacts of
|
|
127
|
+
``symbolic_trace``. This pattern has no effect on graphs exported with
|
|
128
|
+
``torch.export`` because the nodes never appear in those graphs."""
|
|
129
|
+
|
|
130
|
+
export_graph_only: bool = False
|
|
131
|
+
"""If ``True``, this pattern targets nodes that only appear in
|
|
132
|
+
``torch.export`` aten graphs and has no effect on symbolic-trace output."""
|
|
133
|
+
|
|
134
|
+
def match(self, graph_module: fx.GraphModule) -> list["PatternMatch"]:
|
|
135
|
+
"""Find all occurrences of this pattern in `graph_module`.
|
|
136
|
+
|
|
137
|
+
:raises ValueError:
|
|
138
|
+
If the pattern has no ``tree``.
|
|
139
|
+
"""
|
|
140
|
+
tree = self.tree
|
|
141
|
+
if tree is None:
|
|
142
|
+
raise ValueError(f"{type(self).__name__} has no tree to match.")
|
|
143
|
+
tree_matches = match_tree(graph_module, tree)
|
|
144
|
+
return [
|
|
145
|
+
PatternMatch(
|
|
146
|
+
pattern=self,
|
|
147
|
+
graph_module=graph_module,
|
|
148
|
+
tree_match=tm,
|
|
149
|
+
)
|
|
150
|
+
for tm in tree_matches
|
|
151
|
+
]
|
|
152
|
+
|
|
153
|
+
def replace(
|
|
154
|
+
self,
|
|
155
|
+
pattern_match: "PatternMatch",
|
|
156
|
+
) -> list[fx.Node]:
|
|
157
|
+
"""Replace one matched occurrence in-place.
|
|
158
|
+
|
|
159
|
+
:param pattern_match:
|
|
160
|
+
The pattern match to replace.
|
|
161
|
+
:returns:
|
|
162
|
+
The replacement nodes inserted into the graph.
|
|
163
|
+
:raises ValueError:
|
|
164
|
+
If the pattern has no ``graft``.
|
|
165
|
+
:raises TypeError:
|
|
166
|
+
If the ``graft`` class constructor rejects the
|
|
167
|
+
collected modules.
|
|
168
|
+
"""
|
|
169
|
+
assert pattern_match.pattern is self
|
|
170
|
+
tree_match = pattern_match.tree_match
|
|
171
|
+
graft = self.graft
|
|
172
|
+
if graft is None:
|
|
173
|
+
raise ValueError(
|
|
174
|
+
f"{type(self).__name__} has no graft"
|
|
175
|
+
f" — override replace() or set graft."
|
|
176
|
+
)
|
|
177
|
+
replacements = _get_replacements(graft, tree_match)
|
|
178
|
+
return replace_tree(
|
|
179
|
+
pattern_match.graph_module, tree_match, replacements
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@dataclass
|
|
184
|
+
class PatternMatch:
|
|
185
|
+
"""One matched occurrence of a ``Pattern`` in a graph."""
|
|
186
|
+
|
|
187
|
+
#: The pattern that produced this match.
|
|
188
|
+
pattern: Pattern
|
|
189
|
+
#: The graph module that produced this match.
|
|
190
|
+
graph_module: fx.GraphModule
|
|
191
|
+
#: Structured match result produced by
|
|
192
|
+
#: :func:`~embedl_deploy._internal.core.tree.match.match_tree`.
|
|
193
|
+
#: Contains the matched nodes, modules, and nested per-branch
|
|
194
|
+
#: sub-matches for
|
|
195
|
+
#: :class:`~embedl_deploy._internal.core.tree.types.Fork`
|
|
196
|
+
#: topologies.
|
|
197
|
+
tree_match: TreeMatch
|
|
198
|
+
#: Whether to apply this match during transformation.
|
|
199
|
+
apply: bool = True
|
|
200
|
+
|
|
201
|
+
def __repr__(self) -> str:
|
|
202
|
+
pat = type(self.pattern).__name__
|
|
203
|
+
node_names = [n.name for n in self.tree_match.get_tree_nodes()]
|
|
204
|
+
return f"PatternMatch({pat}: {' -> '.join(node_names)})"
|