liger-kernel-nightly 0.6.4.dev20251202054858__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (58)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
  3. liger_kernel/chunked_loss/jsd_loss.py +21 -6
  4. liger_kernel/ops/__init__.py +141 -0
  5. liger_kernel/ops/backends/README.md +151 -0
  6. liger_kernel/ops/backends/__init__.py +13 -0
  7. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  8. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  9. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  10. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  11. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  12. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  13. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  14. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +12 -3
  17. liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
  18. liger_kernel/ops/geglu.py +3 -2
  19. liger_kernel/ops/rms_norm.py +126 -49
  20. liger_kernel/ops/utils.py +12 -0
  21. liger_kernel/transformers/__init__.py +3 -0
  22. liger_kernel/transformers/auto_model.py +21 -0
  23. liger_kernel/transformers/cross_entropy.py +1 -1
  24. liger_kernel/transformers/dyt.py +1 -1
  25. liger_kernel/transformers/experimental/embedding.py +1 -1
  26. liger_kernel/transformers/functional.py +20 -20
  27. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
  29. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  30. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  31. liger_kernel/transformers/geglu.py +1 -1
  32. liger_kernel/transformers/group_norm.py +1 -1
  33. liger_kernel/transformers/grpo_loss.py +1 -1
  34. liger_kernel/transformers/jsd.py +1 -1
  35. liger_kernel/transformers/kl_div.py +1 -1
  36. liger_kernel/transformers/layer_norm.py +1 -1
  37. liger_kernel/transformers/llama4_rope.py +1 -1
  38. liger_kernel/transformers/model/gemma3.py +1 -0
  39. liger_kernel/transformers/model/gpt_oss.py +211 -0
  40. liger_kernel/transformers/model/paligemma.py +1 -0
  41. liger_kernel/transformers/monkey_patch.py +118 -39
  42. liger_kernel/transformers/multi_token_attention.py +1 -1
  43. liger_kernel/transformers/poly_norm.py +1 -1
  44. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  45. liger_kernel/transformers/rms_norm.py +8 -3
  46. liger_kernel/transformers/rope.py +28 -27
  47. liger_kernel/transformers/softmax.py +1 -1
  48. liger_kernel/transformers/sparsemax.py +1 -1
  49. liger_kernel/transformers/swiglu.py +1 -1
  50. liger_kernel/transformers/tiled_mlp.py +3 -3
  51. liger_kernel/transformers/tvd.py +1 -1
  52. liger_kernel/utils.py +27 -0
  53. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +9 -3
  54. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/RECORD +58 -46
  55. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
  56. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
  57. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
  58. {liger_kernel_nightly-0.6.4.dev20251202054858.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ub_manager.py ADDED
@@ -0,0 +1,349 @@
1
+ """
2
+ Unified Buffer (UB) Manager for Ascend NPU.
3
+
4
+ This module provides UB capacity detection and tiling strategy computation
5
+ for running Triton kernels on Ascend NPU. It automatically calculates
6
+ optimal block sizes based on UB capacity constraints to prevent UB overflow.
7
+ """
8
+
9
+ import os
10
+
11
+ from typing import Optional
12
+ from typing import Tuple
13
+ from typing import Union
14
+
15
+ import torch
16
+ import triton
17
+
18
+ from liger_kernel.utils import is_npu_available
19
+
20
# Fallback UB capacities per NPU model, expressed in bits (KiB * 1024 * 8).
# The "default" entry is used when the detected model name is not listed.
_DEFAULT_UB_CAPACITIES = {
    "Ascend910B1": 256 * 1024 * 8,  # 256 KB
    "Ascend910B4": 192 * 1024 * 8,  # 192 KB
    "default": 256 * 1024 * 8,  # 256 KB
}
26
+
27
+
28
def _normalize_tiling_dims(tiling_dim: Union[int, Tuple[int, ...], list]) -> set:
    """
    Normalize a tiling dimension specification to a set of dimension indices.

    Args:
        tiling_dim: Either an integer (single dimension index) or a tuple/list of
            integers (several dimensions that are tiled together).

    Returns:
        Set of dimension indices that can be tiled. An empty set is returned for
        any invalid specification (unsupported container type, non-integer
        entries, or booleans) so that callers can report a descriptive error.
    """
    import numbers

    def _is_index(value) -> bool:
        # bool is an int subclass, but True/False are never meaningful dim indices.
        # numbers.Integral also accepts numpy-style integer scalars.
        return isinstance(value, numbers.Integral) and not isinstance(value, bool)

    if _is_index(tiling_dim):
        return {tiling_dim}
    if isinstance(tiling_dim, (tuple, list)) and all(_is_index(dim) for dim in tiling_dim):
        return set(tiling_dim)
    return set()
44
+
45
+
46
def _default_strategy(
    ub_capacity_bits: int,
    safety_margin: float,
    dtype_size: int,
    memory_multiplier: float,
    shapes: Tuple[Tuple[int, ...], ...],
    tiling_dims: Tuple[Union[int, Tuple[int, ...]], ...],
) -> Tuple[int, ...]:
    """
    Default tiling strategy: calculate maximum safe block size based on UB capacity.

    This is a unified strategy function that works for all kernels by modelling
    peak UB usage of one program as:

        memory_multiplier * prod(tiled block dims) * unit_param * dtype_size * 8 bits

    where unit_param is the product of the fixed (non-tiling) dimensions, each
    padded to the next power of 2 -- the same padding that
    compute_default_tiling_strategy applies to the shape it returns, so the
    estimate matches the block the kernel actually allocates.

    Args:
        ub_capacity_bits: UB capacity in bits
        safety_margin: Safety margin as a float (e.g., 0.80 for 80%)
        dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32)
        memory_multiplier: Memory multiplier for estimating peak memory usage
        shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
            - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
            - For GEGLU: ((n_cols,),)
        tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
            Each element can be:
            - int: single dimension index (e.g., 0 for first dimension)
            - tuple of ints: multiple dimensions that can be tiled together
            - For ROPE: (0, 0) means first dimension of each shape can be tiled
            - For GEGLU: (0,) means first dimension of the shape can be tiled
            Length must match len(shapes).

    Returns:
        Tuple with one entry per shape: the largest power-of-2 block size that may
        be used for EACH tiling dimension of that shape. When a shape tiles k
        dimensions together, the value v satisfies v**k <= budget, so tiling all
        of those dimensions to v stays within the UB budget.
        Returns () if shapes or tiling_dims is empty.

    Raises:
        ValueError: If len(shapes) != len(tiling_dims), or a tiling_dim is empty,
            invalid, or out of range for its shape.

    Note:
        The final block size is computed in compute_default_tiling_strategy by taking
        min(desired_block_size, max_safe_block_size) where desired_block_size = triton.next_power_of_2(original_dim).
        Non-positive fixed dimensions are ignored (treated as size 1).
    """
    if not shapes or not tiling_dims:
        return ()

    if len(shapes) != len(tiling_dims):
        raise ValueError(
            f"shapes and tiling_dims must have the same length, got {len(shapes)} and {len(tiling_dims)}."
        )

    # Usable UB budget after applying the safety margin (same for every shape).
    safe_ub_capacity_bits = int(ub_capacity_bits * safety_margin)
    bits_per_element = dtype_size * 8

    max_safe_sizes = []
    for shape, tiling_dim in zip(shapes, tiling_dims):
        tiling_dim_set = _normalize_tiling_dims(tiling_dim)

        # Validate tiling dimensions are within shape bounds
        if not tiling_dim_set:
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
            )
        if any(dim_idx < 0 or dim_idx >= len(shape) for dim_idx in tiling_dim_set):
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
                f"All dimension indices must be in range [0, {len(shape)})."
            )

        # unit_param: product of the fixed dims, padded to powers of 2 as in the
        # returned shape. A non-positive size carries no data and is skipped;
        # it must not reset the product, which would overestimate the block size.
        unit_param = 1.0
        for dim_idx, dim_size in enumerate(shape):
            if dim_idx in tiling_dim_set or dim_size <= 0:
                continue
            unit_param *= float(triton.next_power_of_2(dim_size))

        # Total number of block elements (product over all tiled dims) that fit:
        # memory_multiplier * elems * unit_param * bits_per_element <= safe capacity
        max_block_elems = int(safe_ub_capacity_bits // (memory_multiplier * unit_param * bits_per_element))
        max_block_elems = max(1, max_block_elems)

        # Largest power of 2 v with v ** k <= max_block_elems (k = number of tiled dims).
        # For k == 1 this equals triton.next_power_of_2(max_block_elems + 1) // 2.
        exponent = (max_block_elems.bit_length() - 1) // len(tiling_dim_set)
        max_safe_sizes.append(1 << exponent)

    return tuple(max_safe_sizes)
135
+
136
+
137
class UBManager:
    """
    Unified Buffer Manager for Ascend NPU.

    Detects the NPU model and its UB capacity once, at construction time, and
    exposes them to the tiling strategy functions that compute block sizes.
    """

    def __init__(self, ub_capacity_bits: Optional[int] = None):
        """
        Initialize UB Manager.

        Args:
            ub_capacity_bits: UB capacity in bits. If None (or not a positive
                value), the capacity is detected automatically.
        """
        self._npu_model = self._detect_npu_model()
        # A zero/negative capacity is meaningless (it would force block size 1),
        # so treat it like "not provided" and detect instead.
        if ub_capacity_bits is None or ub_capacity_bits <= 0:
            ub_capacity_bits = self._detect_ub_capacity()
        self._ub_capacity_bits = ub_capacity_bits

    @property
    def ub_capacity_bits(self) -> int:
        """Get UB capacity in bits."""
        return self._ub_capacity_bits

    @property
    def ub_capacity_bytes(self) -> int:
        """Get UB capacity in bytes."""
        return self._ub_capacity_bits // 8

    @property
    def npu_model(self) -> str:
        """Get detected NPU model name."""
        return self._npu_model

    def _detect_npu_model(self) -> str:
        """
        Detect NPU model from device properties.

        Returns:
            The name reported by the properties of device 0, "unknown" when no
            NPU is available, or "default" if querying the properties fails.
        """
        if not is_npu_available():
            return "unknown"

        try:
            return torch.npu.get_device_properties(0).name
        except Exception:
            # Best effort: an unreadable model name falls back to default capacities.
            return "default"

    def _detect_ub_capacity(self) -> int:
        """
        Detect UB capacity from environment variable or device properties.

        Resolution order:
          1. ``ASCEND_UB_CAPACITY_BITS`` environment variable, if it parses as a
             positive integer (a non-empty invalid value is ignored with a warning).
          2. ``ub_capacity_bits`` attribute of device 0's properties, if present
             and positive.
          3. Per-model defaults from ``_DEFAULT_UB_CAPACITIES``.

        Returns:
            UB capacity in bits.
        """
        env_capacity = os.getenv("ASCEND_UB_CAPACITY_BITS")
        if env_capacity is not None and env_capacity.strip():
            try:
                capacity = int(env_capacity)
            except ValueError:
                capacity = 0
            if capacity > 0:
                return capacity
            import warnings

            warnings.warn(
                f"Ignoring invalid ASCEND_UB_CAPACITY_BITS={env_capacity!r}; "
                "expected a positive integer number of bits.",
                stacklevel=2,
            )

        if is_npu_available():
            try:
                dev_props = torch.npu.get_device_properties(0)
                capacity = int(getattr(dev_props, "ub_capacity_bits", 0) or 0)
            except Exception:
                # Best effort: any driver/API error falls back to model defaults.
                capacity = 0
            if capacity > 0:
                return capacity

        # Fall back to model-based defaults
        return _DEFAULT_UB_CAPACITIES.get(self._npu_model, _DEFAULT_UB_CAPACITIES["default"])
211
+
212
+
213
# Lazily created, process-wide UBManager shared by all tiling computations.
_ub_manager: Optional[UBManager] = None


def get_ub_manager() -> UBManager:
    """Return the shared UBManager, constructing it on first use."""
    global _ub_manager
    if _ub_manager is not None:
        return _ub_manager
    _ub_manager = UBManager()
    return _ub_manager
223
+
224
+
225
def compute_default_tiling_strategy(
    safety_margin: float = 0.80,
    dtype_size: Optional[int] = None,
    memory_multiplier: Optional[float] = None,
    shapes: Optional[Tuple[Tuple[int, ...], ...]] = None,
    tiling_dims: Optional[Tuple[Union[int, Tuple[int, ...]], ...]] = None,
) -> Optional[Tuple[Tuple[int, ...], ...]]:
    """
    Compute tiling strategy using the default strategy function.

    All kernels share the same unified strategy (_default_strategy), so there is
    no kernel_name-based lookup. This wrapper normalizes the arguments, reads the
    UB capacity from the global UB manager, and turns the per-shape maximum safe
    block sizes into concrete tiled shapes.

    Args:
        safety_margin: Fraction of UB capacity that may be used, in (0, 1].
            Default is 0.80. Values outside (0, 1] (or None) fall back to 0.80.
        dtype_size: Size of data type in bytes (e.g., 2 for float16, 4 for float32).
            If None or <= 0, defaults to 4 (float32).
        memory_multiplier: Memory multiplier for estimating peak memory usage.
            - For GEGLU: typically 10.0 for backward, 7.0 for forward
            - For ROPE: typically 3.0
            If None or <= 0, defaults to 10.0 (conservative estimate).
        shapes: Tuple of full shapes. Each shape is a tuple of dimension sizes.
            - For ROPE: ((n_q_head, hd), (n_kv_head, hd))
            - For GEGLU: ((n_cols,),)
            Can pass original shapes (will handle padding internally) or padded shapes.
        tiling_dims: Tuple specifying which dimensions can be tiled for each shape.
            Each element can be:
            - int: single dimension index (e.g., 0 for first dimension)
            - tuple of ints: multiple dimensions that can be tiled together
            - For ROPE: (0, 0) means first dimension of each shape can be tiled
            - For GEGLU: (0,) means first dimension of the shape can be tiled
            Length must match len(shapes). Cannot be empty.

    Returns:
        Tuple of tiled shapes with same structure as input shapes.
        Tiling dimensions are replaced with computed block sizes (power of 2),
        while non-tiling dimensions are padded to next power of 2.
        - For ROPE: ((block_size_q, pad_hd), (block_size_kv, pad_hd))
        - For GEGLU: ((block_size,),)
        Returns None if shapes or tiling_dims is None or empty, if their lengths
        differ, or if the strategy produced no usable result.

    Raises:
        ValueError: If a tiling_dim entry is empty/invalid or out of range for
            its shape.

    Examples:
        >>> # ROPE forward
        >>> strategy = compute_default_tiling_strategy(
        ...     safety_margin=0.90,
        ...     dtype_size=4,
        ...     memory_multiplier=3.0,
        ...     shapes=((32, 128), (32, 128)),
        ...     tiling_dims=(0, 0)
        ... )
        >>> # Returns: ((block_size_q, 128), (block_size_kv, 128))
        >>> # GEGLU forward
        >>> strategy = compute_default_tiling_strategy(
        ...     safety_margin=0.80,
        ...     dtype_size=2,
        ...     memory_multiplier=7.0,
        ...     shapes=((4096,),),
        ...     tiling_dims=(0,)
        ... )
        >>> # Returns: ((block_size,),)
    """
    if not shapes or not tiling_dims:
        return None

    if len(shapes) != len(tiling_dims):
        return None

    if dtype_size is None or dtype_size <= 0:
        dtype_size = 4  # Default to float32

    if memory_multiplier is None or memory_multiplier <= 0:
        memory_multiplier = 10.0  # Default conservative estimate

    # A margin <= 0 would disable tiling and a margin > 1 would allow exceeding
    # the UB; fall back to the default, like the other normalized arguments.
    if safety_margin is None or not (0 < safety_margin <= 1):
        safety_margin = 0.80

    # Only create/query the UB manager (device probing) once inputs are usable.
    ub_manager = get_ub_manager()

    # Call strategy to get max_safe_block_size for each shape
    max_supported = _default_strategy(
        ub_manager.ub_capacity_bits,
        safety_margin,
        dtype_size,
        memory_multiplier,
        shapes,
        tiling_dims,
    )

    if not max_supported or len(max_supported) != len(shapes):
        return None

    # Build result: same structure as shapes, with tiling dims replaced by computed block sizes
    result = []
    for shape, tiling_dim, max_safe in zip(shapes, tiling_dims, max_supported):
        result_shape = list(shape)
        tiling_dim_set = _normalize_tiling_dims(tiling_dim)

        # Validate tiling dimensions are within shape bounds
        if not tiling_dim_set:
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim}. tiling_dim must be an int or a non-empty tuple of ints."
            )
        if any(dim_idx < 0 or dim_idx >= len(result_shape) for dim_idx in tiling_dim_set):
            raise ValueError(
                f"Invalid tiling_dim: {tiling_dim} for shape {shape}. "
                f"All dimension indices must be in range [0, {len(result_shape)})."
            )

        # Tiling dims: min(desired, max_safe), never below 1.
        for dim_idx in tiling_dim_set:
            desired = triton.next_power_of_2(result_shape[dim_idx])
            result_shape[dim_idx] = max(1, min(desired, max_safe))

        # Pad non-tiling dimensions to next power of 2
        for dim_idx, dim_size in enumerate(result_shape):
            if dim_idx not in tiling_dim_set:
                result_shape[dim_idx] = triton.next_power_of_2(dim_size)

        result.append(tuple(result_shape))

    return tuple(result)
liger_kernel/ops/backends/registry.py ADDED
@@ -0,0 +1,61 @@
1
+ """
2
+ Vendor registry for Liger-Kernel multi-backend support.
3
+
4
+ This module defines VendorInfo and the registry for vendor registration.
5
+ Each vendor registers itself by calling register_vendor() in its __init__.py.
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
# Parent package of this module (e.g. "liger_kernel.ops.backends"), derived at
# runtime so that vendor module paths do not hard-code the package name.
_BACKENDS_PACKAGE = __name__.rsplit(".", maxsplit=1)[0]


@dataclass
class VendorInfo:
    """
    Describes one chip vendor backend and the device type it serves.

    Attributes:
        vendor: Vendor name (e.g., "ascend", "intel", "nvidia"); also used to
            build the backend's module path.
        device: Device type this vendor supports (e.g., "npu", "xpu").
    """

    vendor: str
    device: str

    @property
    def module_path(self) -> str:
        """Import path of the vendor's ops package: ``<backends package>._<vendor>.ops``."""
        return f"{_BACKENDS_PACKAGE}._{self.vendor}.ops"
32
+
33
+
34
# Device type (e.g. "npu") -> VendorInfo of the vendor backend serving it.
# Vendor packages populate this by calling register_vendor() in their __init__.py.
VENDOR_REGISTRY: dict[str, VendorInfo] = dict()


def register_vendor(vendor_info: VendorInfo) -> None:
    """
    Record ``vendor_info`` in VENDOR_REGISTRY, keyed by its device type.

    Intended to be called from each vendor's ``__init__.py``. A later
    registration for the same device type replaces the earlier one.

    Args:
        vendor_info: VendorInfo instance to register.
    """
    device_key = vendor_info.device
    VENDOR_REGISTRY[device_key] = vendor_info
49
+
50
+
51
def get_vendor_for_device(device: str) -> Optional[VendorInfo]:
    """
    Look up the vendor registered for a device type.

    Args:
        device: Device type (e.g., "npu", "xpu").

    Returns:
        The registered VendorInfo, or None when no vendor has registered for
        ``device``.
    """
    if device in VENDOR_REGISTRY:
        return VENDOR_REGISTRY[device]
    return None
liger_kernel/ops/cross_entropy.py CHANGED
@@ -143,13 +143,16 @@ def liger_cross_entropy_kernel(
143
143
  block_max = tl.max(X_block)
144
144
 
145
145
  # Track argmax for accuracy computation
146
- if RETURN_TOKEN_ACCURACY and block_max > m:
146
+ if RETURN_TOKEN_ACCURACY:
147
147
  # Find the index of the maximum value in this block
148
148
  is_max_mask = X_block == block_max
149
149
  # Mask out invalid indices with a value larger than n_cols
150
150
  masked_offsets = tl.where(is_max_mask, X_offsets, n_cols)
151
151
  # Get the first (smallest) index where max occurs
152
- argmax_idx = tl.min(masked_offsets)
152
+ current_block_argmax_idx = tl.min(masked_offsets)
153
+
154
+ is_new_max = block_max > m
155
+ argmax_idx = tl.where(is_new_max, current_block_argmax_idx, argmax_idx)
153
156
 
154
157
  if label_smoothing > 0:
155
158
  # scale X beforehand to avoid overflow
@@ -289,7 +292,13 @@ def liger_cross_entropy_kernel(
289
292
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
290
293
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
291
294
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
292
- MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
295
+ # the best size we found by manually tuning on xpu and npu.
296
+ if infer_device() == "xpu":
297
+ MAX_FUSED_SIZE = 4096
298
+ elif infer_device() == "npu":
299
+ MAX_FUSED_SIZE = 2048
300
+ else:
301
+ MAX_FUSED_SIZE = 65536 // 2
293
302
 
294
303
 
295
304
  def cross_entropy_forward(
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -6,11 +6,12 @@ from liger_kernel.ops.utils import amp_custom_bwd
6
6
  from liger_kernel.ops.utils import amp_custom_fwd
7
7
  from liger_kernel.ops.utils import element_mul_kernel
8
8
  from liger_kernel.ops.utils import is_hip
9
+ from liger_kernel.utils import infer_device
9
10
 
10
11
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
11
12
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
12
13
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
13
- MAX_FUSED_SIZE = 65536 // 2
14
+ MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2
14
15
 
15
16
 
16
17
  def fused_linear_cross_entropy_forward(
liger_kernel/ops/geglu.py CHANGED
@@ -67,8 +67,9 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
67
67
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
68
68
  tanh_result = tanh(tanh_arg)
69
69
  geglu_a = 0.5 * a_row * (1 + tanh_result)
70
+ geglu_a = geglu_a.to(dc_row.dtype).to(tl.float32)
70
71
 
71
- db_row = dc_row * geglu_a
72
+ db_row = dc_row.cast(tl.float32) * geglu_a
72
73
 
73
74
  # Gradient w.r.t. a can be computed with:
74
75
  # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
@@ -79,7 +80,7 @@ def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SI
79
80
  da_row = dc_row * b_row * (term1 + term2)
80
81
 
81
82
  tl.store(a + col_offsets, da_row, mask=mask)
82
- tl.store(b + col_offsets, db_row, mask=mask)
83
+ tl.store(b + col_offsets, db_row.to(dc_row.dtype), mask=mask)
83
84
 
84
85
 
85
86
  def geglu_forward(a, b):