liger-kernel-nightly 0.6.4.dev20251202054858__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
- liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
- liger_kernel/chunked_loss/jsd_loss.py +21 -6
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +12 -3
- liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
- liger_kernel/ops/geglu.py +3 -2
- liger_kernel/ops/rms_norm.py +126 -49
- liger_kernel/ops/utils.py +12 -0
- liger_kernel/transformers/__init__.py +3 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +1 -1
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +20 -20
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +1 -1
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/gemma3.py +1 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/paligemma.py +1 -0
- liger_kernel/transformers/monkey_patch.py +118 -39
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +28 -27
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +1 -1
- liger_kernel/transformers/tiled_mlp.py +3 -3
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +27 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +9 -3
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/RECORD +58 -46
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified Buffer (UB) Manager for Ascend NPU.
|
|
3
|
+
|
|
4
|
+
This module provides UB capacity detection and tiling strategy computation
|
|
5
|
+
for running Triton kernels on Ascend NPU. It automatically calculates
|
|
6
|
+
optimal block sizes based on UB capacity constraints to prevent UB overflow.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
|
|
11
|
+
from typing import Optional
|
|
12
|
+
from typing import Tuple
|
|
13
|
+
from typing import Union
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import triton
|
|
17
|
+
|
|
18
|
+
from liger_kernel.utils import is_npu_available
|
|
19
|
+
|
|
20
|
+
# Default UB capacities for different NPU models (in bits)
|
|
21
|
+
_DEFAULT_UB_CAPACITIES = {
|
|
22
|
+
"Ascend910B1": 2097152, # ~256 KB
|
|
23
|
+
"Ascend910B4": 1572864, # ~192 KB
|
|
24
|
+
"default": 2097152, # ~256 KB
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
|
|
29
|
+
"""
|
|
30
|
+
Normalize tiling dimension specification to a set of dimension indices.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions).
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
Set of dimension indices that can be tiled.
|
|
37
|
+
"""
|
|
38
|
+
if isinstance(tiling_dim, int):
|
|
39
|
+
return {tiling_dim}
|
|
40
|
+
elif isinstance(tiling_dim, tuple):
|
|
41
|
+
return set(tiling_dim)
|
|
42
|
+
else:
|
|
43
|
+
return set()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _default_strategy(
    ub_capacity_bits: int,
    safety_margin: float,
    dtype_size: int,
    memory_multiplier: float,
    shapes: Tuple[Tuple[int, ...], ...],
    tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
    """
    Default tiling strategy: calculate maximum safe block size based on UB capacity.

    This is a unified strategy function that works for all kernels by abstracting
    the memory calculation as: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits

    Args:
        ub_capacity_bits: UB capacity in bits
        safety_margin: Safety margin as a float (e.g., 0.80 for 80%)
        dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32)
        memory_multiplier: Memory multiplier for estimating peak memory usage
        shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
            - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
            - For GEGLU: ((n_cols,),)
        tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
            Each element can be:
            - int: single dimension index (e.g., 0 for first dimension)
            - tuple of ints: multiple dimensions that can be tiled together
            - For ROPE: (0, 0) means first dimension of each shape can be tiled
            - For GEGLU: (0,) means first dimension of the shape can be tiled
            Length must match len(shapes).

    Returns:
        Tuple of maximum safe block sizes, one for each shape. Each element is a
        power of 2 and at least 1. It is the BLOCK_SIZE of the memory formula above,
        i.e. a bound on the total number of elements along the tiled dimension(s).
        Returns () if shapes or tiling_dims is empty.

    Raises:
        ValueError: if len(tiling_dims) != len(shapes), or a tiling_dim entry is
            empty/invalid or out of range for its shape.

    Note:
        For each shape, fixed (non-tiling) dimensions are padded to the next power of 2
        and multiplied together to get unit_param. compute_default_tiling_strategy pads
        those dimensions the same way in its result, so the estimate matches the block
        the kernel actually holds even when unpadded shapes are passed in.
        The final block size is computed in compute_default_tiling_strategy by taking
        min(desired_block_size, max_safe_block_size) where desired_block_size = triton.next_power_of_2(original_dim).
    """
    if not shapes or not tiling_dims:
        return ()
    if len(shapes) != len(tiling_dims):
        raise ValueError(
            f"tiling_dims has {len(tiling_dims)} entries but shapes has {len(shapes)}; lengths must match."
        )

    # The usable UB budget is the same for every shape, so compute it once.
    safe_ub_capacity_bits = int(ub_capacity_bits * safety_margin)

    max_safe_sizes = []
    for shape, tiling_dim in zip(shapes, tiling_dims):
        tiling_dim_set = _normalize_tiling_dims(tiling_dim)

        # Validate tiling dimensions are within shape bounds
        if not tiling_dim_set:
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
            )
        if any(dim_idx < 0 or dim_idx >= len(shape) for dim_idx in tiling_dim_set):
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
                f"All dimension indices must be in range [0, {len(shape)})."
            )

        # unit_param: elements held per unit of BLOCK_SIZE, i.e. the product of the
        # fixed (non-tiling) dims, each padded to the power of 2 used for the block.
        unit_param = 1.0
        for dim_idx, dim_size in enumerate(shape):
            if dim_idx in tiling_dim_set:
                continue
            if dim_size <= 0:
                # Non-positive size: ignore all fixed dims and fall back to unit_param = 1.0.
                unit_param = 1.0
                break
            unit_param *= float(triton.next_power_of_2(dim_size))

        # Solve: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 <= safe_ub_capacity_bits
        max_block_size = int(safe_ub_capacity_bits // (memory_multiplier * unit_param * dtype_size * 8))
        max_block_size = max(1, max_block_size)

        # Largest power of 2 that is <= max_block_size.
        max_safe_sizes.append(1 << (max_block_size.bit_length() - 1))

    return tuple(max_safe_sizes)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class UBManager:
    """
    Unified Buffer Manager for Ascend NPU.

    Provides UB capacity detection and management for Ascend NPU devices.
    The UB capacity is used by tiling strategy functions to calculate optimal block sizes.

    Detection runs once, in __init__; the results are exposed through the read-only
    ub_capacity_bits / ub_capacity_bytes / npu_model properties.
    """

    def __init__(self, ub_capacity_bits: Optional[int] = None):
        """
        Initialize UB Manager.

        Args:
            ub_capacity_bits: UB capacity in bits. If None, or not a positive integer,
                the capacity is detected automatically (see _detect_ub_capacity).
        """
        self._npu_model = self._detect_npu_model()
        explicit_capacity = self._parse_positive_int(ub_capacity_bits)
        if explicit_capacity is not None:
            self._ub_capacity_bits = explicit_capacity
        else:
            self._ub_capacity_bits = self._detect_ub_capacity()

    @property
    def ub_capacity_bits(self) -> int:
        """Get UB capacity in bits."""
        return self._ub_capacity_bits

    @property
    def ub_capacity_bytes(self) -> int:
        """Get UB capacity in bytes."""
        return self._ub_capacity_bits // 8

    @property
    def npu_model(self) -> str:
        """Get detected NPU model name."""
        return self._npu_model

    @staticmethod
    def _parse_positive_int(value) -> Optional[int]:
        """
        Convert value to a positive int.

        Returns:
            The parsed integer, or None if value is None, a bool, not convertible with
            int(), or not strictly positive.
        """
        if value is None or isinstance(value, bool):
            return None
        try:
            parsed = int(value)
        except (TypeError, ValueError):
            return None
        return parsed if parsed > 0 else None

    def _detect_npu_model(self) -> str:
        """
        Detect NPU model from device properties.

        Returns:
            "unknown" when no NPU is available, the name of device 0 when its
            properties can be queried, and "default" if that query fails.
        """
        if not is_npu_available():
            return "unknown"

        try:
            return torch.npu.get_device_properties(0).name
        except Exception:
            # Best effort: an unreadable device falls back to the generic default entry.
            return "default"

    def _detect_ub_capacity(self) -> int:
        """
        Detect UB capacity from environment variable or device properties.

        Resolution order:
            1. ASCEND_UB_CAPACITY_BITS environment variable, if it holds a positive integer.
               Any other value is ignored with a RuntimeWarning.
            2. ub_capacity_bits attribute of device 0's properties, if present and positive.
            3. _DEFAULT_UB_CAPACITIES entry for the detected model, else its "default" entry.

        Returns:
            UB capacity in bits.
        """
        import warnings

        # Check environment variable first
        env_capacity = os.getenv("ASCEND_UB_CAPACITY_BITS")
        if env_capacity is not None:
            parsed = self._parse_positive_int(env_capacity)
            if parsed is not None:
                return parsed
            warnings.warn(
                f"Ignoring ASCEND_UB_CAPACITY_BITS={env_capacity!r}: expected a positive integer "
                "number of bits; falling back to device/model defaults.",
                RuntimeWarning,
                stacklevel=2,
            )

        # Try to get from device properties
        if is_npu_available():
            try:
                raw_capacity = getattr(torch.npu.get_device_properties(0), "ub_capacity_bits", None)
            except Exception:
                # Best effort: device query failures fall through to model defaults.
                raw_capacity = None
            parsed = self._parse_positive_int(raw_capacity)
            if parsed is not None:
                return parsed

        # Fall back to model-based defaults
        return _DEFAULT_UB_CAPACITIES.get(self._npu_model, _DEFAULT_UB_CAPACITIES["default"])
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
# Process-wide UBManager, created lazily by get_ub_manager() on first use.
_ub_manager: Optional[UBManager] = None


def get_ub_manager() -> UBManager:
    """Return the shared UBManager, constructing it on the first call."""
    global _ub_manager
    if _ub_manager is not None:
        return _ub_manager
    _ub_manager = UBManager()
    return _ub_manager
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def compute_default_tiling_strategy(
    safety_margin: float = 0.80,
    dtype_size: Optional[int] = None,
    memory_multiplier: Optional[float] = None,
    shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
    tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
) -> Optional[Tuple[Tuple[int, ...], ...]]:
    """
    Compute tiling strategy using the default strategy function.

    This function directly calls the default strategy and computes the final
    tiling result. All kernels use the same unified strategy function, so
    there's no need for kernel_name-based lookup.

    Args:
        safety_margin: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
        dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32).
            Should be provided; if None or <= 0, defaults to 4 (float32).
        memory_multiplier: Memory multiplier for estimating peak memory usage.
            - For GEGLU: typically 10.0 for backward, 7.0 for forward
            - For ROPE: typically 3.0
            If None or <= 0, defaults to 10.0 (conservative estimate).
        shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
            - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
            - For GEGLU: ((n_cols,),)
            Can pass original shapes (will handle padding internally) or padded shapes.
        tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
            Each element can be:
            - int: single dimension index (e.g., 0 for first dimension)
            - tuple of ints: multiple dimensions that can be tiled together
            - For ROPE: (0, 0) means first dimension of each shape can be tiled
            - For GEGLU: (0,) means first dimension of the shape can be tiled
            Length must match len(shapes). Cannot be empty.
            When one shape tiles several dimensions, they share that shape's single
            safe block size: dimensions are visited in ascending index order and each
            takes min(next_power_of_2(dim), remaining budget), so the product of their
            block sizes never exceeds the safe block size.

    Returns:
        Tuple of tiled shapes with same structure as input shapes.
        Tiling dimensions are replaced with computed block sizes (power of 2, at least 1),
        while non-tiling dimensions are padded to next power of 2.
        - For ROPE: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
        - For GEGLU: ((block_size,),)
        Returns None if shapes or tiling_dims is None or empty, if their lengths
        differ, or if the strategy yields no result for every shape.

    Raises:
        ValueError: if a tiling_dim entry is invalid or out of range for its shape.

    Examples:
        >>> # ROPE forward
        >>> strategy = compute_default_tiling_strategy(
        ...     safety_margin=0.90,
        ...     dtype_size=4,
        ...     memory_multiplier=3.0,
        ...     shapes=((32, 128), (32, 128)),
        ...     tiling_dims=(0, 0)
        ... )
        >>> # Returns: ((block_size_q, 128), (block_size_kv, 128))
        >>> # GEGLU forward
        >>> strategy = compute_default_tiling_strategy(
        ...     safety_margin=0.80,
        ...     dtype_size=2,
        ...     memory_multiplier=7.0,
        ...     shapes=((4096,),),
        ...     tiling_dims=(0,)
        ... )
        >>> # Returns: ((block_size,),)
    """
    if not shapes or not tiling_dims:
        return None

    if len(shapes) != len(tiling_dims):
        return None

    if dtype_size is None or dtype_size <= 0:
        dtype_size = 4  # Default to float32

    if memory_multiplier is None or memory_multiplier <= 0:
        memory_multiplier = 10.0  # Default conservative estimate

    # Resolve the lazily created UB manager only after the cheap early exits.
    ub_manager = get_ub_manager()

    # Call strategy to get max_safe_block_size for each shape
    max_supported = _default_strategy(
        ub_manager.ub_capacity_bits,
        safety_margin,
        dtype_size,
        memory_multiplier,
        shapes,
        tiling_dims,
    )

    if not max_supported or len(max_supported) != len(shapes):
        return None

    # Build result: same structure as shapes, with tiling dims replaced by computed block sizes
    result = []
    for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
        result_shape = list(shape)
        tiling_dim_set = _normalize_tiling_dims(tiling_dim)

        # Validate tiling dimensions are within shape bounds
        if not tiling_dim_set:
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
            )
        if any(dim_idx < 0 or dim_idx >= len(result_shape) for dim_idx in tiling_dim_set):
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
                f"All dimension indices must be in range [0, {len(result_shape)})."
            )

        # All tiled dims of this shape share one budget of max_safe elements. Giving
        # every tiled dim the full max_safe would let the block grow to max_safe**k and
        # overflow the UB when k > 1. For a single tiled dim this reduces to
        # min(desired, max_safe).
        remaining = max(1, max_safe)
        for dim_idx in sorted(tiling_dim_set):
            desired = triton.next_power_of_2(result_shape[dim_idx])
            block_size = max(1, min(desired, remaining))  # Ensure at least 1
            result_shape[dim_idx] = block_size
            remaining = max(1, remaining // block_size)

        # Pad non-tiling dimensions to next power of 2
        for dim_idx, dim_size in enumerate(result_shape):
            if dim_idx not in tiling_dim_set:
                result_shape[dim_idx] = triton.next_power_of_2(dim_size)

        result.append(tuple(result_shape))

    return tuple(result)
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Vendor registry for Liger-Kernel multi-backend support.
|
|
3
|
+
|
|
4
|
+
This module defines VendorInfo and the registry for vendor registration.
|
|
5
|
+
Each vendor registers itself by calling register_vendor() in its __init__.py.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
# Dynamically get backends package path to avoid hardcoding
|
|
12
|
+
_BACKENDS_PACKAGE = __name__.rsplit(".", 1)[0] # "liger_kernel.ops.backends"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
|
|
16
|
+
class VendorInfo:
|
|
17
|
+
"""
|
|
18
|
+
Information about a chip vendor and its supported device.
|
|
19
|
+
|
|
20
|
+
Attributes:
|
|
21
|
+
vendor: Vendor name (e.g., "ascend", "intel", "nvidia")
|
|
22
|
+
device: Device type this vendor supports (e.g., "npu", "xpu")
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
vendor: str
|
|
26
|
+
device: str
|
|
27
|
+
|
|
28
|
+
@property
|
|
29
|
+
def module_path(self) -> str:
|
|
30
|
+
"""Auto-generated module path based on vendor name."""
|
|
31
|
+
return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Global table: device type string -> VendorInfo of the vendor serving it.
# Filled by register_vendor() and read by get_vendor_for_device().
VENDOR_REGISTRY: dict[str, VendorInfo] = {}


def register_vendor(vendor_info: VendorInfo) -> None:
    """
    Store vendor_info in the global registry under its device type.

    Meant to be called from each vendor's __init__.py. Registering another vendor
    for a device that already has one replaces the earlier entry.

    Args:
        vendor_info: VendorInfo instance to register
    """
    device_key = vendor_info.device
    VENDOR_REGISTRY[device_key] = vendor_info


def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    """
    Look up the vendor registered for a device type.

    Args:
        device: Device type (e.g., "npu", "xpu")

    Returns:
        The registered VendorInfo, or None when no vendor has claimed this device.
    """
    if device in VENDOR_REGISTRY:
        return VENDOR_REGISTRY[device]
    return None
|
|
@@ -143,13 +143,16 @@ def liger_cross_entropy_kernel(
|
|
|
143
143
|
block_max = tl.max(X_block)
|
|
144
144
|
|
|
145
145
|
# Track argmax for accuracy computation
|
|
146
|
-
if RETURN_TOKEN_ACCURACY
|
|
146
|
+
if RETURN_TOKEN_ACCURACY:
|
|
147
147
|
# Find the index of the maximum value in this block
|
|
148
148
|
is_max_mask = X_block == block_max
|
|
149
149
|
# Mask out invalid indices with a value larger than n_cols
|
|
150
150
|
masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
|
|
151
151
|
# Get the first (smallest) index where max occurs
|
|
152
|
-
|
|
152
|
+
current_block_argmax_idx = tl.min(masked_offsets)
|
|
153
|
+
|
|
154
|
+
is_new_max = block_max > m
|
|
155
|
+
argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
|
|
153
156
|
|
|
154
157
|
if label_smoothing > 0:
|
|
155
158
|
# scale X beforehand to avoid overflow
|
|
@@ -289,7 +292,13 @@ def liger_cross_entropy_kernel(
|
|
|
289
292
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
|
290
293
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
|
291
294
|
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
|
292
|
-
|
|
295
|
+
# The best sizes we found by manually tuning on xpu and npu; every other device
# keeps the generic 65536 // 2 cap. A dispatch table lets infer_device() be
# queried once instead of once per comparison.
MAX_FUSED_SIZE = {"xpu": 4096, "npu": 2048}.get(infer_device(), 65536 // 2)
|
|
293
302
|
|
|
294
303
|
|
|
295
304
|
def cross_entropy_forward(
|
|
@@ -6,11 +6,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
|
|
|
6
6
|
from liger_kernel.ops.utils import amp_custom_fwd
|
|
7
7
|
from liger_kernel.ops.utils import element_mul_kernel
|
|
8
8
|
from liger_kernel.ops.utils import is_hip
|
|
9
|
+
from liger_kernel.utils import infer_device
|
|
9
10
|
|
|
10
11
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
|
11
12
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
|
12
13
|
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
|
13
|
-
MAX_FUSED_SIZE = 65536 // 2
|
|
14
|
+
if infer_device() == "npu":
    # NPU uses a smaller cap on the fused block size.
    MAX_FUSED_SIZE = 2048
else:
    MAX_FUSED_SIZE = 65536 // 2
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
def fused_linear_cross_entropy_forward(
|
liger_kernel/ops/geglu.py
CHANGED
|
@@ -67,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
|
|
|
67
67
|
tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
|
|
68
68
|
tanh_result = tanh(tanh_arg)
|
|
69
69
|
geglu_a = 0.5 * a_row * (1 + tanh_result)
|
|
70
|
+
geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
|
|
70
71
|
|
|
71
|
-
db_row = dc_row * geglu_a
|
|
72
|
+
db_row = dc_row.cast(tl.float32) * geglu_a
|
|
72
73
|
|
|
73
74
|
# Gradient w.r.t. a can be computed with:
|
|
74
75
|
# b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
|
|
@@ -79,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
|
|
|
79
80
|
da_row = dc_row * b_row * (term1 + term2)
|
|
80
81
|
|
|
81
82
|
tl.store(a + col_offsets, da_row, mask=mask)
|
|
82
|
-
tl.store(b + col_offsets, db_row, mask=mask)
|
|
83
|
+
tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
|
|
83
84
|
|
|
84
85
|
|
|
85
86
|
def geglu_forward(a, b):
|