embedl-deploy 0.2.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- embedl_deploy-0.4.0/MANIFEST.in +10 -0
- {embedl_deploy-0.2.0/src/embedl_deploy.egg-info → embedl_deploy-0.4.0}/PKG-INFO +2 -2
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/pyproject.toml +54 -6
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/backend.py +2 -2
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/modules.py +4 -0
- embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/pattern.py +204 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/plan.py +63 -4
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/calibrate.py +5 -4
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/main.py +15 -1
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/stubs.py +2 -1
- embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree/__init__.py +3 -0
- embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree/match.py +334 -0
- {embedl_deploy-0.2.0/src/embedl_deploy/_internal/core → embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree}/replace.py +93 -45
- embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree/types.py +326 -0
- embedl_deploy-0.4.0/src/embedl_deploy/_internal/core/tree/utils.py +64 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/version/public.py +1 -1
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0/src/embedl_deploy.egg-info}/PKG-INFO +2 -2
- embedl_deploy-0.4.0/src/embedl_deploy.egg-info/SOURCES.txt +35 -0
- embedl_deploy-0.4.0/src/embedl_deploy.egg-info/requires.txt +4 -0
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/match.py +0 -256
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/pattern.py +0 -476
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -3
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/backend.py +0 -18
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -3
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/attention.py +0 -274
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/conv.py +0 -232
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/linear.py +0 -112
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +0 -39
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/pool.py +0 -25
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +0 -460
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -3
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -15
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +0 -691
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +0 -300
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/__init__.py +0 -3
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/attention.py +0 -87
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/conv.py +0 -196
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/linear.py +0 -86
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pointwise.py +0 -55
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pool.py +0 -50
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +0 -292
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -819
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +0 -123
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -81
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/plan.py +0 -123
- embedl_deploy-0.2.0/src/embedl_deploy/tensorrt/__init__.py +0 -40
- embedl_deploy-0.2.0/src/embedl_deploy/tensorrt/modules/__init__.py +0 -40
- embedl_deploy-0.2.0/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -60
- embedl_deploy-0.2.0/src/embedl_deploy.egg-info/SOURCES.txt +0 -59
- embedl_deploy-0.2.0/src/embedl_deploy.egg-info/requires.txt +0 -4
- embedl_deploy-0.2.0/tests/test_version.py +0 -20
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/LICENSE +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/NOTICE +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/README.md +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/setup.cfg +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/__init__.py +1 -1
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/config.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/prepare.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/qat.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/utils.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/backend/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/py.typed +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/quantize/__init__.py +1 -1
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/version/__init__.py +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy.egg-info/dependency_links.txt +0 -0
- {embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: embedl-deploy
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: Python package to make AI models deployment-ready for any hardware.
|
|
5
5
|
Author-email: Embedl AB <support@embedl.com>
|
|
6
6
|
Project-URL: Homepage, https://www.embedl.com/
|
|
@@ -15,7 +15,7 @@ License-File: LICENSE
|
|
|
15
15
|
License-File: NOTICE
|
|
16
16
|
Requires-Dist: torch
|
|
17
17
|
Provides-Extra: tensorrt
|
|
18
|
-
Requires-Dist: tensorrt; extra == "tensorrt"
|
|
18
|
+
Requires-Dist: embedl-deploy-tensorrt; extra == "tensorrt"
|
|
19
19
|
Dynamic: license-file
|
|
20
20
|
|
|
21
21
|
# embedl-deploy
|
|
@@ -27,7 +27,7 @@ dynamic = ["version"]
|
|
|
27
27
|
dependencies = ["torch"]
|
|
28
28
|
|
|
29
29
|
[project.optional-dependencies]
|
|
30
|
-
tensorrt = ["tensorrt"]
|
|
30
|
+
tensorrt = ["embedl-deploy-tensorrt"]
|
|
31
31
|
|
|
32
32
|
[project.urls]
|
|
33
33
|
Homepage = "https://www.embedl.com/"
|
|
@@ -100,13 +100,58 @@ line-length = 79
|
|
|
100
100
|
quote-style = "preserve"
|
|
101
101
|
|
|
102
102
|
[tool.ruff.lint]
|
|
103
|
-
select = [
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
103
|
+
select = ["ALL"]
|
|
104
|
+
ignore = [
|
|
105
|
+
# Dynamic attributes on fx.Node require string-based access for mypy
|
|
106
|
+
"B009", "B010",
|
|
107
|
+
# Conflicts with ruff format
|
|
108
|
+
"COM812",
|
|
109
|
+
# Descriptive exception messages preferred
|
|
110
|
+
"EM", "TRY003",
|
|
111
|
+
# Allow long lines for URLs, Sphinx cross-references, and imports
|
|
112
|
+
"E501",
|
|
113
|
+
# Too many false positives
|
|
114
|
+
"ERA001",
|
|
115
|
+
# Common in PyTorch-style APIs
|
|
116
|
+
"FBT",
|
|
117
|
+
# TODOs are fine
|
|
118
|
+
"FIX002",
|
|
119
|
+
# PyTorch naming conventions (N, C, H, W; import F)
|
|
120
|
+
"N806", "N812",
|
|
121
|
+
# Allow magic value comparisons
|
|
122
|
+
"PLR2004",
|
|
123
|
+
# Intermediate variables before return aid readability
|
|
124
|
+
"RET504",
|
|
125
|
+
# Conflicts with quote-style = "preserve"
|
|
126
|
+
"Q000",
|
|
127
|
+
# Intentional Unicode in docstrings and comments
|
|
128
|
+
"RUF002", "RUF003",
|
|
129
|
+
# Explicit if/return True/return False is clearer for predicate functions
|
|
130
|
+
"SIM103",
|
|
131
|
+
# Type-only imports are fine as regular imports
|
|
132
|
+
"TC001",
|
|
133
|
+
# Non-cryptographic random is expected in ML code
|
|
134
|
+
"S311",
|
|
135
|
+
# Prefer unquoted type expressions in cast()
|
|
136
|
+
"TC006",
|
|
137
|
+
# Clashes with dataclass and nn.Module patterns
|
|
138
|
+
"RUF012",
|
|
139
|
+
# Too prescriptive about TODO format
|
|
140
|
+
"TD",
|
|
141
|
+
# D203/D211 and D212/D213 are mutually exclusive pairs
|
|
142
|
+
"D203", "D213",
|
|
108
143
|
]
|
|
109
144
|
|
|
145
|
+
[tool.ruff.lint.per-file-ignores]
|
|
146
|
+
"src/**/*.py" = ["S101"]
|
|
147
|
+
"tests/**/*.py" = ["ANN", "D103", "S101"]
|
|
148
|
+
"docs/**/*.py" = ["ANN", "E402", "INP001", "S", "T201"]
|
|
149
|
+
"examples/**/*.py" = ["INP001", "T201"]
|
|
150
|
+
".claude/**/*.py" = ["ALL"]
|
|
151
|
+
|
|
152
|
+
[tool.ruff.lint.pylint]
|
|
153
|
+
max-args = 8
|
|
154
|
+
|
|
110
155
|
[tool.mypy]
|
|
111
156
|
ignore_missing_imports = false
|
|
112
157
|
strict = true
|
|
@@ -125,5 +170,8 @@ disable_error_code = ["misc", "no-any-return"]
|
|
|
125
170
|
module = ["embedl_deploy._internal.tensorrt.modules.*"]
|
|
126
171
|
disable_error_code = ["no-any-return"]
|
|
127
172
|
|
|
173
|
+
[tool.setuptools.package-data]
|
|
174
|
+
embedl_deploy = ["py.typed"]
|
|
175
|
+
|
|
128
176
|
[tool.setuptools.dynamic]
|
|
129
177
|
version = { attr = "embedl_deploy.version.public.PUBLIC_VERSION" }
|
|
@@ -22,7 +22,7 @@ class Backend:
|
|
|
22
22
|
fusion_patterns: Sequence[Pattern]
|
|
23
23
|
#: SmoothQuant preparation patterns.
|
|
24
24
|
smooth_patterns: Sequence[Pattern]
|
|
25
|
-
#: Q/DQ stub insertion patterns for
|
|
25
|
+
#: Q/DQ stub insertion patterns for quantization.
|
|
26
26
|
quantized_patterns: Sequence[Pattern]
|
|
27
27
|
|
|
28
28
|
|
|
@@ -120,6 +120,6 @@ def set_backend(name: str) -> None:
|
|
|
120
120
|
if name not in backends:
|
|
121
121
|
available = ", ".join(sorted(backends)) or "(none)"
|
|
122
122
|
raise ValueError(
|
|
123
|
-
f"Backend {name!r} not found.
|
|
123
|
+
f"Backend {name!r} not found. Available backends: {available}"
|
|
124
124
|
)
|
|
125
125
|
_BackendState.backend = backends[name]
|
|
@@ -63,6 +63,10 @@ class FusedModule(nn.Module, ABC):
|
|
|
63
63
|
self.input_quant_stubs: dict[int, QuantStub] = {
|
|
64
64
|
idx: QuantStub({self}) for idx in self.inputs_to_quantize
|
|
65
65
|
}
|
|
66
|
+
#: Whether this module has been surrounded with input
|
|
67
|
+
#: ``QuantStub`` entries by
|
|
68
|
+
#: :class:`~embedl_deploy._internal.tensorrt.patterns.quantizations.SurroundWithQuantStubsPattern`.
|
|
69
|
+
self.surrounded: bool = False
|
|
66
70
|
|
|
67
71
|
|
|
68
72
|
class _LeafTracer(fx.Tracer):
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
# Copyright (C) 2026 Embedl AB
|
|
2
|
+
|
|
3
|
+
"""Core abstractions: Pattern base class and PatternMatch dataclass.
|
|
4
|
+
|
|
5
|
+
Every fusion, conversion, and quantization rule is a
|
|
6
|
+
:class:`~embedl_deploy._internal.core.pattern.Pattern` subclass. The two
|
|
7
|
+
methods — :meth:`~embedl_deploy._internal.core.pattern.Pattern.match` and
|
|
8
|
+
:meth:`~embedl_deploy._internal.core.pattern.Pattern.replace` — encapsulate
|
|
9
|
+
what to look for and how to rewrite the graph.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
from torch import fx, nn
|
|
15
|
+
|
|
16
|
+
from embedl_deploy._internal.core.tree.match import match_tree
|
|
17
|
+
from embedl_deploy._internal.core.tree.replace import replace_tree
|
|
18
|
+
from embedl_deploy._internal.core.tree.types import (
|
|
19
|
+
Graft,
|
|
20
|
+
Replacement,
|
|
21
|
+
Tree,
|
|
22
|
+
TreeMatch,
|
|
23
|
+
Wildcard,
|
|
24
|
+
)
|
|
25
|
+
from embedl_deploy._internal.core.tree.utils import get_module
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _collect_modules(tree_match: TreeMatch) -> list[nn.Module | None]:
|
|
29
|
+
"""Resolve matched modules from a tree match.
|
|
30
|
+
|
|
31
|
+
Walks nested branches first (in input order), then
|
|
32
|
+
trunk nodes. For a
|
|
33
|
+
:class:`~embedl_deploy._internal.core.tree.types.Fork`
|
|
34
|
+
tree this means the fork-input branches precede the output
|
|
35
|
+
trunk, so the resulting list matches a constructor signature
|
|
36
|
+
like
|
|
37
|
+
``FusedModule(branch0_mod, branch1_mod, …, output_mod)``.
|
|
38
|
+
|
|
39
|
+
:class:`~embedl_deploy._internal.core.tree.types.Wildcard`
|
|
40
|
+
entries with ``"?"`` quantifier that matched nothing
|
|
41
|
+
contribute ``None``.
|
|
42
|
+
|
|
43
|
+
:raises TypeError:
|
|
44
|
+
If a matched node is not a ``call_module`` node.
|
|
45
|
+
"""
|
|
46
|
+
modules: list[nn.Module | None] = []
|
|
47
|
+
for nested in tree_match.nested:
|
|
48
|
+
modules.extend(_collect_modules(nested))
|
|
49
|
+
for entry in tree_match.trunk_nodes:
|
|
50
|
+
if isinstance(entry, Wildcard):
|
|
51
|
+
if entry.quantifier != "?":
|
|
52
|
+
raise TypeError(
|
|
53
|
+
f"wildcard with quantifier"
|
|
54
|
+
f" {entry.quantifier!r} is not"
|
|
55
|
+
f" supported — graft only supports"
|
|
56
|
+
f" '?' wildcards"
|
|
57
|
+
)
|
|
58
|
+
node = entry.nodes[0] if entry.nodes else None
|
|
59
|
+
else:
|
|
60
|
+
node = entry
|
|
61
|
+
if node is None:
|
|
62
|
+
modules.append(None)
|
|
63
|
+
else:
|
|
64
|
+
mod = get_module(node)
|
|
65
|
+
if mod is None:
|
|
66
|
+
raise TypeError(
|
|
67
|
+
f"node {node.name!r} is not a call_module "
|
|
68
|
+
f"node — graft only works with "
|
|
69
|
+
f"module-only trees"
|
|
70
|
+
)
|
|
71
|
+
modules.append(mod)
|
|
72
|
+
return modules
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _get_replacements(
|
|
76
|
+
graft: Graft,
|
|
77
|
+
tree_match: TreeMatch,
|
|
78
|
+
) -> list[Replacement]:
|
|
79
|
+
"""Build the replacement list from a graft specification."""
|
|
80
|
+
if isinstance(graft, tuple):
|
|
81
|
+
replacements: list[Replacement] = []
|
|
82
|
+
for rep_maker in graft:
|
|
83
|
+
replacements.extend(rep_maker(tree_match))
|
|
84
|
+
return replacements
|
|
85
|
+
modules = _collect_modules(tree_match)
|
|
86
|
+
try:
|
|
87
|
+
return [graft(*modules)]
|
|
88
|
+
except TypeError as exc:
|
|
89
|
+
raise TypeError(
|
|
90
|
+
f"{graft.__name__}() got"
|
|
91
|
+
f" {len(modules)} modules from"
|
|
92
|
+
f" the tree match — check that"
|
|
93
|
+
f" the tree shape matches the"
|
|
94
|
+
f" constructor signature"
|
|
95
|
+
) from exc
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class Pattern:
|
|
99
|
+
"""A graph transformation rule: find a sub-graph and replace it.
|
|
100
|
+
|
|
101
|
+
The default :meth:`match` delegates to
|
|
102
|
+
:func:`~embedl_deploy._internal.core.tree.match.match_tree` using the
|
|
103
|
+
class's :attr:`tree`. The default :meth:`replace` constructs
|
|
104
|
+
replacements from :attr:`graft` and delegates to
|
|
105
|
+
:func:`~embedl_deploy._internal.core.tree.replace.replace_tree`.
|
|
106
|
+
Subclasses override either method when they need custom logic
|
|
107
|
+
(pre/post side-effects, post-match filtering, etc.).
|
|
108
|
+
|
|
109
|
+
Patterns with
|
|
110
|
+
:attr:`~embedl_deploy._internal.core.pattern.Pattern.is_conversion` set to
|
|
111
|
+
``True`` are applied in a first pass to rewrite graph topology before
|
|
112
|
+
fusion patterns are matched.
|
|
113
|
+
"""
|
|
114
|
+
|
|
115
|
+
tree: Tree | None = None
|
|
116
|
+
"""The pattern topology to match, if using tree-based matching."""
|
|
117
|
+
|
|
118
|
+
graft: Graft | None = None
|
|
119
|
+
"""The factories to make replacements for each matched tree, if used."""
|
|
120
|
+
|
|
121
|
+
is_conversion: bool = False
|
|
122
|
+
"""If ``True``, this pattern is a structural conversion that must
|
|
123
|
+
be applied before fusion matching."""
|
|
124
|
+
|
|
125
|
+
symbolic_trace_only: bool = False
|
|
126
|
+
"""If ``True``, this pattern removes nodes that are artifacts of
|
|
127
|
+
``symbolic_trace``. This pattern has no effect on graphs exported with
|
|
128
|
+
``torch.export`` because the nodes never appear in those graphs."""
|
|
129
|
+
|
|
130
|
+
export_graph_only: bool = False
|
|
131
|
+
"""If ``True``, this pattern targets nodes that only appear in
|
|
132
|
+
``torch.export`` aten graphs and has no effect on symbolic-trace output."""
|
|
133
|
+
|
|
134
|
+
def match(self, graph_module: fx.GraphModule) -> list["PatternMatch"]:
|
|
135
|
+
"""Find all occurrences of this pattern in `graph_module`.
|
|
136
|
+
|
|
137
|
+
:raises ValueError:
|
|
138
|
+
If the pattern has no ``tree``.
|
|
139
|
+
"""
|
|
140
|
+
tree = self.tree
|
|
141
|
+
if tree is None:
|
|
142
|
+
raise ValueError(f"{type(self).__name__} has no tree to match.")
|
|
143
|
+
tree_matches = match_tree(graph_module, tree)
|
|
144
|
+
return [
|
|
145
|
+
PatternMatch(
|
|
146
|
+
pattern=self,
|
|
147
|
+
graph_module=graph_module,
|
|
148
|
+
tree_match=tm,
|
|
149
|
+
)
|
|
150
|
+
for tm in tree_matches
|
|
151
|
+
]
|
|
152
|
+
|
|
153
|
+
def replace(
|
|
154
|
+
self,
|
|
155
|
+
pattern_match: "PatternMatch",
|
|
156
|
+
) -> list[fx.Node]:
|
|
157
|
+
"""Replace one matched occurrence in-place.
|
|
158
|
+
|
|
159
|
+
:param pattern_match:
|
|
160
|
+
The pattern match to replace.
|
|
161
|
+
:returns:
|
|
162
|
+
The replacement nodes inserted into the graph.
|
|
163
|
+
:raises ValueError:
|
|
164
|
+
If the pattern has no ``graft``.
|
|
165
|
+
:raises TypeError:
|
|
166
|
+
If the ``graft`` class constructor rejects the
|
|
167
|
+
collected modules.
|
|
168
|
+
"""
|
|
169
|
+
assert pattern_match.pattern is self
|
|
170
|
+
tree_match = pattern_match.tree_match
|
|
171
|
+
graft = self.graft
|
|
172
|
+
if graft is None:
|
|
173
|
+
raise ValueError(
|
|
174
|
+
f"{type(self).__name__} has no graft"
|
|
175
|
+
f" — override replace() or set graft."
|
|
176
|
+
)
|
|
177
|
+
replacements = _get_replacements(graft, tree_match)
|
|
178
|
+
return replace_tree(
|
|
179
|
+
pattern_match.graph_module, tree_match, replacements
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@dataclass
|
|
184
|
+
class PatternMatch:
|
|
185
|
+
"""One matched occurrence of a ``Pattern`` in a graph."""
|
|
186
|
+
|
|
187
|
+
#: The pattern that produced this match.
|
|
188
|
+
pattern: Pattern
|
|
189
|
+
#: The graph module that produced this match.
|
|
190
|
+
graph_module: fx.GraphModule
|
|
191
|
+
#: Structured match result produced by
|
|
192
|
+
#: :func:`~embedl_deploy._internal.core.tree.match.match_tree`.
|
|
193
|
+
#: Contains the matched nodes, modules, and nested per-branch
|
|
194
|
+
#: sub-matches for
|
|
195
|
+
#: :class:`~embedl_deploy._internal.core.tree.types.Fork`
|
|
196
|
+
#: topologies.
|
|
197
|
+
tree_match: TreeMatch
|
|
198
|
+
#: Whether to apply this match during transformation.
|
|
199
|
+
apply: bool = True
|
|
200
|
+
|
|
201
|
+
def __repr__(self) -> str:
|
|
202
|
+
pat = type(self.pattern).__name__
|
|
203
|
+
node_names = [n.name for n in self.tree_match.get_tree_nodes()]
|
|
204
|
+
return f"PatternMatch({pat}: {' -> '.join(node_names)})"
|
|
@@ -144,11 +144,24 @@ def get_transformation_plan(
|
|
|
144
144
|
for node, pats in plan.matches.items():
|
|
145
145
|
for name, match in pats.items():
|
|
146
146
|
print(f"{node}: {name} apply={match.apply}")
|
|
147
|
+
|
|
147
148
|
"""
|
|
148
149
|
if not getattr(graph_module, "_deep_copy_done", False):
|
|
149
150
|
graph_module = copy.deepcopy(graph_module)
|
|
150
151
|
setattr(graph_module, "_deep_copy_done", True)
|
|
151
152
|
|
|
153
|
+
# Strip torch.export shape-guard nodes that ShapeProp cannot evaluate.
|
|
154
|
+
guards = [
|
|
155
|
+
n
|
|
156
|
+
for n in graph_module.graph.nodes
|
|
157
|
+
if n.op == "call_module" and n.name.startswith("_guards")
|
|
158
|
+
]
|
|
159
|
+
for node in guards:
|
|
160
|
+
node.replace_all_uses_with(next(iter(node.args)))
|
|
161
|
+
graph_module.graph.erase_node(node)
|
|
162
|
+
if guards:
|
|
163
|
+
graph_module.recompile()
|
|
164
|
+
|
|
152
165
|
pattern_matches: list[PatternMatch] = []
|
|
153
166
|
for pattern in patterns:
|
|
154
167
|
pattern_matches.extend(pattern.match(graph_module))
|
|
@@ -168,6 +181,53 @@ def get_transformation_plan(
|
|
|
168
181
|
)
|
|
169
182
|
|
|
170
183
|
|
|
184
|
+
def _propagate_shapes(graph_module: fx.GraphModule) -> None:
|
|
185
|
+
"""Re-propagate tensor shapes after graph surgery.
|
|
186
|
+
|
|
187
|
+
Builds fake inputs from placeholder ``tensor_meta`` and runs
|
|
188
|
+
:class:`~torch.fx.passes.shape_prop.ShapeProp`. Pins tensors to the
|
|
189
|
+
graph's parameter device so ``torch.export``'d graphs with
|
|
190
|
+
device-dispatched ops (e.g. SDPA) don't crash on cross-device tensors.
|
|
191
|
+
|
|
192
|
+
:param graph_module:
|
|
193
|
+
The graph module whose shapes should be refreshed.
|
|
194
|
+
"""
|
|
195
|
+
try:
|
|
196
|
+
device = next(graph_module.parameters()).device
|
|
197
|
+
except StopIteration:
|
|
198
|
+
device = torch.device("cpu")
|
|
199
|
+
|
|
200
|
+
# Patterns may register new submodules (QuantStub, FusedX) whose
|
|
201
|
+
# buffers default to CPU. Sync to the graph's device so ShapeProp
|
|
202
|
+
# doesn't hit a mixed-device forward.
|
|
203
|
+
if device.type != "cpu":
|
|
204
|
+
graph_module.to(device)
|
|
205
|
+
|
|
206
|
+
fake_args: list[torch.Tensor] = []
|
|
207
|
+
for n in graph_module.graph.nodes:
|
|
208
|
+
if n.op != "placeholder":
|
|
209
|
+
continue
|
|
210
|
+
meta = n.meta.get("tensor_meta")
|
|
211
|
+
if meta is None or not hasattr(meta, "shape"):
|
|
212
|
+
fake_args.clear()
|
|
213
|
+
break
|
|
214
|
+
dtype = getattr(meta, "dtype", torch.float32)
|
|
215
|
+
if dtype.is_floating_point:
|
|
216
|
+
fake_args.append(
|
|
217
|
+
torch.randn(meta.shape, dtype=dtype, device=device)
|
|
218
|
+
)
|
|
219
|
+
else:
|
|
220
|
+
fake_args.append(
|
|
221
|
+
torch.zeros(meta.shape, dtype=dtype, device=device)
|
|
222
|
+
)
|
|
223
|
+
if fake_args:
|
|
224
|
+
# `no_grad` keeps ShapeProp from materialising an autograd tape
|
|
225
|
+
# for the whole forward pass — for large transformer graphs
|
|
226
|
+
# (SAM3, ViT-L, …) the activation tape can blow GPU memory.
|
|
227
|
+
with torch.no_grad():
|
|
228
|
+
ShapeProp(graph_module).propagate(*fake_args) # type: ignore[no-untyped-call]
|
|
229
|
+
|
|
230
|
+
|
|
171
231
|
def apply_transformation_plan(
|
|
172
232
|
plan: TransformationPlan,
|
|
173
233
|
) -> TransformationResult:
|
|
@@ -196,6 +256,7 @@ def apply_transformation_plan(
|
|
|
196
256
|
result = apply_transformation_plan(plan)
|
|
197
257
|
print(result.report)
|
|
198
258
|
torch.onnx.export(result.model, x, "deployed.onnx")
|
|
259
|
+
|
|
199
260
|
"""
|
|
200
261
|
graph_module = plan.model
|
|
201
262
|
|
|
@@ -219,10 +280,7 @@ def apply_transformation_plan(
|
|
|
219
280
|
graph_module.recompile()
|
|
220
281
|
graph_module.eval()
|
|
221
282
|
|
|
222
|
-
|
|
223
|
-
meta = input_node.meta.get("tensor_meta")
|
|
224
|
-
if meta is not None and hasattr(meta, "shape"):
|
|
225
|
-
ShapeProp(graph_module).propagate(torch.randn(meta.shape)) # type: ignore[no-untyped-call]
|
|
283
|
+
_propagate_shapes(graph_module)
|
|
226
284
|
|
|
227
285
|
report = _build_report(enabled, skipped)
|
|
228
286
|
|
|
@@ -264,6 +322,7 @@ def transform(
|
|
|
264
322
|
from embedl_deploy.tensorrt import TENSORRT_PATTERNS
|
|
265
323
|
|
|
266
324
|
deployable_model = transform(model, patterns=TENSORRT_PATTERNS).model
|
|
325
|
+
|
|
267
326
|
"""
|
|
268
327
|
graph_module = (
|
|
269
328
|
model if isinstance(model, fx.GraphModule) else symbolic_trace(model)
|
{embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/calibrate.py
RENAMED
|
@@ -46,12 +46,13 @@ def calibrate_smooth_quant(
|
|
|
46
46
|
if not mq.smooth:
|
|
47
47
|
return
|
|
48
48
|
|
|
49
|
-
hooks = []
|
|
50
49
|
for stub in enabled_stubs:
|
|
51
50
|
stub.enabled = False
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
51
|
+
hooks = [
|
|
52
|
+
obs.register_forward_hook()
|
|
53
|
+
for obs in mq.smooth
|
|
54
|
+
if obs.downstream_linears
|
|
55
|
+
]
|
|
55
56
|
|
|
56
57
|
try:
|
|
57
58
|
model.eval()
|
{embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/main.py
RENAMED
|
@@ -86,9 +86,23 @@ def configure(
|
|
|
86
86
|
if obs.enabled:
|
|
87
87
|
obs.config = copy.copy(config.smooth_quant)
|
|
88
88
|
|
|
89
|
+
# Snapshot the model device before insertion so newly-created
|
|
90
|
+
# QuantStub / WeightFakeQuantize / observer buffers can be moved
|
|
91
|
+
# to the same device as the rest of the graph after the
|
|
92
|
+
# smooth-quant / Q-DQ insertion passes run. Without this, models
|
|
93
|
+
# already on CUDA hit "scale on cpu vs other tensors on cuda" in
|
|
94
|
+
# ``fake_quantize_per_tensor_affine_cachemask_tensor_qparams``.
|
|
95
|
+
try:
|
|
96
|
+
device = next(model.parameters()).device
|
|
97
|
+
except StopIteration:
|
|
98
|
+
device = torch.device("cpu")
|
|
99
|
+
|
|
89
100
|
prepare_smooth_quant(model)
|
|
90
101
|
prepare_qdq(model)
|
|
91
102
|
|
|
103
|
+
if device.type != "cpu":
|
|
104
|
+
model.to(device)
|
|
105
|
+
|
|
92
106
|
|
|
93
107
|
def quantize(
|
|
94
108
|
model: fx.GraphModule,
|
|
@@ -153,7 +167,7 @@ def freeze_weight_quantization(model: fx.GraphModule) -> None:
|
|
|
153
167
|
mq = get_model_quants(model)
|
|
154
168
|
|
|
155
169
|
for wfq in mq.weight:
|
|
156
|
-
mod =
|
|
170
|
+
mod = next(iter(wfq.consumers))
|
|
157
171
|
weight = _get_quantized_weight(mod)
|
|
158
172
|
if isinstance(weight, torch.Tensor):
|
|
159
173
|
wfq.freeze(weight)
|
{embedl_deploy-0.2.0 → embedl_deploy-0.4.0}/src/embedl_deploy/_internal/core/quantize/stubs.py
RENAMED
|
@@ -159,7 +159,8 @@ class WeightFakeQuantize(nn.Module):
|
|
|
159
159
|
return weight
|
|
160
160
|
if self.frozen:
|
|
161
161
|
scale, zero_point = self.scale, self.zero_point
|
|
162
|
-
assert scale is not None
|
|
162
|
+
assert scale is not None
|
|
163
|
+
assert zero_point is not None
|
|
163
164
|
else:
|
|
164
165
|
scale, zero_point = self._compute_quant_params(weight)
|
|
165
166
|
q_min, q_max = self.config.quant_min, self.config.quant_max
|