embedl-deploy 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embedl_deploy-0.1.0/src/embedl_deploy.egg-info → embedl_deploy-0.2.0}/PKG-INFO +1 -1
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/pyproject.toml +18 -0
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/backend.py +125 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/match.py +75 -48
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/modules.py +27 -7
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/pattern.py +235 -52
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/plan.py +33 -18
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/quantize/__init__.py +1 -1
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/calibrate.py +133 -0
- {embedl_deploy-0.1.0/src/embedl_deploy/_internal/core → embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize}/config.py +60 -76
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/main.py +179 -0
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/prepare.py +52 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/quantize/qat.py +1 -1
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/stubs.py +385 -0
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/utils.py +90 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/replace.py +5 -1
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/backend.py +18 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +23 -16
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +31 -32
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +23 -24
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +39 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +1 -1
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +460 -0
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +15 -0
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +691 -0
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +300 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/attention.py +29 -13
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/conv.py +42 -41
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/linear.py +16 -23
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pointwise.py +55 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pool.py +2 -4
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +47 -31
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +819 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +27 -21
- embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +81 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/plan.py +52 -9
- embedl_deploy-0.2.0/src/embedl_deploy/backend/__init__.py +15 -0
- embedl_deploy-0.2.0/src/embedl_deploy/py.typed +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/quantize/__init__.py +12 -10
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/tensorrt/modules/__init__.py +9 -9
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +19 -9
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/version/public.py +1 -1
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0/src/embedl_deploy.egg-info}/PKG-INFO +1 -1
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy.egg-info/SOURCES.txt +17 -7
- embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/main.py +0 -118
- embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/modules.py +0 -251
- embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/qdq.py +0 -500
- embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/qdq_v2.py +0 -98
- embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/smooth.py +0 -267
- embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/smooth_v2.py +0 -237
- embedl_deploy-0.1.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions.py +0 -425
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/LICENSE +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/NOTICE +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/README.md +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/setup.cfg +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/__init__.py +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/__init__.py +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/__init__.py +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/__init__.py +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/tensorrt/__init__.py +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/version/__init__.py +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy.egg-info/dependency_links.txt +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy.egg-info/requires.txt +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy.egg-info/top_level.txt +0 -0
- {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/tests/test_version.py +0 -0
|
@@ -107,5 +107,23 @@ select = [
|
|
|
107
107
|
"PLR0402",
|
|
108
108
|
]
|
|
109
109
|
|
|
110
|
+
[tool.mypy]
|
|
111
|
+
ignore_missing_imports = false
|
|
112
|
+
strict = true
|
|
113
|
+
|
|
114
|
+
[[tool.mypy.overrides]]
|
|
115
|
+
module = ["torch.*", "pytest.*"]
|
|
116
|
+
ignore_missing_imports = true
|
|
117
|
+
|
|
118
|
+
[[tool.mypy.overrides]]
|
|
119
|
+
module = ["tests.*"]
|
|
120
|
+
disallow_untyped_defs = false
|
|
121
|
+
disallow_untyped_calls = false
|
|
122
|
+
disable_error_code = ["misc", "no-any-return"]
|
|
123
|
+
|
|
124
|
+
[[tool.mypy.overrides]]
|
|
125
|
+
module = ["embedl_deploy._internal.tensorrt.modules.*"]
|
|
126
|
+
disable_error_code = ["no-any-return"]
|
|
127
|
+
|
|
110
128
|
[tool.setuptools.dynamic]
|
|
111
129
|
version = { attr = "embedl_deploy.version.public.PUBLIC_VERSION" }
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# Copyright (C) 2026 Embedl AB
|
|
2
|
+
|
|
3
|
+
"""Backend discovery and selection."""
|
|
4
|
+
|
|
5
|
+
import importlib
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
from embedl_deploy._internal.core.pattern import Pattern
|
|
11
|
+
|
|
12
|
+
_INTERNAL_DIR = Path(__file__).resolve().parent.parent
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
|
|
16
|
+
class Backend:
|
|
17
|
+
"""A collection of patterns for a specific hardware target."""
|
|
18
|
+
|
|
19
|
+
#: Structural rewrite patterns, applied iteratively.
|
|
20
|
+
conversion_patterns: Sequence[Pattern]
|
|
21
|
+
#: Fusion patterns, applied in a single pass after conversions.
|
|
22
|
+
fusion_patterns: Sequence[Pattern]
|
|
23
|
+
#: SmoothQuant preparation patterns.
|
|
24
|
+
smooth_patterns: Sequence[Pattern]
|
|
25
|
+
#: Q/DQ stub insertion patterns for quantisation.
|
|
26
|
+
quantized_patterns: Sequence[Pattern]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class _BackendState:
|
|
30
|
+
"""Module-level mutable state for backend discovery and selection."""
|
|
31
|
+
|
|
32
|
+
#: The currently selected backend.
|
|
33
|
+
backend: Backend | None = None
|
|
34
|
+
#: Cached discovery result.
|
|
35
|
+
backends: dict[str, Backend] | None = None
|
|
36
|
+
|
|
37
|
+
@classmethod
|
|
38
|
+
def reset(cls) -> None:
|
|
39
|
+
"""Clear cached discovery results and the active backend."""
|
|
40
|
+
cls.backend = None
|
|
41
|
+
cls.backends = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _discover_backends() -> dict[str, Backend]:
|
|
45
|
+
"""Scan ``_internal/`` for importable backend packages.
|
|
46
|
+
|
|
47
|
+
Each subdirectory (except ``core``) is tried as
|
|
48
|
+
``embedl_deploy._internal.<name>.backend``. Directories whose
|
|
49
|
+
module cannot be found are skipped; import errors from
|
|
50
|
+
transitive dependencies are propagated. Results are cached
|
|
51
|
+
after the first call.
|
|
52
|
+
|
|
53
|
+
:returns:
|
|
54
|
+
Mapping of backend name to ``Backend`` instance.
|
|
55
|
+
"""
|
|
56
|
+
backends = _BackendState.backends
|
|
57
|
+
if backends is None:
|
|
58
|
+
backends = {}
|
|
59
|
+
for entry in sorted(_INTERNAL_DIR.iterdir()):
|
|
60
|
+
if (
|
|
61
|
+
not entry.is_dir()
|
|
62
|
+
or entry.name.startswith("_")
|
|
63
|
+
or entry.name == "core"
|
|
64
|
+
):
|
|
65
|
+
continue
|
|
66
|
+
module_path = f"embedl_deploy._internal.{entry.name}.backend"
|
|
67
|
+
try:
|
|
68
|
+
mod = importlib.import_module(module_path)
|
|
69
|
+
except ModuleNotFoundError as e:
|
|
70
|
+
if e.name == module_path:
|
|
71
|
+
continue
|
|
72
|
+
raise
|
|
73
|
+
backend = getattr(mod, "BACKEND", None)
|
|
74
|
+
if isinstance(backend, Backend):
|
|
75
|
+
backends[entry.name] = backend
|
|
76
|
+
_BackendState.backends = backends
|
|
77
|
+
return backends
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_backend() -> Backend:
|
|
81
|
+
"""Return the active backend, discovering it if necessary.
|
|
82
|
+
|
|
83
|
+
If no backend has been set via :func:`set_backend`, the installed
|
|
84
|
+
backends are discovered automatically. When exactly one is found
|
|
85
|
+
it becomes the active backend.
|
|
86
|
+
|
|
87
|
+
:returns:
|
|
88
|
+
The active :class:`Backend`.
|
|
89
|
+
:raises RuntimeError:
|
|
90
|
+
If no backends are installed, or if multiple backends are
|
|
91
|
+
installed and none has been explicitly selected.
|
|
92
|
+
"""
|
|
93
|
+
backend = _BackendState.backend
|
|
94
|
+
if backend is None:
|
|
95
|
+
backends = _discover_backends()
|
|
96
|
+
if len(backends) == 0:
|
|
97
|
+
raise RuntimeError(
|
|
98
|
+
"No backends found — install at least one backend"
|
|
99
|
+
)
|
|
100
|
+
if len(backends) > 1:
|
|
101
|
+
names = ", ".join(sorted(backends))
|
|
102
|
+
raise RuntimeError(
|
|
103
|
+
f"Multiple backends found ({names}). "
|
|
104
|
+
"Call set_backend() to select one."
|
|
105
|
+
)
|
|
106
|
+
backend = next(iter(backends.values()))
|
|
107
|
+
_BackendState.backend = backend
|
|
108
|
+
return backend
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def set_backend(name: str) -> None:
|
|
112
|
+
"""Select the active backend by name.
|
|
113
|
+
|
|
114
|
+
:param name:
|
|
115
|
+
The name of a discovered backend (e.g. ``"tensorrt"``).
|
|
116
|
+
:raises ValueError:
|
|
117
|
+
If `name` does not match any installed backend.
|
|
118
|
+
"""
|
|
119
|
+
backends = _discover_backends()
|
|
120
|
+
if name not in backends:
|
|
121
|
+
available = ", ".join(sorted(backends)) or "(none)"
|
|
122
|
+
raise ValueError(
|
|
123
|
+
f"Backend {name!r} not found. " f"Available backends: {available}"
|
|
124
|
+
)
|
|
125
|
+
_BackendState.backend = backends[name]
|
|
@@ -9,8 +9,9 @@ occurrences of operation chains in ``torch.fx`` graphs.
|
|
|
9
9
|
"""
|
|
10
10
|
|
|
11
11
|
import itertools
|
|
12
|
+
import types
|
|
12
13
|
|
|
13
|
-
from torch import fx
|
|
14
|
+
from torch import fx
|
|
14
15
|
|
|
15
16
|
from embedl_deploy._internal.core.pattern import (
|
|
16
17
|
Fork,
|
|
@@ -18,6 +19,7 @@ from embedl_deploy._internal.core.pattern import (
|
|
|
18
19
|
NodeCheck,
|
|
19
20
|
Pattern,
|
|
20
21
|
PatternMatch,
|
|
22
|
+
SharedNodeCheck,
|
|
21
23
|
Tree,
|
|
22
24
|
TreeMatch,
|
|
23
25
|
Trunk,
|
|
@@ -25,26 +27,40 @@ from embedl_deploy._internal.core.pattern import (
|
|
|
25
27
|
get_module,
|
|
26
28
|
)
|
|
27
29
|
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
30
|
+
|
|
31
|
+
class _SharedNodeCheckSession:
|
|
32
|
+
"""Checkpoint/rollback scope for ``SharedNodeCheck`` cache entries.
|
|
33
|
+
|
|
34
|
+
Between permutation attempts inside a fork, call :meth:`rollback`
|
|
35
|
+
explicitly. On normal exit (success) cache entries survive so that
|
|
36
|
+
enclosing sessions can still enforce cross-fork shared-node
|
|
37
|
+
constraints. On failure exit (no permutation matched and the block
|
|
38
|
+
falls through) ``__exit__`` rolls back automatically.
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
def __init__(self) -> None:
|
|
42
|
+
self._checkpoint = SharedNodeCheck.checkpoint()
|
|
43
|
+
self._succeeded = False
|
|
44
|
+
|
|
45
|
+
def __enter__(self) -> "_SharedNodeCheckSession":
|
|
46
|
+
return self
|
|
47
|
+
|
|
48
|
+
def success(self) -> None:
|
|
49
|
+
"""Mark this session as successful — skip rollback on exit."""
|
|
50
|
+
self._succeeded = True
|
|
51
|
+
|
|
52
|
+
def __exit__(self, *args: object) -> None:
|
|
53
|
+
if not self._succeeded:
|
|
54
|
+
self.rollback()
|
|
55
|
+
|
|
56
|
+
def rollback(self) -> None:
|
|
57
|
+
"""Reset every cache event logged since this session started."""
|
|
58
|
+
SharedNodeCheck.rollback_to(self._checkpoint)
|
|
43
59
|
|
|
44
60
|
|
|
45
61
|
def _node_matches(node: fx.Node, checks: ModType | NodeCheck) -> bool:
|
|
46
62
|
"""Return whether `node` satisfies `checks`."""
|
|
47
|
-
if isinstance(checks, (type,
|
|
63
|
+
if isinstance(checks, (type, types.UnionType)):
|
|
48
64
|
return isinstance(get_module(node), checks)
|
|
49
65
|
return checks(node)
|
|
50
66
|
|
|
@@ -130,34 +146,43 @@ def _match_fork_at(
|
|
|
130
146
|
return None
|
|
131
147
|
|
|
132
148
|
fork_node = trunk.pre_trunk_nodes[0]
|
|
133
|
-
if (
|
|
149
|
+
if getattr(fork.operator, "is_node_check", False):
|
|
150
|
+
if not fork.operator(fork_node):
|
|
151
|
+
return None
|
|
152
|
+
elif (
|
|
134
153
|
fork_node.op != "call_function"
|
|
135
154
|
or fork_node.target is not fork.operator
|
|
136
155
|
):
|
|
137
156
|
return None
|
|
138
157
|
|
|
139
158
|
args = [a for a in fork_node.args if isinstance(a, fx.Node)]
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
159
|
+
perms = (
|
|
160
|
+
itertools.permutations(range(len(fork.inputs)))
|
|
161
|
+
if fork.perms_override is None
|
|
162
|
+
else fork.perms_override
|
|
163
|
+
)
|
|
164
|
+
with _SharedNodeCheckSession() as session:
|
|
165
|
+
for perm in perms:
|
|
166
|
+
session.rollback()
|
|
167
|
+
if len(perm) != len(args):
|
|
168
|
+
continue
|
|
169
|
+
fork_matched = True
|
|
170
|
+
tree_matches = [TreeMatch() for _ in fork.inputs]
|
|
171
|
+
for arg_idx, input_idx in enumerate(perm):
|
|
172
|
+
tree_match = _match_tree_at(
|
|
173
|
+
args[arg_idx], fork.inputs[input_idx]
|
|
174
|
+
)
|
|
175
|
+
if tree_match is None:
|
|
176
|
+
fork_matched = False
|
|
177
|
+
break
|
|
178
|
+
tree_matches[input_idx] = tree_match
|
|
179
|
+
if fork_matched:
|
|
180
|
+
session.success()
|
|
181
|
+
return TreeMatch(
|
|
182
|
+
pre_trunk_nodes=[fork_node],
|
|
183
|
+
trunk_nodes=trunk.trunk_nodes,
|
|
184
|
+
nested=tree_matches,
|
|
185
|
+
)
|
|
161
186
|
return None
|
|
162
187
|
|
|
163
188
|
|
|
@@ -215,15 +240,17 @@ def match_tree(
|
|
|
215
240
|
raise ValueError("``pattern`` has no tree to match.")
|
|
216
241
|
|
|
217
242
|
matches: list[PatternMatch] = []
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
243
|
+
with _SharedNodeCheckSession() as session:
|
|
244
|
+
for node in reversed(list(graph_module.graph.nodes)):
|
|
245
|
+
session.rollback()
|
|
246
|
+
matched = _match_tree_at(node, tree)
|
|
247
|
+
if matched is not None:
|
|
248
|
+
matches.append(
|
|
249
|
+
PatternMatch(
|
|
250
|
+
pattern=pattern,
|
|
251
|
+
graph_module=graph_module,
|
|
252
|
+
tree_match=matched,
|
|
253
|
+
)
|
|
226
254
|
)
|
|
227
|
-
)
|
|
228
255
|
|
|
229
256
|
return matches
|
|
@@ -2,11 +2,30 @@
|
|
|
2
2
|
|
|
3
3
|
"""Abstract ``nn.Module`` marker bases and tracing helpers."""
|
|
4
4
|
|
|
5
|
+
# mypy: disable-error-code="misc"
|
|
6
|
+
# torch lacks type stubs, so nn.Module resolves to Any.
|
|
7
|
+
|
|
5
8
|
from abc import ABC
|
|
9
|
+
from typing import TypeAlias
|
|
6
10
|
|
|
7
11
|
from torch import fx, nn
|
|
8
12
|
|
|
9
|
-
from embedl_deploy._internal.core.quantize.
|
|
13
|
+
from embedl_deploy._internal.core.quantize.stubs import QuantStub
|
|
14
|
+
|
|
15
|
+
ActivationLike: TypeAlias = (
|
|
16
|
+
nn.ReLU
|
|
17
|
+
| nn.ReLU6
|
|
18
|
+
| nn.GELU
|
|
19
|
+
| nn.SiLU
|
|
20
|
+
| nn.Mish
|
|
21
|
+
| nn.Hardswish
|
|
22
|
+
| nn.Hardsigmoid
|
|
23
|
+
| nn.LeakyReLU
|
|
24
|
+
| nn.PReLU
|
|
25
|
+
| nn.ELU
|
|
26
|
+
| nn.Sigmoid
|
|
27
|
+
| nn.Tanh
|
|
28
|
+
)
|
|
10
29
|
|
|
11
30
|
|
|
12
31
|
class ConvertedModule(nn.Module, ABC):
|
|
@@ -25,14 +44,14 @@ class FusedModule(nn.Module, ABC):
|
|
|
25
44
|
"""Marker base for all backend-specific fused modules.
|
|
26
45
|
|
|
27
46
|
Backend packages (e.g. ``tensorrt``) subclass this for their concrete
|
|
28
|
-
fused modules (``
|
|
29
|
-
pass in :mod:`~embedl_deploy._internal.core.quantize.
|
|
47
|
+
fused modules (``FusedConvBNAct``, etc.). The generic Q/DQ insertion
|
|
48
|
+
pass in :mod:`~embedl_deploy._internal.core.quantize.prepare` uses
|
|
30
49
|
``isinstance(mod, FusedModule)`` to identify fused nodes without
|
|
31
50
|
knowing backend-specific types.
|
|
32
51
|
"""
|
|
33
52
|
|
|
34
53
|
#: Positional argument indices that should receive a
|
|
35
|
-
#: :class:`~embedl_deploy._internal.core.quantize.
|
|
54
|
+
#: :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub`.
|
|
36
55
|
#: The Q/DQ insertion pass uses this to decide which inputs of the
|
|
37
56
|
#: fused node to quantize. Every subclass must set this explicitly.
|
|
38
57
|
inputs_to_quantize: set[int]
|
|
@@ -40,9 +59,9 @@ class FusedModule(nn.Module, ABC):
|
|
|
40
59
|
def __init__(self) -> None:
|
|
41
60
|
super().__init__()
|
|
42
61
|
#: Maps each index in :attr:`inputs_to_quantize` to a
|
|
43
|
-
#: :class:`~embedl_deploy._internal.core.quantize.
|
|
62
|
+
#: :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub`.
|
|
44
63
|
self.input_quant_stubs: dict[int, QuantStub] = {
|
|
45
|
-
idx: QuantStub() for idx in self.inputs_to_quantize
|
|
64
|
+
idx: QuantStub({self}) for idx in self.inputs_to_quantize
|
|
46
65
|
}
|
|
47
66
|
|
|
48
67
|
|
|
@@ -52,7 +71,8 @@ class _LeafTracer(fx.Tracer):
|
|
|
52
71
|
def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
|
|
53
72
|
if isinstance(m, (ConvertedModule, FusedModule, QuantStub)):
|
|
54
73
|
return True
|
|
55
|
-
|
|
74
|
+
result: bool = super().is_leaf_module(m, module_qualified_name)
|
|
75
|
+
return result
|
|
56
76
|
|
|
57
77
|
|
|
58
78
|
def symbolic_trace(model: nn.Module) -> fx.GraphModule:
|