liger-kernel 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
  3. liger_kernel/chunked_loss/jsd_loss.py +21 -6
  4. liger_kernel/ops/__init__.py +141 -0
  5. liger_kernel/ops/backends/README.md +151 -0
  6. liger_kernel/ops/backends/__init__.py +13 -0
  7. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  8. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  9. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  10. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  11. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  12. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  17. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  18. liger_kernel/ops/backends/registry.py +61 -0
  19. liger_kernel/ops/cross_entropy.py +14 -4
  20. liger_kernel/ops/dyt.py +5 -2
  21. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  22. liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
  23. liger_kernel/ops/geglu.py +5 -3
  24. liger_kernel/ops/group_norm.py +12 -8
  25. liger_kernel/ops/kl_div.py +8 -11
  26. liger_kernel/ops/layer_norm.py +17 -16
  27. liger_kernel/ops/poly_norm.py +19 -21
  28. liger_kernel/ops/rms_norm.py +149 -71
  29. liger_kernel/ops/utils.py +25 -0
  30. liger_kernel/transformers/__init__.py +6 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +1 -1
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +20 -20
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +1 -1
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/exaone4.py +136 -0
  48. liger_kernel/transformers/model/gemma2.py +3 -3
  49. liger_kernel/transformers/model/gemma3.py +11 -5
  50. liger_kernel/transformers/model/gpt_oss.py +211 -0
  51. liger_kernel/transformers/model/loss_utils.py +6 -0
  52. liger_kernel/transformers/model/paligemma.py +1 -0
  53. liger_kernel/transformers/monkey_patch.py +196 -39
  54. liger_kernel/transformers/multi_token_attention.py +1 -1
  55. liger_kernel/transformers/poly_norm.py +1 -1
  56. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  57. liger_kernel/transformers/rms_norm.py +8 -3
  58. liger_kernel/transformers/rope.py +28 -27
  59. liger_kernel/transformers/softmax.py +1 -1
  60. liger_kernel/transformers/sparsemax.py +1 -1
  61. liger_kernel/transformers/swiglu.py +1 -1
  62. liger_kernel/transformers/tiled_mlp.py +5 -13
  63. liger_kernel/transformers/tvd.py +1 -1
  64. liger_kernel/utils.py +54 -0
  65. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
  66. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  67. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  68. liger_kernel-0.6.4.dist-info/RECORD +0 -118
  69. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  70. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  71. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ops/tvd.py
@@ -0,0 +1,223 @@
+ from typing import Literal
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import get_npu_core_count
+
+ MAX_FUSED_SIZE = 65536 // 4
+
+ REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
+
+
+ @triton.jit
+ def _tv_distance_kernel(
+     p_ptr,
+     p_stride,
+     q_ptr,
+     q_stride,
+     loss_ptr,
+     loss_stride,
+     grads_ptr,
+     grads_stride,
+     label_ptr,
+     ignore_index: tl.constexpr,
+     n_cols,  # V
+     total_rows: tl.constexpr,  # BT
+     BLOCK_SIZE: tl.constexpr,
+     HAS_LABEL: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+     reduction: tl.constexpr = "batchmean",
+ ):
+     thread_id = tl.program_id(0)
+     num_threads = tl.num_programs(0)
+
+     for pid in tl.range(thread_id, total_rows, num_threads, num_stages=NUM_STAGES):
+         p_row_ptr = p_ptr + pid * p_stride
+         q_row_ptr = q_ptr + pid * q_stride
+         loss_row_ptr = loss_ptr + pid * loss_stride
+         grads_row_ptr = grads_ptr + pid * grads_stride
+         label_row_ptr = label_ptr + pid
+
+         base_offsets = tl.arange(0, BLOCK_SIZE)
+
+         should_skip = False
+         if HAS_LABEL:
+             label = tl.load(label_row_ptr)
+             if label == ignore_index:
+                 should_skip = True
+
+         if should_skip:
+             for i in range(0, n_cols, BLOCK_SIZE):
+                 offsets = i + base_offsets
+                 mask = offsets < n_cols
+                 tl.store(grads_row_ptr + offsets, 0.0, mask=mask)
+                 if reduction == "none":
+                     tl.store(loss_row_ptr + offsets, 0.0, mask=mask)
+         else:
+             loss_sum = 0.0
+             for i in range(0, n_cols, BLOCK_SIZE):
+                 offsets = i + base_offsets
+                 mask = offsets < n_cols
+
+                 p = tl.load(p_row_ptr + offsets, mask=mask, other=0.0)
+                 q = tl.load(q_row_ptr + offsets, mask=mask, other=0.0)
+
+                 # TVD(P || Q) = 0.5 * |P - Q|
+                 tv_loss = 0.5 * tl.abs(p - q)
+                 grad_res = tl.where(p > q, 0.5, -0.5)
+
+                 tl.store(grads_row_ptr + offsets, grad_res, mask=mask)
+
+                 if reduction == "none":
+                     tl.store(loss_row_ptr + offsets, tv_loss, mask=mask)
+                 else:
+                     loss_sum += tl.sum(tv_loss, axis=0)
+
+             if reduction != "none":
+                 tl.store(loss_row_ptr, loss_sum)
+
+
+ def tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label):
+     BT, V = p.shape
+
+     # TVD forward tiling strategy
+     # - In the main loop (loss and grad computation), the live buffers are:
+     #   * p: BLOCK_SIZE elements
+     #   * q: BLOCK_SIZE elements
+     #   * tv_loss: BLOCK_SIZE elements
+     #   * grad_res: BLOCK_SIZE elements
+     #   * loss_sum: BLOCK_SIZE elements (when reduction != "none")
+     #   * Total: 4 * BLOCK_SIZE elements, or 5 * BLOCK_SIZE when reduction != "none"
+     # - loss_sum is not live in every iteration, but other shared-memory use and the
+     #   HAS_LABEL path also consume UB, so take the conservative estimate:
+     #   5 * BLOCK_SIZE * dtype_size * 8 bits, i.e. memory_multiplier=5.0
+     # - shapes: ((V,),)
+     # - tiling_dims: (0,) means the first dimension of each shape can be tiled
+     # - Returns: ((block_size,),)
+     shapes = ((V,),)
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.80,
+         # Much of the TVD computation is implicitly promoted to f32, so use the f32 size directly.
+         dtype_size=4,
+         memory_multiplier=5.0,
+         shapes=shapes,
+         tiling_dims=(0,),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) > 0 and len(tile_shapes[0]) > 0:
+         # Strategy returns ((block_size,),)
+         BLOCK_SIZE = tile_shapes[0][0]
+     else:
+         # Fall back to the desired block size if no strategy result is available (no tiling needed)
+         BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+
+     num_cores = get_npu_core_count()
+     grid = (min(num_cores, BT),)
+
+     out_size = (BT, V) if reduction == "none" else (BT,)
+
+     # Accumulating loss and grads in BF16 on NPU suffers precision errors, so accumulate in float32.
+     output_tensor = torch.zeros(out_size, device=p.device, dtype=torch.float32)
+     grads = torch.empty_like(p, dtype=torch.float32)
+
+     n_non_ignore = (shift_labels != ignore_index).sum().item() if has_label else BT
+
+     _tv_distance_kernel[grid](
+         p,
+         p.stride(0),
+         q,
+         q.stride(0),
+         output_tensor,
+         output_tensor.stride(0),
+         grads,
+         grads.stride(0),
+         shift_labels if has_label else torch.empty(1, device=p.device),
+         ignore_index,
+         V,
+         BT,
+         BLOCK_SIZE=BLOCK_SIZE,
+         HAS_LABEL=has_label,
+         NUM_STAGES=3 if BT < 4096 else 4,
+         reduction=reduction,
+     )
+
+     if reduction == "batchmean":
+         return output_tensor.sum() / n_non_ignore, grads / n_non_ignore
+     elif reduction == "sum":
+         return output_tensor.sum(dim=0), grads
+     elif reduction == "mean":
+         return output_tensor.sum() / (n_non_ignore * V), grads / (n_non_ignore * V)
+     else:
+         return output_tensor, grads
+
+
+ def tvd_backward_triton(grad_output, grads):
+     # If this is the last layer, grad_output is 1.0; skip the multiplication in that case.
+     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
+         return grads
+
+     return grads * grad_output
+
+
+ class LigerTVDLossFunction(torch.autograd.Function):
+     """
+     Class implementing the forward and backward pass for the Total Variation Distance Loss using Triton.
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(
+         ctx,
+         p: torch.Tensor,
+         q: torch.Tensor,
+         shift_labels: Optional[torch.Tensor] = None,
+         reduction: REDUCTION_LITERAL = "batchmean",
+         ignore_index: int = -100,
+     ) -> torch.Tensor:
+         """A forward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             p (torch.Tensor): A tensor of shape (BT, V) containing the first distribution.
+             q (torch.Tensor): A tensor of shape (BT, V) containing the second distribution.
+             shift_labels (Optional[torch.Tensor]): A tensor of shape (BT,) containing the labels.
+             reduction (REDUCTION_LITERAL, optional): The reduction method to be applied. Defaults to "batchmean".
+             ignore_index (int, optional): The index to ignore during loss calculation. Defaults to -100.
+
+         Returns:
+             torch.Tensor: The computed Total Variation Distance Loss.
+         """
+         has_label = False
+         if shift_labels is not None:
+             assert shift_labels.shape == (p.shape[0],), (
+                 f"The shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+             )
+             shift_labels = shift_labels.contiguous()
+             has_label = True
+
+         loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
+         ctx.save_for_backward(grads)
+         return loss
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
+         """A backward pass for the Total Variation Distance Loss.
+
+         Args:
+             ctx: Torch autograd context
+             grad_output (torch.Tensor): The gradient of the loss with respect to the output.
+
+         Returns:
+             tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs.
+         """
+         (grads,) = ctx.saved_tensors
+         grads = tvd_backward_triton(grad_output, grads)
+
+         return grads, None, None, None, None
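
For orientation, here is a minimal usage sketch of the new op (hypothetical shapes and device placement; it assumes a torch_npu build where the Ascend backend is active):

import torch
from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction

# Two row-stochastic (BT, V) distributions; sizes are illustrative.
logits = torch.randn(8, 128, device="npu", requires_grad=True)
p = torch.softmax(logits, dim=-1)
q = torch.softmax(torch.randn(8, 128, device="npu"), dim=-1)

# batchmean reduction, no label masking.
loss = LigerTVDLossFunction.apply(p, q, None, "batchmean", -100)
loss.backward()  # backward returns a gradient only for p; the other inputs get None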
liger_kernel/ops/backends/_ascend/ub_manager.py
@@ -0,0 +1,367 @@
+ """
+ Unified Buffer (UB) Manager for Ascend NPU.
+
+ This module provides UB capacity detection and tiling strategy computation
+ for running Triton kernels on Ascend NPU. It automatically calculates
+ optimal block sizes based on UB capacity constraints to prevent UB overflow.
+ """
+
+ import os
+
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+ import triton
+
+ from liger_kernel.utils import is_npu_available
+
+
+ def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...]]) -> set:
+     """
+     Normalize a tiling dimension specification to a set of dimension indices.
+
+     Args:
+         tiling_dim: Either an int (single dimension) or a tuple of ints (multiple dimensions).
+
+     Returns:
+         Set of dimension indices that can be tiled.
+     """
+     if isinstance(tiling_dim, int):
+         return {tiling_dim}
+     elif isinstance(tiling_dim, tuple):
+         return set(tiling_dim)
+     else:
+         return set()
+
+
+ def _default_strategy(
+     ub_capacity_bits: int,
+     safety_margin: float,
+     dtype_size: int,
+     memory_multiplier: float,
+     shapes: Tuple[Tuple[int, ...], ...],
+     tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
+ ) -> Tuple[int, ...]:
+     """
+     Default tiling strategy: calculate the maximum safe block size based on UB capacity.
+
+     This is a unified strategy function that works for all kernels by abstracting
+     the memory estimate as: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits.
+
+     Args:
+         ub_capacity_bits: UB capacity in bits.
+         safety_margin: Safety margin as a float (e.g., 0.80 for 80%).
+         dtype_size: Size of the data type in bytes (e.g., 2 for float16, 4 for float32).
+         memory_multiplier: Memory multiplier for estimating peak memory usage.
+         shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
+             - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
+             - For GEGLU: ((n_cols,),)
+         tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
+             Each element can be:
+             - int: single dimension index (e.g., 0 for the first dimension)
+             - tuple of ints: multiple dimensions that can be tiled together
+             - For ROPE: (0, 0) means the first dimension of each shape can be tiled
+             - For GEGLU: (0,) means the first dimension of the shape can be tiled
+             Length must match len(shapes).
+
+     Returns:
+         Tuple of maximum safe block sizes, one for each shape.
+         Each element is a power of 2.
+
+     Note:
+         For each shape, the fixed (non-tiling) dimensions are multiplied together to get unit_param.
+         The final block size is computed in compute_default_tiling_strategy by taking
+         min(desired_block_size, max_safe_block_size), where desired_block_size = triton.next_power_of_2(original_dim).
+     """
+     if not shapes or not tiling_dims:
+         return ()
+
+     # Calculate max_safe_block_size for each tiling dimension
+     max_safe_sizes = []
+
+     for shape, tiling_dim in zip(shapes, tiling_dims):
+         # Normalize tiling_dim to a set of dimension indices
+         tiling_dim_set = _normalize_tiling_dims(tiling_dim)
+
+         # Validate tiling dimensions are within shape bounds
+         if not tiling_dim_set:
+             raise ValueError(
+                 f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
+             )
+         if any(dim_idx < 0 or dim_idx >= len(shape) for dim_idx in tiling_dim_set):
+             raise ValueError(
+                 f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
+                 f"All dimension indices must be in range [0, {len(shape)})."
+             )
+
+         # Calculate unit_param: product of fixed (non-tiling) dimensions
+         unit_param = 1.0
+         for dim_idx, dim_size in enumerate(shape):
+             if dim_idx not in tiling_dim_set:
+                 if dim_size <= 0:
+                     # Invalid dimension size, use conservative default
+                     unit_param = 1.0
+                     break
+                 unit_param *= float(dim_size)
+
+         # Ensure unit_param is at least 1.0
+         if unit_param <= 0:
+             unit_param = 1.0
+
+         # Calculate the maximum safe block size based on UB capacity
+         # Memory: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 bits
+         SAFE_UB_CAPACITY_BITS = int(ub_capacity_bits * safety_margin)
+
+         # Solve: memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 <= SAFE_UB_CAPACITY_BITS
+         # BLOCK_SIZE <= SAFE_UB_CAPACITY_BITS / (memory_multiplier * unit_param * dtype_size * 8)
+         max_block_size = int(SAFE_UB_CAPACITY_BITS // (memory_multiplier * unit_param * dtype_size * 8))
+         max_block_size = max(1, max_block_size)
+
+         # Find the largest power of 2 <= max_block_size
+         # triton.next_power_of_2(max_block_size + 1) // 2 yields the largest power of 2 <= max_block_size
+         safe_block_size = triton.next_power_of_2(max_block_size + 1) // 2
+         max_safe_sizes.append(safe_block_size)
+
+     return tuple(max_safe_sizes)
+
+
+ class UBManager:
+     """
+     Unified Buffer Manager for Ascend NPU.
+
+     Provides UB capacity detection and management for Ascend NPU devices.
+     The UB capacity is used by tiling strategy functions to calculate optimal block sizes.
+     """
+
+     def __init__(self, ub_capacity_bits: Optional[int] = None):
+         """
+         Initialize the UB Manager.
+
+         Args:
+             ub_capacity_bits: UB capacity in bits. If None, it will be detected automatically.
+         """
+         self._npu_model = self._detect_npu_model()
+         self._ub_capacity_bits = ub_capacity_bits or self._detect_ub_capacity()
+
+     @property
+     def ub_capacity_bits(self) -> int:
+         """Get UB capacity in bits."""
+         return self._ub_capacity_bits
+
+     @property
+     def ub_capacity_bytes(self) -> int:
+         """Get UB capacity in bytes."""
+         return self._ub_capacity_bits // 8
+
+     @property
+     def npu_model(self) -> str:
+         """Get the detected NPU model name."""
+         return self._npu_model
+
+     def _detect_npu_model(self) -> str:
+         """Detect the NPU model from device properties."""
+         if not is_npu_available():
+             return "unknown"
+
+         try:
+             dev_props = torch.npu.get_device_properties(0)
+             # Try to get the model name from device properties
+             return dev_props.name
+         except Exception:
+             pass
+
+         return "default"
+
+     def _detect_ub_capacity(self) -> int:
+         """
+         Detect UB capacity from the environment variable or get_soc_spec.
+
+         Returns:
+             UB capacity in bits.
+
+         Raises:
+             RuntimeError: If UB capacity cannot be detected and no environment variable is set.
+         """
+         # Check the environment variable first (in bits)
+         env_capacity = os.getenv("ASCEND_UB_CAPACITY_BITS")
+         if env_capacity is not None:
+             try:
+                 capacity_bits = int(env_capacity)
+                 if capacity_bits > 0:
+                     return capacity_bits
+             except ValueError:
+                 pass
+
+         # Try to get it from get_soc_spec (returns bytes, convert to bits)
+         if is_npu_available():
+             try:
+                 from tbe.common.platform import get_soc_spec
+
+                 # Query the UB size (get_soc_spec returns the size in bytes)
+                 ub_size_bytes = get_soc_spec("UB_SIZE")
+
+                 if ub_size_bytes is None or ub_size_bytes <= 0:
+                     raise ValueError(f"Invalid UB_SIZE from get_soc_spec: {ub_size_bytes}")
+
+                 # Convert bytes to bits
+                 ub_capacity_bits = ub_size_bytes * 8
+                 return ub_capacity_bits
+
+             except ImportError:
+                 raise RuntimeError(
+                     "Cannot import tbe.common.platform.get_soc_spec. "
+                     "Please ensure CANN environment variables are sourced "
+                     "(e.g., source /usr/local/Ascend/ascend-toolkit/set_env.sh)"
+                 )
+             except Exception as e:
+                 raise RuntimeError(
+                     f"Failed to detect UB capacity from get_soc_spec: {e}. "
+                     "Please set the ASCEND_UB_CAPACITY_BITS environment variable as a fallback."
+                 )
+
+         # If the NPU is not available, raise an error
+         raise RuntimeError(
+             "NPU is not available and UB capacity cannot be detected. "
+             "Please set the ASCEND_UB_CAPACITY_BITS environment variable."
+         )
+
+
+ # Global singleton instance
+ _ub_manager: Optional[UBManager] = None
+
+
+ def get_ub_manager() -> UBManager:
+     """Get the global UB manager instance."""
+     global _ub_manager
+     if _ub_manager is None:
+         _ub_manager = UBManager()
+     return _ub_manager
+
+
+ def compute_default_tiling_strategy(
+     safety_margin: float = 0.80,
+     dtype_size: Optional[int] = None,
+     memory_multiplier: Optional[float] = None,
+     shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
+     tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
+ ) -> Optional[Tuple[Tuple[int, ...], ...]]:
+     """
+     Compute a tiling strategy using the default strategy function.
+
+     This function directly calls the default strategy and computes the final
+     tiling result. All kernels use the same unified strategy function, so
+     there is no need for a kernel_name-based lookup.
+
+     Args:
+         safety_margin: Safety margin as a float (e.g., 0.80 for 80%). Default is 0.80.
+         dtype_size: Size of the data type in bytes (e.g., 2 for float16, 4 for float32).
+             If None or <= 0, defaults to 4 (float32).
+         memory_multiplier: Memory multiplier for estimating peak memory usage.
+             - For GEGLU: typically 10.0 for backward, 4.0 for forward
+             - For ROPE: typically 3.0
+             If None, defaults to 10.0 (a conservative estimate).
+         shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
+             - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
+             - For GEGLU: ((n_cols,),)
+             Original shapes may be passed (padding is handled internally), or padded shapes.
+         tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
+             Each element can be:
+             - int: single dimension index (e.g., 0 for the first dimension)
+             - tuple of ints: multiple dimensions that can be tiled together
+             - For ROPE: (0, 0) means the first dimension of each shape can be tiled
+             - For GEGLU: (0,) means the first dimension of the shape can be tiled
+             Length must match len(shapes). Cannot be empty.
+
+     Returns:
+         Tuple of tiled shapes with the same structure as the input shapes.
+         Tiling dimensions are replaced with computed block sizes (powers of 2),
+         while non-tiling dimensions are padded to the next power of 2.
+         - For ROPE: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
+         - For GEGLU: ((block_size,),)
+         Returns None if shapes or tiling_dims is None or empty.
+
+     Examples:
+         >>> # ROPE forward
+         >>> strategy = compute_default_tiling_strategy(
+         ...     safety_margin=0.90,
+         ...     dtype_size=4,
+         ...     memory_multiplier=3.0,
+         ...     shapes=((32, 128), (32, 128)),
+         ...     tiling_dims=(0, 0)
+         ... )
+         >>> # Returns: ((block_size_q, 128), (block_size_kv, 128))
+         >>> # GEGLU forward
+         >>> strategy = compute_default_tiling_strategy(
+         ...     safety_margin=0.80,
+         ...     dtype_size=2,
+         ...     memory_multiplier=7.0,
+         ...     shapes=((4096,),),
+         ...     tiling_dims=(0,)
+         ... )
+         >>> # Returns: ((block_size,),)
+     """
+     ub_manager = get_ub_manager()
+
+     if shapes is None or not shapes or tiling_dims is None or not tiling_dims:
+         return None
+
+     if len(shapes) != len(tiling_dims):
+         return None
+
+     if dtype_size is None or dtype_size <= 0:
+         dtype_size = 4  # Default to float32
+
+     if memory_multiplier is None or memory_multiplier <= 0:
+         memory_multiplier = 10.0  # Default conservative estimate
+
+     # Call the strategy to get max_safe_block_size for each shape
+     max_supported = _default_strategy(
+         ub_manager.ub_capacity_bits,
+         safety_margin,
+         dtype_size,
+         memory_multiplier,
+         shapes,
+         tiling_dims,
+     )
+
+     if not max_supported or len(max_supported) != len(shapes):
+         return None
+
+     # Build the result: same structure as shapes, with tiling dims replaced by computed block sizes
+     result = []
+     for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
+         result_shape = list(shape)
+
+         # Normalize tiling_dim to a set of dimension indices
+         tiling_dim_set = _normalize_tiling_dims(tiling_dim)
+
+         # Validate tiling dimensions are within shape bounds
+         if not tiling_dim_set:
+             raise ValueError(
+                 f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
+             )
+         if any(dim_idx < 0 or dim_idx >= len(result_shape) for dim_idx in tiling_dim_set):
+             raise ValueError(
+                 f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
+                 f"All dimension indices must be in range [0, {len(result_shape)})."
+             )
+
+         # Replace tiling dimensions with computed block sizes
+         # For each tiling dimension, compute: min(desired, max_safe)
+         for dim_idx in tiling_dim_set:
+             original_dim = result_shape[dim_idx]
+             desired = triton.next_power_of_2(original_dim)
+             final_val = min(desired, max_safe)
+             final_val = max(1, final_val)  # Ensure at least 1
+             result_shape[dim_idx] = final_val
+
+         # Pad non-tiling dimensions to the next power of 2
+         for dim_idx, dim_size in enumerate(result_shape):
+             if dim_idx not in tiling_dim_set:
+                 result_shape[dim_idx] = triton.next_power_of_2(dim_size)
+
+         result.append(tuple(result_shape))
+
+     return tuple(result)
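
To make the capacity math concrete, here is a hand-worked instance of the inequality `_default_strategy` solves (the 192 KiB UB size is illustrative, not a detected value):

import triton

ub_capacity_bits = 192 * 1024 * 8           # illustrative UB size: 192 KiB
safe_bits = int(ub_capacity_bits * 0.80)    # safety_margin = 0.80

# Solve memory_multiplier * BLOCK_SIZE * unit_param * dtype_size * 8 <= safe_bits
# with the TVD forward numbers: multiplier 5.0, unit_param 1.0, fp32 (4 bytes).
max_block = int(safe_bits // (5.0 * 1.0 * 4 * 8))        # 7864
block_size = triton.next_power_of_2(max_block + 1) // 2  # largest pow2 <= 7864 -> 4096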
liger_kernel/ops/backends/registry.py
@@ -0,0 +1,61 @@
+ """
+ Vendor registry for Liger-Kernel multi-backend support.
+
+ This module defines VendorInfo and the registry for vendor registration.
+ Each vendor registers itself by calling register_vendor() in its __init__.py.
+ """
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ # Dynamically get the backends package path to avoid hardcoding
+ _BACKENDS_PACKAGE = __name__.rsplit(".", 1)[0]  # "liger_kernel.ops.backends"
+
+
+ @dataclass
+ class VendorInfo:
+     """
+     Information about a chip vendor and its supported device.
+
+     Attributes:
+         vendor: Vendor name (e.g., "ascend", "intel", "nvidia")
+         device: Device type this vendor supports (e.g., "npu", "xpu")
+     """
+
+     vendor: str
+     device: str
+
+     @property
+     def module_path(self) -> str:
+         """Auto-generated module path based on the vendor name."""
+         return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops"
+
+
+ # Registry mapping device types to their vendor info
+ # Vendors register themselves via register_vendor()
+ VENDOR_REGISTRY: dict[str, VendorInfo] = {}
+
+
+ def register_vendor(vendor_info: VendorInfo) -> None:
+     """
+     Register a vendor's info in the global registry.
+
+     This should be called in each vendor's __init__.py to register itself.
+
+     Args:
+         vendor_info: VendorInfo instance to register
+     """
+     VENDOR_REGISTRY[vendor_info.device] = vendor_info
+
+
+ def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
+     """
+     Get the VendorInfo for a given device type.
+
+     Args:
+         device: Device type (e.g., "npu", "xpu")
+
+     Returns:
+         VendorInfo if found, None otherwise
+     """
+     return VENDOR_REGISTRY.get(device)
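
A registration sketch mirroring what a vendor's `__init__.py` is expected to do (the Ascend call site itself is not shown in this diff):

from liger_kernel.ops.backends.registry import VendorInfo
from liger_kernel.ops.backends.registry import get_vendor_for_device
from liger_kernel.ops.backends.registry import register_vendor

register_vendor(VendorInfo(vendor="ascend", device="npu"))

info = get_vendor_for_device("npu")
assert info is not None
print(info.module_path)  # -> "liger_kernel.ops.backends._ascend.ops"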
liger_kernel/ops/cross_entropy.py
@@ -10,8 +10,9 @@ from liger_kernel.ops.utils import compare_version
  from liger_kernel.ops.utils import element_mul_kernel
  from liger_kernel.ops.utils import is_hip
  from liger_kernel.utils import infer_device
+ from liger_kernel.utils import is_npu_available

- if compare_version("triton", operator.ge, "3.0.0"):
+ if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
      try:
          # typical import path with dispatch available
          from triton.language.extra.libdevice import tanh
@@ -142,13 +143,16 @@ def liger_cross_entropy_kernel(
          block_max = tl.max(X_block)

          # Track argmax for accuracy computation
-         if RETURN_TOKEN_ACCURACY and block_max > m:
+         if RETURN_TOKEN_ACCURACY:
              # Find the index of the maximum value in this block
              is_max_mask = X_block == block_max
              # Mask out invalid indices with a value larger than n_cols
              masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
              # Get the first (smallest) index where max occurs
-             argmax_idx = tl.min(masked_offsets)
+             current_block_argmax_idx = tl.min(masked_offsets)
+
+             is_new_max = block_max > m
+             argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)

          if label_smoothing > 0:
              # scale X beforehand to avoid overflow
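
The fix above keeps the running argmax correct across blocks by always computing the per-block argmax and then selecting with a branch-free tl.where. In plain Python the same pattern reads as follows (an illustration of the logic, not the kernel itself):

def running_argmax(values, block_size):
    m = float("-inf")
    argmax_idx = -1
    for start in range(0, len(values), block_size):
        block = values[start : start + block_size]
        block_max = max(block)
        # First index of the block maximum (what tl.min over the masked offsets computes)
        current_block_argmax_idx = start + block.index(block_max)
        # Only adopt the new index when this block strictly improves the running max
        if block_max > m:
            argmax_idx = current_block_argmax_idx
            m = block_max
    return argmax_idx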
@@ -288,7 +292,13 @@ def liger_cross_entropy_kernel(
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
- MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
+ # The best sizes we found by manually tuning on xpu and npu.
+ if infer_device() == "xpu":
+     MAX_FUSED_SIZE = 4096
+ elif infer_device() == "npu":
+     MAX_FUSED_SIZE = 2048
+ else:
+     MAX_FUSED_SIZE = 65536 // 2


  def cross_entropy_forward(