liger-kernel 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/__init__.py +4 -0
- liger_kernel/chunked_loss/cpo_loss.py +107 -0
- liger_kernel/chunked_loss/dpo_loss.py +135 -0
- liger_kernel/chunked_loss/functional.py +9 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +386 -0
- liger_kernel/chunked_loss/orpo_loss.py +113 -0
- liger_kernel/chunked_loss/simpo_loss.py +115 -0
- liger_kernel/env_report.py +22 -0
- liger_kernel/ops/cross_entropy.py +17 -10
- liger_kernel/ops/fused_linear_cross_entropy.py +1 -11
- liger_kernel/ops/fused_linear_jsd.py +1 -1
- liger_kernel/ops/jsd.py +19 -10
- liger_kernel/ops/layer_norm.py +6 -1
- liger_kernel/ops/qwen2vl_mrope.py +238 -0
- liger_kernel/ops/rms_norm.py +6 -1
- liger_kernel/ops/utils.py +5 -2
- liger_kernel/transformers/__init__.py +1 -0
- liger_kernel/transformers/functional.py +128 -11
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/jsd.py +1 -4
- liger_kernel/transformers/model/qwen2_vl.py +43 -17
- liger_kernel/transformers/monkey_patch.py +11 -6
- liger_kernel/transformers/orpo_trainer.py +171 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/utils.py +13 -0
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/METADATA +80 -123
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/RECORD +33 -20
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/WHEEL +1 -1
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/NOTICE +0 -0
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py
CHANGED

@@ -92,8 +92,8 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(
-        X_ptr + y
+    ori_X_y = tl.load(X_ptr + y).cast(
+        tl.float32
     )  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)

@@ -106,8 +106,11 @@ def liger_cross_entropy_kernel(
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)

@@ -141,8 +144,11 @@ def liger_cross_entropy_kernel(
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate

@@ -279,11 +285,12 @@ def cross_entropy_forward(
         num_warps=32 if not is_hip() else 16,
     )

-    loss = torch.sum(loss_1d)
-    if return_z_loss == _TRUE.value:
-        z_loss = torch.sum(z_loss_1d)
+    if reduction == "none":
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss == _TRUE.value else None
     else:
-        z_loss = None
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None

     return loss, z_loss, _input

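The kernel now upcasts the loaded logits to float32 inside the Triton kernel, and cross_entropy_forward gains a reduction="none" branch that returns the per-token losses unreduced instead of always summing. A minimal plain-PyTorch sketch of what that branch means for callers (tensor names and sizes are hypothetical, this is not the library API):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)            # (BT, V), hypothetical sizes
target = torch.tensor([1, 3, -100, 7])  # -100 is the ignore_index

# reduction="none": one loss value per token; ignored positions contribute 0
loss_1d = F.cross_entropy(logits.float(), target, ignore_index=-100, reduction="none")

# the pre-0.5.0 path always summed into a single scalar
loss_sum = loss_1d.sum()
print(loss_1d.shape, loss_sum.shape)  # torch.Size([4]) torch.Size([])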
liger_kernel/ops/fused_linear_cross_entropy.py
CHANGED

@@ -26,7 +26,6 @@ def fused_linear_cross_entropy_forward(
     reduction="mean",
     softcap=None,
 ):
-    dtype = _input.dtype
     device = _input.device

     # inputs have shape: BT x H

@@ -74,9 +73,6 @@ def fused_linear_cross_entropy_forward(
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         n_non_ignore = (target_chunk != ignore_index).sum().item()

-        # when doing CE, use the upcasted precision
-        logits_chunk = logits_chunk.float()
-
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
         target_chunk = target_chunk.contiguous()

@@ -103,13 +99,6 @@ def fused_linear_cross_entropy_forward(
             num_warps=32 if not is_hip() else 16,
         )

-        # gradient of logits_chunk is computed in-place by the above triton kernel.
-        # Following HuggingFace model source code, we do the forward and backward
-        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
-        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
-        # Propagating to lm_head's backward, we'll switch back to the original dtype.
-        logits_chunk = logits_chunk.to(dtype)
-
         # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
         # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
         # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only

@@ -229,6 +218,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
            label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
            reduction: reduction to apply
        """
+
        loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
            _input,
            weight,
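Since the cross-entropy kernel now handles the fp32 upcast itself, the Python wrapper no longer round-trips logits_chunk through .float() and back to the input dtype. For context, a minimal sketch of the chunked fused linear + cross-entropy idea this file implements, in plain PyTorch (illustrative only; the real implementation computes gradients in place inside the Triton kernel and sizes chunks differently):

import torch
import torch.nn.functional as F

def chunked_linear_ce(_input, weight, target, chunk_size=2, ignore_index=-100):
    # _input: (BT, H), weight: (V, H), target: (BT,)
    total_loss = _input.new_zeros(())
    n_non_ignore = (target != ignore_index).sum()
    for start in range(0, _input.shape[0], chunk_size):
        end = start + chunk_size
        # only a (chunk, V) slice of the logits is ever materialized
        logits_chunk = _input[start:end] @ weight.t()
        total_loss = total_loss + F.cross_entropy(
            logits_chunk.float(), target[start:end],
            ignore_index=ignore_index, reduction="sum",
        )
    return total_loss / n_non_ignore

BT, H, V = 8, 16, 32
x, w = torch.randn(BT, H), torch.randn(V, H)
y = torch.randint(0, V, (BT,))
print(chunked_linear_ce(x, w, y))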
liger_kernel/ops/fused_linear_jsd.py
CHANGED

@@ -202,7 +202,7 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
            teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
            teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-            jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+            jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
             ignore_index (int): the index to ignore. Default: -100
             temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
liger_kernel/ops/jsd.py
CHANGED

@@ -18,7 +18,7 @@ def _jsd_kernel(
     dX_ptr,
     dX_stride,
     label_ptr,
-    beta,
+    beta: tl.constexpr,
     n_non_ignore: int,
     ignore_index: tl.constexpr,
     n_cols,

@@ -50,17 +50,26 @@ def _jsd_kernel(
     X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

-    Q = tl.exp(X)
-    P = tl.exp(Y)
-    M = beta * P + (1 - beta) * Q
-    log_M = tl.log(M)
+    if beta == 0.0:  # forward KL
+        Y_prob = tl.exp(Y)
+        loss = Y_prob * (Y - X)
+        dX = -Y_prob
+    elif beta == 1.0:
+        X_prob = tl.exp(X)
+        loss = X_prob * (X - Y)
+        dX = loss + X_prob
+    else:
+        Q = tl.exp(X)
+        P = tl.exp(Y)
+        M = beta * P + (1 - beta) * Q
+        log_M = tl.log(M)
+
+        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
+        dX = (1 - beta) * Q * (X - log_M)

-    loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-    # reduction == "batchmean"
     loss = loss / n_non_ignore
+    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
-
-    dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
     tl.store(dX_ptr + offsets, dX, mask=mask)


@@ -142,7 +151,7 @@ class LigerJSDFunction(torch.autograd.Function):
            _input (torch.Tensor): predict values with shape (BT, V) in logspace
            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-            beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+            beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
             ignore_index (int): the index to ignore. Default: -100

         Returns:
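The kernel now special-cases beta == 0.0 and beta == 1.0 (forward and reverse KL) instead of always evaluating the generalized-JSD expression, and it applies the batchmean division to dX in the same pass. A plain-PyTorch restatement of the per-element math the three branches compute, where X and Y are student and teacher log-probabilities as in the kernel:

import torch

def jsd_elementwise(X, Y, beta):
    Q, P = X.exp(), Y.exp()          # student / teacher probabilities
    if beta == 0.0:                  # forward KL: KL(P || Q)
        return P * (Y - X)
    if beta == 1.0:                  # reverse KL: KL(Q || P)
        return Q * (X - Y)
    M = beta * P + (1 - beta) * Q    # mixture distribution
    return beta * P * Y + (1 - beta) * Q * X - M * M.log()

X = torch.log_softmax(torch.randn(3, 5), dim=-1)
Y = torch.log_softmax(torch.randn(3, 5), dim=-1)
for b in (0.0, 0.5, 1.0):
    print(b, jsd_elementwise(X, Y, b).sum().item())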
liger_kernel/ops/layer_norm.py
CHANGED

@@ -180,8 +180,13 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape

+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
     _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
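layer_norm_backward now sizes its partial-gradient buffers by a device-dependent sm_count (CUDA multiprocessors, or XPU subslices) with a fallback of 1 for other devices. A short sketch of the pattern, assuming, as the (sm_count, n_cols) buffer shapes suggest, that each of the sm_count kernel programs accumulates one partial row of dW/dB that is reduced afterwards:

import torch

def reduce_partial_weight_grad(_DW: torch.Tensor) -> torch.Tensor:
    # _DW: (sm_count, n_cols) partial sums, one row per persistent program.
    # The final weight gradient is the reduction over the program dimension
    # (an assumption about the host-side reduction, not shown in this hunk).
    return _DW.sum(dim=0)

sm_count, n_cols = 4, 8
_DW = torch.randn(sm_count, n_cols)
print(reduce_partial_weight_grad(_DW).shape)  # torch.Size([8])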
liger_kernel/ops/qwen2vl_mrope.py
ADDED

@@ -0,0 +1,238 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _triton_qwen2vl_mrope(
+    q_ptr,
+    k_ptr,
+    cos,
+    sin,
+    sl,
+    n_qh: tl.constexpr,
+    n_kh: tl.constexpr,
+    hd: tl.constexpr,
+    pad_n_qh: tl.constexpr,
+    pad_n_kh: tl.constexpr,
+    pad_hd: tl.constexpr,
+    mrope_section_t: tl.constexpr,
+    mrope_section_h: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BACKWARD_PASS: tl.constexpr = False,
+):
+    pid = tl.program_id(0)
+
+    # locate start address
+    q_ptr = q_ptr + pid * (n_qh * hd)
+    k_ptr = k_ptr + pid * (n_kh * hd)
+
+    # ####################################################################
+    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
+    # m of this program instance
+    # ####################################################################
+
+    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
+    # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
+    # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
+    # and pid % sl to get the sequence index.
+    # 2. We only need the left half of cos and sin matrix because the right half is just
+    # a clone of the left half.
+    t_end = mrope_section_t
+    h_end = t_end + mrope_section_h
+
+    cos_row_idx = pid % sl
+    t_cos = cos + cos_row_idx * hd
+    h_cos = t_cos + sl * hd
+    w_cos = h_cos + sl * hd
+    t_sin = sin + cos_row_idx * hd
+    h_sin = t_sin + sl * hd
+    w_sin = h_sin + sl * hd
+
+    cos_offsets = tl.arange(0, pad_hd // 2)
+    t_mask = cos_offsets < t_end
+    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
+    w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
+    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
+    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
+    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
+    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
+    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
+    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
+    cos_row = t_cos_row + h_cos_row + w_cos_row
+    sin_row = t_sin_row + h_sin_row + w_sin_row
+
+    # ####################################################################
+    # Load the left and right half of q and k for the current
+    # program instance (i.e. for the current token) separately
+    # ####################################################################
+    # left half of the head
+    first_half_q_offsets = (
+        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    )
+    first_half_k_offsets = (
+        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    )
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
+    )
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
+    )
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    # right half of the head
+    second_half_q_offsets = first_half_q_offsets + (hd // 2)
+    second_half_k_offsets = first_half_k_offsets + (hd // 2)
+    second_q_mask = first_q_mask
+    second_k_mask = first_k_mask
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    if not BACKWARD_PASS:
+        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+    else:
+        # with some math, we can get:
+        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
+        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+
+
+def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
+
+    # transpose it back to the physical shape because Triton looks at the physical storage
+    # note: q and k are incontiguous before the transformation and will become contiguous after transpose
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = q.shape
+    n_kv_head = k.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
+    q = q.contiguous()
+    k = k.contiguous()
+    cos = cos.contiguous()
+    sin = sin.contiguous()
+
+    _triton_qwen2vl_mrope[(n_row,)](
+        q,
+        k,
+        cos,
+        sin,
+        seq_len,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        mrope_section[0],
+        mrope_section[1],
+        BLOCK_SIZE=BLOCK_SIZE,
+        BACKWARD_PASS=False,
+    )
+    return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
+    dq = dq.transpose(1, 2)
+    dk = dk.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = dq.shape
+    n_kv_head = dk.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure dq and dk are contiguous
+    dq = dq.contiguous()
+    dk = dk.contiguous()
+
+    # backward is similar to forward except swapping few ops
+    _triton_qwen2vl_mrope[(n_row,)](
+        dq,
+        dk,
+        cos,
+        sin,
+        seq_len,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        mrope_section[0],
+        mrope_section[1],
+        BLOCK_SIZE=BLOCK_SIZE,
+        BACKWARD_PASS=True,
+    )
+    return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+class LigerQwen2VLMRopeFunction(torch.autograd.Function):
+    """
+    Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.
+
+    Please find the corresponding HuggingFace implementation here:
+    https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+    """
+
+    @staticmethod
+    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+        """
+        q size: (bsz, n_q_head, seq_len, head_dim)
+        k size: (bsz, n_kv_head, seq_len, head_dim)
+        cos size: (3, 1, seq_len, head_dim)
+        sin size: (3, 1, seq_len, head_dim)
+        """
+        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
+        ctx.save_for_backward(cos, sin)
+        ctx.mrope_section = mrope_section
+        return q, k
+
+    def backward(ctx, dq, dk):
+        """
+        dq size: (bsz, n_q_head, seq_len, head_dim)
+        dk size: (bsz, n_kv_head, seq_len, head_dim)
+        cos size: (3, 1, seq_len, head_dim)
+        sin size: (3, 1, seq_len, head_dim)
+        """
+
+        cos, sin = ctx.saved_tensors
+        mrope_section = ctx.mrope_section
+        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
+        return dq, dk, None, None, None, None
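A minimal usage sketch of the new autograd op, assuming (per the kernel's t/h/w masks) that the three mrope_section entries partition the first half of head_dim; shapes follow the docstring above. A GPU build of Triton is required at runtime, so this is illustrative rather than a doctest:

import torch
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 4, 2, 8, 32
mrope_section = [4, 6, 6]  # temporal/height/width split; sums to head_dim // 2 (assumption)

q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda", requires_grad=True)
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda", requires_grad=True)
cos = torch.randn(3, 1, seq_len, head_dim, device="cuda")
sin = torch.randn(3, 1, seq_len, head_dim, device="cuda")

q_rot, k_rot = LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section)
(q_rot.sum() + k_rot.sum()).backward()
print(q_rot.shape, q.grad.shape)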
liger_kernel/ops/rms_norm.py
CHANGED

@@ -264,7 +264,12 @@ def rms_norm_backward(
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape

-    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
liger_kernel/ops/utils.py
CHANGED

@@ -20,6 +20,8 @@ import triton
 import triton.language as tl
 from packaging.version import Version

+from liger_kernel.utils import infer_device
+

 def is_hip() -> bool:
     return torch.version.hip is not None

@@ -69,10 +71,11 @@ def compare_version(package: str, operator: Callable, target: str):


 def get_amp_custom_fwd_bwd() -> Callable:
+    device = infer_device()
     if compare_version("torch", operator.ge, "2.4.0"):
         return (
-            functools.partial(torch.amp.custom_fwd, device_type="cuda"),
-            functools.partial(torch.amp.custom_bwd, device_type="cuda"),
+            functools.partial(torch.amp.custom_fwd, device_type=device),
+            functools.partial(torch.amp.custom_bwd, device_type=device),
         )
     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
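get_amp_custom_fwd_bwd now resolves the autocast device type through the new liger_kernel.utils.infer_device helper instead of hardcoding "cuda". The body of liger_kernel/utils.py is not shown in this diff; a plausible minimal shape of such a helper, written purely as an assumption for illustration:

import torch

def infer_device() -> str:
    # Hypothetical sketch: prefer CUDA, then Intel XPU, else CPU.
    # The actual liger_kernel.utils.infer_device may differ.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"

print(infer_device())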
liger_kernel/transformers/__init__.py
CHANGED

@@ -22,6 +22,7 @@ from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     apply_liger_kernel_to_qwen2,
     apply_liger_kernel_to_qwen2_vl,
 )
+from liger_kernel.transformers.orpo_trainer import LigerORPOTrainer  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
 from liger_kernel.transformers.swiglu import (  # noqa: F401
liger_kernel/transformers/functional.py
CHANGED

@@ -10,21 +10,11 @@ from liger_kernel.ops.group_norm import LigerGroupNormFunction
 from liger_kernel.ops.jsd import LigerJSDFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.ops.rope import LigerRopeFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction

-liger_swiglu = LigerSiLUMulFunction.apply
-liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
-liger_geglu = LigerGELUMulFunction.apply
-liger_rms_norm = LigerRMSNormFunction.apply
-liger_rope = LigerRopeFunction.apply
-liger_layer_norm = LigerLayerNormFunction.apply
-liger_kl_div = LigerKLDivLossFunction.apply
-liger_jsd = LigerJSDFunction.apply
-liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
-liger_group_norm = LigerGroupNormFunction.apply
-

 # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
 # `weight` and `size_average` are placeholders and not implemented yet

@@ -54,3 +44,130 @@ def liger_cross_entropy(
     if not return_z_loss:
         return loss
     return loss, z_loss
+
+
+def liger_fused_linear_cross_entropy(
+    input,
+    weight,
+    target,
+    bias=None,
+    ignore_index: int = -100,
+    lse_square_scale: float = 0.0,
+    label_smoothing: float = 0.0,
+    reduction: str = "mean",
+    softcap: Optional[float] = None,
+):
+    return LigerFusedLinearCrossEntropyFunction.apply(
+        input,
+        weight,
+        target,
+        bias,
+        ignore_index,
+        lse_square_scale,
+        label_smoothing,
+        reduction,
+        softcap,
+    )
+
+
+def liger_fused_linear_jsd(
+    student_input,
+    student_weight,
+    teacher_input,
+    teacher_weight,
+    shift_labels=None,
+    jsd_beta: float = 0.5,
+    ignore_index: int = -100,
+    temperature: float = 1.0,
+):
+    return LigerFusedLinearJSDFunction.apply(
+        student_input,
+        student_weight,
+        teacher_input,
+        teacher_weight,
+        shift_labels,
+        jsd_beta,
+        ignore_index,
+        temperature,
+    )
+
+
+def liger_geglu(a, b):
+    return LigerGELUMulFunction.apply(a, b)
+
+
+def liger_group_norm(
+    X,
+    affine_scaling_weight,
+    affine_shifting_bias,
+    num_channels,
+    num_groups,
+    eps,
+):
+    return LigerGroupNormFunction.apply(
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    )
+
+
+def liger_jsd(
+    input,
+    target,
+    shift_labels=None,
+    beta: float = 0.5,
+    ignore_index: int = -100,
+):
+    return LigerJSDFunction.apply(
+        input,
+        target,
+        shift_labels,
+        beta,
+        ignore_index,
+    )
+
+
+# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
+# `size_average` and `mean` are being deprecated in torch API and are placeholders here
+def liger_kl_div(
+    input,
+    target,
+    size_average: bool = True,
+    reduce: bool = True,
+    reduction: str = "mean",
+    log_target: bool = False,
+    eps: float = 1e-10,
+):
+    # Note: the default reduction in torch is `mean`, but being `batchmean` in Liger
+    return LigerKLDivLossFunction.apply(
+        input,
+        target,
+        reduction,
+        log_target,
+        eps,
+    )
+
+
+def liger_layer_norm(X, W, B, eps):
+    return LigerLayerNormFunction.apply(X, W, B, eps)
+
+
+def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
+
+
+def liger_rms_norm(
+    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
+):
+    return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
+
+
+def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_swiglu(a, b):
+    return LigerSiLUMulFunction.apply(a, b)
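The module now exposes explicit wrapper functions with keyword arguments and defaults instead of bare *.apply aliases. A small usage sketch of one of them, with hypothetical tensor sizes; CUDA and Triton are required at runtime:

import torch
from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy

BT, H, V = 8, 64, 128  # hypothetical sizes
_input = torch.randn(BT, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")

# keyword arguments now work, unlike with the old Function.apply aliases
loss = liger_fused_linear_cross_entropy(_input, weight, target, reduction="mean")
loss.backward()
print(loss.item(), _input.grad.shape)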
liger_kernel/transformers/fused_linear_jsd.py
CHANGED

@@ -12,7 +12,7 @@ class LigerFusedLinearJSD(torch.nn.Module):
     the materialization of the large logits tensor.

     Args:
-        jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): The index to ignore in the target. Default: `-100`
         temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`

@@ -70,9 +70,6 @@ class LigerFusedLinearJSD(torch.nn.Module):

     def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
         super().__init__()
-        assert (
-            jsd_beta > 0 and jsd_beta < 1
-        ), f"beta must be greater than 0 and less than 1. Got: {jsd_beta}"
         assert temperature != 0, "temperature cannot be 0."
         self.jsd_beta = jsd_beta
         self.temperature = temperature
liger_kernel/transformers/jsd.py
CHANGED

@@ -18,7 +18,7 @@ class LigerJSD(torch.nn.Module):
     :math:`P` denotes the teacher model and :math:`Q` denotes the student model.

     Args:
-        beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): The index to ignore in the target. Default: `-100`

     Shape:

@@ -58,9 +58,6 @@ class LigerJSD(torch.nn.Module):

     def __init__(self, beta: float = 0.5, ignore_index: int = -100):
         super().__init__()
-        assert (
-            beta > 0 and beta < 1
-        ), f"beta must be greater than 0 and less than 1. Got: {beta}"
         self.beta = beta
         self.ignore_index = ignore_index

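With the (0, 1) assertions removed, beta=0.0 and beta=1.0 are now accepted and, per the updated kernel, select forward and reverse KL respectively. A brief sketch using the module constructor and the functional form whose signature appears in this diff (both arguments are log-probabilities of shape (BT, V); CUDA and Triton required at runtime):

import torch
from liger_kernel.transformers.functional import liger_jsd
from liger_kernel.transformers.jsd import LigerJSD

# The previously rejected endpoints are now valid constructor values.
fwd_kl = LigerJSD(beta=0.0)  # forward KL, KL(teacher || student), per the kernel's branch
rev_kl = LigerJSD(beta=1.0)  # reverse KL, KL(student || teacher)

student_logp = torch.log_softmax(torch.randn(4, 16, device="cuda"), dim=-1)
teacher_logp = torch.log_softmax(torch.randn(4, 16, device="cuda"), dim=-1)
loss = liger_jsd(student_logp, teacher_logp, beta=0.0)
print(loss)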