liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.
Files changed (97)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  12. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  17. liger_kernel/ops/backends/registry.py +61 -0
  18. liger_kernel/ops/cross_entropy.py +75 -12
  19. liger_kernel/ops/dyt.py +5 -2
  20. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  21. liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
  22. liger_kernel/ops/geglu.py +5 -3
  23. liger_kernel/ops/group_norm.py +2 -1
  24. liger_kernel/ops/grpo_loss.py +3 -1
  25. liger_kernel/ops/layer_norm.py +86 -66
  26. liger_kernel/ops/poly_norm.py +390 -0
  27. liger_kernel/ops/rms_norm.py +131 -49
  28. liger_kernel/ops/tiled_mlp.py +136 -0
  29. liger_kernel/ops/utils.py +14 -0
  30. liger_kernel/transformers/__init__.py +30 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +9 -4
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +48 -25
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +57 -2
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/falcon_h1.py +19 -5
  48. liger_kernel/transformers/model/gemma.py +17 -6
  49. liger_kernel/transformers/model/gemma2.py +14 -5
  50. liger_kernel/transformers/model/gemma3.py +26 -12
  51. liger_kernel/transformers/model/glm4.py +16 -4
  52. liger_kernel/transformers/model/glm4v.py +16 -4
  53. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  54. liger_kernel/transformers/model/gpt_oss.py +211 -0
  55. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  56. liger_kernel/transformers/model/internvl.py +12 -5
  57. liger_kernel/transformers/model/llama.py +14 -5
  58. liger_kernel/transformers/model/llama4.py +16 -4
  59. liger_kernel/transformers/model/llava.py +12 -4
  60. liger_kernel/transformers/model/loss_utils.py +31 -3
  61. liger_kernel/transformers/model/mistral.py +15 -6
  62. liger_kernel/transformers/model/mixtral.py +16 -7
  63. liger_kernel/transformers/model/mllama.py +12 -4
  64. liger_kernel/transformers/model/olmo2.py +16 -4
  65. liger_kernel/transformers/model/olmo3.py +142 -0
  66. liger_kernel/transformers/model/output_classes.py +147 -0
  67. liger_kernel/transformers/model/paligemma.py +23 -5
  68. liger_kernel/transformers/model/phi3.py +14 -7
  69. liger_kernel/transformers/model/qwen2.py +16 -3
  70. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  71. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  72. liger_kernel/transformers/model/qwen3.py +20 -5
  73. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  74. liger_kernel/transformers/model/qwen3_next.py +146 -0
  75. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  76. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  77. liger_kernel/transformers/model/smollm3.py +15 -6
  78. liger_kernel/transformers/model/smolvlm.py +158 -0
  79. liger_kernel/transformers/monkey_patch.py +702 -48
  80. liger_kernel/transformers/multi_token_attention.py +1 -1
  81. liger_kernel/transformers/poly_norm.py +42 -0
  82. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  83. liger_kernel/transformers/rms_norm.py +15 -3
  84. liger_kernel/transformers/rope.py +45 -1
  85. liger_kernel/transformers/softmax.py +1 -1
  86. liger_kernel/transformers/sparsemax.py +1 -1
  87. liger_kernel/transformers/swiglu.py +18 -1
  88. liger_kernel/transformers/tiled_mlp.py +133 -0
  89. liger_kernel/transformers/tvd.py +1 -1
  90. liger_kernel/utils.py +52 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
  92. liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
  93. liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
  94. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ub_manager.py ADDED
@@ -0,0 +1,349 @@
+"""
+Unified Buffer (UB) Manager for Ascend NPU.
+
+This module provides UB capacity detection and tiling strategy computation
+for running Triton kernels on Ascend NPU. It automatically calculates
+optimal block sizes based on UB capacity constraints to prevent UB overflow.
+"""
+
+import os
+
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+import triton
+
+from liger_kernel.utils import is_npu_available
+
+# Default UB capacities for different NPU models (in bits)
+_DEFAULT_UB_CAPACITIES = {
+    "Ascend910B1": 2097152,  # ~256 KB
+    "Ascend910B4": 1572864,  # ~192 KB
+    "default": 2097152,  # ~256 KB
+}
+
+
+def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
+    """
+    Normalize tiling dimension specification to a set of dimension indices.
+
+    Args:
+        tiling_dim: Either an int (single dimension) or tuple of ints (multiple dimensions).
+
+    Returns:
+        Set of dimension indices that can be tiled.
+    """
+    if isinstance(tiling_dim, int):
+        return {tiling_dim}
+    elif isinstance(tiling_dim, tuple):
+        return set(tiling_dim)
+    else:
+        return set()
+
+
+def _default_strategy(
+    ub_capacity_bits: int,
+    safety_margin: float,
+    dtype_size: int,
+    memory_multiplier: float,
+    shapes: Tuple[Tuple[int, ...], ...],
+    tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
+) -> Tuple[int, ...]:
+    """
+    Default tiling strategy: calculate maximum safe block size based on UB capacity.
+
+    This is a unified strategy function that works for all kernels by abstracting
+    the memory calculation as: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
+
+    Args:
+        ub_capacity_bits: UB capacity in bits
+        safety_margin: Safety margin as a float (e.g., 0.80 for 80%)
+        dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32)
+        memory_multiplier: Memory multiplier for estimating peak memory usage
+        shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
+            - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
+            - For GEGLU: ((n_cols,),)
+        tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
+            Each element can be:
+            - int: single dimension index (e.g., 0 for first dimension)
+            - tuple of ints: multiple dimensions that can be tiled together
+            - For ROPE: (0, 0) means first dimension of each shape can be tiled
+            - For GEGLU: (0,) means first dimension of the shape can be tiled
+            Length must match len(shapes).
+
+    Returns:
+        Tuple of maximum safe block sizes, one for each shape.
+        Each element is a power of 2.
+
+    Note:
+        For each shape, fixed dimensions (non-tiling) are multiplied together to get unit_param.
+        The final block size is computed in compute_default_tiling_strategy by taking
+        min(desired_block_size, max_safe_block_size) where desired_block_size = triton.next_power_of_2(original_dim).
+    """
+    if not shapes or not tiling_dims:
+        return ()
+
+    # Calculate max_safe_block_size for each tiling dimension
+    max_safe_sizes = []
+
+    for shape, tiling_dim in zip(shapes, tiling_dims):
+        # Normalize tiling_dim to a set of dimension indices
+        tiling_dim_set = _normalize_tiling_dims(tiling_dim)
+
+        # Validate tiling dimensions are within shape bounds
+        if not tiling_dim_set:
+            raise ValueError(
+                f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
+            )
+        if any(dim_idx < 0 or dim_idx >= len(shape) for dim_idx in tiling_dim_set):
+            raise ValueError(
+                f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
+                f"All dimension indices must be in range [0, {len(shape)})."
+            )
+
+        # Calculate unit_param: product of fixed (non-tiling) dimensions
+        unit_param = 1.0
+        for dim_idx, dim_size in enumerate(shape):
+            if dim_idx not in tiling_dim_set:
+                if dim_size <= 0:
+                    # Invalid dimension size, use conservative default
+                    unit_param = 1.0
+                    break
+                unit_param *= float(dim_size)
+
+        # Ensure unit_param is at least 1.0
+        if unit_param <= 0:
+            unit_param = 1.0
+
+        # Calculate maximum safe block size based on UB capacity
+        # Memory: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
+        SAFE_UB_CAPACITY_BITS = int(ub_capacity_bits * safety_margin)
+
+        # Solve: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 <= SAFE_UB_CAPACITY_BITS
+        # BLOCK_SIZE <= SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)
+        max_block_size = int(SAFE_UB_CAPACITY_BITS // (memory_multiplier * unit_param * dtype_size * 8))
+        max_block_size = max(1, max_block_size)
+
+        # Find largest power of 2 <= max_block_size
+        # Use triton.next_power_of_2(max_block_size + 1) // 2 to get the largest power of 2 <= max_block_size
+        safe_block_size = triton.next_power_of_2(max_block_size + 1) // 2
+        max_safe_sizes.append(safe_block_size)
+
+    return tuple(max_safe_sizes)
+
+
+class UBManager:
+    """
+    Unified Buffer Manager for Ascend NPU.
+
+    Provides UB capacity detection and management for Ascend NPU devices.
+    The UB capacity is used by tiling strategy functions to calculate optimal block sizes.
+    """
+
+    def __init__(self, ub_capacity_bits: Optional[int] = None):
+        """
+        Initialize UB Manager.
+
+        Args:
+            ub_capacity_bits: UB capacity in bits. If None, will be detected automatically.
+        """
+        self._npu_model = self._detect_npu_model()
+        self._ub_capacity_bits = ub_capacity_bits or self._detect_ub_capacity()
+
+    @property
+    def ub_capacity_bits(self) -> int:
+        """Get UB capacity in bits."""
+        return self._ub_capacity_bits
+
+    @property
+    def ub_capacity_bytes(self) -> int:
+        """Get UB capacity in bytes."""
+        return self._ub_capacity_bits // 8
+
+    @property
+    def npu_model(self) -> str:
+        """Get detected NPU model name."""
+        return self._npu_model
+
+    def _detect_npu_model(self) -> str:
+        """Detect NPU model from device properties."""
+        if not is_npu_available():
+            return "unknown"
+
+        try:
+            dev_props = torch.npu.get_device_properties(0)
+            # Try to get model name from device properties
+            return dev_props.name
+        except Exception:
+            pass
+
+        return "default"
+
+    def _detect_ub_capacity(self) -> int:
+        """
+        Detect UB capacity from environment variable or device properties.
+
+        Returns:
+            UB capacity in bits.
+        """
+        # Check environment variable first
+        env_capacity = os.getenv("ASCEND_UB_CAPACITY_BITS")
+        if env_capacity is not None:
+            try:
+                return int(env_capacity)
+            except ValueError:
+                pass
+
+        # Try to get from device properties
+        if is_npu_available():
+            try:
+                dev_props = torch.npu.get_device_properties(0)
+                if hasattr(dev_props, "ub_capacity_bits"):
+                    return dev_props.ub_capacity_bits
+            except Exception:
+                pass
+
+        # Fall back to model-based defaults
+        model = self._npu_model
+        return _DEFAULT_UB_CAPACITIES.get(model, _DEFAULT_UB_CAPACITIES["default"])
+
+
+# Global singleton instance
+_ub_manager: Optional[UBManager] = None
+
+
+def get_ub_manager() -> UBManager:
+    """Get global UB manager instance."""
+    global _ub_manager
+    if _ub_manager is None:
+        _ub_manager = UBManager()
+    return _ub_manager
+
+
+def compute_default_tiling_strategy(
+    safety_margin: float = 0.80,
+    dtype_size: Optional[int] = None,
+    memory_multiplier: Optional[float] = None,
+    shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
+    tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
+) -> Optional[Tuple[Tuple[int, ...], ...]]:
+    """
+    Compute tiling strategy using the default strategy function.
+
+    This function directly calls the default strategy and computes the final
+    tiling result. All kernels use the same unified strategy function, so
+    there's no need for kernel_name-based lookup.
+
+    Args:
+        safety_margin: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
+        dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32).
+            Must be provided. If None or <= 0, defaults to 4 (float32).
+        memory_multiplier: Memory multiplier for estimating peak memory usage.
+            - For GEGLU: typically 10.0 for backward, 7.0 for forward
+            - For ROPE: typically 3.0
+            If None, defaults to 10.0 (conservative estimate).
+        shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
+            - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
+            - For GEGLU: ((n_cols,),)
+            Can pass original shapes (will handle padding internally) or padded shapes.
+        tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
+            Each element can be:
+            - int: single dimension index (e.g., 0 for first dimension)
+            - tuple of ints: multiple dimensions that can be tiled together
+            - For ROPE: (0, 0) means first dimension of each shape can be tiled
+            - For GEGLU: (0,) means first dimension of the shape can be tiled
+            Length must match len(shapes). Cannot be empty.
+
+    Returns:
+        Tuple of tiled shapes with same structure as input shapes.
+        Tiling dimensions are replaced with computed block sizes (power of 2),
+        while non-tiling dimensions are padded to next power of 2.
+        - For ROPE: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+        - For GEGLU: ((block_size,),)
+        Returns None if shapes or tiling_dims is None or empty.
+
+    Examples:
+        >>> # ROPE forward
+        >>> strategy = compute_default_tiling_strategy(
+        ...     safety_margin=0.90,
+        ...     dtype_size=4,
+        ...     memory_multiplier=3.0,
+        ...     shapes=((32, 128), (32, 128)),
+        ...     tiling_dims=(0, 0)
+        ... )
+        >>> # Returns: ((block_size_q, 128), (block_size_kv, 128))
+        >>> # GEGLU forward
+        >>> strategy = compute_default_tiling_strategy(
+        ...     safety_margin=0.80,
+        ...     dtype_size=2,
+        ...     memory_multiplier=7.0,
+        ...     shapes=((4096,),),
+        ...     tiling_dims=(0,)
+        ... )
+        >>> # Returns: ((block_size,),)
+    """
+    ub_manager = get_ub_manager()
+
+    if shapes is None or not shapes or tiling_dims is None or not tiling_dims:
+        return None
+
+    if len(shapes) != len(tiling_dims):
+        return None
+
+    if dtype_size is None or dtype_size <= 0:
+        dtype_size = 4  # Default to float32
+
+    if memory_multiplier is None or memory_multiplier <= 0:
+        memory_multiplier = 10.0  # Default conservative estimate
+
+    # Call strategy to get max_safe_block_size for each shape
+    max_supported = _default_strategy(
+        ub_manager.ub_capacity_bits,
+        safety_margin,
+        dtype_size,
+        memory_multiplier,
+        shapes,
+        tiling_dims,
+    )
+
+    if not max_supported or len(max_supported) != len(shapes):
+        return None
+
+    # Build result: same structure as shapes, with tiling dims replaced by computed block sizes
+    result = []
+    for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
+        result_shape = list(shape)

+        # Normalize tiling_dim to a set of dimension indices
+        tiling_dim_set = _normalize_tiling_dims(tiling_dim)
+
+        # Validate tiling dimensions are within shape bounds
+        if not tiling_dim_set:
+            raise ValueError(
+                f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
+            )
+        if any(dim_idx < 0 or dim_idx >= len(result_shape) for dim_idx in tiling_dim_set):
+            raise ValueError(
+                f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
+                f"All dimension indices must be in range [0, {len(result_shape)})."
+            )
+
+        # Replace tiling dimensions with computed block sizes
+        # For each tiling dimension, compute: min(desired, max_safe)
+        for dim_idx in tiling_dim_set:
+            original_dim = result_shape[dim_idx]
+            desired = triton.next_power_of_2(original_dim)
+            final_val = min(desired, max_safe)
+            final_val = max(1, final_val)  # Ensure at least 1
+            result_shape[dim_idx] = final_val
+
+        # Pad non-tiling dimensions to next power of 2
+        for dim_idx, dim_size in enumerate(result_shape):
+            if dim_idx not in tiling_dim_set:
+                result_shape[dim_idx] = triton.next_power_of_2(dim_size)
+
+        result.append(tuple(result_shape))
+
+    return tuple(result)
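
In short, UB capacity is resolved in this order: the ASCEND_UB_CAPACITY_BITS environment variable, then torch.npu device properties, then the model-based defaults above. A minimal sketch of how a launcher might consume the computed tiling for a GEGLU-style kernel follows; the 4096-column shape and the tile-count arithmetic are illustrative assumptions, not code from this wheel:

    from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy, get_ub_manager

    manager = get_ub_manager()  # singleton; detects UB capacity (env var -> device props -> model default)
    tiling = compute_default_tiling_strategy(
        safety_margin=0.80,
        dtype_size=2,            # float16
        memory_multiplier=7.0,   # forward-pass estimate from the docstring above
        shapes=((4096,),),       # hypothetical n_cols
        tiling_dims=(0,),
    )
    if tiling is not None:
        (block_size,) = tiling[0]
        # A launcher would then walk the 4096 columns in chunks of block_size,
        # each chunk sized to stay within safety_margin * manager.ub_capacity_bits.
        num_col_tiles = (4096 + block_size - 1) // block_size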
liger_kernel/ops/backends/registry.py ADDED
@@ -0,0 +1,61 @@
+"""
+Vendor registry for Liger-Kernel multi-backend support.
+
+This module defines VendorInfo and the registry for vendor registration.
+Each vendor registers itself by calling register_vendor() in its __init__.py.
+"""
+
+from dataclasses import dataclass
+from typing import Optional
+
+# Dynamically get backends package path to avoid hardcoding
+_BACKENDS_PACKAGE = __name__.rsplit(".", 1)[0]  # "liger_kernel.ops.backends"
+
+
+@dataclass
+class VendorInfo:
+    """
+    Information about a chip vendor and its supported device.
+
+    Attributes:
+        vendor: Vendor name (e.g., "ascend", "intel", "nvidia")
+        device: Device type this vendor supports (e.g., "npu", "xpu")
+    """
+
+    vendor: str
+    device: str
+
+    @property
+    def module_path(self) -> str:
+        """Auto-generated module path based on vendor name."""
+        return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops"
+
+
+# Registry mapping device types to their vendor info
+# Vendors register themselves via register_vendor()
+VENDOR_REGISTRY: dict[str, VendorInfo] = {}
+
+
+def register_vendor(vendor_info: VendorInfo) -> None:
+    """
+    Register a vendor's info in the global registry.
+
+    This should be called in each vendor's __init__.py to register itself.
+
+    Args:
+        vendor_info: VendorInfo instance to register
+    """
+    VENDOR_REGISTRY[vendor_info.device] = vendor_info
+
+
+def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
+    """
+    Get the VendorInfo for a given device type.
+
+    Args:
+        device: Device type (e.g., "npu", "xpu")
+
+    Returns:
+        VendorInfo if found, None otherwise
+    """
+    return VENDOR_REGISTRY.get(device)
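
A minimal sketch of the registration flow this module enables; the vendor="ascend", device="npu" arguments mirror the _ascend package above but are written here as an illustration rather than copied from its __init__.py:

    from liger_kernel.ops.backends.registry import VendorInfo, get_vendor_for_device, register_vendor

    # A vendor package registers itself once, typically at import time in its __init__.py.
    register_vendor(VendorInfo(vendor="ascend", device="npu"))

    # Callers can then resolve the vendor ops module for a device type.
    info = get_vendor_for_device("npu")
    if info is not None:
        print(info.module_path)  # "liger_kernel.ops.backends._ascend.ops"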
liger_kernel/ops/cross_entropy.py CHANGED
@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -32,6 +33,8 @@ def liger_cross_entropy_kernel(
     loss_ptr,
     z_loss_ptr,
     loss_stride,
+    token_accuracy_ptr,
+    token_accuracy_stride,
     n_cols,
     n_non_ignore,
     sum_non_ignore_weight,
@@ -42,6 +45,7 @@ def liger_cross_entropy_kernel(
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
+    RETURN_TOKEN_ACCURACY: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
@@ -60,6 +64,8 @@ def liger_cross_entropy_kernel(
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
+    token_accuracy_ptr: Pointer to tensor to store the per-token accuracy. No operation if RETURN_TOKEN_ACCURACY is 0.
+    token_accuracy_stride (int): The stride of the token accuracy tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (float): The number of non-ignored elements in the batch.
     sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
@@ -69,7 +75,8 @@ def liger_cross_entropy_kernel(
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
-    RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_Z_LOSS (int): The boolean value to decide whether to store z loss to z_loss_ptr or not. It must be 0 or 1.
+    RETURN_TOKEN_ACCURACY (int): The boolean value to decide whether to store per-token accuracy to token_accuracy_ptr or not. It must be 0 or 1.
     BLOCK_SIZE (int): The block size for Triton operations.
     HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
@@ -92,11 +99,17 @@ def liger_cross_entropy_kernel(
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        # For ignored tokens, set token accuracy to 0
+        if RETURN_TOKEN_ACCURACY:
+            token_accuracy_ptr += program_id * token_accuracy_stride
+            tl.store(token_accuracy_ptr, 0.0)
         return
 
     loss_ptr += program_id * loss_stride
     if RETURN_Z_LOSS:
         z_loss_ptr += program_id * loss_stride
+    if RETURN_TOKEN_ACCURACY:
+        token_accuracy_ptr += program_id * token_accuracy_stride
 
     if HAS_WEIGHT:
         weight_y = tl.load(weight_ptr + y).cast(tl.float32)
@@ -107,6 +120,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
+    argmax_idx = 0  # Track the index of the maximum value for token accuracy computation
    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
    if HAS_SOFTCAPPING:
        ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -127,6 +141,19 @@ def liger_cross_entropy_kernel(
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
+
+        # Track argmax for accuracy computation
+        if RETURN_TOKEN_ACCURACY:
+            # Find the index of the maximum value in this block
+            is_max_mask = X_block == block_max
+            # Mask out invalid indices with a value larger than n_cols
+            masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
+            # Get the first (smallest) index where max occurs
+            current_block_argmax_idx = tl.min(masked_offsets)
+
+            is_new_max = block_max > m
+            argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
+
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             if HAS_WEIGHT:
@@ -256,12 +283,22 @@ def liger_cross_entropy_kernel(
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS:
         tl.store(z_loss_ptr, z_loss)
+    if RETURN_TOKEN_ACCURACY:
+        # Store 1.0 if prediction is correct, 0.0 otherwise
+        is_correct = 1.0 if argmax_idx == y else 0.0
+        tl.store(token_accuracy_ptr, is_correct)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
+# the best size we found by manually tuning on xpu and npu.
+if infer_device() == "xpu":
+    MAX_FUSED_SIZE = 4096
+elif infer_device() == "npu":
+    MAX_FUSED_SIZE = 2048
+else:
+    MAX_FUSED_SIZE = 65536 // 2
 
 
 def cross_entropy_forward(
@@ -274,8 +311,12 @@ def cross_entropy_forward(
     reduction,
     softcap,
     return_z_loss,
+    return_token_accuracy=False,
 ):
     assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
+    assert isinstance(return_token_accuracy, bool), (
+        f"return_token_accuracy must be True or False. Got: {return_token_accuracy}"
+    )
 
     BT, V = _input.shape
     n_rows = BT
@@ -285,6 +326,9 @@ def cross_entropy_forward(
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
     z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+    token_accuracy_1d = (
+        torch.zeros(n_rows, dtype=torch.float32, device=_input.device) if return_token_accuracy else None
+    )
 
     target_mask = target != ignore_index
     n_non_ignore = target_mask.sum().item()
@@ -321,6 +365,10 @@ def cross_entropy_forward(
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
+        token_accuracy_ptr=token_accuracy_1d,
+        token_accuracy_stride=token_accuracy_1d.stride(-1)
+        if return_token_accuracy
+        else 0,  # always 1 if accuracy is enabled
         n_cols=V,
         n_non_ignore=n_non_ignore,
         sum_non_ignore_weight=sum_non_ignore_weight,
@@ -331,6 +379,7 @@ def cross_entropy_forward(
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=return_z_loss,
+        RETURN_TOKEN_ACCURACY=return_token_accuracy,
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
@@ -343,11 +392,14 @@ def cross_entropy_forward(
     if reduction == "none":
         loss = loss_1d
         z_loss = z_loss_1d if return_z_loss else None
+        token_accuracy = token_accuracy_1d if return_token_accuracy else None
     else:
         loss = torch.sum(loss_1d)
         z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+        # For accuracy, we compute the mean across all non-ignored tokens
+        token_accuracy = torch.sum(token_accuracy_1d) / n_non_ignore if return_token_accuracy else None
 
-    return loss, z_loss, _input
+    return loss, z_loss, token_accuracy, _input
 
 
 def cross_entropy_backward(_input, grad_output):
@@ -395,6 +447,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         reduction: str = "mean",
         softcap: Optional[float] = None,
         return_z_loss: bool = False,
+        return_token_accuracy: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -409,12 +462,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
-        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss, token_accuracy) instead of (loss, None, None). Default: `False`
+        return_token_accuracy (bool): When `return_token_accuracy` is `True`, computes and returns per-token accuracy without materializing logits. Default: `False`
 
        Returns:
-        tuple: A tuple with the compouted losses with respect to loss and z loss. The elements are tensors or None.
+        tuple: A tuple with the computed losses and accuracy: (loss, z_loss, token_accuracy). z_loss and token_accuracy are None if not requested.
        """
-        loss, z_loss, _input = cross_entropy_forward(
+        input_requires_grad = _input.requires_grad
+
+        loss, z_loss, token_accuracy, _input = cross_entropy_forward(
            _input,
            target,
            weight,
@@ -424,29 +480,35 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
            reduction,
            softcap,
            return_z_loss,
+            return_token_accuracy,
        )
        # TODO: investigation
        # If we don't detach the _input tensor, the memory will double
        # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
+        if input_requires_grad:
+            ctx.save_for_backward(_input.detach())
        ctx.return_z_loss = return_z_loss
+        ctx.return_token_accuracy = return_token_accuracy
 
-        return loss, z_loss
+        return loss, z_loss, token_accuracy
 
    @staticmethod
-    def backward(ctx, grad_output, grad_ouput2):
+    def backward(ctx, grad_output, grad_output2, grad_output3):
        """
        The backward pass of the Liger Cross Entropy loss.
 
        Parameters:
        ctx : The context object with saved tensors.
        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-        grad_output2 (tenosr): No use.
+        grad_output2 (tensor): No use. Gradient for z_loss (not used as z_loss is only for logging).
+        grad_output3 (tensor): No use. Gradient for token_accuracy (not used as token_accuracy is only for metrics).
        Returns:
        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
        """
        if ctx.return_z_loss:
-            del grad_ouput2  # z_loss is only for logging
+            del grad_output2  # z_loss is only for logging
+        if ctx.return_token_accuracy:
+            del grad_output3  # token_accuracy is only for metrics
 
        (_input,) = ctx.saved_tensors
        _input = cross_entropy_backward(_input, grad_output)
@@ -460,4 +522,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
            None,
            None,
            None,
+            None,
        )
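
For reference, an eager-mode sketch of what the new RETURN_TOKEN_ACCURACY path computes when reduction is not "none"; it glosses over the fused kernel's softcapping and tie-breaking details, and the helper name is illustrative:

    import torch

    def reference_token_accuracy(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
        # Per-token correctness: argmax over the vocabulary equals the target label.
        # Ignored tokens contribute 0, and the result is averaged over non-ignored tokens,
        # mirroring torch.sum(token_accuracy_1d) / n_non_ignore above.
        mask = target != ignore_index
        correct = (logits.argmax(dim=-1) == target) & mask
        return correct.sum().float() / mask.sum().clamp(min=1).float()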
liger_kernel/ops/dyt.py CHANGED
@@ -7,8 +7,10 @@ import triton.language as tl
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available
 
-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
         NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
     elif device == "xpu":
         NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-
+    elif device == "npu":
+        NUM_SMS = get_npu_multi_processor_count()
     da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
     dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
     db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
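
The backward grid sizing now covers NPU with the same per-device dispatch; a standalone sketch of the selection logic shown above (get_npu_multi_processor_count ships in liger_kernel.utils in this release, its internals are not part of this diff, and the helper name _num_sms is illustrative):

    import torch

    from liger_kernel.utils import get_npu_multi_processor_count

    def _num_sms(device_type: str, device: torch.device) -> int:
        # Mirrors the NUM_SMS selection in liger_dyt_bwd.
        if device_type == "cuda":
            return torch.cuda.get_device_properties(device).multi_processor_count
        if device_type == "xpu":
            return torch.xpu.get_device_properties(device).gpu_subslice_count
        if device_type == "npu":
            return get_npu_multi_processor_count()
        raise ValueError(f"Unsupported device type: {device_type}")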