liger-kernel 0.6.4__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +7 -1
- liger_kernel/chunked_loss/fused_linear_distillation.py +10 -3
- liger_kernel/chunked_loss/jsd_loss.py +21 -6
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +14 -4
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +21 -23
- liger_kernel/ops/fused_linear_cross_entropy.py +2 -1
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +12 -8
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +17 -16
- liger_kernel/ops/poly_norm.py +19 -21
- liger_kernel/ops/rms_norm.py +149 -71
- liger_kernel/ops/utils.py +25 -0
- liger_kernel/transformers/__init__.py +6 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +1 -1
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +20 -20
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +1 -1
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +1 -1
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/gemma2.py +3 -3
- liger_kernel/transformers/model/gemma3.py +11 -5
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/loss_utils.py +6 -0
- liger_kernel/transformers/model/paligemma.py +1 -0
- liger_kernel/transformers/monkey_patch.py +196 -39
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +28 -27
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +1 -1
- liger_kernel/transformers/tiled_mlp.py +5 -13
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +11 -4
- liger_kernel-0.6.5.dist-info/RECORD +134 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
- liger_kernel-0.6.4.dist-info/RECORD +0 -118
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.4.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
|
|
6
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
7
|
+
from liger_kernel.ops.utils import get_npu_core_count
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@triton.jit
def embedding_forward_kernel(
    embeddings_ptr,
    indices_ptr,
    output_ptr,
    n_elements,
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Gather rows of the embedding table selected by ``indices`` into ``output``.

    The (n_elements, embedding_dim) output is split into BLOCK_SIZE_M x BLOCK_SIZE_N
    tiles, numbered row-major. Program ``pid`` handles tiles pid, pid + num_programs,
    pid + 2 * num_programs, ... so a grid smaller than the tile count still covers
    every tile.
    """
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)

    grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
    grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
    total_2d_blocks = grid_m * grid_n

    for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
        block_m = block_idx // grid_n
        block_n = block_idx % grid_n

        start_m = block_m * BLOCK_SIZE_M
        start_n = block_n * BLOCK_SIZE_N

        offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
        mask_m = offsets_m < n_elements

        indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)

        offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
        mask_n = offsets_n < embedding_dim

        block_mask = mask_m[:, None] & mask_n[None, :]

        # Row numbers are promoted to int64 before scaling by embedding_dim.
        # offsets_m is int32, and indices may be int32 as well. In 32-bit
        # arithmetic, row * embedding_dim wraps once it reaches 2**31 elements.
        src_rows = indices.to(tl.int64)
        dst_rows = offsets_m.to(tl.int64)

        embedding_offsets = src_rows[:, None] * embedding_dim + offsets_n[None, :]
        embeddings = tl.load(
            embeddings_ptr + embedding_offsets,
            mask=block_mask,
            other=0.0,
        )

        output_offsets = dst_rows[:, None] * embedding_dim + offsets_n[None, :]
        tl.store(
            output_ptr + output_offsets,
            embeddings,
            mask=block_mask,
        )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@triton.jit
def embedding_backward_kernel(
    grad_output_ptr,
    grad_weight_ptr,
    indices_ptr,
    n_elements,
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Accumulate ``grad_output`` rows into ``grad_weight`` at the looked-up indices.

    The tiling matches ``embedding_forward_kernel``: BLOCK_SIZE_M x BLOCK_SIZE_N tiles,
    and program ``pid`` visits tiles pid, pid + num_programs, and so on.
    ``tl.atomic_add`` is used because several output rows can map to the same
    embedding row (repeated indices).
    """
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)

    grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
    grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
    total_2d_blocks = grid_m * grid_n

    for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
        block_m = block_idx // grid_n
        block_n = block_idx % grid_n

        start_m = block_m * BLOCK_SIZE_M
        start_n = block_n * BLOCK_SIZE_N

        offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
        mask_m = offsets_m < n_elements

        indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)

        offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
        mask_n = offsets_n < embedding_dim

        block_mask = mask_m[:, None] & mask_n[None, :]

        # Row numbers are promoted to int64 before scaling by embedding_dim, so
        # offsets past 2**31 elements do not wrap around in 32-bit arithmetic.
        src_rows = offsets_m.to(tl.int64)
        dst_rows = indices.to(tl.int64)

        grad_output_offsets = src_rows[:, None] * embedding_dim + offsets_n[None, :]
        grad_output = tl.load(
            grad_output_ptr + grad_output_offsets,
            mask=block_mask,
            other=0.0,
        )

        grad_weight_offsets = dst_rows[:, None] * embedding_dim + offsets_n[None, :]
        tl.atomic_add(
            grad_weight_ptr + grad_weight_offsets,
            grad_output,
            mask=block_mask,
        )
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def get_optimal_block_size(total_elements, dtype_size, BLOCK_SIZE_N: int):
    """Choose the row-tile size (BLOCK_SIZE_M) for the embedding kernels.

    Args:
        total_elements: number of looked-up indices, i.e. rows of the output.
        dtype_size: bytes per embedding element.
        BLOCK_SIZE_N: column-tile width already chosen by the caller. This is a
            plain host-side int; it is not a Triton constexpr here.

    Returns:
        The first tiled dimension reported by ``compute_default_tiling_strategy``.
        If that helper returns nothing, returns a power-of-two fallback that is
        always at least 1.
    """
    # Memory multiplier (per the original tuning note, empirical for 910B with
    # 192KB UB):
    # - two 2-D tiles of BLOCK_SIZE_M * BLOCK_SIZE_N elements are live in the kernel;
    # - one more unit is reserved for the remaining 1-D temporaries.
    # This gives a conservative total of 3 * BLOCK_SIZE_M * BLOCK_SIZE_N.
    memory_multiplier = 3.0

    # Tile only along dim 0 (rows); the column width is fixed to BLOCK_SIZE_N.
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.9,
        dtype_size=dtype_size,
        memory_multiplier=memory_multiplier,
        shapes=((total_elements, BLOCK_SIZE_N),),
        tiling_dims=(0,),
    )

    if tile_shapes:
        return tile_shapes[0][0]

    # triton.next_power_of_2(0) returns 0. Clamp to at least 1 so the caller's
    # triton.cdiv(n_elements, BLOCK_SIZE_M) never divides by zero.
    return triton.next_power_of_2(max(1, min(128, total_elements)))
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def embedding_forward(embeddings, indices):
    """Gather ``embeddings`` rows for every entry of ``indices`` with the Ascend kernel.

    Args:
        embeddings: weight table indexed as (num_embeddings, embedding_dim).
        indices: integer index tensor of any shape. It is flattened with
            ``view(-1)``, so it must be viewable as 1-D; ``LigerEmbeddingFunction``
            makes it contiguous first.

    Returns:
        Tensor of shape ``(*indices.shape, embedding_dim)`` with ``embeddings.dtype``,
        allocated on ``indices.device``.
    """
    ori_shape = indices.shape
    indices = indices.view(-1)

    n_elements = indices.numel()
    embedding_dim = embeddings.shape[1]
    output = torch.empty(
        n_elements,
        embedding_dim,
        device=indices.device,
        dtype=embeddings.dtype,
    )

    # Empty lookup: skip the tiling query and the kernel launch. The reshape
    # names embedding_dim explicitly, because view(..., -1) cannot infer -1 on
    # a tensor with zero elements and raises.
    if n_elements == 0:
        return output.view(*ori_shape, embedding_dim)

    # Block sizes:
    # - With 2-D tiling, BLOCK_SIZE_M and BLOCK_SIZE_N compete for the same UB space.
    # - embedding_dim is usually the smaller extent, so the column tile is fixed
    #   first, then the largest row tile that still fits is requested.
    BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
    BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
    num_cores = get_npu_core_count()
    total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
    # Launch at most one program per core; the kernel loops over any remaining tiles.
    grid = min(num_cores, total_blocks)

    embedding_forward_kernel[(grid,)](
        embeddings,
        indices,
        output,
        n_elements,
        embedding_dim=embedding_dim,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        NUM_STAGES=3,
    )

    return output.view(*ori_shape, embedding_dim)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def embedding_backward(embeddings, indices, grad_output):
    """Compute the embedding-table gradient by scatter-adding ``grad_output`` rows.

    Args:
        embeddings: the forward weight table; only its shape, dtype and device are used.
        indices: the forward index tensor, of any shape.
        grad_output: gradient of the forward output, shaped ``(*indices.shape, embedding_dim)``.

    Returns:
        ``grad_weight`` with the same shape and dtype as ``embeddings``. It is zero
        for rows that were never looked up.
    """
    embedding_dim = embeddings.shape[1]
    grad_output = grad_output.contiguous().view(-1, embedding_dim)
    # The kernel addresses indices as a flat buffer; make that explicit.
    indices = indices.contiguous().view(-1)

    grad_weight = torch.zeros_like(embeddings)

    n_elements = indices.numel()
    # No lookups means no gradient contributions. This also avoids a zero-sized
    # tiling query and kernel launch (mirrors embedding_forward).
    if n_elements == 0:
        return grad_weight

    BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
    BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
    num_cores = get_npu_core_count()
    total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
    # Launch at most one program per core; the kernel loops over any remaining tiles.
    grid = min(num_cores, total_blocks)

    embedding_backward_kernel[(grid,)](
        grad_output,
        grad_weight,
        indices,
        n_elements,
        embedding_dim=embedding_dim,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        NUM_STAGES=3,
    )

    return grad_weight
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class LigerEmbeddingFunction(torch.autograd.Function):
    """Autograd bridge for the Ascend embedding lookup.

    The forward gathers rows with ``embedding_forward``. The backward scatter-adds
    the incoming gradient into a fresh weight gradient; the integer indices
    receive no gradient.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
        gathered = embedding_forward(embeddings, indices)
        # The backward needs both the indices (where to add) and the table (shape/dtype).
        ctx.save_for_backward(indices, embeddings)
        return gathered

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor):
        saved_indices, saved_table = ctx.saved_tensors
        # Indices are integer inputs, so their gradient slot is None.
        return embedding_backward(saved_table, saved_indices, grad_output), None
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
from triton.language.math import tanh
|
|
6
|
+
|
|
7
|
+
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
|
|
8
|
+
from liger_kernel.ops.utils import ensure_contiguous
|
|
9
|
+
from liger_kernel.ops.utils import get_npu_core_count
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@triton.jit
def _geglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr):
    """Elementwise GEGLU forward over flattened buffers: c = gelu_tanh(a) * b.

    Program ``pid`` starts at element pid * BLOCK_SIZE. It then advances by
    num_programs * BLOCK_SIZE per iteration (a grid-stride loop), so a grid of
    any size covers ``total_elements``. The GELU is evaluated in float32 and cast
    to b's dtype before the multiply.
    """
    # Constants of the tanh approximation of GELU.
    k_sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    k_gelu_coeff = 0.044715

    chunk_start = tl.program_id(0) * BLOCK_SIZE
    chunk_step = tl.num_programs(0) * BLOCK_SIZE

    for base in tl.range(chunk_start, total_elements, chunk_step, num_stages=NUM_STAGES):
        offs = base + tl.arange(0, BLOCK_SIZE)
        live = offs < total_elements

        x = tl.load(a_ptr + offs, mask=live, other=0.0).to(tl.float32)
        y = tl.load(b_ptr + offs, mask=live, other=0.0)

        # gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        inner = k_sqrt_2_over_pi * (x + k_gelu_coeff * (x * x * x))
        gelu_x = 0.5 * x * (1.0 + tanh(inner))

        tl.store(c_ptr + offs, gelu_x.cast(y.dtype) * y, mask=live)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@triton.jit
def _geglu_backward_kernel_flat(
    dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
):
    """Elementwise GEGLU backward over flattened buffers.

    Given dc = dL/dc for c = gelu_tanh(a) * b, it writes:
    - db = dc * gelu_tanh(a);
    - da = dc * b * gelu_tanh'(a).

    The iteration pattern is the same grid-stride loop as the forward kernel. The
    GELU activation is recomputed from ``a`` instead of being read from memory.
    """
    k_sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    k_gelu_coeff = 0.044715

    chunk_start = tl.program_id(0) * BLOCK_SIZE
    chunk_step = tl.num_programs(0) * BLOCK_SIZE

    for base in tl.range(chunk_start, total_elements, chunk_step, num_stages=NUM_STAGES):
        offs = base + tl.arange(0, BLOCK_SIZE)
        live = offs < total_elements

        grad_c = tl.load(dc_ptr + offs, mask=live, other=0.0)
        x = tl.load(a_ptr + offs, mask=live, other=0.0).to(tl.float32)
        y = tl.load(b_ptr + offs, mask=live, other=0.0)

        # Recompute tanh(z), with z = sqrt(2/pi) * (x + 0.044715 * x^3).
        t = tanh(k_sqrt_2_over_pi * (x + k_gelu_coeff * (x * x * x)))

        # Round the activation through dc's dtype before using it for db.
        # Presumably this mirrors the forward's cast to b's dtype
        # (assumes dc and b share a dtype).
        gelu_x = (0.5 * x * (1 + t)).to(grad_c.dtype).to(tl.float32)
        grad_b = grad_c.cast(tl.float32) * gelu_x

        # gelu_tanh'(x) = 0.5 * (1 + t) + 0.5 * x * (1 - t^2) * sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
        x_sq = x * x
        dgelu = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * (k_sqrt_2_over_pi * (1.0 + 3.0 * k_gelu_coeff * x_sq))
        grad_a = grad_c * y * dgelu

        tl.store(da_ptr + offs, grad_a, mask=live)
        tl.store(db_ptr + offs, grad_b.to(grad_c.dtype), mask=live)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def get_optimal_block_size(total_elements, is_backward=False):
    """Return the 1-D block size for the flat GEGLU kernels.

    The input is treated as a single flat dimension of ``total_elements`` float32
    values and handed to ``compute_default_tiling_strategy``.

    Args:
        total_elements: number of elements the kernel will process.
        is_backward: True for the backward kernel, which keeps more live buffers.

    Returns:
        The suggested tile length, raised to at least 256. Returns 2048 when the
        tiling helper produces no result.
    """
    # Peak live-buffer estimate: the backward holds roughly twice what the forward does.
    memory_multiplier = 6.0 if is_backward else 3.0

    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.9,
        dtype_size=4,
        memory_multiplier=memory_multiplier,
        shapes=((total_elements,),),
        tiling_dims=(0,),
    )

    if not tile_shapes:
        return 2048
    return max(256, tile_shapes[0][0])
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def geglu_forward(a, b):
    """GEGLU forward on the NPU: ``gelu_tanh(a) * b``, computed over flattened buffers.

    Args:
        a: input passed through the tanh-approximated GELU.
        b: multiplier. It must hold the same number of elements as ``a``, because
            the kernel walks both buffers with identical flat offsets.

    Returns:
        Tensor shaped and typed like ``a``.

    Raises:
        ValueError: if ``a`` and ``b`` hold different numbers of elements (the
            kernel would otherwise read ``b`` out of bounds).
    """
    if a.numel() != b.numel():
        raise ValueError(
            f"geglu_forward expects a and b with the same number of elements, "
            f"got shapes {tuple(a.shape)} and {tuple(b.shape)}"
        )

    # The kernel indexes raw flat memory, so both inputs must be contiguous.
    a = a.contiguous()
    b = b.contiguous()

    total_elements = a.numel()
    c = torch.empty_like(a)

    # Nothing to compute: avoid a zero-length tiling query and an empty launch.
    if total_elements == 0:
        return c

    block_size = get_optimal_block_size(total_elements, is_backward=False)

    # Launch at most one program per core; the kernel strides over the remaining blocks.
    num_cores = get_npu_core_count()
    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)

    _geglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4)
    return c
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def geglu_backward(a, b, dc):
    """GEGLU backward on the NPU, computed over flattened buffers.

    Args:
        a: forward GELU input.
        b: forward multiplier.
        dc: gradient of the forward output. It must hold as many elements as
            ``a`` and ``b``.

    Returns:
        ``(grad_a, grad_b)``, shaped and typed like ``a`` and ``b`` respectively.

    Raises:
        ValueError: if ``dc``, ``a`` and ``b`` hold different numbers of elements.
            The kernel walks all three with the same flat offsets.
    """
    if not (dc.numel() == a.numel() == b.numel()):
        raise ValueError(
            f"geglu_backward expects dc, a and b with the same number of elements, "
            f"got shapes {tuple(dc.shape)}, {tuple(a.shape)} and {tuple(b.shape)}"
        )

    # The kernel indexes raw flat memory, so all inputs must be contiguous.
    dc = dc.contiguous()
    a = a.contiguous()
    b = b.contiguous()

    total_elements = dc.numel()
    grad_a = torch.empty_like(a)
    grad_b = torch.empty_like(b)

    # Nothing to compute: avoid a zero-length tiling query and an empty launch.
    if total_elements == 0:
        return grad_a, grad_b

    block_size = get_optimal_block_size(total_elements, is_backward=True)

    # Launch at most one program per core; the kernel strides over the remaining blocks.
    num_cores = get_npu_core_count()
    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)

    _geglu_backward_kernel_flat[(grid_size,)](
        dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4
    )
    return grad_a, grad_b
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class LigerGELUMulFunction(torch.autograd.Function):
    """Autograd function for GEGLU (``gelu_tanh(a) * b``) on Ascend NPU.

    Only the inputs are saved for backward; the activation is recomputed there.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        out = geglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        saved_a, saved_b = ctx.saved_tensors
        # geglu_backward already returns the (grad_a, grad_b) pair autograd expects.
        return geglu_backward(saved_a, saved_b, dc)
|
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _prepare_freqs(freqs_cis: torch.Tensor, seq_len: int, head_dim_half: int):
|
|
9
|
+
"""
|
|
10
|
+
Canonicalize freqs to (seq_len, head_dim_half) real/imag tensors.
|
|
11
|
+
|
|
12
|
+
Supports:
|
|
13
|
+
- complex freqs: (..., head_dim_half) complex -> real/imag
|
|
14
|
+
- packed freqs: (..., 2*head_dim_half) real -> split into real/imag
|
|
15
|
+
"""
|
|
16
|
+
if freqs_cis.is_complex():
|
|
17
|
+
freqs_real = freqs_cis.real
|
|
18
|
+
freqs_imag = freqs_cis.imag
|
|
19
|
+
else:
|
|
20
|
+
if freqs_cis.shape[-1] == 2 * head_dim_half:
|
|
21
|
+
freqs_real = freqs_cis[..., :head_dim_half]
|
|
22
|
+
freqs_imag = freqs_cis[..., head_dim_half:]
|
|
23
|
+
else:
|
|
24
|
+
raise ValueError(
|
|
25
|
+
f"Unexpected freqs_cis shape for non-complex input: {freqs_cis.shape}, "
|
|
26
|
+
f"expected last dim = {2 * head_dim_half}"
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
if freqs_real.shape[-1] != head_dim_half:
|
|
30
|
+
raise ValueError(f"Unexpected last dim for freqs: {freqs_real.shape[-1]} (expected {head_dim_half})")
|
|
31
|
+
|
|
32
|
+
# Flatten leading dims -> (N, head_dim_half)
|
|
33
|
+
freqs_real = freqs_real.reshape(-1, head_dim_half)
|
|
34
|
+
freqs_imag = freqs_imag.reshape(-1, head_dim_half)
|
|
35
|
+
|
|
36
|
+
# Broadcast/slice to (seq_len, head_dim_half)
|
|
37
|
+
if freqs_real.shape[0] < seq_len:
|
|
38
|
+
if freqs_real.shape[0] == 1:
|
|
39
|
+
freqs_real = freqs_real.expand(seq_len, -1)
|
|
40
|
+
freqs_imag = freqs_imag.expand(seq_len, -1)
|
|
41
|
+
else:
|
|
42
|
+
raise ValueError(f"Insufficient rows in freqs: {freqs_real.shape[0]} < seq_len={seq_len}")
|
|
43
|
+
elif freqs_real.shape[0] > seq_len:
|
|
44
|
+
freqs_real = freqs_real[:seq_len]
|
|
45
|
+
freqs_imag = freqs_imag[:seq_len]
|
|
46
|
+
|
|
47
|
+
return freqs_real, freqs_imag
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _cast_and_contiguous(q, k, freqs_real, freqs_imag):
|
|
51
|
+
# Align dtype: fp32 only when q is fp32; otherwise keep q dtype for perf
|
|
52
|
+
compute_dtype = torch.float32 if q.dtype == torch.float32 else q.dtype
|
|
53
|
+
|
|
54
|
+
if k.dtype != q.dtype:
|
|
55
|
+
k = k.to(q.dtype)
|
|
56
|
+
|
|
57
|
+
q = q.to(compute_dtype).contiguous()
|
|
58
|
+
k = k.to(compute_dtype).contiguous()
|
|
59
|
+
freqs_real = freqs_real.to(compute_dtype).contiguous()
|
|
60
|
+
freqs_imag = freqs_imag.to(compute_dtype).contiguous()
|
|
61
|
+
return q, k, freqs_real, freqs_imag, compute_dtype
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@triton.jit
def _triton_llama4_rope_npu(
    q_ptr,
    k_ptr,
    freqs_real_ptr,
    freqs_imag_ptr,
    q_row_stride,
    k_row_stride,
    q_head_stride,
    k_head_stride,
    freqs_row_stride,
    sl,
    bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    imag_sign: tl.constexpr,
):
    """In-place Llama4 rotary embedding for the interleaved complex layout on Ascend NPU.

    Layout:
      - q/k: (bs, sl, n_heads, hd), last dim stored as [re0, im0, re1, im1, ...];
      - freqs_real / freqs_imag: (sl, hd // 2).

    Each program handles one (batch, position) row, with program id = batch * sl + position.
    Every (re, im) pair is multiplied by (freqs_real + i * imag_sign * freqs_imag),
    processing BLOCK_Q / BLOCK_K heads at a time.

    Addressing a row as ``row * row_stride`` assumes the batch and sequence dims are
    packed contiguously (the host wrapper makes q/k contiguous).
    """
    row = tl.program_id(0).to(tl.int64)
    batch = row // sl
    pos = row % sl

    if batch >= bs:
        return

    q_row = q_ptr + row * q_row_stride
    k_row = k_ptr + row * k_row_stride

    cols = tl.arange(0, hd)
    cols_ok = cols < (hd)
    pairs = tl.arange(0, hd // 2)
    pairs_ok = pairs < (hd // 2)

    # Frequencies of this position, shared by every head.
    freq_off = pos * freqs_row_stride
    f_re = tl.load(freqs_real_ptr + freq_off + pairs, mask=pairs_ok, other=0.0)
    f_im = tl.load(freqs_imag_ptr + freq_off + pairs, mask=pairs_ok, other=0.0) * imag_sign

    # Query heads, BLOCK_Q at a time so a tile fits in UB.
    for h0 in range(0, n_qh, BLOCK_Q):
        heads = h0 + tl.arange(0, BLOCK_Q)
        tile_mask = (heads < n_qh)[:, None] & cols_ok[None, :]
        tile_ptrs = q_row + heads[:, None] * q_head_stride + cols[None, :]

        tile = tl.load(tile_ptrs, mask=tile_mask, other=0.0)
        re, im = tl.split(tile.reshape(BLOCK_Q, hd // 2, 2, can_reorder=True))

        rotated = tl.interleave(
            tl.math.fma(re, f_re, -(im * f_im)),
            tl.math.fma(re, f_im, im * f_re),
        )
        tl.store(tile_ptrs, rotated, mask=tile_mask)

    # Key heads, BLOCK_K at a time.
    for h0 in range(0, n_kh, BLOCK_K):
        heads = h0 + tl.arange(0, BLOCK_K)
        tile_mask = (heads < n_kh)[:, None] & cols_ok[None, :]
        tile_ptrs = k_row + heads[:, None] * k_head_stride + cols[None, :]

        tile = tl.load(tile_ptrs, mask=tile_mask, other=0.0)
        re, im = tl.split(tile.reshape(BLOCK_K, hd // 2, 2, can_reorder=True))

        rotated = tl.interleave(
            tl.math.fma(re, f_re, -(im * f_im)),
            tl.math.fma(re, f_im, im * f_re),
        )
        tl.store(tile_ptrs, rotated, mask=tile_mask)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def llama4_rope_forward(q, k, freqs_cis):
    """Apply Llama4 rotary position embeddings to q and k on Ascend NPU.

    Args:
        q: (bs, sl, n_q_heads, hd) tensor whose last dim interleaves real/imag parts.
        k: (bs, sl, n_k_heads, hd) tensor with the same layout.
        freqs_cis: complex (..., hd // 2) or packed real (..., 2 * (hd // 2)) frequencies.

    Returns:
        ``(q, k)`` after rotation. The kernel writes through the prepared tensors,
        so inputs that are already contiguous and in the compute dtype may be
        modified in place.

    Raises:
        ValueError: if hd is odd or freqs_cis has an unsupported shape.
    """
    input_dtype = q.dtype

    bs, sl, n_qh, hd = q.shape
    _, _, n_kh, _ = k.shape
    if hd % 2 != 0:
        raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")

    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd // 2)
    q, k, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(q, k, freqs_real, freqs_imag)

    # Only the heads dimension is tiled, so a block of heads x hd fits in UB.
    head_shapes = ((n_qh, hd), (n_kh, hd))
    tiles = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=q.element_size(),
        memory_multiplier=12.0,
        shapes=head_shapes,
        tiling_dims=(0, 0),
    )
    if tiles is None or len(tiles) != len(head_shapes):
        # No usable plan: cover all heads of a row in a single padded block.
        block_q = triton.next_power_of_2(n_qh)
        block_k = triton.next_power_of_2(n_kh)
    else:
        (block_q, _), (block_k, _) = tiles

    # One program per (batch, position) row.
    _triton_llama4_rope_npu[(bs * sl,)](
        q,
        k,
        freqs_real,
        freqs_imag,
        q.stride(1),
        k.stride(1),
        q.stride(2),
        k.stride(2),
        freqs_real.stride(0),
        sl,
        bs,
        n_qh,
        n_kh,
        hd,
        block_q,
        block_k,
        imag_sign=1.0,
    )

    if compute_dtype != input_dtype:
        q, k = q.to(input_dtype), k.to(input_dtype)
    return q, k
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def llama4_rope_backward(dq, dk, freqs_cis):
    """Backward of the Llama4 rotary embedding on Ascend NPU.

    The forward multiplies each interleaved (re, im) pair by a complex frequency.
    Its gradient is obtained by multiplying by the conjugate frequency: the same
    kernel, run with imag_sign = -1.

    Args:
        dq: (bs, sl, n_q_heads, hd) gradient w.r.t. the rotated q.
        dk: (bs, sl, n_k_heads, hd) gradient w.r.t. the rotated k.
        freqs_cis: complex (..., hd // 2) or packed real (..., 2 * (hd // 2)) frequencies.

    Returns:
        ``(dq, dk)`` rotated back. Prepared tensors that alias the inputs are
        updated in place.

    Raises:
        ValueError: if hd is odd or freqs_cis has an unsupported shape.
    """
    input_dtype = dq.dtype

    bs, sl, n_qh, hd = dq.shape
    _, _, n_kh, _ = dk.shape
    if hd % 2 != 0:
        raise ValueError(f"head_dim must be even for interleaved complex layout, got {hd}")

    freqs_real, freqs_imag = _prepare_freqs(freqs_cis, sl, hd // 2)
    dq, dk, freqs_real, freqs_imag, compute_dtype = _cast_and_contiguous(dq, dk, freqs_real, freqs_imag)

    # Only the heads dimension is tiled, so a block of heads x hd fits in UB.
    head_shapes = ((n_qh, hd), (n_kh, hd))
    tiles = compute_default_tiling_strategy(
        safety_margin=0.90,
        dtype_size=dq.element_size(),
        memory_multiplier=12.0,
        shapes=head_shapes,
        tiling_dims=(0, 0),
    )
    if tiles is None or len(tiles) != len(head_shapes):
        # No usable plan: cover all heads of a row in a single padded block.
        block_q = triton.next_power_of_2(n_qh)
        block_k = triton.next_power_of_2(n_kh)
    else:
        (block_q, _), (block_k, _) = tiles

    # One program per (batch, position) row; conjugate rotation via imag_sign=-1.
    _triton_llama4_rope_npu[(bs * sl,)](
        dq,
        dk,
        freqs_real,
        freqs_imag,
        dq.stride(1),
        dk.stride(1),
        dq.stride(2),
        dk.stride(2),
        freqs_real.stride(0),
        sl,
        bs,
        n_qh,
        n_kh,
        hd,
        block_q,
        block_k,
        imag_sign=-1.0,
    )

    if compute_dtype != input_dtype:
        dq, dk = dq.to(input_dtype), dk.to(input_dtype)
    return dq, dk
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class LigerLlama4RopeFunction(torch.autograd.Function):
    """Autograd wrapper around the Ascend Llama4 RoPE forward/backward launchers."""

    @staticmethod
    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: int = None):
        # BLOCK_SIZE is accepted only for API compatibility. On Ascend the head
        # tiling is derived from UB capacity, so the value is ignored.
        rotated_q, rotated_k = llama4_rope_forward(q, k, freqs_cis)
        ctx.save_for_backward(freqs_cis.detach() if isinstance(freqs_cis, torch.Tensor) else freqs_cis)
        return rotated_q, rotated_k

    @staticmethod
    def backward(ctx, dq, dk):
        (saved_freqs,) = ctx.saved_tensors
        grad_q, grad_k = llama4_rope_backward(dq, dk, saved_freqs)
        # freqs_cis and BLOCK_SIZE receive no gradient.
        return grad_q, grad_k, None, None
|