liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +39 -11
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +75 -12
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +5 -1
- liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +2 -1
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/layer_norm.py +86 -66
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +131 -49
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +30 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +48 -25
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +14 -5
- liger_kernel/transformers/model/gemma3.py +26 -12
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +31 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +23 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +702 -48
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +15 -3
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +52 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
- liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
- liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ub_manager.py
ADDED (new file, 349 lines)

```python
"""
Unified Buffer (UB) Manager for Ascend NPU.

This module provides UB capacity detection and tiling strategy computation
for running Triton kernels on Ascend NPU. It automatically calculates
optimal block sizes based on UB capacity constraints to prevent UB overflow.
"""

import os

from typing import Optional
from typing import Tuple
from typing import Union

import torch
import triton

from liger_kernel.utils import is_npu_available

# Default UB capacities for different NPU models (in bits)
_DEFAULT_UB_CAPACITIES = {
    "Ascend910B1": 2097152,  # ~256 KB
    "Ascend910B4": 1572864,  # ~192 KB
    "default": 2097152,  # ~256 KB
}


def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
    """
    Normalize tiling dimension specification to a set of dimension indices.

    Args:
        tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions).

    Returns:
        Set of dimension indices that can be tiled.
    """
    if isinstance(tiling_dim, int):
        return {tiling_dim}
    elif isinstance(tiling_dim, tuple):
        return set(tiling_dim)
    else:
        return set()


def _default_strategy(
    ub_capacity_bits: int,
    safety_margin: float,
    dtype_size: int,
    memory_multiplier: float,
    shapes: Tuple[Tuple[int, ...], ...],
    tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
    """
    Default tiling strategy: calculate maximum safe block size based on UB capacity.

    This is a unified strategy function that works for all kernels by abstracting
    the memory calculation as: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits

    Args:
        ub_capacity_bits: UB capacity in bits
        safety_margin: Safety margin as a float (e.g., 0.80 for 80%)
        dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32)
        memory_multiplier: Memory multiplier for estimating peak memory usage
        shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
            - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
            - For GEGLU: ((n_cols,),)
        tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
            Each element can be:
            - int: single dimension index (e.g., 0 for first dimension)
            - tuple of ints: multiple dimensions that can be tiled together
            - For ROPE: (0, 0) means first dimension of each shape can be tiled
            - For GEGLU: (0,) means first dimension of the shape can be tiled
            Length must match len(shapes).

    Returns:
        Tuple of maximum safe block sizes, one for each shape.
        Each element is a power of 2.

    Note:
        For each shape, fixed dimensions (non-tiling) are multiplied together to get unit_param.
        The final block size is computed in compute_default_tiling_strategy by taking
        min(desired_block_size, max_safe_block_size) where desired_block_size = triton.next_power_of_2(original_dim).
    """
    if not shapes or not tiling_dims:
        return ()

    # Calculate max_safe_block_size for each tiling dimension
    max_safe_sizes = []

    for shape, tiling_dim in zip(shapes, tiling_dims):
        # Normalize tiling_dim to a set of dimension indices
        tiling_dim_set = _normalize_tiling_dims(tiling_dim)

        # Validate tiling dimensions are within shape bounds
        if not tiling_dim_set:
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
            )
        if any(dim_idx < 0 or dim_idx >= len(shape) for dim_idx in tiling_dim_set):
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
                f"All dimension indices must be in range [0, {len(shape)})."
            )

        # Calculate unit_param: product of fixed (non-tiling) dimensions
        unit_param = 1.0
        for dim_idx, dim_size in enumerate(shape):
            if dim_idx not in tiling_dim_set:
                if dim_size <= 0:
                    # Invalid dimension size, use conservative default
                    unit_param = 1.0
                    break
                unit_param *= float(dim_size)

        # Ensure unit_param is at least 1.0
        if unit_param <= 0:
            unit_param = 1.0

        # Calculate maximum safe block size based on UB capacity
        # Memory: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
        SAFE_UB_CAPACITY_BITS = int(ub_capacity_bits * safety_margin)

        # Solve: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 <= SAFE_UB_CAPACITY_BITS
        # BLOCK_SIZE <= SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)
        max_block_size = int(SAFE_UB_CAPACITY_BITS // (memory_multiplier * unit_param * dtype_size * 8))
        max_block_size = max(1, max_block_size)

        # Find largest power of 2 <= max_block_size
        # Use triton.next_power_of_2(max_block_size + 1) // 2 to get the largest power of 2 <= max_block_size
        safe_block_size = triton.next_power_of_2(max_block_size + 1) // 2
        max_safe_sizes.append(safe_block_size)

    return tuple(max_safe_sizes)


class UBManager:
    """
    Unified Buffer Manager for Ascend NPU.

    Provides UB capacity detection and management for Ascend NPU devices.
    The UB capacity is used by tiling strategy functions to calculate optimal block sizes.
    """

    def __init__(self, ub_capacity_bits: Optional[int] = None):
        """
        Initialize UB Manager.

        Args:
            ub_capacity_bits: UB capacity in bits. If None, will be detected automatically.
        """
        self._npu_model = self._detect_npu_model()
        self._ub_capacity_bits = ub_capacity_bits or self._detect_ub_capacity()

    @property
    def ub_capacity_bits(self) -> int:
        """Get UB capacity in bits."""
        return self._ub_capacity_bits

    @property
    def ub_capacity_bytes(self) -> int:
        """Get UB capacity in bytes."""
        return self._ub_capacity_bits // 8

    @property
    def npu_model(self) -> str:
        """Get detected NPU model name."""
        return self._npu_model

    def _detect_npu_model(self) -> str:
        """Detect NPU model from device properties."""
        if not is_npu_available():
            return "unknown"

        try:
            dev_props = torch.npu.get_device_properties(0)
            # Try to get model name from device properties
            return dev_props.name
        except Exception:
            pass

        return "default"

    def _detect_ub_capacity(self) -> int:
        """
        Detect UB capacity from environment variable or device properties.

        Returns:
            UB capacity in bits.
        """
        # Check environment variable first
        env_capacity = os.getenv("ASCEND_UB_CAPACITY_BITS")
        if env_capacity is not None:
            try:
                return int(env_capacity)
            except ValueError:
                pass

        # Try to get from device properties
        if is_npu_available():
            try:
                dev_props = torch.npu.get_device_properties(0)
                if hasattr(dev_props, "ub_capacity_bits"):
                    return dev_props.ub_capacity_bits
            except Exception:
                pass

        # Fall back to model-based defaults
        model = self._npu_model
        return _DEFAULT_UB_CAPACITIES.get(model, _DEFAULT_UB_CAPACITIES["default"])


# Global singleton instance
_ub_manager: Optional[UBManager] = None


def get_ub_manager() -> UBManager:
    """Get global UB manager instance."""
    global _ub_manager
    if _ub_manager is None:
        _ub_manager = UBManager()
    return _ub_manager


def compute_default_tiling_strategy(
    safety_margin: float = 0.80,
    dtype_size: Optional[int] = None,
    memory_multiplier: Optional[float] = None,
    shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
    tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
) -> Optional[Tuple[Tuple[int, ...], ...]]:
    """
    Compute tiling strategy using the default strategy function.

    This function directly calls the default strategy and computes the final
    tiling result. All kernels use the same unified strategy function, so
    there's no need for kernel_name-based lookup.

    Args:
        safety_margin: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
        dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32).
            Must be provided. If None or <= 0, defaults to 4 (float32).
        memory_multiplier: Memory multiplier for estimating peak memory usage.
            - For GEGLU: typically 10.0 for backward, 7.0 for forward
            - For ROPE: typically 3.0
            If None, defaults to 10.0 (conservative estimate).
        shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
            - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
            - For GEGLU: ((n_cols,),)
            Can pass original shapes (will handle padding internally) or padded shapes.
        tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
            Each element can be:
            - int: single dimension index (e.g., 0 for first dimension)
            - tuple of ints: multiple dimensions that can be tiled together
            - For ROPE: (0, 0) means first dimension of each shape can be tiled
            - For GEGLU: (0,) means first dimension of the shape can be tiled
            Length must match len(shapes). Cannot be empty.

    Returns:
        Tuple of tiled shapes with same structure as input shapes.
        Tiling dimensions are replaced with computed block sizes (power of 2),
        while non-tiling dimensions are padded to next power of 2.
        - For ROPE: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
        - For GEGLU: ((block_size,),)
        Returns None if shapes or tiling_dims is None or empty.

    Examples:
        >>> # ROPE forward
        >>> strategy = compute_default_tiling_strategy(
        ...     safety_margin=0.90,
        ...     dtype_size=4,
        ...     memory_multiplier=3.0,
        ...     shapes=((32, 128), (32, 128)),
        ...     tiling_dims=(0, 0)
        ... )
        >>> # Returns: ((block_size_q, 128), (block_size_kv, 128))
        >>> # GEGLU forward
        >>> strategy = compute_default_tiling_strategy(
        ...     safety_margin=0.80,
        ...     dtype_size=2,
        ...     memory_multiplier=7.0,
        ...     shapes=((4096,),),
        ...     tiling_dims=(0,)
        ... )
        >>> # Returns: ((block_size,),)
    """
    ub_manager = get_ub_manager()

    if shapes is None or not shapes or tiling_dims is None or not tiling_dims:
        return None

    if len(shapes) != len(tiling_dims):
        return None

    if dtype_size is None or dtype_size <= 0:
        dtype_size = 4  # Default to float32

    if memory_multiplier is None or memory_multiplier <= 0:
        memory_multiplier = 10.0  # Default conservative estimate

    # Call strategy to get max_safe_block_size for each shape
    max_supported = _default_strategy(
        ub_manager.ub_capacity_bits,
        safety_margin,
        dtype_size,
        memory_multiplier,
        shapes,
        tiling_dims,
    )

    if not max_supported or len(max_supported) != len(shapes):
        return None

    # Build result: same structure as shapes, with tiling dims replaced by computed block sizes
    result = []
    for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
        result_shape = list(shape)

        # Normalize tiling_dim to a set of dimension indices
        tiling_dim_set = _normalize_tiling_dims(tiling_dim)

        # Validate tiling dimensions are within shape bounds
        if not tiling_dim_set:
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
            )
        if any(dim_idx < 0 or dim_idx >= len(result_shape) for dim_idx in tiling_dim_set):
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
                f"All dimension indices must be in range [0, {len(result_shape)})."
            )

        # Replace tiling dimensions with computed block sizes
        # For each tiling dimension, compute: min(desired, max_safe)
        for dim_idx in tiling_dim_set:
            original_dim = result_shape[dim_idx]
            desired = triton.next_power_of_2(original_dim)
            final_val = min(desired, max_safe)
            final_val = max(1, final_val)  # Ensure at least 1
            result_shape[dim_idx] = final_val

        # Pad non-tiling dimensions to next power of 2
        for dim_idx, dim_size in enumerate(result_shape):
            if dim_idx not in tiling_dim_set:
                result_shape[dim_idx] = triton.next_power_of_2(dim_size)

        result.append(tuple(result_shape))

    return tuple(result)
```
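To make the UB budget arithmetic concrete, the following worked example (editorial, not part of the package) reproduces the calculation `_default_strategy` performs for the GEGLU-style case from the docstring above, assuming the default 2,097,152-bit UB capacity:

```python
# Worked example of the block-size arithmetic in _default_strategy / compute_default_tiling_strategy,
# using the "default" UB capacity and the GEGLU docstring inputs (shapes=((4096,),), tiling_dims=(0,)).
import triton

ub_capacity_bits = 2_097_152          # "default" entry in _DEFAULT_UB_CAPACITIES
safety_margin = 0.80                  # keep 80% of UB for the kernel's working set
dtype_size = 2                        # float16
memory_multiplier = 7.0               # forward GEGLU estimate from the docstring
unit_param = 1.0                      # the only dimension is tiled, so no fixed dims remain

safe_bits = int(ub_capacity_bits * safety_margin)                                # 1_677_721
max_block = int(safe_bits // (memory_multiplier * unit_param * dtype_size * 8))  # 14_979
max_safe = triton.next_power_of_2(max_block + 1) // 2                            # 8_192, largest power of 2 <= 14_979

desired = triton.next_power_of_2(4096)       # 4_096
block_size = max(1, min(desired, max_safe))  # 4_096: the whole row fits within the UB budget
print(block_size)
```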
liger_kernel/ops/backends/registry.py
ADDED (new file, 61 lines)

```python
"""
Vendor registry for Liger-Kernel multi-backend support.

This module defines VendorInfo and the registry for vendor registration.
Each vendor registers itself by calling register_vendor() in its __init__.py.
"""

from dataclasses import dataclass
from typing import Optional

# Dynamically get backends package path to avoid hardcoding
_BACKENDS_PACKAGE = __name__.rsplit(".", 1)[0]  # "liger_kernel.ops.backends"


@dataclass
class VendorInfo:
    """
    Information about a chip vendor and its supported device.

    Attributes:
        vendor: Vendor name (e.g., "ascend", "intel", "nvidia")
        device: Device type this vendor supports (e.g., "npu", "xpu")
    """

    vendor: str
    device: str

    @property
    def module_path(self) -> str:
        """Auto-generated module path based on vendor name."""
        return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops"


# Registry mapping device types to their vendor info
# Vendors register themselves via register_vendor()
VENDOR_REGISTRY: dict[str, VendorInfo] = {}


def register_vendor(vendor_info: VendorInfo) -> None:
    """
    Register a vendor's info in the global registry.

    This should be called in each vendor's __init__.py to register itself.

    Args:
        vendor_info: VendorInfo instance to register
    """
    VENDOR_REGISTRY[vendor_info.device] = vendor_info


def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    """
    Get the VendorInfo for a given device type.

    Args:
        device: Device type (e.g., "npu", "xpu")

    Returns:
        VendorInfo if found, None otherwise
    """
    return VENDOR_REGISTRY.get(device)
```
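For illustration, a vendor backend would register itself at import time roughly as sketched below. The actual body of `liger_kernel/ops/backends/_ascend/__init__.py` (+5 lines in this release) is not shown in this diff, so treat the registration call as an assumed usage pattern rather than the shipped code:

```python
# Hypothetical vendor registration, mirroring the pattern described in the registry docstring.
# The real _ascend/__init__.py ships in this release, but its contents are not displayed here.
from liger_kernel.ops.backends.registry import VendorInfo, get_vendor_for_device, register_vendor

register_vendor(VendorInfo(vendor="ascend", device="npu"))

info = get_vendor_for_device("npu")
assert info is not None
print(info.module_path)  # "liger_kernel.ops.backends._ascend.ops", matching the _ascend/ops package above
```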
liger_kernel/ops/cross_entropy.py
CHANGED

```diff
@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,6 +45,7 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
@@ -60,6 +64,8 @@ def liger_cross_entropy_kernel(
         loss_ptr: Pointer to tensor to store the loss.
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
+        token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+        token_accuracy_stride (int): The stride of the token accuracy tensor.
         n_cols (int): The number of columns in the input tensor.
         n_non_ignore (float): The number of non-ignored elements in the batch.
         sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -69,7 +75,8 @@ def liger_cross_entropy_kernel(
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         reduction (str): The string for the reduction to apply
         softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        RETURN_Z_LOSS (int): The boolean value to decide whether
+        RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+        RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
         BLOCK_SIZE (int): The block size for Triton operations.
         HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -92,11 +99,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return
 
     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride
 
     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -107,6 +120,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
     ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -127,6 +141,19 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            current_block_argmax_idx = tl.min(masked_offsets)
+
+            is_new_max = block_max > m
+            argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
```
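The block-wise argmax bookkeeping in the hunk above can be hard to follow in Triton syntax. The following plain-Python sketch (illustrative only, not part of the kernel) shows the same streaming logic: each block contributes its local max and the first offset attaining it, and the running argmax is replaced only when a block's max strictly exceeds the running max, so ties keep the earliest index:

```python
# Plain-Python illustration of the streaming argmax the kernel tracks per row.
# `logits` stands in for one row of X; `block_size` mirrors the kernel's BLOCK_SIZE tile width.
def streaming_argmax(logits, block_size):
    m = float("-inf")   # running max, as in the kernel
    argmax_idx = 0      # running index of the max
    for start in range(0, len(logits), block_size):
        block = logits[start:start + block_size]
        block_max = max(block)
        # first (smallest) offset within the row where the block max occurs
        block_argmax = start + block.index(block_max)
        if block_max > m:
            argmax_idx = block_argmax
        m = max(m, block_max)
    return argmax_idx

print(streaming_argmax([0.1, 2.5, -1.0, 2.5, 3.0, 0.0], block_size=2))  # -> 4
```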
(liger_kernel/ops/cross_entropy.py, continued)

```diff
@@ -256,12 +283,22 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+# the best size we found by manually tuning on xpu and npu.
+if infer_device() == "xpu":
+    MAX_FUSED_SIZE = 4096
+elif infer_device() == "npu":
+    MAX_FUSED_SIZE = 2048
+else:
+    MAX_FUSED_SIZE = 65536 // 2
 
 
 def cross_entropy_forward(
```
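To see what the per-device cap above means in practice, here is a small illustrative snippet. It assumes the surrounding `cross_entropy_forward` code (not shown in this hunk) clamps the tile width to the vocabulary size, along the lines of `BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))`; that exact call is an assumption, the per-device values are from the diff:

```python
# Illustration only: how the per-device MAX_FUSED_SIZE bounds the kernel's tile width for a large vocabulary.
import triton

V = 128_256  # e.g. a Llama-3-sized vocabulary
for device, max_fused in (("xpu", 4096), ("npu", 2048), ("default", 65536 // 2)):
    block_size = min(max_fused, triton.next_power_of_2(V))
    tiles_per_row = -(-V // block_size)  # ceil division: tiles visited per pass over one row
    print(device, block_size, tiles_per_row)
# default: 32768-wide tiles, 4 per row; npu: 2048-wide tiles, 63 per row
```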
(liger_kernel/ops/cross_entropy.py, continued)

```diff
@@ -274,8 +311,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
 
     BT, V = _input.shape
     n_rows = BT
@@ -285,6 +326,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )
 
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -321,6 +365,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -331,6 +379,7 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -343,11 +392,14 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
 
-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input
 
 
 def cross_entropy_backward(_input, grad_output):
@@ -395,6 +447,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -409,12 +462,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
             reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
             softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-            return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+            return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+            return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
 
         Returns:
-            tuple: A tuple with the
+            tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
         """
-        loss, z_loss, _input = cross_entropy_forward(
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
             _input,
             target,
             weight,
@@ -424,29 +480,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             reduction,
             softcap,
             return_z_loss,
+            return_token_accuracy,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
         ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy
 
-        return loss, z_loss
+        return loss, z_loss, token_accuracy
 
     @staticmethod
-    def backward(ctx, grad_output, grad_output2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
        """
        The backward pass of the Liger Cross Entropy loss.

        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-       grad_output2 (tensor): No use.
+       grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+       grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
        Returns:
            tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        if ctx.return_z_loss:
-           del grad_output2
+           del grad_output2  # z_loss is only for logging
+       if ctx.return_token_accuracy:
+           del grad_output3  # token_accuracy is only for metrics
 
        (_input,) = ctx.saved_tensors
        _input = cross_entropy_backward(_input, grad_output)
@@ -460,4 +522,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
```
liger_kernel/ops/dyt.py
CHANGED

```diff
@@ -7,8 +7,10 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
         NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
     elif device == "xpu":
         NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-
+    elif device == "npu":
+        NUM_SMS = get_npu_multi_processor_count()
     da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
     dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
     db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
```