embedl-deploy 0.3.0__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/MANIFEST.in +2 -0
  2. {embedl_deploy-0.3.0/src/embedl_deploy.egg-info → embedl_deploy-0.4.1}/PKG-INFO +64 -32
  3. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/README.md +62 -30
  4. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/pyproject.toml +63 -12
  5. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/backend.py +2 -2
  6. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/modules.py +4 -0
  7. embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/pattern.py +204 -0
  8. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/plan.py +51 -11
  9. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/calibrate.py +5 -4
  10. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/main.py +15 -1
  11. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/stubs.py +2 -1
  12. embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree/__init__.py +3 -0
  13. embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree/match.py +334 -0
  14. {embedl_deploy-0.3.0/src/embedl_deploy/_internal/core → embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree}/replace.py +93 -45
  15. embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree/types.py +325 -0
  16. embedl_deploy-0.4.1/src/embedl_deploy/_internal/core/tree/utils.py +64 -0
  17. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/version/public.py +1 -1
  18. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1/src/embedl_deploy.egg-info}/PKG-INFO +64 -32
  19. embedl_deploy-0.4.1/src/embedl_deploy.egg-info/SOURCES.txt +35 -0
  20. embedl_deploy-0.4.1/src/embedl_deploy.egg-info/requires.txt +4 -0
  21. embedl_deploy-0.3.0/src/embedl_deploy/_internal/core/match.py +0 -256
  22. embedl_deploy-0.3.0/src/embedl_deploy/_internal/core/pattern.py +0 -480
  23. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -3
  24. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/backend.py +0 -18
  25. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -3
  26. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/attention.py +0 -275
  27. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/conv.py +0 -238
  28. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/linear.py +0 -159
  29. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +0 -39
  30. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/pool.py +0 -25
  31. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +0 -460
  32. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -3
  33. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -15
  34. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +0 -891
  35. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +0 -356
  36. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/__init__.py +0 -3
  37. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/attention.py +0 -87
  38. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/conv.py +0 -196
  39. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/linear.py +0 -86
  40. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pointwise.py +0 -55
  41. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pool.py +0 -50
  42. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +0 -329
  43. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -1584
  44. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +0 -123
  45. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -81
  46. embedl_deploy-0.3.0/src/embedl_deploy/_internal/tensorrt/plan.py +0 -131
  47. embedl_deploy-0.3.0/src/embedl_deploy/tensorrt/__init__.py +0 -45
  48. embedl_deploy-0.3.0/src/embedl_deploy/tensorrt/modules/__init__.py +0 -40
  49. embedl_deploy-0.3.0/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -60
  50. embedl_deploy-0.3.0/src/embedl_deploy.egg-info/SOURCES.txt +0 -59
  51. embedl_deploy-0.3.0/src/embedl_deploy.egg-info/requires.txt +0 -4
  52. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/LICENSE +0 -0
  53. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/NOTICE +0 -0
  54. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/setup.cfg +0 -0
  55. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/__init__.py +1 -1
  56. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/__init__.py +0 -0
  57. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/__init__.py +0 -0
  58. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/__init__.py +0 -0
  59. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/config.py +0 -0
  60. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/prepare.py +0 -0
  61. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/qat.py +0 -0
  62. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/_internal/core/quantize/utils.py +0 -0
  63. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/backend/__init__.py +0 -0
  64. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/py.typed +0 -0
  65. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/quantize/__init__.py +1 -1
  66. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy/version/__init__.py +0 -0
  67. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy.egg-info/dependency_links.txt +0 -0
  68. {embedl_deploy-0.3.0 → embedl_deploy-0.4.1}/src/embedl_deploy.egg-info/top_level.txt +0 -0
@@ -3,6 +3,8 @@ graft src
3
3
  include LICENSE
4
4
  include NOTICE
5
5
  include README.md
6
+ prune src/embedl_deploy/tensorrt
7
+ prune src/embedl_deploy/_internal/tensorrt
6
8
  global-exclude CLAUDE.md
7
9
  global-exclude *.pyc
8
10
  global-exclude __pycache__
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedl-deploy
3
- Version: 0.3.0
3
+ Version: 0.4.1
4
4
  Summary: Python package to make AI models deployment-ready for any hardware.
5
5
  Author-email: Embedl AB <support@embedl.com>
6
6
  Project-URL: Homepage, https://www.embedl.com/
@@ -15,7 +15,7 @@ License-File: LICENSE
15
15
  License-File: NOTICE
16
16
  Requires-Dist: torch
17
17
  Provides-Extra: tensorrt
18
- Requires-Dist: tensorrt; extra == "tensorrt"
18
+ Requires-Dist: embedl-deploy-tensorrt; extra == "tensorrt"
19
19
  Dynamic: license-file
20
20
 
21
21
  # embedl-deploy
@@ -55,16 +55,16 @@ hardware target ensuring correct quantization and compilation.
55
55
 
56
56
  ## Supported Backends
57
57
 
58
- | Backend | Status |
59
- |---------------------|-------------|
60
- | NVIDIA TensorRT | Supported |
58
+ | Backend | Status |
59
+ |-------------------------|-------------|
60
+ | NVIDIA TensorRT (v10.3) | Supported |
61
61
 
62
- Contact us for other backends.
62
+ Contact Embedl for other backends.
63
63
 
64
64
  ## Installation
65
65
 
66
66
  ```bash
67
- pip install embedl-deploy
67
+ pip install "embedl-deploy[tensorrt]"
68
68
  ```
69
69
  Note that you may need to also install `onnx` and `onnx-simplifier` to export
70
70
  and get the exported model compiled with TensorRT if using ONNX as an
@@ -86,6 +86,9 @@ model = Model().eval()
86
86
  example_input = torch.randn(1, 3, 224, 224)
87
87
 
88
88
  # 2. Transform — fuse and optimize for TensorRT in one call
89
+ # For more compatibility you can trace your model with torch.export.export
90
+ # as follows:
91
+ # model = torch.export.export(model, (example_input,)).module()
89
92
  res = transform(model, patterns=TENSORRT_PATTERNS)
90
93
  print("Model\n", res.model.print_readable())
91
94
  print("Matches", "\n".join([str(match) for match in res.matches]))
@@ -112,28 +115,54 @@ torch.onnx.export(
112
115
  qat_model = quantized_model.train()
113
116
  # Freeze BatchNorm, or apply other QAT utilities as needed
114
117
  # train(qat_model)
118
+ ```
119
+
120
+ ### Compile
121
+
122
+ Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
123
+ model and compile it for inference. The exported layer info and profile can
124
+ be used for debugging, optimization and visualization.
125
+
126
+ Note that the ONNX model might need to be simplified with onnx-simplifier to
127
+ make trtexec compile it. Dynamo exported models may have compilation issues,
128
+ so it's recommended to export with dynamo=False.
129
+
130
+ ```bash
131
+ onnxsim model.onnx model.onnx
132
+ /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --fp16 --int8 --useCudaGraph
133
+ ```
134
+
135
+ Optionally you can get the layer profile with the following flags:
136
+ ```
137
+ --exportLayerInfo=layer_info.json
138
+ --exportProfile=profile.json
139
+ --profilingVerbosity=detailed
140
+ ```
115
141
 
116
- # Compile
117
- # -------
118
- # Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
119
- # model and compile it for inference. The exported layer info and profile can
120
- # be used for debugging, optimization and visualization.
121
- #
122
- # Note: that the ONNX model might need to be simplified with onnx-simplifier to
123
- # make trtexec compile it. Dynamo exported models may have compilation issues,
124
- # so it's recommended to export with dynamo=False.
125
- #
126
- # We are working on a Aten-based export path that should be more robust and
127
- # support more models in the future.
128
-
129
- # >> onnxsim model.onnx model.onnx
130
- # >> trtexec \
131
- # --onnx=model.onnx \
132
- # --exportLayerInfo=layer_info.json \
133
- # --exportProfile=profile.json \
134
- # --profilingVerbosity=detailed
135
-
136
- # More benchmarking scripts can be found in the examples/ directory
142
+ ## Mixed Precision
143
+
144
+ To keep a specific layer in higher precision while quantizing the rest to INT8,
145
+ pass its `nn.Conv2d` instance to `ModulesToSkip` after `transform`. Note that
146
+ `torch.fx.GraphModule` deep-copies submodules during tracing, so you must take
147
+ the reference **from the fused graph**, not from the original model:
148
+
149
+ ```python
150
+ from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
151
+
152
+ res = transform(model, patterns=TENSORRT_PATTERNS)
153
+
154
+ # Grab the conv instance from the fused graph (not from the original model)
155
+ first_conv = res.model.FusedConvBNActMaxPool_0.conv
156
+
157
+ config = QuantConfig(
158
+ skip=ModulesToSkip(
159
+ stub={first_conv}, # disables input activation quantization
160
+ weight={first_conv}, # disables weight fake-quantization
161
+ )
162
+ )
163
+ quantized_model = quantize(
164
+ res.model, (example_input,), config=config, forward_loop=calibration_loop
165
+ )
137
166
  ```
138
167
 
139
168
  ## Design Principles
@@ -150,10 +179,13 @@ qat_model = quantized_model.train()
150
179
  `transform()` is a convenience for the common case where you want
151
180
  everything applied.
152
181
 
153
- 3. **FX-graph-based.**
154
- All graph analysis and surgery uses `torch.fx`. Models are traced once
155
- and manipulated as `fx.GraphModule` objects. Support for Aten graphs
156
- produced by `torch.export.export` is planned for the future.
182
+ 3. **Graph-based models (torch.export.export and symbolic traced).**
183
+ All graph analysis and surgery uses traced graphs. Models are traced once
184
+ and manipulated as `fx.GraphModule` objects with support for tracing via both
185
+ `torch.fx` (symbolic) as well as `torch.export.export` (Aten). Support for
186
+ Aten graphs is automatically enabled using Aten recomposition
187
+ patterns that compose Aten operations into equivalent `torch.nn` modules
188
+ automatically before conversions and fusions.
157
189
 
158
190
  ## Support
159
191
 
@@ -35,16 +35,16 @@ hardware target ensuring correct quantization and compilation.
35
35
 
36
36
  ## Supported Backends
37
37
 
38
- | Backend | Status |
39
- |---------------------|-------------|
40
- | NVIDIA TensorRT | Supported |
38
+ | Backend | Status |
39
+ |-------------------------|-------------|
40
+ | NVIDIA TensorRT (v10.3) | Supported |
41
41
 
42
- Contact us for other backends.
42
+ Contact Embedl for other backends.
43
43
 
44
44
  ## Installation
45
45
 
46
46
  ```bash
47
- pip install embedl-deploy
47
+ pip install "embedl-deploy[tensorrt]"
48
48
  ```
49
49
  Note that you may need to also install `onnx` and `onnx-simplifier` to export
50
50
  and get the exported model compiled with TensorRT if using ONNX as an
@@ -66,6 +66,9 @@ model = Model().eval()
66
66
  example_input = torch.randn(1, 3, 224, 224)
67
67
 
68
68
  # 2. Transform — fuse and optimize for TensorRT in one call
69
+ # For more compatibility you can trace your model with torch.export.export
70
+ # as follows:
71
+ # model = torch.export.export(model, (example_input,)).module()
69
72
  res = transform(model, patterns=TENSORRT_PATTERNS)
70
73
  print("Model\n", res.model.print_readable())
71
74
  print("Matches", "\n".join([str(match) for match in res.matches]))
@@ -92,28 +95,54 @@ torch.onnx.export(
92
95
  qat_model = quantized_model.train()
93
96
  # Freeze BatchNorm, or apply other QAT utilities as needed
94
97
  # train(qat_model)
98
+ ```
99
+
100
+ ### Compile
101
+
102
+ Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
103
+ model and compile it for inference. The exported layer info and profile can
104
+ be used for debugging, optimization and visualization.
105
+
106
+ Note that the ONNX model might need to be simplified with onnx-simplifier to
107
+ make trtexec compile it. Dynamo exported models may have compilation issues,
108
+ so it's recommended to export with dynamo=False.
109
+
110
+ ```bash
111
+ onnxsim model.onnx model.onnx
112
+ /usr/src/tensorrt/bin/trtexec --onnx=model.onnx --fp16 --int8 --useCudaGraph
113
+ ```
114
+
115
+ Optionally you can get the layer profile with the following flags:
116
+ ```
117
+ --exportLayerInfo=layer_info.json
118
+ --exportProfile=profile.json
119
+ --profilingVerbosity=detailed
120
+ ```
95
121
 
96
- # Compile
97
- # -------
98
- # Compilation can be done with TensorRT's trtexec tool, which can take the ONNX
99
- # model and compile it for inference. The exported layer info and profile can
100
- # be used for debugging, optimization and visualization.
101
- #
102
- # Note: that the ONNX model might need to be simplified with onnx-simplifier to
103
- # make trtexec compile it. Dynamo exported models may have compilation issues,
104
- # so it's recommended to export with dynamo=False.
105
- #
106
- # We are working on a Aten-based export path that should be more robust and
107
- # support more models in the future.
108
-
109
- # >> onnxsim model.onnx model.onnx
110
- # >> trtexec \
111
- # --onnx=model.onnx \
112
- # --exportLayerInfo=layer_info.json \
113
- # --exportProfile=profile.json \
114
- # --profilingVerbosity=detailed
115
-
116
- # More benchmarking scripts can be found in the examples/ directory
122
+ ## Mixed Precision
123
+
124
+ To keep a specific layer in higher precision while quantizing the rest to INT8,
125
+ pass its `nn.Conv2d` instance to `ModulesToSkip` after `transform`. Note that
126
+ `torch.fx.GraphModule` deep-copies submodules during tracing, so you must take
127
+ the reference **from the fused graph**, not from the original model:
128
+
129
+ ```python
130
+ from embedl_deploy.quantize import quantize, QuantConfig, ModulesToSkip
131
+
132
+ res = transform(model, patterns=TENSORRT_PATTERNS)
133
+
134
+ # Grab the conv instance from the fused graph (not from the original model)
135
+ first_conv = res.model.FusedConvBNActMaxPool_0.conv
136
+
137
+ config = QuantConfig(
138
+ skip=ModulesToSkip(
139
+ stub={first_conv}, # disables input activation quantization
140
+ weight={first_conv}, # disables weight fake-quantization
141
+ )
142
+ )
143
+ quantized_model = quantize(
144
+ res.model, (example_input,), config=config, forward_loop=calibration_loop
145
+ )
117
146
  ```
118
147
 
119
148
  ## Design Principles
@@ -130,10 +159,13 @@ qat_model = quantized_model.train()
130
159
  `transform()` is a convenience for the common case where you want
131
160
  everything applied.
132
161
 
133
- 3. **FX-graph-based.**
134
- All graph analysis and surgery uses `torch.fx`. Models are traced once
135
- and manipulated as `fx.GraphModule` objects. Support for Aten graphs
136
- produced by `torch.export.export` is planned for the future.
162
+ 3. **Graph-based models (torch.export.export and symbolic traced).**
163
+ All graph analysis and surgery uses traced graphs. Models are traced once
164
+ and manipulated as `fx.GraphModule` objects with support for tracing via both
165
+ `torch.fx` (symbolic) as well as `torch.export.export` (Aten). Support for
166
+ Aten graphs is automatically enabled using Aten recomposition
167
+ patterns that compose Aten operations into equivalent `torch.nn` modules
168
+ automatically before conversions and fusions.
137
169
 
138
170
  ## Support
139
171
 
@@ -27,16 +27,11 @@ dynamic = ["version"]
27
27
  dependencies = ["torch"]
28
28
 
29
29
  [project.optional-dependencies]
30
- tensorrt = ["tensorrt"]
30
+ tensorrt = ["embedl-deploy-tensorrt"]
31
31
 
32
32
  [project.urls]
33
33
  Homepage = "https://www.embedl.com/"
34
34
 
35
- [tool.black]
36
- line-length = 79
37
- target-version = ["py310"]
38
- skip-string-normalization = true
39
-
40
35
  [tool.coverage.html]
41
36
  show_contexts = true
42
37
 
@@ -100,19 +95,72 @@ line-length = 79
100
95
  quote-style = "preserve"
101
96
 
102
97
  [tool.ruff.lint]
103
- select = [
104
- # isort
105
- "I",
106
- # Use `from X import Y` instead of `import X.Y as Y`
107
- "PLR0402",
98
+ select = ["ALL"]
99
+ ignore = [
100
+ # Dynamic attributes on fx.Node require string-based access for mypy
101
+ "B009", "B010",
102
+ # Conflicts with ruff format
103
+ "COM812",
104
+ # Descriptive exception messages preferred
105
+ "EM", "TRY003",
106
+ # Allow long lines for URLs, Sphinx cross-references, and imports
107
+ "E501",
108
+ # Too many false positives
109
+ "ERA001",
110
+ # Common in PyTorch-style APIs
111
+ "FBT",
112
+ # TODOs are fine
113
+ "FIX002",
114
+ # PyTorch naming conventions (N, C, H, W; import F)
115
+ "N806", "N812",
116
+ # Allow magic value comparisons
117
+ "PLR2004",
118
+ # Intermediate variables before return aid readability
119
+ "RET504",
120
+ # Conflicts with quote-style = "preserve"
121
+ "Q000",
122
+ # Intentional Unicode in docstrings and comments
123
+ "RUF002", "RUF003",
124
+ # Explicit if/return True/return False is clearer for predicate functions
125
+ "SIM103",
126
+ # Type-only imports are fine as regular imports
127
+ "TC001",
128
+ # Non-cryptographic random is expected in ML code
129
+ "S311",
130
+ # Prefer unquoted type expressions in cast()
131
+ "TC006",
132
+ # Clashes with dataclass and nn.Module patterns
133
+ "RUF012",
134
+ # Too prescriptive about TODO format
135
+ "TD",
136
+ # D203/D211 and D212/D213 are mutually exclusive pairs
137
+ "D203", "D213",
108
138
  ]
109
139
 
140
+ [tool.ruff.lint.per-file-ignores]
141
+ "src/**/*.py" = ["S101"]
142
+ "tests/**/*.py" = ["ANN", "D103", "S101"]
143
+ "docs/**/*.py" = ["ANN", "E402", "INP001", "S", "T201"]
144
+ "examples/**/*.py" = ["INP001", "T201"]
145
+ ".claude/**/*.py" = ["ALL"]
146
+
147
+ [tool.ruff.lint.pylint]
148
+ max-args = 8
149
+
110
150
  [tool.mypy]
111
151
  ignore_missing_imports = false
112
152
  strict = true
113
153
 
114
154
  [[tool.mypy.overrides]]
115
- module = ["torch.*", "pytest.*"]
155
+ module = [
156
+ "torch.*",
157
+ "pytest.*",
158
+ "torchvision.*",
159
+ "tensorrt.*",
160
+ "onnx.*",
161
+ "onnxsim.*",
162
+ "embedl_studio.*",
163
+ ]
116
164
  ignore_missing_imports = true
117
165
 
118
166
  [[tool.mypy.overrides]]
@@ -125,5 +173,8 @@ disable_error_code = ["misc", "no-any-return"]
125
173
  module = ["embedl_deploy._internal.tensorrt.modules.*"]
126
174
  disable_error_code = ["no-any-return"]
127
175
 
176
+ [tool.setuptools.package-data]
177
+ embedl_deploy = ["py.typed"]
178
+
128
179
  [tool.setuptools.dynamic]
129
180
  version = { attr = "embedl_deploy.version.public.PUBLIC_VERSION" }
@@ -22,7 +22,7 @@ class Backend:
22
22
  fusion_patterns: Sequence[Pattern]
23
23
  #: SmoothQuant preparation patterns.
24
24
  smooth_patterns: Sequence[Pattern]
25
- #: Q/DQ stub insertion patterns for quantisation.
25
+ #: Q/DQ stub insertion patterns for quantization.
26
26
  quantized_patterns: Sequence[Pattern]
27
27
 
28
28
 
@@ -120,6 +120,6 @@ def set_backend(name: str) -> None:
120
120
  if name not in backends:
121
121
  available = ", ".join(sorted(backends)) or "(none)"
122
122
  raise ValueError(
123
- f"Backend {name!r} not found. " f"Available backends: {available}"
123
+ f"Backend {name!r} not found. Available backends: {available}"
124
124
  )
125
125
  _BackendState.backend = backends[name]
@@ -63,6 +63,10 @@ class FusedModule(nn.Module, ABC):
63
63
  self.input_quant_stubs: dict[int, QuantStub] = {
64
64
  idx: QuantStub({self}) for idx in self.inputs_to_quantize
65
65
  }
66
+ #: Whether this module has been surrounded with input
67
+ #: ``QuantStub`` entries by
68
+ #: :class:`~embedl_deploy._internal.tensorrt.patterns.quantizations.SurroundWithQuantStubsPattern`.
69
+ self.surrounded: bool = False
66
70
 
67
71
 
68
72
  class _LeafTracer(fx.Tracer):
@@ -0,0 +1,204 @@
1
+ # Copyright (C) 2026 Embedl AB
2
+
3
+ """Core abstractions: Pattern base class and PatternMatch dataclass.
4
+
5
+ Every fusion, conversion, and quantization rule is a
6
+ :class:`~embedl_deploy._internal.core.pattern.Pattern` subclass. The two
7
+ methods — :meth:`~embedl_deploy._internal.core.pattern.Pattern.match` and
8
+ :meth:`~embedl_deploy._internal.core.pattern.Pattern.replace` — encapsulate
9
+ what to look for and how to rewrite the graph.
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+
14
+ from torch import fx, nn
15
+
16
+ from embedl_deploy._internal.core.tree.match import match_tree
17
+ from embedl_deploy._internal.core.tree.replace import replace_tree
18
+ from embedl_deploy._internal.core.tree.types import (
19
+ Graft,
20
+ Replacement,
21
+ Tree,
22
+ TreeMatch,
23
+ Wildcard,
24
+ )
25
+ from embedl_deploy._internal.core.tree.utils import get_module
26
+
27
+
28
+ def _collect_modules(tree_match: TreeMatch) -> list[nn.Module | None]:
29
+ """Resolve matched modules from a tree match.
30
+
31
+ Walks nested branches first (in input order), then
32
+ trunk nodes. For a
33
+ :class:`~embedl_deploy._internal.core.tree.types.Fork`
34
+ tree this means the fork-input branches precede the output
35
+ trunk, so the resulting list matches a constructor signature
36
+ like
37
+ ``FusedModule(branch0_mod, branch1_mod, …, output_mod)``.
38
+
39
+ :class:`~embedl_deploy._internal.core.tree.types.Wildcard`
40
+ entries with ``"?"`` quantifier that matched nothing
41
+ contribute ``None``.
42
+
43
+ :raises TypeError:
44
+ If a matched node is not a ``call_module`` node.
45
+ """
46
+ modules: list[nn.Module | None] = []
47
+ for nested in tree_match.nested:
48
+ modules.extend(_collect_modules(nested))
49
+ for entry in tree_match.trunk_nodes:
50
+ if isinstance(entry, Wildcard):
51
+ if entry.quantifier != "?":
52
+ raise TypeError(
53
+ f"wildcard with quantifier"
54
+ f" {entry.quantifier!r} is not"
55
+ f" supported — graft only supports"
56
+ f" '?' wildcards"
57
+ )
58
+ node = entry.nodes[0] if entry.nodes else None
59
+ else:
60
+ node = entry
61
+ if node is None:
62
+ modules.append(None)
63
+ else:
64
+ mod = get_module(node)
65
+ if mod is None:
66
+ raise TypeError(
67
+ f"node {node.name!r} is not a call_module "
68
+ f"node — graft only works with "
69
+ f"module-only trees"
70
+ )
71
+ modules.append(mod)
72
+ return modules
73
+
74
+
75
+ def _get_replacements(
76
+ graft: Graft,
77
+ tree_match: TreeMatch,
78
+ ) -> list[Replacement]:
79
+ """Build the replacement list from a graft specification."""
80
+ if isinstance(graft, tuple):
81
+ replacements: list[Replacement] = []
82
+ for rep_maker in graft:
83
+ replacements.extend(rep_maker(tree_match))
84
+ return replacements
85
+ modules = _collect_modules(tree_match)
86
+ try:
87
+ return [graft(*modules)]
88
+ except TypeError as exc:
89
+ raise TypeError(
90
+ f"{graft.__name__}() got"
91
+ f" {len(modules)} modules from"
92
+ f" the tree match — check that"
93
+ f" the tree shape matches the"
94
+ f" constructor signature"
95
+ ) from exc
96
+
97
+
98
+ class Pattern:
99
+ """A graph transformation rule: find a sub-graph and replace it.
100
+
101
+ The default :meth:`match` delegates to
102
+ :func:`~embedl_deploy._internal.core.tree.match.match_tree` using the
103
+ class's :attr:`tree`. The default :meth:`replace` constructs
104
+ replacements from :attr:`graft` and delegates to
105
+ :func:`~embedl_deploy._internal.core.tree.replace.replace_tree`.
106
+ Subclasses override either method when they need custom logic
107
+ (pre/post side-effects, post-match filtering, etc.).
108
+
109
+ Patterns with
110
+ :attr:`~embedl_deploy._internal.core.pattern.Pattern.is_conversion` set to
111
+ ``True`` are applied in a first pass to rewrite graph topology before
112
+ fusion patterns are matched.
113
+ """
114
+
115
+ tree: Tree | None = None
116
+ """The pattern topology to match, if using tree-based matching."""
117
+
118
+ graft: Graft | None = None
119
+ """The factories to make replacements for each matched tree, if used."""
120
+
121
+ is_conversion: bool = False
122
+ """If ``True``, this pattern is a structural conversion that must
123
+ be applied before fusion matching."""
124
+
125
+ symbolic_trace_only: bool = False
126
+ """If ``True``, this pattern removes nodes that are artifacts of
127
+ ``symbolic_trace``. This pattern has no effect on graphs exported with
128
+ ``torch.export`` because the nodes never appear in those graphs."""
129
+
130
+ export_graph_only: bool = False
131
+ """If ``True``, this pattern targets nodes that only appear in
132
+ ``torch.export`` aten graphs and has no effect on symbolic-trace output."""
133
+
134
+ def match(self, graph_module: fx.GraphModule) -> list["PatternMatch"]:
135
+ """Find all occurrences of this pattern in `graph_module`.
136
+
137
+ :raises ValueError:
138
+ If the pattern has no ``tree``.
139
+ """
140
+ tree = self.tree
141
+ if tree is None:
142
+ raise ValueError(f"{type(self).__name__} has no tree to match.")
143
+ tree_matches = match_tree(graph_module, tree)
144
+ return [
145
+ PatternMatch(
146
+ pattern=self,
147
+ graph_module=graph_module,
148
+ tree_match=tm,
149
+ )
150
+ for tm in tree_matches
151
+ ]
152
+
153
+ def replace(
154
+ self,
155
+ pattern_match: "PatternMatch",
156
+ ) -> list[fx.Node]:
157
+ """Replace one matched occurrence in-place.
158
+
159
+ :param pattern_match:
160
+ The pattern match to replace.
161
+ :returns:
162
+ The replacement nodes inserted into the graph.
163
+ :raises ValueError:
164
+ If the pattern has no ``graft``.
165
+ :raises TypeError:
166
+ If the ``graft`` class constructor rejects the
167
+ collected modules.
168
+ """
169
+ assert pattern_match.pattern is self
170
+ tree_match = pattern_match.tree_match
171
+ graft = self.graft
172
+ if graft is None:
173
+ raise ValueError(
174
+ f"{type(self).__name__} has no graft"
175
+ f" — override replace() or set graft."
176
+ )
177
+ replacements = _get_replacements(graft, tree_match)
178
+ return replace_tree(
179
+ pattern_match.graph_module, tree_match, replacements
180
+ )
181
+
182
+
183
+ @dataclass
184
+ class PatternMatch:
185
+ """One matched occurrence of a ``Pattern`` in a graph."""
186
+
187
+ #: The pattern that produced this match.
188
+ pattern: Pattern
189
+ #: The graph module that produced this match.
190
+ graph_module: fx.GraphModule
191
+ #: Structured match result produced by
192
+ #: :func:`~embedl_deploy._internal.core.tree.match.match_tree`.
193
+ #: Contains the matched nodes, modules, and nested per-branch
194
+ #: sub-matches for
195
+ #: :class:`~embedl_deploy._internal.core.tree.types.Fork`
196
+ #: topologies.
197
+ tree_match: TreeMatch
198
+ #: Whether to apply this match during transformation.
199
+ apply: bool = True
200
+
201
+ def __repr__(self) -> str:
202
+ pat = type(self.pattern).__name__
203
+ node_names = [n.name for n in self.tree_match.get_tree_nodes()]
204
+ return f"PatternMatch({pat}: {' -> '.join(node_names)})"