liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
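Note for reviewers: most of the new surface area lands under liger_kernel/chunked_loss and liger_kernel/transformers. For orientation, the upstream liger-kernel README documents an auto-model entry point along these lines (a sketch only; confirm against this nightly build before relying on it):

# Sketch based on upstream liger-kernel documentation, not on this diff;
# verify the exact API against this nightly build.
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# Loads a supported HF checkpoint with Liger's Triton kernels patched in.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/checkpoint")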
liger_kernel/ops/geglu.py
CHANGED
@@ -4,11 +4,9 @@ import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
+from liger_kernel.ops.utils import calculate_settings
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -22,9 +20,7 @@ else:
 
 
 @triton.jit
-def _geglu_tanh_forward_kernel(
-    a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -44,14 +40,12 @@ def _geglu_tanh_forward_kernel(
     tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
     tanh_result = tanh(tanh_arg)
     geglu_a = 0.5 * a_row * (1 + tanh_result)
-    c_row = geglu_a * b_row
+    c_row = geglu_a.cast(b_row.dtype) * b_row
     tl.store(c + col_offsets, c_row, mask=mask)
 
 
 @triton.jit
-def _geglu_tanh_backward_kernel(
-    dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr
-):
+def _geglu_tanh_backward_kernel(dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr):
     program_id = tl.program_id(0).to(tl.int64)
 
     # locate start index
@@ -80,12 +74,7 @@ def _geglu_tanh_backward_kernel(
     # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
     term1 = 0.5 * (1 + tanh_result)
     tanh_sq = tanh_result * tanh_result
-    term2 = (
-        0.5
-        * a_row
-        * (1 - tanh_sq)
-        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
-    )
+    term2 = 0.5 * a_row * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))
     da_row = dc_row * b_row * (term1 + term2)
 
     tl.store(a + col_offsets, da_row, mask=mask)
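Reviewer note: besides collapsing the multi-line signatures, the forward kernel now casts the gate to b_row's dtype before the elementwise product, so the multiply stays in the input precision. A hypothetical eager-mode reference of the same tanh-approximated GEGLU (not part of the package):

import math

import torch

def geglu_tanh_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # assumption: the kernel computes the gate in fp32, which the new
    # `geglu_a.cast(b_row.dtype)` line suggests
    a32 = a.float()
    sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
    gate = 0.5 * a32 * (1.0 + torch.tanh(sqrt_2_over_pi * (a32 + 0.044715 * a32 * a32 * a32)))
    # cast the gate back to b's dtype before the product, mirroring the
    # updated kernel under bf16/fp16 inputs
    return gate.to(b.dtype) * b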
liger_kernel/ops/group_norm.py
ADDED
@@ -0,0 +1,305 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import ensure_contiguous
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+MAX_FUSED_SIZE = 65536
+
+
+@triton.jit
+def _group_norm_forward_kernel(
+    Y_ptr,  # pointer to output, shape (n_rows, n_groups, hidden_size)
+    Y_row_stride,  # stride of each row in output
+    Y_col_stride,  # stride of each column in output
+    X_ptr,  # pointer to input, shape (n_rows, n_groups, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_row_stride,  # stride of each row in mean
+    Mean_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    RSTD_row_stride,  # stride of each row in rstd
+    RSTD_col_stride,  # stride of each column in rstd
+    W_ptr,  # pointer to W
+    B_ptr,  # pointer to B
+    hidden_size,  # hidden size of X
+    channels_per_group,  # the number of channels per group
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride
+    Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride
+
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # Compute mean and variance using the online algorithm
+    s = 0.0
+    squared_sum = 0.0
+    for i in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_size_offsets = i + block_range
+        mask = hidden_size_offsets < hidden_size
+        X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0)
+        s += tl.sum(X)
+        # X**2
+        squared_sum += tl.sum(X * X)
+
+    m = s / hidden_size
+
+    # variance = E[X**2] - E[X]**2
+    variance = (squared_sum / hidden_size) - (m * m)
+
+    # 1/std
+    rstd = rsqrt(variance + eps)
+
+    # Normalize
+    hidden_size_per_channel = hidden_size // channels_per_group
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
+        W = tl.load(W_ptr + channel_idx)
+        B = tl.load(B_ptr + channel_idx)
+        for i in range(0, hidden_size_per_channel, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size_per_channel
+            X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m)
+            Y = (X - m) * rstd * W + B
+            tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask)
+
+        X_ptr += hidden_size_per_channel
+        Y_ptr += hidden_size_per_channel
+
+    tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m)
+    tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd)
+
+
+@triton.jit
+def _group_norm_backward_kernel(
+    X_ptr,  # pointer to input, shape (n_rows, n_channels, hidden_size)
+    X_row_stride,  # stride of each row in input
+    X_col_stride,  # stride of each column in input
+    W_ptr,  # pointer to weights, shape (n_channels)
+    Mean_ptr,  # pointer to mean, shape (n_rows, n_groups)
+    Mean_ptr_row_stride,  # stride of each row in mean
+    Mean_ptr_col_stride,  # stride of each column in mean
+    RSTD_ptr,  # pointer to rstd, shape (n_rows, n_groups)
+    DX_ptr,  # pointer to input grad, shape (n_rows, n_groups, hidden_size)
+    DW_ptr,  # pointer to weights grad, shape (n_channels)
+    DB_ptr,  # pointer to bias grad, shape (n_channels)
+    UPSTREAM_ptr,  # pointer to output grad, shape (n_rows, n_channels, hidden_size)
+    hidden_size: tl.constexpr,  # hidden size
+    channels_per_group: tl.constexpr,  # the number of channels per group
+    BLOCK_SIZE: tl.constexpr,
+    dtype: tl.constexpr,
+):
+    """
+    References:
+    https://nn.labml.ai/normalization/group_norm/index.html
+    https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
+
+    The backprop equations are the same for group_norm and layer_norm;
+    the only difference here is that we load the Mean and Rstd corresponding to the
+    group we're computing gradients for, and the mean and rstd are computed over n channels,
+    so the total number of elements we compute the mean over is num_channels_per_group * hidden_size.
+
+    We also need to load the Weights corresponding to the current channel to compute the gradients.
+    """
+    batch_idx = tl.program_id(0)
+    group_idx = tl.program_id(1)
+
+    # Move the pointers to the correct batch
+    X_ptr += batch_idx * X_row_stride
+    DX_ptr += batch_idx * X_row_stride
+    UPSTREAM_ptr += batch_idx * X_row_stride
+
+    # Mean and rstd are the same shape so have the same strides
+    mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+    rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride)
+
+    c1 = 0.0
+    c2 = 0.0
+    block_range = tl.arange(0, BLOCK_SIZE)
+
+    # We need to compute the sum terms of the backprop equations across all channels in the group
+    for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
+        dW = 0.0
+        dB = 0.0
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in tl.range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            dW += tl.sum(UPSTREAM_grad * x_hat)
+            dB += tl.sum(UPSTREAM_grad)
+
+            wdy = W * UPSTREAM_grad
+            c1 += tl.sum(x_hat * wdy)
+            c2 += tl.sum(wdy)
+
+        # Need to ensure additions to the same channel are atomic
+        tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype))
+        tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype))
+
+    N = hidden_size * channels_per_group
+    c1 = c1 / N
+    c2 = c2 / N
+
+    for channel_idx in tl.range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group):
+        # Move the pointers to the correct channel
+        W = tl.load(W_ptr + channel_idx)
+        for i in range(0, hidden_size, BLOCK_SIZE):
+            hidden_size_offsets = i + block_range
+            mask = hidden_size_offsets < hidden_size
+            X = tl.load(
+                X_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+            UPSTREAM_grad = tl.load(
+                UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets,
+                mask=mask,
+                other=0.0,
+            )
+
+            x_hat = (X - mean) * rstd
+            wdy = W * UPSTREAM_grad
+            dx = (wdy - (x_hat * c1 + c2)) * rstd
+            tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask)
+
+
+def group_norm_forward(X, num_channels, num_groups, W, B, eps):
+    shape = X.shape
+    batch_size = shape[0]
+    channels_per_group = num_channels // num_groups
+    # Reshape X so that the mean and std are computed across the groups
+    X = X.view(batch_size, num_groups, -1).contiguous()
+    hidden_size = X.shape[-1]
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device)
+    Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+    RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device)
+
+    _group_norm_forward_kernel[(batch_size, num_groups)](
+        Y,
+        Y.stride(0),
+        Y.stride(1),
+        X,
+        X.stride(0),
+        X.stride(1),
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        RSTD.stride(0),
+        RSTD.stride(1),
+        W,
+        B,
+        hidden_size,
+        channels_per_group,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    # Return tensors in the original shape
+    return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE
+
+
+def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups):
+    shape = dY.shape
+    batch_size = shape[0]
+    hidden_size = dY.shape[-1]
+    channels_per_group = num_channels // num_groups
+    dY = dY.view(batch_size, num_groups, -1)
+    DX = torch.empty(
+        (batch_size, num_groups, hidden_size * channels_per_group),
+        dtype=X.dtype,
+        device=X.device,
+    )
+    DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device)
+    DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device)
+    triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
+
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size))
+    _group_norm_backward_kernel[(batch_size, num_groups)](
+        X,
+        X.stride(0),
+        X.stride(1),
+        W,
+        Mean,
+        Mean.stride(0),
+        Mean.stride(1),
+        RSTD,
+        DX,
+        DW,
+        DB,
+        dY,
+        hidden_size,
+        channels_per_group,
+        BLOCK_SIZE=BLOCK_SIZE,
+        dtype=triton_dtype,
+    )
+
+    # Return tensors in the original shape
+    return DX.view(*shape), DW, DB
+
+
+class LigerGroupNormFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    ):
+        Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(
+            X,
+            num_channels,
+            num_groups,
+            affine_scaling_weight,
+            affine_shifting_bias,
+            eps,
+        )
+        ctx.num_channels = num_channels
+        ctx.num_groups = num_groups
+        ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, W, B, Mean, RSTD = ctx.saved_tensors
+        DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups)
+        return DX, DW, DB, None, None, None
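Reviewer note: a minimal usage sketch for the new autograd function, assuming a CUDA device with Triton available (hypothetical check, not part of the package). The forward kernel's biased variance E[X**2] - E[X]**2 should agree with torch.nn.functional.group_norm up to floating-point tolerance:

import torch
import torch.nn.functional as F

from liger_kernel.ops.group_norm import LigerGroupNormFunction

num_channels, num_groups, eps = 8, 4, 1e-6
x = torch.randn(2, num_channels, 64, device="cuda", requires_grad=True)
w = torch.ones(num_channels, device="cuda", requires_grad=True)
b = torch.zeros(num_channels, device="cuda", requires_grad=True)

y = LigerGroupNormFunction.apply(x, w, b, num_channels, num_groups, eps)
y_ref = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
torch.testing.assert_close(y, y_ref, rtol=1e-4, atol=1e-4)
y.sum().backward()  # exercises _group_norm_backward_kernel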
liger_kernel/ops/grpo_loss.py
ADDED
@@ -0,0 +1,310 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _selective_log_softmax_kernel(
+    LOGITS,
+    INPUT_IDS,
+    LOG_P,
+    MASK,
+    TEMPERATURE,
+    stride_input_ids_b,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * stride_input_ids_b + off_l
+    LOG_P += off_b * L + off_l
+
+    if MASK is not None:
+        MASK += off_b * stride_input_ids_b + off_l
+        not_skip = tl.load(MASK)
+        if not_skip == 0:
+            return
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    ids = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    tl.store(LOG_P, logp)
+
+
+# compute old_logp and ref_logp; this reduces ~10G of peak memory and does not require grad
+@torch.no_grad
+def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
+    assert logits.is_contiguous()
+    B, L_ADD_1, N = logits.shape
+    L = L_ADD_1 - 1
+    input_ids = input_ids[:, -L:]
+    if mask is not None:
+        mask = mask[:, -L:]
+    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
+    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
+    _selective_log_softmax_kernel[(B, L)](
+        logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
+    )
+    return log_p
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+#                   for BLOCK_N in [2048, 4096, 8192]
+#                   for ns in [1, 2, 4]
+#                   for nw in [1, 2, 4, 8, 16]],
+#                  key=['N'])
+@triton.jit
+def _grpo_loss_fwd_kernel(
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    COMPLETION_MASK,
+    ADVANTAGES,
+    LOSS,
+    LSE,
+    KL,
+    IS_CLIPPED,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LOSS += off_b * L + off_l
+    LSE += off_b * L + off_l
+    IS_CLIPPED += off_b * L + off_l
+
+    m_i = float("-inf")
+    l_i = 0.0
+    for start in range(0, N, BLOCK_N):
+        cols = start + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
+        new_m_i = tl.maximum(m_i, tl.max(logits))
+        alpha = tl.exp(m_i - new_m_i)
+        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
+        m_i = new_m_i
+    lse = m_i + tl.log(l_i)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
+    is_clipped = per_token_loss1 < per_token_loss2
+
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        KL += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
+        per_token_loss += BETA * kl
+        tl.store(KL, kl)
+
+    tl.store(LOSS, per_token_loss)
+    tl.store(LSE, lse)
+    tl.store(IS_CLIPPED, is_clipped)
+
+
+# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
+#                   for BLOCK_N in [2048, 4096, 8192]
+#                   for ns in [1, 2, 4]
+#                   for nw in [1, 2, 4, 8, 16]],
+#                  key=['N'])
+@triton.jit
+def _grpo_loss_bwd_kernel(
+    DLOSS,
+    DLOGITS,
+    LOGITS,
+    OLD_LOGP,
+    REF_LOGP,
+    INPUT_IDS,
+    ADVANTAGES,
+    COMPLETION_MASK,
+    LSE,
+    TEMPERATURE,
+    BETA: tl.constexpr,
+    EPS_LOW,
+    EPS_HIGH,
+    loss_stride0,
+    loss_stride1,
+    L: tl.constexpr,
+    N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 4096,
+):
+    off_b = tl.program_id(0).cast(tl.int64)
+    off_l = tl.program_id(1).cast(tl.int64)
+
+    DLOGITS += off_b * (L + 1) * N + off_l * N
+    if COMPLETION_MASK is not None:
+        COMPLETION_MASK += off_b * L + off_l
+        not_skip = tl.load(COMPLETION_MASK)
+        if not_skip == 0:
+            for start in range(0, N, BLOCK_N):
+                cols = tl.arange(0, BLOCK_N) + start
+                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
+            return
+
+    LOGITS += off_b * (L + 1) * N + off_l * N
+    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
+    INPUT_IDS += off_b * L + off_l
+    ADVANTAGES += off_b
+    LSE += off_b * L + off_l
+
+    dloss = tl.load(DLOSS).to(tl.float32)
+    lse = tl.load(LSE).to(tl.float32)
+
+    idx = tl.load(INPUT_IDS)
+    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
+    logp = x - lse
+    if OLD_LOGP is None:
+        old_logp = logp
+    else:
+        OLD_LOGP += off_b * L + off_l
+        old_logp = tl.load(OLD_LOGP).to(tl.float32)
+    coef_1 = tl.exp(logp - old_logp)
+    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
+    advantage = tl.load(ADVANTAGES).to(tl.float32)
+    per_token_loss1 = coef_1 * advantage
+    per_token_loss2 = coef_2 * advantage
+    mask = per_token_loss2 >= per_token_loss1
+
+    dlogp = -per_token_loss1 * mask
+    if BETA != 0.0:
+        REF_LOGP += off_b * L + off_l
+        ref_logp = tl.load(REF_LOGP).to(tl.float32)
+        dlogp += BETA * (1 - tl.exp(ref_logp - logp))
+
+    dlogp = dlogp * dloss / TEMPERATURE
+    tl.debug_barrier()
+    for start_n in tl.range(0, N, BLOCK_N):
+        cols = start_n + tl.arange(0, BLOCK_N)
+        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
+        probs = tl.exp(logits - lse)
+        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
+        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
+
+
+class GrpoLossFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        logits,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace,
+    ):
+        assert logits.is_contiguous() and completion_ids.is_contiguous()
+        assert old_logp is None or old_logp.is_contiguous()
+        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
+
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+
+        if completion_mask is not None:
+            assert completion_mask.is_contiguous()
+
+        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
+        lse = torch.zeros_like(loss)
+        is_clipped = torch.zeros_like(loss)
+        kl = torch.zeros_like(loss) if beta != 0.0 else None
+        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
+        _grpo_loss_fwd_kernel[(B, L)](
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            completion_mask,
+            advantages,
+            loss,
+            lse,
+            kl,
+            is_clipped,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            L,
+            N,
+            **kwargs,
+        )
+        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
+        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
+        # return loss
+        return loss, kl, is_clipped
+
+    @staticmethod
+    def backward(ctx, *args):
+        dloss = args[0]
+        # print(dloss.shape)
+        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
+        temperature, beta, eps_low, eps_high, inplace = ctx.infos
+        B, L_ADD_1, N = logits.shape
+        L = L_ADD_1 - 1
+        dlogits = logits.data if inplace else torch.empty_like(logits)
+        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
+        _grpo_loss_bwd_kernel[(B, L)](
+            dloss,
+            dlogits,
+            logits,
+            old_logp,
+            ref_logp,
+            completion_ids,
+            advantages,
+            completion_mask,
+            lse,
+            temperature,
+            beta,
+            eps_low,
+            eps_high,
+            *dloss.stride(),
+            L,
+            N,
+            **kwargs,
+        )
+        dlogits[:, -1, :] = 0
+        return dlogits, None, None, None, None, None, None, None, None, None, None
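Reviewer note: for readers checking the math, the kernels above correspond to the following eager-mode references (hypothetical sketches, not part of the package; names and shapes follow the Triton code, not a public API). selective_log_softmax_reference mirrors fused_selective_log_softmax, and grpo_per_token_loss_reference mirrors the forward kernel's clipped PPO-style surrogate with the k3 KL penalty:

import torch

def selective_log_softmax_reference(logits, input_ids, temperature=0.9):
    # logits: (B, L+1, N); the kernel scores input_ids[:, -L:] with the
    # first L logit rows, i.e. next-token log-probabilities
    L = logits.shape[1] - 1
    logp = torch.log_softmax(logits[:, :-1].float() / temperature, dim=-1)
    return logp.gather(-1, input_ids[:, -L:].unsqueeze(-1)).squeeze(-1)

def grpo_per_token_loss_reference(logp, old_logp, ref_logp, advantages, beta, eps_low, eps_high):
    # logp/old_logp/ref_logp: (B, L); advantages: (B, 1), broadcast per token
    coef_1 = torch.exp(logp - old_logp)                      # importance ratio
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)  # PPO-style clip
    loss = -torch.minimum(coef_1 * advantages, coef_2 * advantages)
    if beta != 0.0:
        # k3 estimator of KL(pi || pi_ref), matching the forward kernel
        kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1
        loss = loss + beta * kl
    return loss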