embedl-deploy-tensorrt 0.4.1__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/PKG-INFO +7 -6
  2. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/README.md +6 -5
  3. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/backend.py +1 -0
  4. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/attention.py +35 -10
  5. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/conv.py +4 -4
  6. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/linear.py +3 -3
  7. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/pointwise.py +1 -1
  8. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/pool.py +1 -1
  9. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/swin_attention.py +56 -64
  10. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/attention.py +123 -171
  11. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/general.py +13 -82
  12. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/fusions.py +26 -3
  13. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/quantizations.py +80 -15
  14. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/__init__.py +3 -0
  15. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/attention.py +808 -0
  16. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/functional.py +325 -0
  17. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/general.py +718 -0
  18. embedl_deploy_tensorrt-0.5.0/src/embedl_deploy/_internal/tensorrt/patterns/recompositions/utils.py +44 -0
  19. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/smoothings.py +6 -4
  20. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/plan.py +24 -12
  21. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/__init__.py +2 -2
  22. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/patterns/__init__.py +0 -2
  23. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/version/public.py +1 -1
  24. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy_tensorrt.egg-info/SOURCES.txt +5 -2
  25. embedl_deploy_tensorrt-0.4.1/src/embedl_deploy/_internal/tensorrt/patterns/recompositions.py +0 -1475
  26. embedl_deploy_tensorrt-0.4.1/src/embedl_deploy/_internal/tensorrt/patterns/utils.py +0 -89
  27. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/LICENSE +0 -0
  28. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/MANIFEST.in +0 -0
  29. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/NOTICE +0 -0
  30. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/pyproject.toml +0 -0
  31. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/setup.cfg +0 -0
  32. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/__init__.py +0 -0
  33. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/__init__.py +0 -0
  34. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/__init__.py +0 -0
  35. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/modules/__init__.py +0 -0
  36. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/__init__.py +0 -0
  37. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/_internal/tensorrt/patterns/conversions/__init__.py +0 -0
  38. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/tensorrt/modules/__init__.py +0 -0
  39. {embedl_deploy_tensorrt-0.4.1 → embedl_deploy_tensorrt-0.5.0}/src/embedl_deploy/version/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embedl-deploy-tensorrt
3
- Version: 0.4.1
3
+ Version: 0.5.0
4
4
  Summary: TensorRT backend for embedl-deploy.
5
5
  Author-email: Embedl AB <support@embedl.com>
6
6
  Project-URL: Homepage, https://www.embedl.com/
@@ -54,9 +54,10 @@ hardware target ensuring correct quantization and compilation.
54
54
 
55
55
  ## Supported Backends
56
56
 
57
- | Backend | Status |
58
- |-------------------------|-------------|
59
- | NVIDIA TensorRT (v10.3) | Supported |
57
+ | Backend | Status |
58
+ |---------------------------|-----------------|
59
+ | NVIDIA TensorRT (v10.3) | Supported |
60
+ | Lattice SensAI (v8.0) | In Development |
60
61
 
61
62
  Contact Embedl for other backends.
62
63
 
@@ -71,7 +72,7 @@ intermediate.
71
72
 
72
73
  ---
73
74
 
74
- ## Quick Start
75
+ ## Quick Start for TensorRT Backend
75
76
 
76
77
  ```python
77
78
  import torch
@@ -85,7 +86,7 @@ model = Model().eval()
85
86
  example_input = torch.randn(1, 3, 224, 224)
86
87
 
87
88
  # 2. Transform — fuse and optimize for TensorRT in one call
88
- # For more compatibilty you can trace your model with torch.export.export
89
+ # For more compatibility you can trace your model with torch.export.export
89
90
  # as follows:
90
91
  # model = torch.export.export(model, (example_input,)).module()
91
92
  res = transform(model, patterns=TENSORRT_PATTERNS)
@@ -35,9 +35,10 @@ hardware target ensuring correct quantization and compilation.
35
35
 
36
36
  ## Supported Backends
37
37
 
38
- | Backend | Status |
39
- |-------------------------|-------------|
40
- | NVIDIA TensorRT (v10.3) | Supported |
38
+ | Backend | Status |
39
+ |---------------------------|-----------------|
40
+ | NVIDIA TensorRT (v10.3) | Supported |
41
+ | Lattice SensAI (v8.0) | In Development |
41
42
 
42
43
  Contact Embedl for other backends.
43
44
 
@@ -52,7 +53,7 @@ intermediate.
52
53
 
53
54
  ---
54
55
 
55
- ## Quick Start
56
+ ## Quick Start for TensorRT Backend
56
57
 
57
58
  ```python
58
59
  import torch
@@ -66,7 +67,7 @@ model = Model().eval()
66
67
  example_input = torch.randn(1, 3, 224, 224)
67
68
 
68
69
  # 2. Transform — fuse and optimize for TensorRT in one call
69
- # For more compatibilty you can trace your model with torch.export.export
70
+ # For more compatibility you can trace your model with torch.export.export
70
71
  # as follows:
71
72
  # model = torch.export.export(model, (example_input,)).module()
72
73
  res = transform(model, patterns=TENSORRT_PATTERNS)
@@ -11,6 +11,7 @@ from embedl_deploy._internal.tensorrt.plan import (
11
11
  )
12
12
 
13
13
  BACKEND = Backend(
14
+ name="tensorrt",
14
15
  conversion_patterns=TENSORRT_CONVERSION_PATTERNS,
15
16
  fusion_patterns=TENSORRT_FUSION_PATTERNS,
16
17
  smooth_patterns=TENSORRT_SMOOTH_PATTERNS,
@@ -69,7 +69,7 @@ class MHAInProjection(ConvertedModule):
69
69
  v = v.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
70
70
  return q, k, v
71
71
 
72
- def __repr__(self) -> str:
72
+ def __repr__(self) -> str: # pragma: no cover
73
73
  embed_dim = self.num_heads * self.head_dim
74
74
  return (
75
75
  f"MHAInProjection("
@@ -80,7 +80,7 @@ class MHAInProjection(ConvertedModule):
80
80
 
81
81
 
82
82
  class ScaledDotProductAttention(ConvertedModule):
83
- """Core attention: ``softmax(Q · Kᵀ / √H) · V``.
83
+ """Core attention: ``softmax(Q · Kᵀ · scale) · V``.
84
84
 
85
85
  :param num_heads:
86
86
  Number of attention heads.
@@ -88,6 +88,14 @@ class ScaledDotProductAttention(ConvertedModule):
88
88
  Dimension of each head.
89
89
  :param dropout:
90
90
  Dropout probability (applied during training only).
91
+ :param is_causal:
92
+ Whether to apply a causal mask. Mirrors the ``is_causal`` kwarg
93
+ of ``F.scaled_dot_product_attention``.
94
+ :param scale:
95
+ Explicit attention score scale (multiplied on Q·Kᵀ). When
96
+ ``None`` the PyTorch default ``1/√head_dim`` is used. Models
97
+ that pre-scale Q themselves (e.g. chronos-2 + RoPE) must pass
98
+ ``scale=1.0`` so the default scaling does not apply twice.
91
99
  """
92
100
 
93
101
  def __init__(
@@ -95,11 +103,15 @@ class ScaledDotProductAttention(ConvertedModule):
95
103
  num_heads: int,
96
104
  head_dim: int,
97
105
  dropout: float = 0.0,
106
+ is_causal: bool = False,
107
+ scale: float | None = None,
98
108
  ) -> None:
99
109
  super().__init__()
100
110
  self.num_heads = num_heads
101
111
  self.head_dim = head_dim
102
112
  self.dropout = dropout
113
+ self.is_causal = is_causal
114
+ self.scale = scale
103
115
 
104
116
  def forward(
105
117
  self,
@@ -117,8 +129,9 @@ class ScaledDotProductAttention(ConvertedModule):
117
129
  :param v:
118
130
  Value tensor ``[B, num_heads, S, head_dim]``.
119
131
  :param attn_mask:
120
- Optional attention mask. ``aten.scaled_dot_product_attention``
121
- takes an optional 4th positional arg; ``WrapAtenSDPAPattern``
132
+ Optional attention mask.
133
+ ``torch.nn.functional.scaled_dot_product_attention`` takes an
134
+ optional 4th positional arg; ``WrapFunctionalSDPAPattern``
122
135
  forwards whatever positional args were on the source node, so
123
136
  this module accepts the mask too. SAM3, masked-LM, and
124
137
  similar models that compile with mixed-mask attention rely
@@ -135,14 +148,18 @@ class ScaledDotProductAttention(ConvertedModule):
135
148
  v,
136
149
  attn_mask=attn_mask,
137
150
  dropout_p=self.dropout if self.training else 0.0,
151
+ is_causal=self.is_causal,
152
+ scale=self.scale,
138
153
  )
139
154
 
140
- def __repr__(self) -> str:
155
+ def __repr__(self) -> str: # pragma: no cover
141
156
  return (
142
157
  f"ScaledDotProductAttention("
143
158
  f"num_heads={self.num_heads}, "
144
159
  f"head_dim={self.head_dim}, "
145
- f"dropout={self.dropout})"
160
+ f"dropout={self.dropout}, "
161
+ f"is_causal={self.is_causal}, "
162
+ f"scale={self.scale})"
146
163
  )
147
164
 
148
165
 
@@ -197,7 +214,7 @@ class FusedMHAInProjection(FusedModule):
197
214
  v = v.view(batch, seq, num_heads, head_dim).transpose(1, 2)
198
215
  return q, k, v
199
216
 
200
- def __repr__(self) -> str:
217
+ def __repr__(self) -> str: # pragma: no cover
201
218
  embed_dim = self.in_proj.num_heads * self.in_proj.head_dim
202
219
  return (
203
220
  f"FusedMHAInProjection("
@@ -275,10 +292,18 @@ class FusedScaledDotProductAttention(FusedModule):
275
292
  # MHA kernel onto the slower INT8-aware variant for no gain.
276
293
  if not self.surrounded or not self.softmax_quant.enabled:
277
294
  return self.attention(q, k, v, attn_mask)
278
- # Use ``1/sqrt(head_dim)`` rather than ``head_dim ** -0.5``: the
295
+ # Honour the wrapped attention module's explicit ``scale`` if
296
+ # set — models that pre-scale Q themselves (chronos-2 + RoPE,
297
+ # for example) build with ``scale=1.0`` to disable the default
298
+ # ``1/sqrt(head_dim)`` scaling. Falling back to the default
299
+ # here would apply it twice and collapse softmax.
300
+ # Note on ``1/sqrt(head_dim)`` vs ``head_dim ** -0.5``: the
279
301
  # tensor Pow with a negative float exponent traces to ONNX as a
280
302
  # ``Cast → complex128`` node that TRT 10.x can't parse.
281
- scale = 1.0 / math.sqrt(q.shape[-1])
303
+ if self.attention.scale is not None:
304
+ scale = self.attention.scale
305
+ else:
306
+ scale = 1.0 / math.sqrt(q.shape[-1])
282
307
  attn_weight = torch.matmul(q, k.transpose(-2, -1)) * scale
283
308
  if attn_mask is not None:
284
309
  if attn_mask.dtype == torch.bool:
@@ -297,7 +322,7 @@ class FusedScaledDotProductAttention(FusedModule):
297
322
  )
298
323
  return torch.matmul(attn_weight, v)
299
324
 
300
- def __repr__(self) -> str:
325
+ def __repr__(self) -> str: # pragma: no cover
301
326
  a = self.attention
302
327
  qdq = "yes" if self.softmax_quant.enabled else "no"
303
328
  return (
@@ -91,7 +91,7 @@ class FusedConvBNAct(FusedModule):
91
91
  x = self.bn(x)
92
92
  return self.act(x)
93
93
 
94
- def __repr__(self) -> str:
94
+ def __repr__(self) -> str: # pragma: no cover
95
95
  bn_info = ""
96
96
  if self.bn is not None:
97
97
  bn_info = f", bn={self.bn.num_features} (foldable)"
@@ -129,7 +129,7 @@ class FusedConvBN(FusedModule):
129
129
  x = self.bn(x)
130
130
  return x
131
131
 
132
- def __repr__(self) -> str:
132
+ def __repr__(self) -> str: # pragma: no cover
133
133
  bn_info = ""
134
134
  if self.bn is not None:
135
135
  bn_info = f", bn={self.bn.num_features} (foldable)"
@@ -168,7 +168,7 @@ class FusedConvBNActMaxPool(FusedModule):
168
168
  x = self.act(x)
169
169
  return self.maxpool(x)
170
170
 
171
- def __repr__(self) -> str:
171
+ def __repr__(self) -> str: # pragma: no cover
172
172
  bn_info = ""
173
173
  if self.bn is not None:
174
174
  bn_info = f", bn={self.bn.num_features} (foldable)"
@@ -213,7 +213,7 @@ class FusedConvBNAddAct(FusedModule):
213
213
  x = self.bn(x)
214
214
  return self.act(x + residual)
215
215
 
216
- def __repr__(self) -> str:
216
+ def __repr__(self) -> str: # pragma: no cover
217
217
  return (
218
218
  f"FusedConvBNAddAct("
219
219
  f"{self.conv.in_channels}→{self.conv.out_channels}, "
@@ -82,7 +82,7 @@ class FusedLinear(FusedModule):
82
82
  # pylint: disable-next=not-callable
83
83
  return F.linear(x, weight, self.linear.bias)
84
84
 
85
- def __repr__(self) -> str:
85
+ def __repr__(self) -> str: # pragma: no cover
86
86
  return (
87
87
  f"FusedLinear("
88
88
  f"{self.linear.in_features}→{self.linear.out_features})"
@@ -113,7 +113,7 @@ class FusedLinearAct(FusedModule):
113
113
  x = F.linear(x, weight, self.linear.bias)
114
114
  return self.act(x)
115
115
 
116
- def __repr__(self) -> str:
116
+ def __repr__(self) -> str: # pragma: no cover
117
117
  act_name = type(self.act).__name__
118
118
  return (
119
119
  f"FusedLinearAct("
@@ -151,7 +151,7 @@ class FusedLayerNorm(FusedModule):
151
151
  """Apply ``layer_norm``."""
152
152
  return self.layer_norm(x)
153
153
 
154
- def __repr__(self) -> str:
154
+ def __repr__(self) -> str: # pragma: no cover
155
155
  return (
156
156
  f"FusedLayerNorm("
157
157
  f"normalized_shape={self.layer_norm.normalized_shape}, "
@@ -34,5 +34,5 @@ class FusedActAdd(FusedModule):
34
34
  """Apply ``act(x) + residual``."""
35
35
  return self.act(x) + residual
36
36
 
37
- def __repr__(self) -> str:
37
+ def __repr__(self) -> str: # pragma: no cover
38
38
  return f"FusedActAdd({type(self.act).__name__})"
@@ -21,5 +21,5 @@ class FusedAdaptiveAvgPool2d(FusedModule):
21
21
  """Apply adaptive average pooling."""
22
22
  return self.pool(x)
23
23
 
24
- def __repr__(self) -> str:
24
+ def __repr__(self) -> str: # pragma: no cover
25
25
  return f"FusedAdaptiveAvgPool2d(output_size={self.pool.output_size})"
@@ -9,7 +9,7 @@ are spatial rearrangements that need no quantization.
9
9
  computes the shifted-window attention mask.
10
10
  """
11
11
 
12
- from dataclasses import dataclass, field
12
+ from dataclasses import dataclass
13
13
 
14
14
  import torch
15
15
  import torch.nn.functional as F
@@ -19,7 +19,7 @@ from embedl_deploy._internal.core.modules import ConvertedModule, FusedModule
19
19
  from embedl_deploy._internal.core.quantize.stubs import QuantStub
20
20
 
21
21
 
22
- @dataclass
22
+ @dataclass(eq=False)
23
23
  class SwinSpatialState:
24
24
  """Shared mutable state for spatial dimensions.
25
25
 
@@ -34,9 +34,7 @@ class SwinSpatialState:
34
34
  pad_height: int = 0
35
35
  pad_width: int = 0
36
36
  #: Effective shift size after clamping for small feature maps.
37
- effective_shift_size: list[int] = field(
38
- default_factory=lambda: [0, 0],
39
- )
37
+ effective_shift_size: tuple[int, int] = (0, 0)
40
38
 
41
39
 
42
40
  class SwinWindowPartition(ConvertedModule):
@@ -71,31 +69,29 @@ class SwinWindowPartition(ConvertedModule):
71
69
  :returns:
72
70
  Windowed tensor ``[B*nW, Ws*Ws, C]``.
73
71
  """
74
- # pylint: disable-next=invalid-name
75
- B, H, W, C = x.shape
72
+ b, h, w, c = x.shape
76
73
 
77
74
  # Pad to multiples of window size.
78
75
  ws_h, ws_w = self.window_size
79
- pad_b = (ws_h - H % ws_h) % ws_h
80
- pad_r = (ws_w - W % ws_w) % ws_w
76
+ pad_b = (ws_h - h % ws_h) % ws_h
77
+ pad_r = (ws_w - w % ws_w) % ws_w
81
78
  x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
82
- # pylint: disable-next=invalid-name
83
- _, pad_H, pad_W, _ = x.shape
79
+ _, pad_h, pad_w, _ = x.shape
84
80
 
85
81
  # Clamp shift size when the window covers the whole feature map.
86
- eff_shift = list(self.shift_size)
87
- if ws_h >= pad_H:
88
- eff_shift[0] = 0
89
- if ws_w >= pad_W:
90
- eff_shift[1] = 0
82
+ sh, sw = self.shift_size
83
+ eff_shift = (
84
+ 0 if ws_h >= pad_h else sh,
85
+ 0 if ws_w >= pad_w else sw,
86
+ )
91
87
 
92
88
  # Write spatial state for downstream modules.
93
89
  st = self._spatial_state
94
- st.batch_size = B
95
- st.height = H
96
- st.width = W
97
- st.pad_height = pad_H
98
- st.pad_width = pad_W
90
+ st.batch_size = b
91
+ st.height = h
92
+ st.width = w
93
+ st.pad_height = pad_h
94
+ st.pad_width = pad_w
99
95
  st.effective_shift_size = eff_shift
100
96
 
101
97
  # Cyclic shift.
@@ -108,22 +104,22 @@ class SwinWindowPartition(ConvertedModule):
108
104
 
109
105
  # Window partition.
110
106
  x = x.view(
111
- B,
112
- pad_H // ws_h,
107
+ b,
108
+ pad_h // ws_h,
113
109
  ws_h,
114
- pad_W // ws_w,
110
+ pad_w // ws_w,
115
111
  ws_w,
116
- C,
112
+ c,
117
113
  )
118
- num_windows = (pad_H // ws_h) * (pad_W // ws_w)
114
+ num_windows = (pad_h // ws_h) * (pad_w // ws_w)
119
115
  x = x.permute(0, 1, 3, 2, 4, 5).reshape(
120
- B * num_windows,
116
+ b * num_windows,
121
117
  ws_h * ws_w,
122
- C,
118
+ c,
123
119
  )
124
120
  return x
125
121
 
126
- def __repr__(self) -> str:
122
+ def __repr__(self) -> str: # pragma: no cover
127
123
  return (
128
124
  f"SwinWindowPartition("
129
125
  f"window_size={self.window_size}, "
@@ -177,11 +173,10 @@ class SwinAttention(ConvertedModule):
177
173
 
178
174
  def _get_relative_position_bias(self) -> torch.Tensor:
179
175
  """Compute relative position bias ``[1, nH, N, N]``."""
180
- # pylint: disable-next=invalid-name
181
- N = self.window_size[0] * self.window_size[1]
176
+ n = self.window_size[0] * self.window_size[1]
182
177
  assert isinstance(self.relative_position_index, torch.Tensor)
183
178
  bias = self.relative_position_bias_table[self.relative_position_index]
184
- bias = bias.view(N, N, -1).permute(2, 0, 1).contiguous()
179
+ bias = bias.view(n, n, -1).permute(2, 0, 1).contiguous()
185
180
  return bias.unsqueeze(0)
186
181
 
187
182
  def _compute_attn_mask( # pylint: disable=too-many-locals
@@ -193,13 +188,13 @@ class SwinAttention(ConvertedModule):
193
188
  if sum(eff_shift) == 0:
194
189
  return None
195
190
 
196
- pad_H = st.pad_height # pylint: disable=invalid-name
197
- pad_W = st.pad_width # pylint: disable=invalid-name
191
+ pad_h = st.pad_height
192
+ pad_w = st.pad_width
198
193
  ws_h, ws_w = self.window_size
199
- num_windows = (pad_H // ws_h) * (pad_W // ws_w)
194
+ num_windows = (pad_h // ws_h) * (pad_w // ws_w)
200
195
 
201
196
  attn_mask = torch.zeros(
202
- (pad_H, pad_W),
197
+ (pad_h, pad_w),
203
198
  device=self.relative_position_bias_table.device,
204
199
  )
205
200
  h_slices = (
@@ -219,9 +214,9 @@ class SwinAttention(ConvertedModule):
219
214
  count += 1
220
215
 
221
216
  attn_mask = attn_mask.view(
222
- pad_H // ws_h,
217
+ pad_h // ws_h,
223
218
  ws_h,
224
- pad_W // ws_w,
219
+ pad_w // ws_w,
225
220
  ws_w,
226
221
  )
227
222
  attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(
@@ -259,10 +254,9 @@ class SwinAttention(ConvertedModule):
259
254
 
260
255
  attn_mask = self._compute_attn_mask()
261
256
  if attn_mask is not None:
262
- # pylint: disable-next=invalid-name
263
- B = self._spatial_state.batch_size
264
- nW = attn.size(0) // B # pylint: disable=invalid-name
265
- attn = attn.view(B, nW, self.num_heads, -1, attn.size(-1))
257
+ b = self._spatial_state.batch_size
258
+ n_w = attn.size(0) // b
259
+ attn = attn.view(b, n_w, self.num_heads, -1, attn.size(-1))
266
260
  attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
267
261
  attn = attn.view(-1, self.num_heads, attn.size(-2), attn.size(-1))
268
262
 
@@ -282,7 +276,7 @@ class SwinAttention(ConvertedModule):
282
276
  )
283
277
  return x
284
278
 
285
- def __repr__(self) -> str:
279
+ def __repr__(self) -> str: # pragma: no cover
286
280
  embed_dim = self.num_heads * self.head_dim
287
281
  return (
288
282
  f"SwinAttention("
@@ -320,27 +314,25 @@ class SwinWindowReverse(ConvertedModule):
320
314
  Spatial tensor ``[B, H, W, C]``.
321
315
  """
322
316
  st = self._spatial_state
323
- # pylint: disable=invalid-name
324
- B = st.batch_size
325
- pad_H = st.pad_height
326
- pad_W = st.pad_width
327
- H = st.height
328
- W = st.width
329
- # pylint: enable=invalid-name
317
+ b = st.batch_size
318
+ pad_h = st.pad_height
319
+ pad_w = st.pad_width
320
+ h = st.height
321
+ w = st.width
330
322
  eff_shift = st.effective_shift_size
331
323
  ws_h, ws_w = self.window_size
332
- C = x.size(-1) # pylint: disable=invalid-name
324
+ c = x.size(-1)
333
325
 
334
326
  # Reverse window partition.
335
327
  x = x.view(
336
- B,
337
- pad_H // ws_h,
338
- pad_W // ws_w,
328
+ b,
329
+ pad_h // ws_h,
330
+ pad_w // ws_w,
339
331
  ws_h,
340
332
  ws_w,
341
- C,
333
+ c,
342
334
  )
343
- x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)
335
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(b, pad_h, pad_w, c)
344
336
 
345
337
  # Reverse cyclic shift.
346
338
  if sum(eff_shift) > 0:
@@ -351,10 +343,10 @@ class SwinWindowReverse(ConvertedModule):
351
343
  )
352
344
 
353
345
  # Remove padding.
354
- x = x[:, :H, :W, :].contiguous()
346
+ x = x[:, :h, :w, :].contiguous()
355
347
  return x
356
348
 
357
- def __repr__(self) -> str:
349
+ def __repr__(self) -> str: # pragma: no cover
358
350
  return f"SwinWindowReverse(window_size={self.window_size})"
359
351
 
360
352
 
@@ -411,13 +403,13 @@ class FusedSwinAttention(FusedModule):
411
403
  # pylint: disable-next=protected-access
412
404
  attn_mask = a._compute_attn_mask() # noqa: SLF001
413
405
  if attn_mask is not None:
414
- # pylint: disable-next=protected-access,invalid-name
415
- B = a._spatial_state.batch_size # noqa: SLF001
416
- nW = attn_weight.size(0) // B # pylint: disable=invalid-name
406
+ # pylint: disable-next=protected-access
407
+ b = a._spatial_state.batch_size # noqa: SLF001
408
+ n_w = attn_weight.size(0) // b
417
409
  n = attn_weight.size(-2)
418
410
  attn_weight = attn_weight.view(
419
- B,
420
- nW,
411
+ b,
412
+ n_w,
421
413
  a.num_heads,
422
414
  n,
423
415
  attn_weight.size(-1),
@@ -448,7 +440,7 @@ class FusedSwinAttention(FusedModule):
448
440
  )
449
441
  )
450
442
 
451
- def __repr__(self) -> str:
443
+ def __repr__(self) -> str: # pragma: no cover
452
444
  a = self.attention
453
445
  qdq = "yes" if self.softmax_quant.enabled else "no"
454
446
  return (