liger-kernel-nightly 0.5.5.dev20250402185702__py3-none-any.whl → 0.6.4.dev20260112233432__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +142 -0
- liger_kernel/chunked_loss/dpo_loss.py +61 -3
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +36 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/grpo_loss.py +76 -5
- liger_kernel/chunked_loss/jsd_loss.py +46 -15
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +49 -0
- liger_kernel/ops/backends/_ascend/ops/geglu.py +266 -0
- liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
- liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
- liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
- liger_kernel/ops/backends/_ascend/ops/tvd.py +221 -0
- liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +134 -65
- liger_kernel/ops/dyt.py +115 -180
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +117 -23
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +6 -4
- liger_kernel/ops/group_norm.py +7 -7
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +2 -1
- liger_kernel/ops/kl_div.py +9 -5
- liger_kernel/ops/layer_norm.py +146 -78
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/rms_norm.py +398 -99
- liger_kernel/ops/rope.py +1 -1
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/utils.py +14 -0
- liger_kernel/transformers/__init__.py +208 -17
- liger_kernel/transformers/auto_model.py +21 -0
- liger_kernel/transformers/cross_entropy.py +9 -4
- liger_kernel/transformers/dyt.py +6 -4
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -1
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +122 -20
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
- liger_kernel/transformers/fused_linear_jsd.py +1 -1
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -1
- liger_kernel/transformers/group_norm.py +1 -1
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +1 -1
- liger_kernel/transformers/kl_div.py +1 -1
- liger_kernel/transformers/layer_norm.py +1 -1
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/exaone4.py +136 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +57 -27
- liger_kernel/transformers/model/gemma2.py +65 -28
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +109 -27
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +111 -136
- liger_kernel/transformers/model/loss_utils.py +50 -12
- liger_kernel/transformers/model/mistral.py +51 -34
- liger_kernel/transformers/model/mixtral.py +50 -29
- liger_kernel/transformers/model/mllama.py +46 -24
- liger_kernel/transformers/model/olmo2.py +47 -22
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +50 -14
- liger_kernel/transformers/model/phi3.py +47 -172
- liger_kernel/transformers/model/qwen2.py +55 -23
- liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
- liger_kernel/transformers/model/qwen2_vl.py +59 -108
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2018 -244
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +1 -1
- liger_kernel/transformers/rms_norm.py +54 -6
- liger_kernel/transformers/rope.py +45 -1
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +39 -1
- liger_kernel/transformers/tiled_mlp.py +125 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/transformers/tvd.py +1 -1
- liger_kernel/utils.py +63 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/METADATA +73 -39
- liger_kernel_nightly-0.6.4.dev20260112233432.dist-info/RECORD +132 -0
- liger_kernel_nightly-0.5.5.dev20250402185702.dist-info/RECORD +0 -80
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402185702.dist-info → liger_kernel_nightly-0.6.4.dev20260112233432.dist-info}/top_level.txt +0 -0
liger_kernel/ops/softmax.py
ADDED
@@ -0,0 +1,201 @@
from typing import Tuple

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous


@triton.jit
def _softmax_single_block_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    x = tl.load(X_ptr + row_id * X_row_stride + offs, mask=mask, other=-float("inf"), cache_modifier=".ca")
    m = tl.max(x, axis=0)
    e = tl.exp(x - m)
    d = tl.sum(e, axis=0)
    y = e / d
    tl.store(Y_ptr + row_id * Y_row_stride + offs, y, mask=mask, cache_modifier=".cs")


@triton.jit
def _softmax_multi_block_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)

    m = tl.float32(-float("inf"))
    d = tl.float32(0.0)
    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
        blk_max = tl.max(xblk, axis=0)
        new_m = tl.max(m, blk_max)
        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(xblk - new_m), axis=0)
        m = new_m

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        xblk = tl.load(X_ptr + row_id * X_row_stride + idx, mask=mask, other=-float("inf"), cache_modifier=".ca")
        yblk = tl.exp(xblk - m) / d
        tl.store(Y_ptr + row_id * Y_row_stride + idx, yblk, mask=mask, cache_modifier=".cs")


@triton.jit
def _softmax_single_block_backward_kernel(
    dy_ptr,
    dy_stride,
    y_ptr,
    y_stride,
    dx_ptr,
    dx_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    dy = tl.load(dy_ptr + row_id * dy_stride + offs, mask=mask, other=0.0)
    y = tl.load(y_ptr + row_id * y_stride + offs, mask=mask, other=0.0, cache_modifier=".ca")
    dot = tl.sum(dy * y, axis=0)
    dx = y * (dy - dot)
    tl.store(dx_ptr + row_id * dx_stride + offs, dx, mask=mask, cache_modifier=".wb")


@triton.jit
def _softmax_multi_block_backward_kernel(
    dy_ptr,
    dy_stride,
    y_ptr,
    y_stride,
    dx_ptr,
    dx_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    acc = tl.float32(0.0)

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
        acc += tl.sum(dy_blk * y_blk, axis=0)

    for start in tl.range(0, n_cols, BLOCK_SIZE):
        idx = start + offs
        mask = idx < n_cols
        dy_blk = tl.load(dy_ptr + row_id * dy_stride + idx, mask=mask, other=0.0)
        y_blk = tl.load(y_ptr + row_id * y_stride + idx, mask=mask, other=0.0, cache_modifier=".ca")
        dx_blk = y_blk * (dy_blk - acc)
        tl.store(dx_ptr + row_id * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")


def _softmax_forward(x: torch.Tensor) -> Tuple[torch.Tensor, int, int, bool]:
    *batch, n_cols = x.shape
    x2d = x.contiguous().view(-1, n_cols)
    n_rows = x2d.shape[0]

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    y2d = torch.empty_like(x2d)

    if n_cols <= BLOCK_SIZE:
        _softmax_single_block_forward_kernel[(n_rows,)](
            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
        )
        multi_block_launch = False
    else:
        _softmax_multi_block_forward_kernel[(n_rows,)](
            y2d, y2d.stride(0), x2d, x2d.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
        )
        multi_block_launch = True

    return y2d.view(*batch, n_cols), BLOCK_SIZE, num_warps, multi_block_launch


def _softmax_backward(
    dy: torch.Tensor,
    y: torch.Tensor,
    BLOCK_SIZE: int,
    num_warps: int,
    multi_block_launch: bool,
) -> torch.Tensor:
    *batch, n_cols = dy.shape
    dy2d = dy.contiguous().view(-1, n_cols)
    y2d = y.contiguous().view(-1, n_cols)
    n_rows = dy2d.shape[0]
    dx2d = torch.empty_like(dy2d)

    if not multi_block_launch and n_cols <= BLOCK_SIZE:
        _softmax_single_block_backward_kernel[(n_rows,)](
            dy2d,
            dy2d.stride(0),
            y2d,
            y2d.stride(0),
            dx2d,
            dx2d.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
    else:
        _softmax_multi_block_backward_kernel[(n_rows,)](
            dy2d,
            dy2d.stride(0),
            y2d,
            y2d.stride(0),
            dx2d,
            dx2d.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return dx2d.view(*batch, n_cols)


class LigerSoftmaxFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, input_: torch.Tensor):
        y, BLOCK_SIZE, num_warps, multi_block_launch = _softmax_forward(input_)
        ctx.save_for_backward(y)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.multi_block_launch = multi_block_launch
        return y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        dx = _softmax_backward(
            grad_output,
            y,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.multi_block_launch,
        )
        return dx
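Note: the multi-block forward kernel above is the standard two-pass online softmax; the first pass keeps a running maximum m and a rescaled denominator d across column blocks, and the second pass normalizes. A minimal plain-PyTorch sketch of that recurrence, for illustration only (not part of the packaged code):

import torch

def online_softmax_row(x: torch.Tensor, block: int = 4) -> torch.Tensor:
    # First pass: running max and rescaled denominator over column blocks.
    m = torch.tensor(float("-inf"))
    d = torch.tensor(0.0)
    for start in range(0, x.numel(), block):
        blk = x[start:start + block]
        new_m = torch.maximum(m, blk.max())
        d = d * torch.exp(m - new_m) + torch.exp(blk - new_m).sum()
        m = new_m
    # Second pass: normalize with the final statistics.
    return torch.exp(x - m) / d

x = torch.randn(10)
assert torch.allclose(online_softmax_row(x), torch.softmax(x, dim=0), atol=1e-6)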
liger_kernel/ops/sparsemax.py
ADDED
@@ -0,0 +1,179 @@
from typing import Tuple

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import ensure_contiguous


@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,
    x_stride_row,
    sorted_x_ptr,
    sorted_x_stride_row,
    o_ptr,
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    pid_row = tl.program_id(0)
    ptr_x_data_row = x_ptr + pid_row * x_stride_row
    ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
    ptr_output_row = o_ptr + pid_row * o_stride_row

    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols

    z_sorted_block = tl.load(
        ptr_sorted_x_data_row + offs,
        mask=mask,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)

    z_valid = tl.where(mask, z_sorted_block, 0.0)
    cssv = tl.cumsum(z_valid, 0)

    r = (offs + 1).to(tl.float32)
    safe_r = tl.where(mask, r, 1.0)

    t_vec = (cssv - 1.0) / safe_r

    support = (z_sorted_block > t_vec) & mask

    k_int = tl.sum(support.to(tl.int32), 0)
    k_clamped_int = tl.maximum(k_int, 1)
    k = k_clamped_int.to(tl.float32)

    s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)

    tau = (s - 1.0) / k

    x_block = tl.load(
        ptr_x_data_row + offs,
        mask=mask,
        other=0.0,
        cache_modifier=".ca",
    ).to(tl.float32)

    y = tl.maximum(x_block - tau, 0.0)

    tl.store(
        ptr_output_row + offs,
        y.to(ptr_output_row.dtype.element_ty),
        mask=mask,
        cache_modifier=".cs",
    )


@triton.jit
def _sparsemax_backward_kernel(
    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
    row = tl.program_id(0)
    o_row = o_ptr + row * stride
    go_row = go_ptr + row * stride
    gi_row = gi_ptr + row * stride

    offs = tl.arange(0, BLOCK_SIZE)

    supp_cnt = tl.zeros((), tl.float32)
    go_sum = tl.zeros((), tl.float32)

    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        go_sum += tl.sum(tl.where(supp, go_val, 0.0))
        supp_cnt += tl.sum(supp.to(tl.float32))

    for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs_iter = i * BLOCK_SIZE + offs
        mask_iter = offs_iter < n_cols
        o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
        go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
        supp = o_val > 0.0
        gi_val = tl.where(
            supp,
            go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
            0.0,
        )
        tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")


def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    if dim < 0:
        dim += x.dim()
    x_sw = x.transpose(dim, -1).contiguous()
    n_cols = x_sw.size(-1)
    n_rows = x_sw.numel() // n_cols
    x_flat = x_sw.view(n_rows, n_cols)
    x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    out_flat = torch.empty_like(x_flat)
    grid = (n_rows,)
    _sparsemax_forward_kernel[grid](
        x_flat,
        x_flat.stride(0),
        x_sorted_flat,
        x_sorted_flat.stride(0),
        out_flat,
        out_flat.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    y = out_flat.view_as(x_sw).transpose(dim, -1)
    return y, out_flat


def _sparsemax_backward(
    grad_out: torch.Tensor,
    out_flat: torch.Tensor,
    dim: int,
) -> torch.Tensor:
    grad_sw = grad_out.transpose(dim, -1).contiguous()
    n_cols = grad_sw.size(-1)
    n_rows = grad_sw.numel() // n_cols
    go_flat = grad_sw.view(n_rows, n_cols)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    dx_flat = torch.empty_like(go_flat)
    grid = (n_rows,)
    _sparsemax_backward_kernel[grid](
        out_flat,
        go_flat,
        dx_flat,
        out_flat.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    dx = dx_flat.view_as(grad_sw).transpose(dim, -1)
    return dx


class LigerSparsemaxFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        y, out_flat = _sparsemax_forward(x, dim)
        ctx.save_for_backward(out_flat)
        ctx.dim = dim
        return y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        (out_flat,) = ctx.saved_tensors
        dx = _sparsemax_backward(grad_out, out_flat, ctx.dim)
        return dx, None
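Note: the forward kernel above derives the sparsemax threshold tau from the row sorted in descending order, so that max(x - tau, 0) sums to 1 over its support. A plain-PyTorch sketch of the same per-row computation, for illustration only (not part of the packaged code):

import torch

def sparsemax_row(x: torch.Tensor) -> torch.Tensor:
    z, _ = torch.sort(x.float(), descending=True)   # sorted values, descending
    cssv = torch.cumsum(z, dim=0)                    # cumulative sums of sorted values
    r = torch.arange(1, x.numel() + 1, dtype=torch.float32)
    support = z > (cssv - 1.0) / r                   # support set of the projection
    k = support.sum().clamp(min=1).float()           # support size, clamped like the kernel
    s = (z * support).sum()                          # sum of in-support values
    tau = (s - 1.0) / k                              # threshold so the output sums to 1
    return torch.clamp(x.float() - tau, min=0.0)

p = sparsemax_row(torch.tensor([1.5, 0.3, 0.2, -1.0]))
assert torch.isclose(p.sum(), torch.tensor(1.0))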
liger_kernel/ops/swiglu.py
CHANGED
@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
     # sigmoid requires type float32
     a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
     b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
-    c_row = silu(a_row) * b_row
+    c_row = silu(a_row).cast(b_row.dtype) * b_row
     tl.store(c_ptr + col_offsets, c_row, mask=mask)


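Note: the change keeps the SiLU computation in float32 (the sigmoid path requires it) and casts the result back to b_row's dtype before the multiply, so the product and the store stay in the input dtype. An illustrative plain-PyTorch equivalent of that dtype handling (not part of the packaged code):

import torch

a = torch.randn(8, dtype=torch.bfloat16)
b = torch.randn(8, dtype=torch.bfloat16)

silu_fp32 = torch.nn.functional.silu(a.float())  # activation computed in float32
c = silu_fp32.to(b.dtype) * b                    # cast back before the product, as the new line does
assert c.dtype == torch.bfloat16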
liger_kernel/ops/tiled_mlp.py
ADDED
@@ -0,0 +1,136 @@
import math

from typing import Callable
from typing import List
from typing import Optional

import torch

from liger_kernel.ops.utils import ensure_contiguous


class LigerTiledMLPFunction(torch.autograd.Function):
    """
    Based on DeepSpeed's TiledMLP:
    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838

    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
    when using very long sequence lengths.

    This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
    And if you're using activation checkpointing it then occurs thrice.

    Args:
        fn: the function to call on sharded inputs (e.g., mlp.forward)
        mlp_module: the MLP nn.Module object
        x: the input to MLP.forward (hidden_states)
        shards: how many shards to use
        compute_params: a list of weights engaged in the compute

    Returns:
        the computed hidden_states
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        fn: Callable,
        mlp_module: torch.nn.Module,
        x: torch.Tensor,
        shards: int,
        compute_params: Optional[List[torch.nn.Parameter]] = None,
    ) -> torch.Tensor:
        ctx.fn = fn
        ctx.mlp_module = mlp_module
        ctx.shards = shards
        ctx.save_for_backward(x)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
        with torch.no_grad():
            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
        output_unsharded = torch.cat(output_shards, dim=-2)

        return output_unsharded

    @staticmethod
    @ensure_contiguous
    def backward(ctx, *grads) -> tuple:
        fn = ctx.fn
        (x,) = ctx.saved_tensors
        mlp_module = ctx.mlp_module
        shards = ctx.shards

        x_requires_grad = x.requires_grad
        x = x.detach()
        # detach() unsets x.requires_grad, so restore it
        x.requires_grad_(x_requires_grad)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        hidden_size = x.shape[-1]
        x_shape_orig = x.shape

        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
        x = x.view(-1, hidden_size)
        incoming_grad = grads[0].view(-1, hidden_size)
        x_grad = torch.zeros_like(x)

        x_shards = list(torch.chunk(x, chunks=shards, dim=0))

        for i, x_shard in enumerate(x_shards):
            x_shard.requires_grad_(x_requires_grad)

            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
            shard_step = x_shards[i].shape[0]
            shard_offset = i * x_shards[0].shape[0]

            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)

            with torch.enable_grad():
                output = fn(mlp_module, x_shard)
            torch.autograd.backward(output, incoming_grad_shard)

        # unflatten
        x_grad = x_grad.view(x_shape_orig)

        return (None, None, x_grad, None, None)


def apply_tiled_mlp(
    fn: Callable,
    mlp_module: torch.nn.Module,
    x: torch.Tensor,
    num_shards: Optional[int] = None,
    compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
    """
    Apply tiled MLP computation for memory efficiency.

    Args:
        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
        mlp_module: the MLP nn.Module object
        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
        compute_params: list of parameters for DeepSpeed ZeRO optimization

    Returns:
        output tensor with the same shape as input
    """
    if num_shards is None:
        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        hidden_size = x.shape[-1]
        seqlen = x.shape[-2]
        num_shards = math.ceil(seqlen / hidden_size)

    # Ensure num_shards is at least 1
    num_shards = max(1, num_shards)

    return LigerTiledMLPFunction.apply(
        fn,
        mlp_module,
        x,
        num_shards,
        compute_params,
    )
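Note: a hypothetical usage sketch of apply_tiled_mlp; the two-layer module and sizes below are made up for illustration, only the apply_tiled_mlp signature comes from the file above:

import torch
from liger_kernel.ops.tiled_mlp import apply_tiled_mlp

mlp = torch.nn.Sequential(
    torch.nn.Linear(64, 256),
    torch.nn.GELU(),
    torch.nn.Linear(256, 64),
)
x = torch.randn(2, 512, 64, requires_grad=True)  # [bs, seqlen, hidden_size]

# fn receives (module, shard) and returns that shard's output
out = apply_tiled_mlp(lambda module, t: module(t), mlp, x, num_shards=4)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 512, 64])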
liger_kernel/ops/utils.py
CHANGED
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
             functools.partial(torch.amp.custom_fwd, device_type=device),
             functools.partial(torch.amp.custom_bwd, device_type=device),
         )
+    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
+        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


@@ -125,3 +127,15 @@ def element_mul_kernel(
     X_offsets = i + tl.arange(0, BLOCK_SIZE)
     X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
     tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
+
+
+def get_npu_core_count(default: int = 20) -> int:
+    """Return NPU vector core count.
+    Fallback to `default` if Triton runtime or NPU device is unavailable.
+    """
+    try:
+        utils = triton.runtime.driver.active.utils
+        props = utils.get_device_properties(0)
+        return int(props.get("num_vectorcore", default))
+    except Exception:
+        return default
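Note: a hypothetical sketch of how a launcher might use the new get_npu_core_count helper to size a 1-D grid; the grid-sizing policy here is an assumption, not taken from this diff:

from liger_kernel.ops.utils import get_npu_core_count

def pick_grid(n_rows: int) -> tuple:
    # Falls back to the default core count when no NPU/Triton driver is available.
    num_cores = get_npu_core_count(default=20)
    return (min(n_rows, num_cores),)

print(pick_grid(4096))  # (20,) under the default fallback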