liger-kernel 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
  3. liger_kernel/chunked_loss/jsd_loss.py +21 -6
  4. liger_kernel/ops/__init__.py +141 -0
  5. liger_kernel/ops/backends/README.md +151 -0
  6. liger_kernel/ops/backends/__init__.py +13 -0
  7. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  8. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
  9. liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
  10. liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
  11. liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
  12. liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
  17. liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
  18. liger_kernel/ops/backends/registry.py +61 -0
  19. liger_kernel/ops/cross_entropy.py +14 -4
  20. liger_kernel/ops/dyt.py +5 -2
  21. liger_kernel/ops/fused_add_rms_norm.py +21 -23
  22. liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
  23. liger_kernel/ops/geglu.py +5 -3
  24. liger_kernel/ops/group_norm.py +12 -8
  25. liger_kernel/ops/kl_div.py +8 -11
  26. liger_kernel/ops/layer_norm.py +17 -16
  27. liger_kernel/ops/poly_norm.py +19 -21
  28. liger_kernel/ops/rms_norm.py +149 -71
  29. liger_kernel/ops/utils.py +25 -0
  30. liger_kernel/transformers/__init__.py +6 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +1 -1
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +20 -20
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +1 -1
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/exaone4.py +136 -0
  48. liger_kernel/transformers/model/gemma2.py +3 -3
  49. liger_kernel/transformers/model/gemma3.py +11 -5
  50. liger_kernel/transformers/model/gpt_oss.py +211 -0
  51. liger_kernel/transformers/model/loss_utils.py +6 -0
  52. liger_kernel/transformers/model/paligemma.py +1 -0
  53. liger_kernel/transformers/monkey_patch.py +196 -39
  54. liger_kernel/transformers/multi_token_attention.py +1 -1
  55. liger_kernel/transformers/poly_norm.py +1 -1
  56. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  57. liger_kernel/transformers/rms_norm.py +8 -3
  58. liger_kernel/transformers/rope.py +28 -27
  59. liger_kernel/transformers/softmax.py +1 -1
  60. liger_kernel/transformers/sparsemax.py +1 -1
  61. liger_kernel/transformers/swiglu.py +1 -1
  62. liger_kernel/transformers/tiled_mlp.py +5 -13
  63. liger_kernel/transformers/tvd.py +1 -1
  64. liger_kernel/utils.py +54 -0
  65. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
  66. liger_kernel-0.6.5.dist-info/RECORD +134 -0
  67. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
  68. liger_kernel-0.6.4.dist-info/RECORD +0 -118
  69. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
  70. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
  71. {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ops/embedding.py
@@ -0,0 +1,214 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import get_npu_core_count
+
+
+ @triton.jit
+ def embedding_forward_kernel(
+     embeddings_ptr,
+     indices_ptr,
+     output_ptr,
+     n_elements,
+     embedding_dim: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+ ):
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+
+     grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
+     grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
+     total_2d_blocks = grid_m * grid_n
+
+     for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
+         block_m = block_idx // grid_n
+         block_n = block_idx % grid_n
+
+         start_m = block_m * BLOCK_SIZE_M
+         start_n = block_n * BLOCK_SIZE_N
+
+         offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+         mask_m = offsets_m < n_elements
+
+         indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+
+         offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+         mask_n = offsets_n < embedding_dim
+
+         block_mask = mask_m[:, None] & mask_n[None, :]
+
+         embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+         embeddings = tl.load(
+             embeddings_ptr + embedding_offsets,
+             mask=block_mask,
+             other=0.0,
+         )
+
+         output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
+         tl.store(
+             output_ptr + output_offsets,
+             embeddings,
+             mask=block_mask,
+         )
+
+
+ @triton.jit
+ def embedding_backward_kernel(
+     grad_output_ptr,
+     grad_weight_ptr,
+     indices_ptr,
+     n_elements,
+     embedding_dim: tl.constexpr,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+ ):
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+
+     grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
+     grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
+     total_2d_blocks = grid_m * grid_n
+
+     for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
+         block_m = block_idx // grid_n
+         block_n = block_idx % grid_n
+
+         start_m = block_m * BLOCK_SIZE_M
+         start_n = block_n * BLOCK_SIZE_N
+
+         offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+         mask_m = offsets_m < n_elements
+
+         indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+
+         offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+         mask_n = offsets_n < embedding_dim
+
+         block_mask = mask_m[:, None] & mask_n[None, :]
+
+         grad_output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
+         grad_output = tl.load(
+             grad_output_ptr + grad_output_offsets,
+             mask=block_mask,
+             other=0.0,
+         )
+
+         grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+         tl.atomic_add(
+             grad_weight_ptr + grad_weight_offsets,
+             grad_output,
+             mask=block_mask,
+         )
+
+
+ def get_optimal_block_size(total_elements, dtype_size, BLOCK_SIZE_N: tl.constexpr):
+     # 1. Set the memory multiplier.
+     # 3.0 is an empirical value based on the 910B UB (192KB):
+     # the two 2D tiles (embedding_offsets and output_offsets, or their backward counterparts)
+     # each take BLOCK_SIZE_M * BLOCK_SIZE_N (2 * BLOCK_SIZE_M * BLOCK_SIZE_N in total),
+     # and one more unit is reserved for the remaining one-dimensional UB buffers,
+     # so a conservative estimate of the total UB usage is 3 * BLOCK_SIZE_M * BLOCK_SIZE_N.
+     multiplier = 3.0
+
+     # 2. Call the tiling calculation.
+     # Treat the input as 1D (total_elements,) and tile only on dim 0.
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.9,
+         dtype_size=dtype_size,
+         memory_multiplier=multiplier,
+         shapes=((total_elements, BLOCK_SIZE_N),),
+         tiling_dims=(0,),
+     )
+
+     # 3. Parse the result.
+     if tile_shapes and len(tile_shapes) > 0:
+         block_size = tile_shapes[0][0]
+         return block_size
+     else:
+         return triton.next_power_of_2(min(128, total_elements))
+
+
+ def embedding_forward(embeddings, indices):
+     ori_shape = indices.shape
+     indices = indices.view(-1)
+
+     n_elements = indices.numel()
+     embedding_dim = embeddings.shape[1]
+     output = torch.empty(
+         indices.shape[0],
+         embeddings.shape[1],
+         device=indices.device,
+         dtype=embeddings.dtype,
+     )
+
+     # With two-dimensional tiling, the block_m and block_n tile sizes compete for UB space.
+     # Since embedding_dim is usually the smaller dimension, BLOCK_SIZE_N is fixed first
+     # and then the largest possible BLOCK_SIZE_M is chosen.
+     BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+     BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
+     num_cores = get_npu_core_count()
+     total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
+     grid = min(num_cores, total_blocks)
+
+     embedding_forward_kernel[(grid,)](
+         embeddings,
+         indices,
+         output,
+         n_elements,
+         embedding_dim=embedding_dim,
+         BLOCK_SIZE_M=BLOCK_SIZE_M,
+         BLOCK_SIZE_N=BLOCK_SIZE_N,
+         NUM_STAGES=3,
+     )
+
+     return output.view(*ori_shape, -1)
+
+
+ def embedding_backward(embeddings, indices, grad_output):
+     grad_output = grad_output.contiguous().view(-1, embeddings.shape[1])
+
+     grad_weight = torch.zeros_like(embeddings)
+
+     n_elements = indices.numel()
+     embedding_dim = embeddings.shape[1]
+     BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+     BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
+     num_cores = get_npu_core_count()
+     total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
+     grid = min(num_cores, total_blocks)
+
+     embedding_backward_kernel[(grid,)](
+         grad_output,
+         grad_weight,
+         indices,
+         n_elements,
+         embedding_dim=embedding_dim,
+         BLOCK_SIZE_M=BLOCK_SIZE_M,
+         BLOCK_SIZE_N=BLOCK_SIZE_N,
+         NUM_STAGES=3,
+     )
+
+     return grad_weight
+
+
+ class LigerEmbeddingFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
+         output = embedding_forward(embeddings, indices)
+         ctx.save_for_backward(indices, embeddings)
+         return output
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_output: torch.Tensor):
+         indices, embeddings = ctx.saved_tensors
+         grad_weight = embedding_backward(embeddings, indices, grad_output)
+
+         return grad_weight, None
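
The new Ascend embedding op is a plain row gather in the forward kernel (`indices[:, None] * embedding_dim + offsets_n[None, :]`) and a scatter-add of `grad_output` rows into `grad_weight` via `tl.atomic_add` in the backward kernel; the atomics are what make repeated indices accumulate correctly. Below is a minimal plain-PyTorch reference for that gather/scatter-add semantics, useful as a CPU sanity check; it is not part of the diff and the shapes are illustrative.

import torch

# Illustrative shapes: a (vocab, dim) table and a (batch, seq) index tensor.
weight = torch.randn(1000, 64)
indices = torch.randint(0, 1000, (8, 16))

# Forward reference: gather rows, then restore the index shape,
# mirroring output.view(*ori_shape, -1) in embedding_forward.
out_ref = weight[indices.view(-1)].view(*indices.shape, -1)

# Backward reference: scatter-add grad_output rows into the selected table rows,
# mirroring the tl.atomic_add in embedding_backward_kernel.
grad_out = torch.randn_like(out_ref)
grad_weight_ref = torch.zeros_like(weight)
grad_weight_ref.index_add_(0, indices.view(-1), grad_out.view(-1, weight.shape[1]))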
liger_kernel/ops/backends/_ascend/ops/geglu.py
@@ -0,0 +1,191 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from triton.language.math import tanh
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+ from liger_kernel.ops.utils import ensure_contiguous
+ from liger_kernel.ops.utils import get_npu_core_count
+
+
+ @triton.jit
+ def _geglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr):
+     """
+     High-performance GEGLU forward kernel using a flattened 1D approach.
+
+     Uses a grid-stride loop pattern for optimal performance on NPU.
+     """
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+
+     # Grid-stride loop
+     start_idx = pid * BLOCK_SIZE
+     stride = num_progs * BLOCK_SIZE
+
+     # Constants for the GELU tanh approximation
+     sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+     gelu_coeff = 0.044715
+
+     for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+         offsets = idx + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < total_elements
+
+         a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0)
+
+         # The tanh approximation form of GELU is computed as:
+         # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
+         a_cubed = a_val * a_val * a_val
+         tanh_arg = sqrt_2_over_pi * (a_val + gelu_coeff * a_cubed)
+         tanh_result = tanh(tanh_arg)
+         geglu_a = 0.5 * a_val * (1.0 + tanh_result)
+         c_row = geglu_a.cast(b_val.dtype) * b_val
+         tl.store(c_ptr + offsets, c_row, mask=mask)
+
+
+ @triton.jit
+ def _geglu_backward_kernel_flat(
+     dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
+ ):
+     """
+     High-performance GEGLU backward kernel using a flattened 1D approach.
+
+     Uses a grid-stride loop pattern for optimal performance on NPU.
+     """
+     pid = tl.program_id(0)
+     num_progs = tl.num_programs(0)
+     start_idx = pid * BLOCK_SIZE
+     stride = num_progs * BLOCK_SIZE
+
+     # Constants for the GELU tanh approximation
+     sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+     gelu_coeff = 0.044715
+
+     for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+         offsets = idx + tl.arange(0, BLOCK_SIZE)
+         mask = offsets < total_elements
+
+         dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0)
+         a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+         b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
+
+         # Recompute the forward intermediate to save memory
+         a_cubed = a * a * a
+         tanh_arg = sqrt_2_over_pi * (a + gelu_coeff * a_cubed)
+         tanh_result = tanh(tanh_arg)
+         geglu_a = 0.5 * a * (1 + tanh_result)
+         geglu_a = geglu_a.to(dc.dtype).to(tl.float32)
+
+         db = dc.cast(tl.float32) * geglu_a
+
+         # The gradient w.r.t. a can be computed as:
+         # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
+         # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
+         term1 = 0.5 * (1.0 + tanh_result)
+         tanh_sq = tanh_result * tanh_result
+         a_sq = a * a
+         term2 = 0.5 * a * (1.0 - tanh_sq) * (sqrt_2_over_pi * (1.0 + 3.0 * gelu_coeff * a_sq))
+         da = dc * b * (term1 + term2)
+
+         tl.store(da_ptr + offsets, da, mask=mask)
+         tl.store(db_ptr + offsets, db.to(dc.dtype), mask=mask)
+
+
+ def get_optimal_block_size(total_elements, is_backward=False):
+     """
+     Calculate the optimal block size using compute_default_tiling_strategy.
+
+     Args:
+         total_elements: Total number of elements to process
+         is_backward: Whether this is for the backward pass (which requires more memory)
+
+     Returns:
+         Optimal block size for the kernel
+     """
+     # Memory multiplier based on peak memory usage analysis
+     if is_backward:
+         memory_multiplier = 6.0
+     else:
+         memory_multiplier = 3.0
+     # Call the tiling calculation.
+     # Treat the input as 1D (total_elements,) and tile only on dim 0.
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.9,
+         dtype_size=4,
+         memory_multiplier=memory_multiplier,
+         shapes=((total_elements,),),
+         tiling_dims=(0,),
+     )
+
+     # Parse the result
+     if tile_shapes and len(tile_shapes) > 0:
+         block_size = tile_shapes[0][0]
+         return max(256, block_size)
+     else:
+         return 2048
+
+
+ def geglu_forward(a, b):
+     """
+     High-performance GEGLU forward pass for NPU using a flattened 1D approach.
+     """
+     if not a.is_contiguous():
+         a = a.contiguous()
+     if not b.is_contiguous():
+         b = b.contiguous()
+
+     total_elements = a.numel()
+     c = torch.empty_like(a)
+
+     block_size = get_optimal_block_size(total_elements, is_backward=False)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+     _geglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4)
+     return c
+
+
+ def geglu_backward(a, b, dc):
+     """
+     High-performance GEGLU backward pass for NPU using a flattened 1D approach.
+     """
+     if not dc.is_contiguous():
+         dc = dc.contiguous()
+     if not a.is_contiguous():
+         a = a.contiguous()
+     if not b.is_contiguous():
+         b = b.contiguous()
+
+     total_elements = dc.numel()
+     grad_a = torch.empty_like(a)
+     grad_b = torch.empty_like(b)
+
+     block_size = get_optimal_block_size(total_elements, is_backward=True)
+
+     num_cores = get_npu_core_count()
+     grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+     _geglu_backward_kernel_flat[(grid_size,)](
+         dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4
+     )
+     return grad_a, grad_b
+
+
+ class LigerGELUMulFunction(torch.autograd.Function):
+     """High-performance GEGLU function for Ascend NPU."""
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, a, b):
+         c = geglu_forward(a, b)
+         ctx.save_for_backward(a, b)
+         return c
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, dc):
+         a, b = ctx.saved_tensors
+         grad_a, grad_b = geglu_backward(a, b, dc)
+         return grad_a, grad_b
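
The GEGLU kernels evaluate GELU with the tanh approximation, 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3))), and multiply it by the gate b; the backward kernel recomputes this intermediate rather than saving it. A minimal CPU reference for the forward math, compared against torch.nn.functional.gelu with approximate="tanh" (not part of the diff; the function name and shapes are illustrative):

import math

import torch
import torch.nn.functional as F


def geglu_tanh_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Same closed form the Triton kernel evaluates element-wise.
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    gelu_a = 0.5 * a * (1.0 + torch.tanh(sqrt_2_over_pi * (a + 0.044715 * a.pow(3))))
    return gelu_a * b


a = torch.randn(4, 128)
b = torch.randn(4, 128)
# PyTorch's tanh-approximate GELU uses the same formula, so the two should agree.
torch.testing.assert_close(geglu_tanh_reference(a, b), F.gelu(a, approximate="tanh") * b)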
liger_kernel/ops/backends/_ascend/ops/llama4_rope.py
@@ -0,0 +1,298 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+
+
+ def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
+     """
+     Canonicalize freqs to (seq_len, head_dim_half) real/imag tensors.
+
+     Supports:
+     - complex freqs: (..., head_dim_half) complex -> real/imag
+     - packed freqs: (..., 2*head_dim_half) real -> split into real/imag
+     """
+     if freqs_cis.is_complex():
+         freqs_real = freqs_cis.real
+         freqs_imag = freqs_cis.imag
+     else:
+         if freqs_cis.shape[-1] == 2 * head_dim_half:
+             freqs_real = freqs_cis[..., :head_dim_half]
+             freqs_imag = freqs_cis[..., head_dim_half:]
+         else:
+             raise ValueError(
+                 f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, "
+                 f"expected last dim = {2 * head_dim_half}"
+             )
+
+     if freqs_real.shape[-1] != head_dim_half:
+         raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
+
+     # Flatten leading dims -> (N, head_dim_half)
+     freqs_real = freqs_real.reshape(-1, head_dim_half)
+     freqs_imag = freqs_imag.reshape(-1, head_dim_half)
+
+     # Broadcast/slice to (seq_len, head_dim_half)
+     if freqs_real.shape[0] < seq_len:
+         if freqs_real.shape[0] == 1:
+             freqs_real = freqs_real.expand(seq_len, -1)
+             freqs_imag = freqs_imag.expand(seq_len, -1)
+         else:
+             raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
+     elif freqs_real.shape[0] > seq_len:
+         freqs_real = freqs_real[:seq_len]
+         freqs_imag = freqs_imag[:seq_len]
+
+     return freqs_real, freqs_imag
+
+
+ def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
+     # Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
+     compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
+
+     if k.dtype != q.dtype:
+         k = k.to(q.dtype)
+
+     q = q.to(compute_dtype).contiguous()
+     k = k.to(compute_dtype).contiguous()
+     freqs_real = freqs_real.to(compute_dtype).contiguous()
+     freqs_imag = freqs_imag.to(compute_dtype).contiguous()
+     return q, k, freqs_real, freqs_imag, compute_dtype
+
+
+ @triton.jit
+ def _triton_llama4_rope_npu(
+     q_ptr,
+     k_ptr,
+     freqs_real_ptr,
+     freqs_imag_ptr,
+     q_row_stride,
+     k_row_stride,
+     q_head_stride,
+     k_head_stride,
+     freqs_row_stride,
+     sl,
+     bs: tl.constexpr,
+     n_qh: tl.constexpr,
+     n_kh: tl.constexpr,
+     hd: tl.constexpr,
+     BLOCK_Q: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     imag_sign: tl.constexpr,
+ ):
+     """
+     Llama4 RoPE on Ascend NPU for interleaved complex layout:
+     - q/k shape: (bs, sl, n_heads, hd)
+     - last dim layout: [real0, imag0, real1, imag1, ...]
+     - freqs_real/imag: (sl, hd//2)
+     """
+     pid = tl.program_id(0).to(tl.int64)
+     batch_idx = pid // sl
+     seq_idx = pid % sl
+
+     if batch_idx >= bs:
+         return
+
+     q_base = q_ptr + pid * q_row_stride
+     k_base = k_ptr + pid * k_row_stride
+
+     freq_base = seq_idx * freqs_row_stride
+     hd_idx = tl.arange(0, hd)
+     hd_mask = hd_idx < (hd)
+
+     freq_idx = tl.arange(0, hd // 2)
+     freq_mask = freq_idx < (hd // 2)
+
+     freqs_real = tl.load(freqs_real_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0)
+     freqs_imag = tl.load(freqs_imag_ptr + freq_base + freq_idx, mask=freq_mask, other=0.0) * imag_sign
+
+     # Q heads (chunked for UB)
+     for qh_block in range(0, n_qh, BLOCK_Q):
+         qh_idx = tl.arange(0, BLOCK_Q) + qh_block
+         qh_mask = qh_idx < n_qh
+         block_mask = qh_mask[:, None] & hd_mask[None, :]
+
+         head_ptr = q_base + qh_idx[:, None] * q_head_stride
+
+         q_pair = tl.load(
+             head_ptr + hd_idx[None, :],
+             mask=block_mask,
+             other=0.0,
+         )
+         q_pair = q_pair.reshape(BLOCK_Q, hd // 2, 2, can_reorder=True)
+         q_real, q_imag = tl.split(q_pair)
+
+         new_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
+         new_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)
+         new_q_pair = tl.interleave(new_real, new_imag)
+
+         tl.store(head_ptr + hd_idx[None, :], new_q_pair, mask=block_mask)
+
+     # K heads (chunked for UB)
+     for kh_block in range(0, n_kh, BLOCK_K):
+         kh_idx = tl.arange(0, BLOCK_K) + kh_block
+         kh_mask = kh_idx < n_kh
+         block_mask = kh_mask[:, None] & hd_mask[None, :]
+
+         head_ptr = k_base + kh_idx[:, None] * k_head_stride
+
+         k_pair = tl.load(
+             head_ptr + hd_idx[None, :],
+             mask=block_mask,
+             other=0.0,
+         )
+
+         k_pair = k_pair.reshape(BLOCK_K, hd // 2, 2, can_reorder=True)
+         k_real, k_imag = tl.split(k_pair)
+
+         new_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
+         new_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)
+         new_k_pair = tl.interleave(new_real, new_imag)
+
+         tl.store(head_ptr + hd_idx[None, :], new_k_pair, mask=block_mask)
+
+
+ def llama4_rope_forward(q, k, freqs_cis):
+     """
+     Ascend NPU implementation of Llama4 RoPE.
+
+     q/k: (bs, sl, n_heads, hd) with interleaved complex last-dim layout.
+     freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).
+     """
+     original_dtype = q.dtype
+
+     bs, sl, n_qh, hd = q.shape
+     _, _, n_kh, _ = k.shape
+     if hd % 2 != 0:
+         raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
+     hd_half = hd // 2
+
+     freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half)
+     q, k, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(q, k, freqs_real, freqs_imag)
+
+     # UB tiling strategy: tile heads dimension only
+     dtype_size = q.element_size()
+     shapes = ((n_qh, hd), (n_kh, hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=12.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         BLOCK_Q = triton.next_power_of_2(n_qh)
+         BLOCK_K = triton.next_power_of_2(n_kh)
+
+     n_row = bs * sl
+
+     _triton_llama4_rope_npu[(n_row,)](
+         q,
+         k,
+         freqs_real,
+         freqs_imag,
+         q.stride(1),
+         k.stride(1),
+         q.stride(2),
+         k.stride(2),
+         freqs_real.stride(0),
+         sl,
+         bs,
+         n_qh,
+         n_kh,
+         hd,
+         BLOCK_Q,
+         BLOCK_K,
+         imag_sign=1.0,
+     )
+
+     if compute_dtype != original_dtype:
+         q = q.to(original_dtype)
+         k = k.to(original_dtype)
+     return q, k
+
+
+ def llama4_rope_backward(dq, dk, freqs_cis):
+     """
+     Ascend NPU implementation of the Llama4 RoPE backward pass (conjugate rotation via imag_sign=-1).
+
+     dq/dk: (bs, sl, n_heads, hd) with interleaved complex last-dim layout.
+     freqs_cis: complex (..., hd//2) OR packed (..., 2*(hd//2)).
+     """
+     original_dtype = dq.dtype
+
+     bs, sl, n_qh, hd = dq.shape
+     _, _, n_kh, _ = dk.shape
+     if hd % 2 != 0:
+         raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")
+     hd_half = hd // 2
+
+     freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd_half)
+     dq, dk, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(dq, dk, freqs_real, freqs_imag)
+
+     # UB tiling strategy: tile heads dimension only
+     dtype_size = dq.element_size()
+     shapes = ((n_qh, hd), (n_kh, hd))
+     tile_shapes = compute_default_tiling_strategy(
+         safety_margin=0.90,
+         dtype_size=dtype_size,
+         memory_multiplier=12.0,
+         shapes=shapes,
+         tiling_dims=(0, 0),
+     )
+
+     if tile_shapes is not None and len(tile_shapes) == len(shapes):
+         q_tile_shape, k_tile_shape = tile_shapes
+         BLOCK_Q, _ = q_tile_shape
+         BLOCK_K, _ = k_tile_shape
+     else:
+         BLOCK_Q = triton.next_power_of_2(n_qh)
+         BLOCK_K = triton.next_power_of_2(n_kh)
+
+     n_row = bs * sl
+
+     _triton_llama4_rope_npu[(n_row,)](
+         dq,
+         dk,
+         freqs_real,
+         freqs_imag,
+         dq.stride(1),
+         dk.stride(1),
+         dq.stride(2),
+         dk.stride(2),
+         freqs_real.stride(0),
+         sl,
+         bs,
+         n_qh,
+         n_kh,
+         hd,
+         BLOCK_Q,
+         BLOCK_K,
+         imag_sign=-1.0,
+     )
+
+     if compute_dtype != original_dtype:
+         dq = dq.to(original_dtype)
+         dk = dk.to(original_dtype)
+     return dq, dk
+
+
+ class LigerLlama4RopeFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
+         # BLOCK_SIZE is ignored for Ascend (heads are auto-tiled by UB); kept for API compatibility
+         q_out, k_out = llama4_rope_forward(q, k, freqs_cis)
+         ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
+         return q_out, k_out
+
+     @staticmethod
+     def backward(ctx, dq, dk):
+         (freqs_cis,) = ctx.saved_tensors
+         dq_out, dk_out = llama4_rope_backward(dq, dk, freqs_cis)
+         return dq_out, dk_out, None, None
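
The RoPE kernel treats each consecutive [real, imag] pair of the head dimension as a complex number and rotates it by the per-position frequency: new_real = q_real * f_real - q_imag * f_imag and new_imag = q_real * f_imag + q_imag * f_real, with imag_sign=-1.0 selecting the conjugate rotation that the backward pass applies to dq/dk. A minimal sketch of the equivalent computation using torch complex views (not part of the diff; the function name and shapes are illustrative):

import torch


def rope_interleaved_reference(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (bs, sl, n_heads, hd) with interleaved [real, imag] pairs on the last dim.
    # freqs_cis: complex tensor of shape (sl, hd // 2).
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rotated = x_complex * freqs_cis[None, :, None, :]  # broadcast over batch and heads
    return torch.view_as_real(rotated).flatten(-2).to(x.dtype)


bs, sl, n_heads, hd = 2, 16, 8, 64
q = torch.randn(bs, sl, n_heads, hd)
# Unit-magnitude complex frequencies, one per (position, pair).
freqs_cis = torch.polar(torch.ones(sl, hd // 2), torch.randn(sl, hd // 2))
q_rotated = rope_interleaved_reference(q, freqs_cis)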