embedl-deploy 0.1.0__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {embedl_deploy-0.1.0/src/embedl_deploy.egg-info → embedl_deploy-0.2.0}/PKG-INFO +1 -1
  2. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/pyproject.toml +18 -0
  3. embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/backend.py +125 -0
  4. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/match.py +75 -48
  5. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/modules.py +27 -7
  6. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/pattern.py +235 -52
  7. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/plan.py +33 -18
  8. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/quantize/__init__.py +1 -1
  9. embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/calibrate.py +133 -0
  10. {embedl_deploy-0.1.0/src/embedl_deploy/_internal/core → embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize}/config.py +60 -76
  11. embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/main.py +179 -0
  12. embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/prepare.py +52 -0
  13. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/quantize/qat.py +1 -1
  14. embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/stubs.py +385 -0
  15. embedl_deploy-0.2.0/src/embedl_deploy/_internal/core/quantize/utils.py +90 -0
  16. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/replace.py +5 -1
  17. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/backend.py +18 -0
  18. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +23 -16
  19. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +31 -32
  20. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +23 -24
  21. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +39 -0
  22. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +1 -1
  23. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +460 -0
  24. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +15 -0
  25. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +691 -0
  26. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +300 -0
  27. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/attention.py +29 -13
  28. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/conv.py +42 -41
  29. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/linear.py +16 -23
  30. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pointwise.py +55 -0
  31. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/pool.py +2 -4
  32. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +47 -31
  33. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +819 -0
  34. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +27 -21
  35. embedl_deploy-0.2.0/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +81 -0
  36. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/plan.py +52 -9
  37. embedl_deploy-0.2.0/src/embedl_deploy/backend/__init__.py +15 -0
  38. embedl_deploy-0.2.0/src/embedl_deploy/py.typed +0 -0
  39. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/quantize/__init__.py +12 -10
  40. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/tensorrt/modules/__init__.py +9 -9
  41. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +19 -9
  42. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/version/public.py +1 -1
  43. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0/src/embedl_deploy.egg-info}/PKG-INFO +1 -1
  44. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy.egg-info/SOURCES.txt +17 -7
  45. embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/main.py +0 -118
  46. embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/modules.py +0 -251
  47. embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/qdq.py +0 -500
  48. embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/qdq_v2.py +0 -98
  49. embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/smooth.py +0 -267
  50. embedl_deploy-0.1.0/src/embedl_deploy/_internal/core/quantize/smooth_v2.py +0 -237
  51. embedl_deploy-0.1.0/src/embedl_deploy/_internal/tensorrt/patterns/conversions.py +0 -425
  52. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/LICENSE +0 -0
  53. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/NOTICE +0 -0
  54. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/README.md +0 -0
  55. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/setup.cfg +0 -0
  56. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/__init__.py +0 -0
  57. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/__init__.py +0 -0
  58. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/core/__init__.py +0 -0
  59. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
  60. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
  61. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
  62. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions/__init__.py +0 -0
  63. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/tensorrt/__init__.py +0 -0
  64. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy/version/__init__.py +0 -0
  65. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy.egg-info/dependency_links.txt +0 -0
  66. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy.egg-info/requires.txt +0 -0
  67. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/src/embedl_deploy.egg-info/top_level.txt +0 -0
  68. {embedl_deploy-0.1.0 → embedl_deploy-0.2.0}/tests/test_version.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedl-deploy
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: Python package to make AI models deployment-ready for any hardware.
5
5
  Author-email: Embedl AB <support@embedl.com>
6
6
  Project-URL: Homepage, https://www.embedl.com/
@@ -107,5 +107,23 @@ select = [
107
107
  "PLR0402",
108
108
  ]
109
109
 
110
+ [tool.mypy]
111
+ ignore_missing_imports = false
112
+ strict = true
113
+
114
+ [[tool.mypy.overrides]]
115
+ module = ["torch.*", "pytest.*"]
116
+ ignore_missing_imports = true
117
+
118
+ [[tool.mypy.overrides]]
119
+ module = ["tests.*"]
120
+ disallow_untyped_defs = false
121
+ disallow_untyped_calls = false
122
+ disable_error_code = ["misc", "no-any-return"]
123
+
124
+ [[tool.mypy.overrides]]
125
+ module = ["embedl_deploy._internal.tensorrt.modules.*"]
126
+ disable_error_code = ["no-any-return"]
127
+
110
128
  [tool.setuptools.dynamic]
111
129
  version = { attr = "embedl_deploy.version.public.PUBLIC_VERSION" }
@@ -0,0 +1,125 @@
1
+ # Copyright (C) 2026 Embedl AB
2
+
3
+ """Backend discovery and selection."""
4
+
5
+ import importlib
6
+ from collections.abc import Sequence
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+
10
+ from embedl_deploy._internal.core.pattern import Pattern
11
+
12
# Absolute path of the ``_internal`` package directory (two levels above
# this file); its sub-directories are scanned for backend packages.
_INTERNAL_DIR = Path(__file__).resolve().parents[1]
13
+
14
+
15
@dataclass(frozen=True)
class Backend:
    """Set of rewrite patterns that make up one hardware target.

    The dataclass is frozen, so its fields cannot be reassigned after
    construction. A backend package publishes an instance as its
    module-level ``BACKEND`` attribute so that backend discovery can
    find it.
    """

    #: Structural rewrite patterns; these are applied iteratively.
    conversion_patterns: Sequence[Pattern]
    #: Fusion patterns; a single pass that runs after the conversions.
    fusion_patterns: Sequence[Pattern]
    #: Patterns that prepare the graph for SmoothQuant.
    smooth_patterns: Sequence[Pattern]
    #: Patterns that insert Q/DQ stubs for quantisation.
    quantized_patterns: Sequence[Pattern]
27
+
28
+
29
class _BackendState:
    """Shared, mutable bookkeeping for backend discovery and selection.

    All state lives in class attributes, so every function in this module
    reads and writes the same single copy.
    """

    #: Backend picked by ``set_backend`` or auto-selected by ``get_backend``.
    backend: Backend | None = None
    #: Memoised discovery result; ``None`` until the first scan completes.
    backends: dict[str, Backend] | None = None

    @classmethod
    def reset(cls) -> None:
        """Forget both the discovery cache and the active backend."""
        cls.backends = None
        cls.backend = None
42
+
43
+
44
def _discover_backends() -> dict[str, Backend]:
    """Scan ``_internal/`` for importable backend packages.

    Every subdirectory whose name is a valid Python identifier, does not
    start with ``_`` and is not ``core`` is tried as
    ``embedl_deploy._internal.<name>.backend``. Directories without such
    a module are skipped, as are modules that do not expose a
    :class:`Backend` instance as ``BACKEND``. Import errors raised by
    transitive dependencies are propagated. The result is cached in
    :class:`_BackendState` once a scan completes without raising.

    :returns:
        Mapping of backend name (the directory name) to ``Backend``
        instance.
    """
    backends = _BackendState.backends
    if backends is None:
        backends = {}
        for entry in sorted(_INTERNAL_DIR.iterdir()):
            name = entry.name
            # Names that are not identifiers (e.g. ``.mypy_cache`` or
            # ``.ipynb_checkpoints`` in a source checkout) cannot be
            # sub-packages. Importing them fails on a malformed parent
            # package name rather than on ``module_path``, so the error
            # would escape the filter below and abort discovery.
            if (
                not entry.is_dir()
                or not name.isidentifier()
                or name.startswith("_")
                or name == "core"
            ):
                continue
            module_path = f"embedl_deploy._internal.{name}.backend"
            try:
                mod = importlib.import_module(module_path)
            except ModuleNotFoundError as e:
                # Only "this backend module does not exist" means "not a
                # backend"; a missing dependency of an existing backend
                # is a real error that must reach the user.
                if e.name == module_path:
                    continue
                raise
            backend = getattr(mod, "BACKEND", None)
            if isinstance(backend, Backend):
                backends[name] = backend
        _BackendState.backends = backends
    return backends
78
+
79
+
80
def get_backend() -> Backend:
    """Return the active backend, discovering one if none is selected yet.

    A backend chosen explicitly with :func:`set_backend` is returned as-is.
    Otherwise the installed backends are discovered. If exactly one is
    found, it is remembered as the active backend and returned.

    :returns:
        The active :class:`Backend`.
    :raises RuntimeError:
        If discovery finds no backend at all, or finds several while
        none has been selected explicitly.
    """
    active = _BackendState.backend
    if active is not None:
        return active

    available = _discover_backends()
    if not available:
        raise RuntimeError(
            "No backends found — install at least one backend"
        )
    if len(available) > 1:
        listing = ", ".join(sorted(available))
        raise RuntimeError(
            f"Multiple backends found ({listing}). "
            "Call set_backend() to select one."
        )

    # Exactly one backend is installed: make it the active one.
    (active,) = available.values()
    _BackendState.backend = active
    return active
109
+
110
+
111
def set_backend(name: str) -> None:
    """Make the installed backend called `name` the active one.

    :param name:
        Name of a discovered backend, i.e. its directory name under
        ``_internal/`` (e.g. ``"tensorrt"``).
    :raises ValueError:
        If no installed backend has that name.
    """
    installed = _discover_backends()
    # Discovered values are always ``Backend`` instances, never ``None``,
    # so a ``None`` lookup result means "unknown name".
    chosen = installed.get(name)
    if chosen is None:
        listing = ", ".join(sorted(installed)) or "(none)"
        raise ValueError(
            f"Backend {name!r} not found. Available backends: {listing}"
        )
    _BackendState.backend = chosen
@@ -9,8 +9,9 @@ occurrences of operation chains in ``torch.fx`` graphs.
9
9
  """
10
10
 
11
11
  import itertools
12
+ import types
12
13
 
13
- from torch import fx, nn
14
+ from torch import fx
14
15
 
15
16
  from embedl_deploy._internal.core.pattern import (
16
17
  Fork,
@@ -18,6 +19,7 @@ from embedl_deploy._internal.core.pattern import (
18
19
  NodeCheck,
19
20
  Pattern,
20
21
  PatternMatch,
22
+ SharedNodeCheck,
21
23
  Tree,
22
24
  TreeMatch,
23
25
  Trunk,
@@ -25,26 +27,40 @@ from embedl_deploy._internal.core.pattern import (
25
27
  get_module,
26
28
  )
27
29
 
28
- #: Module types recognized as activation functions by the matchers.
29
- ACTIVATION_MODULES: tuple[type[nn.Module], ...] = (
30
- nn.ReLU,
31
- nn.ReLU6,
32
- nn.GELU,
33
- nn.SiLU,
34
- nn.Mish,
35
- nn.Hardswish,
36
- nn.Hardsigmoid,
37
- nn.LeakyReLU,
38
- nn.PReLU,
39
- nn.ELU,
40
- nn.Sigmoid,
41
- nn.Tanh,
42
- )
30
+
31
+ class _SharedNodeCheckSession:
32
+ """Checkpoint/rollback scope for ``SharedNodeCheck`` cache entries.
33
+
34
+ Between permutation attempts inside a fork, call :meth:`rollback`
35
+ explicitly. On normal exit (success) cache entries survive so that
36
+ enclosing sessions can still enforce cross-fork shared-node
37
+ constraints. On failure exit (no permutation matched and the block
38
+ falls through) ``__exit__`` rolls back automatically.
39
+ """
40
+
41
+ def __init__(self) -> None:
42
+ self._checkpoint = SharedNodeCheck.checkpoint()
43
+ self._succeeded = False
44
+
45
+ def __enter__(self) -> "_SharedNodeCheckSession":
46
+ return self
47
+
48
+ def success(self) -> None:
49
+ """Mark this session as successful — skip rollback on exit."""
50
+ self._succeeded = True
51
+
52
+ def __exit__(self, *args: object) -> None:
53
+ if not self._succeeded:
54
+ self.rollback()
55
+
56
+ def rollback(self) -> None:
57
+ """Reset every cache event logged since this session started."""
58
+ SharedNodeCheck.rollback_to(self._checkpoint)
43
59
 
44
60
 
45
61
  def _node_matches(node: fx.Node, checks: ModType | NodeCheck) -> bool:
46
62
  """Return whether `node` satisfies `checks`."""
47
- if isinstance(checks, (type, tuple)):
63
+ if isinstance(checks, (type, types.UnionType)):
48
64
  return isinstance(get_module(node), checks)
49
65
  return checks(node)
50
66
 
@@ -130,34 +146,43 @@ def _match_fork_at(
130
146
  return None
131
147
 
132
148
  fork_node = trunk.pre_trunk_nodes[0]
133
- if (
149
+ if getattr(fork.operator, "is_node_check", False):
150
+ if not fork.operator(fork_node):
151
+ return None
152
+ elif (
134
153
  fork_node.op != "call_function"
135
154
  or fork_node.target is not fork.operator
136
155
  ):
137
156
  return None
138
157
 
139
158
  args = [a for a in fork_node.args if isinstance(a, fx.Node)]
140
- if len(args) != len(fork.inputs):
141
- return None
142
-
143
- for perm in itertools.permutations(range(len(fork.inputs))):
144
- tree_matches: dict[int, TreeMatch] = {}
145
- for arg_idx, branch_idx in enumerate(perm):
146
- matched = _match_tree_at(
147
- args[arg_idx],
148
- fork.inputs[branch_idx],
149
- )
150
- if matched is None:
151
- break
152
- tree_matches[branch_idx] = matched
153
-
154
- if len(tree_matches) == len(fork.inputs):
155
- nested = [tree_matches[i] for i in range(len(fork.inputs))]
156
- return TreeMatch(
157
- pre_trunk_nodes=[fork_node],
158
- trunk_nodes=trunk.trunk_nodes,
159
- nested=nested,
160
- )
159
+ perms = (
160
+ itertools.permutations(range(len(fork.inputs)))
161
+ if fork.perms_override is None
162
+ else fork.perms_override
163
+ )
164
+ with _SharedNodeCheckSession() as session:
165
+ for perm in perms:
166
+ session.rollback()
167
+ if len(perm) != len(args):
168
+ continue
169
+ fork_matched = True
170
+ tree_matches = [TreeMatch() for _ in fork.inputs]
171
+ for arg_idx, input_idx in enumerate(perm):
172
+ tree_match = _match_tree_at(
173
+ args[arg_idx], fork.inputs[input_idx]
174
+ )
175
+ if tree_match is None:
176
+ fork_matched = False
177
+ break
178
+ tree_matches[input_idx] = tree_match
179
+ if fork_matched:
180
+ session.success()
181
+ return TreeMatch(
182
+ pre_trunk_nodes=[fork_node],
183
+ trunk_nodes=trunk.trunk_nodes,
184
+ nested=tree_matches,
185
+ )
161
186
  return None
162
187
 
163
188
 
@@ -215,15 +240,17 @@ def match_tree(
215
240
  raise ValueError("``pattern`` has no tree to match.")
216
241
 
217
242
  matches: list[PatternMatch] = []
218
- for node in reversed(graph_module.graph.nodes):
219
- matched = _match_tree_at(node, tree)
220
- if matched is not None:
221
- matches.append(
222
- PatternMatch(
223
- pattern=pattern,
224
- graph_module=graph_module,
225
- tree_match=matched,
243
+ with _SharedNodeCheckSession() as session:
244
+ for node in reversed(list(graph_module.graph.nodes)):
245
+ session.rollback()
246
+ matched = _match_tree_at(node, tree)
247
+ if matched is not None:
248
+ matches.append(
249
+ PatternMatch(
250
+ pattern=pattern,
251
+ graph_module=graph_module,
252
+ tree_match=matched,
253
+ )
226
254
  )
227
- )
228
255
 
229
256
  return matches
@@ -2,11 +2,30 @@
2
2
 
3
3
  """Abstract ``nn.Module`` marker bases and tracing helpers."""
4
4
 
5
+ # mypy: disable-error-code="misc"
6
+ # torch lacks type stubs, so nn.Module resolves to Any.
7
+
5
8
  from abc import ABC
9
+ from typing import TypeAlias
6
10
 
7
11
  from torch import fx, nn
8
12
 
9
- from embedl_deploy._internal.core.quantize.modules import QuantStub
13
+ from embedl_deploy._internal.core.quantize.stubs import QuantStub
14
+
15
+ ActivationLike: TypeAlias = (
16
+ nn.ReLU
17
+ | nn.ReLU6
18
+ | nn.GELU
19
+ | nn.SiLU
20
+ | nn.Mish
21
+ | nn.Hardswish
22
+ | nn.Hardsigmoid
23
+ | nn.LeakyReLU
24
+ | nn.PReLU
25
+ | nn.ELU
26
+ | nn.Sigmoid
27
+ | nn.Tanh
28
+ )
10
29
 
11
30
 
12
31
  class ConvertedModule(nn.Module, ABC):
@@ -25,14 +44,14 @@ class FusedModule(nn.Module, ABC):
25
44
  """Marker base for all backend-specific fused modules.
26
45
 
27
46
  Backend packages (e.g. ``tensorrt``) subclass this for their concrete
28
- fused modules (``FusedConvBNReLU``, etc.). The generic Q/DQ insertion
29
- pass in :mod:`~embedl_deploy._internal.core.quantize.qdq` uses
47
+ fused modules (``FusedConvBNAct``, etc.). The generic Q/DQ insertion
48
+ pass in :mod:`~embedl_deploy._internal.core.quantize.prepare` uses
30
49
  ``isinstance(mod, FusedModule)`` to identify fused nodes without
31
50
  knowing backend-specific types.
32
51
  """
33
52
 
34
53
  #: Positional argument indices that should receive a
35
- #: :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub`.
54
+ #: :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub`.
36
55
  #: The Q/DQ insertion pass uses this to decide which inputs of the
37
56
  #: fused node to quantize. Every subclass must set this explicitly.
38
57
  inputs_to_quantize: set[int]
@@ -40,9 +59,9 @@ class FusedModule(nn.Module, ABC):
40
59
  def __init__(self) -> None:
41
60
  super().__init__()
42
61
  #: Maps each index in :attr:`inputs_to_quantize` to a
43
- #: :class:`~embedl_deploy._internal.core.quantize.modules.QuantStub`.
62
+ #: :class:`~embedl_deploy._internal.core.quantize.stubs.QuantStub`.
44
63
  self.input_quant_stubs: dict[int, QuantStub] = {
45
- idx: QuantStub() for idx in self.inputs_to_quantize
64
+ idx: QuantStub({self}) for idx in self.inputs_to_quantize
46
65
  }
47
66
 
48
67
 
@@ -52,7 +71,8 @@ class _LeafTracer(fx.Tracer):
52
71
  def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
53
72
  if isinstance(m, (ConvertedModule, FusedModule, QuantStub)):
54
73
  return True
55
- return super().is_leaf_module(m, module_qualified_name)
74
+ result: bool = super().is_leaf_module(m, module_qualified_name)
75
+ return result
56
76
 
57
77
 
58
78
  def symbolic_trace(model: nn.Module) -> fx.GraphModule: