liger-kernel 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in the public registry.
- liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
- liger_kernel/chunked_loss/grpo_loss.py +8 -5
- liger_kernel/chunked_loss/jsd_loss.py +39 -11
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +492 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +61 -0
- liger_kernel/ops/backends/_ascend/ops/embedding.py +214 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +191 -0
- liger_kernel/ops/backends/_ascend/ops/llama4_rope.py +298 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +275 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +265 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +223 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +367 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +71 -11
- liger_kernel/ops/dyt.py +5 -2
- liger_kernel/ops/fused_add_rms_norm.py +21 -23
- liger_kernel/ops/fused_linear_cross_entropy.py +32 -5
- liger_kernel/ops/geglu.py +5 -3
- liger_kernel/ops/group_norm.py +12 -8
- liger_kernel/ops/grpo_loss.py +3 -1
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +89 -69
- liger_kernel/ops/poly_norm.py +19 -21
- liger_kernel/ops/rms_norm.py +149 -71
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +25 -0
- liger_kernel/transformers/__init__.py +25 -0
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +1 -1
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/functional.py +44 -26
- liger_kernel/transformers/fused_add_rms_norm.py +1 -1
- liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +57 -2
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +1 -1
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +19 -5
- liger_kernel/transformers/model/gemma.py +17 -6
- liger_kernel/transformers/model/gemma2.py +17 -8
- liger_kernel/transformers/model/gemma3.py +35 -16
- liger_kernel/transformers/model/glm4.py +16 -4
- liger_kernel/transformers/model/glm4v.py +16 -4
- liger_kernel/transformers/model/glm4v_moe.py +23 -4
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +12 -5
- liger_kernel/transformers/model/llama.py +14 -5
- liger_kernel/transformers/model/llama4.py +16 -4
- liger_kernel/transformers/model/llava.py +12 -4
- liger_kernel/transformers/model/loss_utils.py +37 -3
- liger_kernel/transformers/model/mistral.py +15 -6
- liger_kernel/transformers/model/mixtral.py +16 -7
- liger_kernel/transformers/model/mllama.py +12 -4
- liger_kernel/transformers/model/olmo2.py +16 -4
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +23 -5
- liger_kernel/transformers/model/phi3.py +14 -7
- liger_kernel/transformers/model/qwen2.py +16 -3
- liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
- liger_kernel/transformers/model/qwen2_vl.py +16 -4
- liger_kernel/transformers/model/qwen3.py +20 -5
- liger_kernel/transformers/model/qwen3_moe.py +19 -5
- liger_kernel/transformers/model/qwen3_next.py +17 -5
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +15 -6
- liger_kernel/transformers/monkey_patch.py +584 -49
- liger_kernel/transformers/multi_token_attention.py +1 -1
- liger_kernel/transformers/poly_norm.py +1 -1
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +8 -3
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +1 -1
- liger_kernel/transformers/sparsemax.py +1 -1
- liger_kernel/transformers/swiglu.py +18 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +54 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/METADATA +14 -4
- liger_kernel-0.6.5.dist-info/RECORD +134 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/WHEEL +1 -1
- liger_kernel-0.6.3.dist-info/RECORD +0 -111
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.6.3.dist-info → liger_kernel-0.6.5.dist-info}/top_level.txt +0 -0
liger_kernel/ops/backends/_ascend/ops/__init__.py (new file, +61 lines)

@@ -0,0 +1,61 @@
+"""
+Ascend NPU operator implementations.
+
+This module exports Ascend NPU-optimized implementations that will automatically
+replace the default implementations when running on NPU devices.
+
+Both Function classes and kernel functions can be exported here.
+
+To add a new operator:
+1. Create the implementation file (e.g., rms_norm.py)
+2. Import the Function class and/or kernel functions here
+3. Optionally add to __all__ for explicit control
+
+If __all__ is not defined, all public symbols will be auto-discovered.
+"""
+
+from liger_kernel.ops.backends._ascend.ops.embedding import LigerEmbeddingFunction
+from liger_kernel.ops.backends._ascend.ops.embedding import embedding_backward
+from liger_kernel.ops.backends._ascend.ops.embedding import embedding_forward
+from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
+from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
+from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
+from liger_kernel.ops.backends._ascend.ops.llama4_rope import LigerLlama4RopeFunction
+from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_backward
+from liger_kernel.ops.backends._ascend.ops.llama4_rope import llama4_rope_forward
+from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
+from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_backward
+from liger_kernel.ops.backends._ascend.ops.qwen2vl_mrope import qwen2vl_mrope_forward
+from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction
+from liger_kernel.ops.backends._ascend.ops.rope import rope_backward
+from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
+from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction
+from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward
+from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward
+from liger_kernel.ops.backends._ascend.ops.tvd import LigerTVDLossFunction
+from liger_kernel.ops.backends._ascend.ops.tvd import tv_distance_forward_triton
+from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton
+
+__all__ = [
+    "LigerEmbeddingFunction",
+    "embedding_forward",
+    "embedding_backward",
+    "LigerGELUMulFunction",
+    "geglu_forward",
+    "geglu_backward",
+    "LigerQwen2VLMRopeFunction",
+    "qwen2vl_mrope_forward",
+    "qwen2vl_mrope_backward",
+    "LigerRopeFunction",
+    "rope_forward",
+    "rope_backward",
+    "LigerSiLUMulFunction",
+    "swiglu_forward",
+    "swiglu_backward",
+    "LigerTVDLossFunction",
+    "tv_distance_forward_triton",
+    "tvd_backward_triton",
+    "LigerLlama4RopeFunction",
+    "llama4_rope_forward",
+    "llama4_rope_backward",
+]
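The docstring above already sketches the three steps and uses rms_norm.py as its own example. To make that concrete, the wiring for such a hypothetical operator would look roughly like the sketch below; the rms_norm module and its symbol names are illustrative assumptions and do not exist in this release.

# Hypothetical sketch only: the rms_norm module and its symbols are assumed for
# illustration and are not shipped in 0.6.5.

# Step 2: import the Function class and/or kernel functions into
# liger_kernel/ops/backends/_ascend/ops/__init__.py.
from liger_kernel.ops.backends._ascend.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.ops.backends._ascend.ops.rms_norm import rms_norm_backward
from liger_kernel.ops.backends._ascend.ops.rms_norm import rms_norm_forward

# Step 3 (optional): extend __all__ for explicit control. If __all__ were omitted,
# public symbols would be auto-discovered instead, per the module docstring.
__all__ += [
    "LigerRMSNormFunction",
    "rms_norm_forward",
    "rms_norm_backward",
]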
liger_kernel/ops/backends/_ascend/ops/embedding.py (new file, +214 lines)

@@ -0,0 +1,214 @@
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+
+
+@triton.jit
+def embedding_forward_kernel(
+    embeddings_ptr,
+    indices_ptr,
+    output_ptr,
+    n_elements,
+    embedding_dim: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
+
+    grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
+    grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
+    total_2d_blocks = grid_m * grid_n
+
+    for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
+        block_m = block_idx // grid_n
+        block_n = block_idx % grid_n
+
+        start_m = block_m * BLOCK_SIZE_M
+        start_n = block_n * BLOCK_SIZE_N
+
+        offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+        mask_m = offsets_m < n_elements
+
+        indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+
+        offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+        mask_n = offsets_n < embedding_dim
+
+        block_mask = mask_m[:, None] & mask_n[None, :]
+
+        embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+        embeddings = tl.load(
+            embeddings_ptr + embedding_offsets,
+            mask=block_mask,
+            other=0.0,
+        )
+
+        output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
+        tl.store(
+            output_ptr + output_offsets,
+            embeddings,
+            mask=block_mask,
+        )
+
+
+@triton.jit
+def embedding_backward_kernel(
+    grad_output_ptr,
+    grad_weight_ptr,
+    indices_ptr,
+    n_elements,
+    embedding_dim: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
+
+    grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
+    grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
+    total_2d_blocks = grid_m * grid_n
+
+    for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
+        block_m = block_idx // grid_n
+        block_n = block_idx % grid_n
+
+        start_m = block_m * BLOCK_SIZE_M
+        start_n = block_n * BLOCK_SIZE_N
+
+        offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
+        mask_m = offsets_m < n_elements
+
+        indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
+
+        offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
+        mask_n = offsets_n < embedding_dim
+
+        block_mask = mask_m[:, None] & mask_n[None, :]
+
+        grad_output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
+        grad_output = tl.load(
+            grad_output_ptr + grad_output_offsets,
+            mask=block_mask,
+            other=0.0,
+        )
+
+        grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
+        tl.atomic_add(
+            grad_weight_ptr + grad_weight_offsets,
+            grad_output,
+            mask=block_mask,
+        )
+
+
+def get_optimal_block_size(total_elements, dtype_size, BLOCK_SIZE_N: tl.constexpr):
+    # 1. Set the memory multiplier.
+    # 3.0 is an empirical value based on the 910B UB size (192KB):
+    # the 2D blocks (embedding_offsets and output_offsets) each take BLOCK_SIZE_M * BLOCK_SIZE_N
+    # (2 * BLOCK_SIZE_M * BLOCK_SIZE_N in total), and one more unit is reserved for the remaining 1D UB buffers,
+    # so a conservative estimate of the total UB footprint is 3 * BLOCK_SIZE_M * BLOCK_SIZE_N.
+    multiplier = 3.0
+
+    # 2. Run the tiling calculation.
+    # Treat the input as 1D (total_elements,), tiling only on dim 0.
+    tile_shapes = compute_default_tiling_strategy(
+        safety_margin=0.9,
+        dtype_size=dtype_size,
+        memory_multiplier=multiplier,
+        shapes=((total_elements, BLOCK_SIZE_N),),
+        tiling_dims=(0,),
+    )
+
+    # 3. Parse the result.
+    if tile_shapes and len(tile_shapes) > 0:
+        block_size = tile_shapes[0][0]
+        return block_size
+    else:
+        return triton.next_power_of_2(min(128, total_elements))
+
+
+def embedding_forward(embeddings, indices):
+    ori_shape = indices.shape
+    indices = indices.view(-1)
+
+    n_elements = indices.numel()
+    embedding_dim = embeddings.shape[1]
+    output = torch.empty(
+        indices.shape[0],
+        embeddings.shape[1],
+        device=indices.device,
+        dtype=embeddings.dtype,
+    )
+
+    # Because the tiling is two-dimensional, the BLOCK_SIZE_M and BLOCK_SIZE_N tile
+    # sizes compete for the same UB space. Since embedding_dim is usually the smaller
+    # dimension, BLOCK_SIZE_N is fixed first and the largest possible BLOCK_SIZE_M is
+    # then derived from it.
+    BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+    BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
+    num_cores = get_npu_core_count()
+    total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
+    grid = min(num_cores, total_blocks)
+
+    embedding_forward_kernel[(grid,)](
+        embeddings,
+        indices,
+        output,
+        n_elements,
+        embedding_dim=embedding_dim,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        NUM_STAGES=3,
+    )
+
+    return output.view(*ori_shape, -1)
+
+
+def embedding_backward(embeddings, indices, grad_output):
+    grad_output = grad_output.contiguous().view(-1, embeddings.shape[1])
+
+    grad_weight = torch.zeros_like(embeddings)
+
+    n_elements = indices.numel()
+    embedding_dim = embeddings.shape[1]
+    BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
+    BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
+    num_cores = get_npu_core_count()
+    total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
+    grid = min(num_cores, total_blocks)
+
+    embedding_backward_kernel[(grid,)](
+        grad_output,
+        grad_weight,
+        indices,
+        n_elements,
+        embedding_dim=embedding_dim,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        NUM_STAGES=3,
+    )
+
+    return grad_weight
+
+
+class LigerEmbeddingFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
+        output = embedding_forward(embeddings, indices)
+        ctx.save_for_backward(indices, embeddings)
+        return output
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, grad_output: torch.Tensor):
+        indices, embeddings = ctx.saved_tensors
+        grad_weight = embedding_backward(embeddings, indices, grad_output)
+
+        return grad_weight, None
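For orientation, here is a minimal usage sketch of the Function added above. It assumes a torch_npu environment in which this Ascend backend implementation is selected; the vocabulary, hidden, batch, and sequence sizes are made-up illustration values.

# Minimal usage sketch (assumes torch_npu is installed and an Ascend NPU is visible).
import torch

from liger_kernel.ops.backends._ascend.ops.embedding import LigerEmbeddingFunction

weight = torch.randn(32000, 4096, device="npu", requires_grad=True)  # (vocab_size, embedding_dim)
token_ids = torch.randint(0, 32000, (2, 128), device="npu")          # (batch, seq_len)

out = LigerEmbeddingFunction.apply(weight, token_ids)  # shape (2, 128, 4096)
out.sum().backward()  # weight.grad is filled by embedding_backward via tl.atomic_add

One design choice worth noting: the backward pass scatters into a dense torch.zeros_like(embeddings) buffer with tl.atomic_add, so repeated token ids accumulate correctly, but a full dense weight gradient is always materialized (there is no sparse-gradient path as in nn.Embedding(sparse=True)).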
liger_kernel/ops/backends/_ascend/ops/geglu.py (new file, +191 lines)

@@ -0,0 +1,191 @@
+import torch
+import triton
+import triton.language as tl
+
+from triton.language.math import tanh
+
+from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
+from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.ops.utils import get_npu_core_count
+
+
+@triton.jit
+def _geglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr):
+    """
+    High-performance GEGLU forward kernel using a flattened 1D approach.
+
+    Uses a grid-stride loop pattern for optimal performance on NPU.
+    """
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
+
+    # Grid-stride loop
+    start_idx = pid * BLOCK_SIZE
+    stride = num_progs * BLOCK_SIZE
+
+    # Constants for the GELU tanh approximation
+    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+    gelu_coeff = 0.044715
+
+    for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+        offsets = idx + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < total_elements
+
+        a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0)
+
+        # tanh approximation form of GELU:
+        # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
+        a_cubed = a_val * a_val * a_val
+        tanh_arg = sqrt_2_over_pi * (a_val + gelu_coeff * a_cubed)
+        tanh_result = tanh(tanh_arg)
+        geglu_a = 0.5 * a_val * (1.0 + tanh_result)
+        c_row = geglu_a.cast(b_val.dtype) * b_val
+        tl.store(c_ptr + offsets, c_row, mask=mask)
+
+
+@triton.jit
+def _geglu_backward_kernel_flat(
+    dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
+):
+    """
+    High-performance GEGLU backward kernel using a flattened 1D approach.
+
+    Uses a grid-stride loop pattern for optimal performance on NPU.
+    """
+    pid = tl.program_id(0)
+    num_progs = tl.num_programs(0)
+    start_idx = pid * BLOCK_SIZE
+    stride = num_progs * BLOCK_SIZE
+
+    # Constants for the GELU tanh approximation
+    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+    gelu_coeff = 0.044715
+
+    for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
+        offsets = idx + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < total_elements
+
+        dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0)
+        a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+        b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
+
+        # Recompute the GELU activation instead of saving it, to reduce memory
+        a_cubed = a * a * a
+        tanh_arg = sqrt_2_over_pi * (a + gelu_coeff * a_cubed)
+        tanh_result = tanh(tanh_arg)
+        geglu_a = 0.5 * a * (1 + tanh_result)
+        geglu_a = geglu_a.to(dc.dtype).to(tl.float32)
+
+        db = dc.cast(tl.float32) * geglu_a
+
+        # Gradient w.r.t. a:
+        # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
+        # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
+        term1 = 0.5 * (1.0 + tanh_result)
+        tanh_sq = tanh_result * tanh_result
+        a_sq = a * a
+        term2 = 0.5 * a * (1.0 - tanh_sq) * (sqrt_2_over_pi * (1.0 + 3.0 * gelu_coeff * a_sq))
+        da = dc * b * (term1 + term2)
+
+        tl.store(da_ptr + offsets, da, mask=mask)
+        tl.store(db_ptr + offsets, db.to(dc.dtype), mask=mask)
+
+
+def get_optimal_block_size(total_elements, is_backward=False):
+    """
+    Calculate the optimal block size using compute_default_tiling_strategy.
+
+    Args:
+        total_elements: Total number of elements to process
+        is_backward: Whether this is for the backward pass (requires more memory)
+
+    Returns:
+        Optimal block size for the kernel
+    """
+    # Memory multiplier based on peak memory usage analysis
+    if is_backward:
+        memory_multiplier = 6.0
+    else:
+        memory_multiplier = 3.0
+    # Run the tiling calculation.
+    # Treat the input as 1D (total_elements,), tiling only on dim 0.
+    tile_shapes = compute_default_tiling_strategy(
+        safety_margin=0.9,
+        dtype_size=4,
+        memory_multiplier=memory_multiplier,
+        shapes=((total_elements,),),
+        tiling_dims=(0,),
+    )
+
+    # Parse the result
+    if tile_shapes and len(tile_shapes) > 0:
+        block_size = tile_shapes[0][0]
+        return max(256, block_size)
+    else:
+        return 2048
+
+
+def geglu_forward(a, b):
+    """
+    High-performance GEGLU forward pass for NPU using a flattened 1D approach.
+    """
+    if not a.is_contiguous():
+        a = a.contiguous()
+    if not b.is_contiguous():
+        b = b.contiguous()
+
+    total_elements = a.numel()
+    c = torch.empty_like(a)
+
+    block_size = get_optimal_block_size(total_elements, is_backward=False)
+
+    num_cores = get_npu_core_count()
+    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+    _geglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4)
+    return c
+
+
+def geglu_backward(a, b, dc):
+    """
+    High-performance GEGLU backward pass for NPU using a flattened 1D approach.
+    """
+    if not dc.is_contiguous():
+        dc = dc.contiguous()
+    if not a.is_contiguous():
+        a = a.contiguous()
+    if not b.is_contiguous():
+        b = b.contiguous()
+
+    total_elements = dc.numel()
+    grad_a = torch.empty_like(a)
+    grad_b = torch.empty_like(b)
+
+    block_size = get_optimal_block_size(total_elements, is_backward=True)
+
+    num_cores = get_npu_core_count()
+    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)
+
+    _geglu_backward_kernel_flat[(grid_size,)](
+        dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4
+    )
+    return grad_a, grad_b
+
+
+class LigerGELUMulFunction(torch.autograd.Function):
+    """High-performance GEGLU function for Ascend NPU."""
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, a, b):
+        c = geglu_forward(a, b)
+        ctx.save_for_backward(a, b)
+        return c
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dc):
+        a, b = ctx.saved_tensors
+        grad_a, grad_b = geglu_backward(a, b, dc)
+        return grad_a, grad_b
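As a sanity check on the formulas in these kernels, the same tanh-approximated GEGLU can be written in plain PyTorch and differentiated with autograd. The reference below is an illustrative sketch, not part of the package; its a.grad and b.grad correspond to the hand-derived da and db expressions above (up to dtype casting).

# Plain-PyTorch reference for the tanh-approximated GEGLU used by the kernels above.
# Illustrative sketch only; runs on CPU and is not part of liger-kernel.
import torch


def geglu_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # GELU(a) * b with GELU in its tanh-approximation form:
    # 0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3)))
    sqrt_2_over_pi = 0.7978845608028654
    gelu_a = 0.5 * a * (1.0 + torch.tanh(sqrt_2_over_pi * (a + 0.044715 * a.pow(3))))
    return gelu_a * b


a = torch.randn(4, 8, dtype=torch.float32, requires_grad=True)
b = torch.randn(4, 8, dtype=torch.float32, requires_grad=True)

out = geglu_reference(a, b)
out.sum().backward()  # here dc is all ones, so a.grad ~ da and b.grad ~ db from the kernel

The backward kernel also recomputes tanh_result from a rather than saving it from the forward pass, trading a small amount of extra compute for lower activation memory; that is why LigerGELUMulFunction only stashes a and b in ctx.save_for_backward.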