embedl-deploy 0.2.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69):
  1. embedl_deploy-0.4.0/MANIFEST.in +10 -0
  2. {embedl_deploy-0.2.0/src/embedl_deploy.egg-info → embedl_deploy-0.4.0}/PKG-INFO +2 -2
  3. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/pyproject.toml +54 -6
  4. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/backend.py +2 -2
  5. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/modules.py +4 -0
  6. embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/pattern.py +204 -0
  7. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/plan.py +63 -4
  8. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/calibrate.py +5 -4
  9. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/main.py +15 -1
  10. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/stubs.py +2 -1
  11. embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree/__init__.py +3 -0
  12. embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree/match.py +334 -0
  13. {embedl_deploy-0.2.0/src/embedl_deploy/_internal/core → embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree}/replace.py +93 -45
  14. embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree/types.py +326 -0
  15. embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree/utils.py +64 -0
  16. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/version/public.py +1 -1
  17. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0/src/embedl_deploy.egg-info}/PKG-INFO +2 -2
  18. embedl_deploy-0.4.0/src/embedl_deploy.egg-info/SOURCES.txt +35 -0
  19. embedl_deploy-0.4.0/src/embedl_deploy.egg-info/requires.txt +4 -0
  20. embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/match.py +0 -256
  21. embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/pattern.py +0 -476
  22. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -3
  23. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/backend.py +0 -18
  24. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -3
  25. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/attention.py +0 -274
  26. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/conv.py +0 -232
  27. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/linear.py +0 -112
  28. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +0 -39
  29. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/pool.py +0 -25
  30. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +0 -460
  31. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -3
  32. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -15
  33. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +0 -691
  34. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +0 -300
  35. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/__init__.py +0 -3
  36. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/attention.py +0 -87
  37. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/conv.py +0 -196
  38. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/linear.py +0 -86
  39. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pointwise.py +0 -55
  40. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pool.py +0 -50
  41. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +0 -292
  42. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -819
  43. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +0 -123
  44. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -81
  45. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/plan.py +0 -123
  46. embedl_deploy-0.2.0/src/embedl_deploy/tensorrt/__init__.py +0 -40
  47. embedl_deploy-0.2.0/src/embedl_deploy/tensorrt/modules/__init__.py +0 -40
  48. embedl_deploy-0.2.0/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -60
  49. embedl_deploy-0.2.0/src/embedl_deploy.egg-info/SOURCES.txt +0 -59
  50. embedl_deploy-0.2.0/src/embedl_deploy.egg-info/requires.txt +0 -4
  51. embedl_deploy-0.2.0/tests/test_version.py +0 -20
  52. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/LICENSE +0 -0
  53. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/NOTICE +0 -0
  54. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/README.md +0 -0
  55. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/setup.cfg +0 -0
  56. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/__init__.py +1 -1
  57. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/__init__.py +0 -0
  58. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/__init__.py +0 -0
  59. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/__init__.py +0 -0
  60. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/config.py +0 -0
  61. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/prepare.py +0 -0
  62. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/qat.py +0 -0
  63. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/utils.py +0 -0
  64. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/backend/__init__.py +0 -0
  65. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/py.typed +0 -0
  66. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/quantize/__init__.py +1 -1
  67. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/version/__init__.py +0 -0
  68. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy.egg-info/dependency_links.txt +0 -0
  69. {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy.egg-info/top_level.txt +0 -0
@@ -0,0 +1,10 @@
1
+ prune *
2
+ graft src
3
+ include LICENSE
4
+ include NOTICE
5
+ include README.md
6
+ prune src/embedl_deploy/tensorrt
7
+ prune src/embedl_deploy/_internal/tensorrt
8
+ global-exclude CLAUDE.md
9
+ global-exclude *.pyc
10
+ global-exclude __pycache__
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedl-deploy
3
- Version: 0.2.0
3
+ Version: 0.4.0
4
4
  Summary: Python package to make AI models deployment-ready for any hardware.
5
5
  Author-email: Embedl AB <support@embedl.com>
6
6
  Project-URL: Homepage, https://www.embedl.com/
@@ -15,7 +15,7 @@ License-File: LICENSE
15
15
  License-File: NOTICE
16
16
  Requires-Dist: torch
17
17
  Provides-Extra: tensorrt
18
- Requires-Dist: tensorrt; extra == "tensorrt"
18
+ Requires-Dist: embedl-deploy-tensorrt; extra == "tensorrt"
19
19
  Dynamic: license-file
20
20
 
21
21
  # embedl-deploy
@@ -27,7 +27,7 @@ dynamic = ["version"]
27
27
  dependencies = ["torch"]
28
28
 
29
29
  [project.optional-dependencies]
30
- tensorrt = ["tensorrt"]
30
+ tensorrt = ["embedl-deploy-tensorrt"]
31
31
 
32
32
  [project.urls]
33
33
  Homepage = "https://www.embedl.com/"
@@ -100,13 +100,58 @@ line-length = 79
100
100
  quote-style = "preserve"
101
101
 
102
102
  [tool.ruff.lint]
103
- select = [
104
- # isort
105
- "I",
106
- # Use `from X import Y` instead of `import X.Y as Y`
107
- "PLR0402",
103
+ select = ["ALL"]
104
+ ignore = [
105
+ # Dynamic attributes on fx.Node require string-based access for mypy
106
+ "B009", "B010",
107
+ # Conflicts with ruff format
108
+ "COM812",
109
+ # Descriptive exception messages preferred
110
+ "EM", "TRY003",
111
+ # Allow long lines for URLs, Sphinx cross-references, and imports
112
+ "E501",
113
+ # Too many false positives
114
+ "ERA001",
115
+ # Common in PyTorch-style APIs
116
+ "FBT",
117
+ # TODOs are fine
118
+ "FIX002",
119
+ # PyTorch naming conventions (N, C, H, W; import F)
120
+ "N806", "N812",
121
+ # Allow magic value comparisons
122
+ "PLR2004",
123
+ # Intermediate variables before return aid readability
124
+ "RET504",
125
+ # Conflicts with quote-style = "preserve"
126
+ "Q000",
127
+ # Intentional Unicode in docstrings and comments
128
+ "RUF002", "RUF003",
129
+ # Explicit if/return True/return False is clearer for predicate functions
130
+ "SIM103",
131
+ # Type-only imports are fine as regular imports
132
+ "TC001",
133
+ # Non-cryptographic random is expected in ML code
134
+ "S311",
135
+ # Prefer unquoted type expressions in cast()
136
+ "TC006",
137
+ # Clashes with dataclass and nn.Module patterns
138
+ "RUF012",
139
+ # Too prescriptive about TODO format
140
+ "TD",
141
+ # D203/D211 and D212/D213 are mutually exclusive pairs
142
+ "D203", "D213",
108
143
  ]
109
144
 
145
+ [tool.ruff.lint.per-file-ignores]
146
+ "src/**/*.py" = ["S101"]
147
+ "tests/**/*.py" = ["ANN", "D103", "S101"]
148
+ "docs/**/*.py" = ["ANN", "E402", "INP001", "S", "T201"]
149
+ "examples/**/*.py" = ["INP001", "T201"]
150
+ ".claude/**/*.py" = ["ALL"]
151
+
152
+ [tool.ruff.lint.pylint]
153
+ max-args = 8
154
+
110
155
  [tool.mypy]
111
156
  ignore_missing_imports = false
112
157
  strict = true
@@ -125,5 +170,8 @@ disable_error_code = ["misc", "no-any-return"]
125
170
  module = ["embedl_deploy._internal.tensorrt.modules.*"]
126
171
  disable_error_code = ["no-any-return"]
127
172
 
173
+ [tool.setuptools.package-data]
174
+ embedl_deploy = ["py.typed"]
175
+
128
176
  [tool.setuptools.dynamic]
129
177
  version = { attr = "embedl_deploy.version.public.PUBLIC_VERSION" }
@@ -22,7 +22,7 @@ class Backend:
22
22
  fusion_patterns: Sequence[Pattern]
23
23
  #: SmoothQuant preparation patterns.
24
24
  smooth_patterns: Sequence[Pattern]
25
- #: Q/DQ stub insertion patterns for quantisation.
25
+ #: Q/DQ stub insertion patterns for quantization.
26
26
  quantized_patterns: Sequence[Pattern]
27
27
 
28
28
 
@@ -120,6 +120,6 @@ def set_backend(name: str) -> None:
120
120
  if name not in backends:
121
121
  available = ", ".join(sorted(backends)) or "(none)"
122
122
  raise ValueError(
123
- f"Backend {name!r} not found. " f"Available backends: {available}"
123
+ f"Backend {name!r} not found. Available backends: {available}"
124
124
  )
125
125
  _BackendState.backend = backends[name]
@@ -63,6 +63,10 @@ class FusedModule(nn.Module, ABC):
63
63
  self.input_quant_stubs: dict[int, QuantStub] = {
64
64
  idx: QuantStub({self}) for idx in self.inputs_to_quantize
65
65
  }
66
+ #: Whether this module has been surrounded with input
67
+ #: ``QuantStub`` entries by
68
+ #: :class:`~embedl_deploy._internal.tensorrt.patterns.quantizations.SurroundWithQuantStubsPattern`.
69
+ self.surrounded: bool = False
66
70
 
67
71
 
68
72
  class _LeafTracer(fx.Tracer):
@@ -0,0 +1,204 @@
1
+ # Copyright (C) 2026 Embedl AB
2
+
3
+ """Core abstractions: Pattern base class and PatternMatch dataclass.
4
+
5
+ Every fusion, conversion, and quantization rule is a
6
+ :class:`~embedl_deploy._internal.core.pattern.Pattern` subclass. The two
7
+ methods — :meth:`~embedl_deploy._internal.core.pattern.Pattern.match` and
8
+ :meth:`~embedl_deploy._internal.core.pattern.Pattern.replace` — encapsulate
9
+ what to look for and how to rewrite the graph.
10
+ """
11
+
12
+ from dataclasses import dataclass
13
+
14
+ from torch import fx, nn
15
+
16
+ from embedl_deploy._internal.core.tree.match import match_tree
17
+ from embedl_deploy._internal.core.tree.replace import replace_tree
18
+ from embedl_deploy._internal.core.tree.types import (
19
+ Graft,
20
+ Replacement,
21
+ Tree,
22
+ TreeMatch,
23
+ Wildcard,
24
+ )
25
+ from embedl_deploy._internal.core.tree.utils import get_module
26
+
27
+
28
def _collect_modules(tree_match: TreeMatch) -> list[nn.Module | None]:
    """Gather the modules captured by ``tree_match`` in constructor order.

    Nested branch matches are resolved first, recursively and in input
    order; the trunk entries of ``tree_match`` follow. For a
    :class:`~embedl_deploy._internal.core.tree.types.Fork` tree this puts
    the fork-input branches ahead of the output trunk, which lines up with
    a constructor such as
    ``FusedModule(branch0_mod, branch1_mod, …, output_mod)``.

    A :class:`~embedl_deploy._internal.core.tree.types.Wildcard` using the
    ``"?"`` quantifier that matched no node fills its slot with ``None``.

    :param tree_match:
        The structured match whose modules should be resolved.
    :returns:
        The resolved modules (or ``None`` placeholders), nested branches
        first, then trunk entries.
    :raises TypeError:
        If a wildcard uses a quantifier other than ``"?"``, or a matched
        node is not a ``call_module`` node.
    """
    # Depth-first: every nested branch contributes its modules before
    # any trunk entry of this level.
    collected: list[nn.Module | None] = [
        module
        for branch in tree_match.nested
        for module in _collect_modules(branch)
    ]
    for item in tree_match.trunk_nodes:
        if isinstance(item, Wildcard):
            # Only optional single-node wildcards map onto one positional
            # constructor argument.
            if item.quantifier != "?":
                raise TypeError(
                    "wildcard with quantifier"
                    f" {item.quantifier!r} is not supported"
                    " — graft only supports '?' wildcards"
                )
            node = item.nodes[0] if item.nodes else None
        else:
            node = item

        if node is None:
            collected.append(None)
            continue

        module = get_module(node)
        if module is None:
            raise TypeError(
                f"node {node.name!r} is not a call_module node"
                " — graft only works with module-only trees"
            )
        collected.append(module)
    return collected
73
+
74
+
75
def _get_replacements(
    graft: Graft,
    tree_match: TreeMatch,
) -> list[Replacement]:
    """Turn a graft specification into the replacements for one match.

    A tuple graft is a sequence of replacement makers: each one is called
    with ``tree_match`` and their results are concatenated in order. Any
    other graft is called with the modules resolved by
    :func:`_collect_modules`, producing exactly one replacement.

    :param graft:
        The graft specification from the owning pattern.
    :param tree_match:
        The match being rewritten.
    :returns:
        The replacements to hand to ``replace_tree``.
    :raises TypeError:
        If calling a non-tuple graft with the collected modules raises
        ``TypeError`` — e.g. the number of collected modules does not fit
        its constructor signature. Note that a ``TypeError`` raised inside
        the constructor body is reported the same way (the original is
        chained as ``__cause__``).
    """
    if not isinstance(graft, tuple):
        modules = _collect_modules(tree_match)
        try:
            replacement = graft(*modules)
        except TypeError as exc:
            raise TypeError(
                f"{graft.__name__}() got {len(modules)} modules from the"
                " tree match — check that the tree shape matches the"
                " constructor signature"
            ) from exc
        return [replacement]

    # Tuple graft: flatten the output of every maker, preserving order.
    return [
        replacement
        for make_replacements in graft
        for replacement in make_replacements(tree_match)
    ]
96
+
97
+
98
class Pattern:
    """A graph rewrite rule: locate a sub-graph and substitute it.

    Out of the box, :meth:`match` runs
    :func:`~embedl_deploy._internal.core.tree.match.match_tree` against the
    class-level :attr:`tree`, and :meth:`replace` builds replacements from
    :attr:`graft` before handing them to
    :func:`~embedl_deploy._internal.core.tree.replace.replace_tree`.
    Subclasses override either method when they need custom behavior
    (side effects around the rewrite, filtering of matches, and so on).

    Patterns whose
    :attr:`~embedl_deploy._internal.core.pattern.Pattern.is_conversion`
    flag is ``True`` run in a first pass that rewrites graph topology
    before any fusion pattern is matched.
    """

    #: The pattern topology to match, when tree-based matching is used.
    tree: Tree | None = None

    #: Factories producing the replacements for each matched tree, if any.
    graft: Graft | None = None

    #: ``True`` marks a structural conversion that must be applied before
    #: fusion matching.
    is_conversion: bool = False

    #: ``True`` marks a pattern that removes nodes which are artifacts of
    #: ``symbolic_trace``. It has no effect on graphs exported with
    #: ``torch.export``, where those nodes never appear.
    symbolic_trace_only: bool = False

    #: ``True`` marks a pattern that targets nodes appearing only in
    #: ``torch.export`` aten graphs; it has no effect on symbolic-trace
    #: output.
    export_graph_only: bool = False

    def match(self, graph_module: fx.GraphModule) -> list["PatternMatch"]:
        """Return every occurrence of this pattern in `graph_module`.

        :param graph_module:
            The graph module to search.
        :returns:
            One :class:`PatternMatch` per tree match, in the order
            produced by ``match_tree``.
        :raises ValueError:
            If the pattern has no ``tree``.
        """
        topology = self.tree
        if topology is None:
            raise ValueError(f"{type(self).__name__} has no tree to match.")
        return [
            PatternMatch(
                pattern=self,
                graph_module=graph_module,
                tree_match=found,
            )
            for found in match_tree(graph_module, topology)
        ]

    def replace(
        self,
        pattern_match: "PatternMatch",
    ) -> list[fx.Node]:
        """Rewrite one matched occurrence of this pattern in place.

        :param pattern_match:
            A match previously produced by this pattern's :meth:`match`.
        :returns:
            The replacement nodes inserted into the graph.
        :raises ValueError:
            If the pattern has no ``graft``.
        :raises TypeError:
            If the ``graft`` class constructor rejects the collected
            modules.
        """
        # The match must originate from this very pattern instance.
        assert pattern_match.pattern is self
        matched_tree = pattern_match.tree_match
        factories = self.graft
        if factories is None:
            raise ValueError(
                f"{type(self).__name__} has no graft"
                " — override replace() or set graft."
            )
        new_parts = _get_replacements(factories, matched_tree)
        return replace_tree(pattern_match.graph_module, matched_tree, new_parts)
181
+
182
+
183
@dataclass
class PatternMatch:
    """A single occurrence of a ``Pattern`` found in a graph."""

    #: The pattern responsible for this match.
    pattern: Pattern
    #: The graph module in which the match was found.
    graph_module: fx.GraphModule
    #: Structured result from
    #: :func:`~embedl_deploy._internal.core.tree.match.match_tree`,
    #: holding the matched nodes, modules, and nested per-branch
    #: sub-matches for
    #: :class:`~embedl_deploy._internal.core.tree.types.Fork` topologies.
    tree_match: TreeMatch
    #: Whether this match is applied during transformation.
    apply: bool = True

    def __repr__(self) -> str:
        """Render as ``PatternMatch(<PatternClass>: node_a -> node_b)``."""
        pattern_name = type(self.pattern).__name__
        chain = " -> ".join(
            node.name for node in self.tree_match.get_tree_nodes()
        )
        return f"PatternMatch({pattern_name}: {chain})"
@@ -144,11 +144,24 @@ def get_transformation_plan(
144
144
  for node, pats in plan.matches.items():
145
145
  for name, match in pats.items():
146
146
  print(f"{node}: {name} apply={match.apply}")
147
+
147
148
  """
148
149
  if not getattr(graph_module, "_deep_copy_done", False):
149
150
  graph_module = copy.deepcopy(graph_module)
150
151
  setattr(graph_module, "_deep_copy_done", True)
151
152
 
153
+ # Strip torch.export shape-guard nodes that ShapeProp cannot evaluate.
154
+ guards = [
155
+ n
156
+ for n in graph_module.graph.nodes
157
+ if n.op == "call_module" and n.name.startswith("_guards")
158
+ ]
159
+ for node in guards:
160
+ node.replace_all_uses_with(next(iter(node.args)))
161
+ graph_module.graph.erase_node(node)
162
+ if guards:
163
+ graph_module.recompile()
164
+
152
165
  pattern_matches: list[PatternMatch] = []
153
166
  for pattern in patterns:
154
167
  pattern_matches.extend(pattern.match(graph_module))
@@ -168,6 +181,53 @@ def get_transformation_plan(
168
181
  )
169
182
 
170
183
 
184
+ def _propagate_shapes(graph_module: fx.GraphModule) -> None:
185
+ """Re-propagate tensor shapes after graph surgery.
186
+
187
+ Builds fake inputs from placeholder ``tensor_meta`` and runs
188
+ :class:`~torch.fx.passes.shape_prop.ShapeProp`. Pins tensors to the
189
+ graph's parameter device so ``torch.export``'d graphs with
190
+ device-dispatched ops (e.g. SDPA) don't crash on cross-device tensors.
191
+
192
+ :param graph_module:
193
+ The graph module whose shapes should be refreshed.
194
+ """
195
+ try:
196
+ device = next(graph_module.parameters()).device
197
+ except StopIteration:
198
+ device = torch.device("cpu")
199
+
200
+ # Patterns may register new submodules (QuantStub, FusedX) whose
201
+ # buffers default to CPU. Sync to the graph's device so ShapeProp
202
+ # doesn't hit a mixed-device forward.
203
+ if device.type != "cpu":
204
+ graph_module.to(device)
205
+
206
+ fake_args: list[torch.Tensor] = []
207
+ for n in graph_module.graph.nodes:
208
+ if n.op != "placeholder":
209
+ continue
210
+ meta = n.meta.get("tensor_meta")
211
+ if meta is None or not hasattr(meta, "shape"):
212
+ fake_args.clear()
213
+ break
214
+ dtype = getattr(meta, "dtype", torch.float32)
215
+ if dtype.is_floating_point:
216
+ fake_args.append(
217
+ torch.randn(meta.shape, dtype=dtype, device=device)
218
+ )
219
+ else:
220
+ fake_args.append(
221
+ torch.zeros(meta.shape, dtype=dtype, device=device)
222
+ )
223
+ if fake_args:
224
+ # `no_grad` keeps ShapeProp from materialising an autograd tape
225
+ # for the whole forward pass — for large transformer graphs
226
+ # (SAM3, ViT-L, …) the activation tape can blow GPU memory.
227
+ with torch.no_grad():
228
+ ShapeProp(graph_module).propagate(*fake_args) # type: ignore[no-untyped-call]
229
+
230
+
171
231
  def apply_transformation_plan(
172
232
  plan: TransformationPlan,
173
233
  ) -> TransformationResult:
@@ -196,6 +256,7 @@ def apply_transformation_plan(
196
256
  result = apply_transformation_plan(plan)
197
257
  print(result.report)
198
258
  torch.onnx.export(result.model, x, "deployed.onnx")
259
+
199
260
  """
200
261
  graph_module = plan.model
201
262
 
@@ -219,10 +280,7 @@ def apply_transformation_plan(
219
280
  graph_module.recompile()
220
281
  graph_module.eval()
221
282
 
222
- input_node = next(iter(graph_module.graph.nodes))
223
- meta = input_node.meta.get("tensor_meta")
224
- if meta is not None and hasattr(meta, "shape"):
225
- ShapeProp(graph_module).propagate(torch.randn(meta.shape)) # type: ignore[no-untyped-call]
283
+ _propagate_shapes(graph_module)
226
284
 
227
285
  report = _build_report(enabled, skipped)
228
286
 
@@ -264,6 +322,7 @@ def transform(
264
322
  from embedl_deploy.tensorrt import TENSORRT_PATTERNS
265
323
 
266
324
  deployable_model = transform(model, patterns=TENSORRT_PATTERNS).model
325
+
267
326
  """
268
327
  graph_module = (
269
328
  model if isinstance(model, fx.GraphModule) else symbolic_trace(model)
@@ -46,12 +46,13 @@ def calibrate_smooth_quant(
46
46
  if not mq.smooth:
47
47
  return
48
48
 
49
- hooks = []
50
49
  for stub in enabled_stubs:
51
50
  stub.enabled = False
52
- for obs in mq.smooth:
53
- if obs.downstream_linears:
54
- hooks.append(obs.register_forward_hook())
51
+ hooks = [
52
+ obs.register_forward_hook()
53
+ for obs in mq.smooth
54
+ if obs.downstream_linears
55
+ ]
55
56
 
56
57
  try:
57
58
  model.eval()
@@ -86,9 +86,23 @@ def configure(
86
86
  if obs.enabled:
87
87
  obs.config = copy.copy(config.smooth_quant)
88
88
 
89
+ # Snapshot the model device before insertion so newly-created
90
+ # QuantStub / WeightFakeQuantize / observer buffers can be moved
91
+ # to the same device as the rest of the graph after the
92
+ # smooth-quant / Q-DQ insertion passes run. Without this, models
93
+ # already on CUDA hit "scale on cpu vs other tensors on cuda" in
94
+ # ``fake_quantize_per_tensor_affine_cachemask_tensor_qparams``.
95
+ try:
96
+ device = next(model.parameters()).device
97
+ except StopIteration:
98
+ device = torch.device("cpu")
99
+
89
100
  prepare_smooth_quant(model)
90
101
  prepare_qdq(model)
91
102
 
103
+ if device.type != "cpu":
104
+ model.to(device)
105
+
92
106
 
93
107
  def quantize(
94
108
  model: fx.GraphModule,
@@ -153,7 +167,7 @@ def freeze_weight_quantization(model: fx.GraphModule) -> None:
153
167
  mq = get_model_quants(model)
154
168
 
155
169
  for wfq in mq.weight:
156
- mod = list(wfq.consumers)[0]
170
+ mod = next(iter(wfq.consumers))
157
171
  weight = _get_quantized_weight(mod)
158
172
  if isinstance(weight, torch.Tensor):
159
173
  wfq.freeze(weight)
@@ -159,7 +159,8 @@ class WeightFakeQuantize(nn.Module):
159
159
  return weight
160
160
  if self.frozen:
161
161
  scale, zero_point = self.scale, self.zero_point
162
- assert scale is not None and zero_point is not None
162
+ assert scale is not None
163
+ assert zero_point is not None
163
164
  else:
164
165
  scale, zero_point = self._compute_quant_params(weight)
165
166
  q_min, q_max = self.config.quant_min, self.config.quant_max
@@ -0,0 +1,3 @@
1
+ # Copyright (C) 2026 Embedl AB
2
+
3
+ """Tree-based pattern topology DSL, matching, and replacement."""