liger-kernel 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/__init__.py +4 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +107 -0
  4. liger_kernel/chunked_loss/dpo_loss.py +95 -17
  5. liger_kernel/chunked_loss/functional.py +9 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +245 -65
  8. liger_kernel/chunked_loss/orpo_loss.py +63 -13
  9. liger_kernel/chunked_loss/simpo_loss.py +115 -0
  10. liger_kernel/env_report.py +22 -0
  11. liger_kernel/ops/cross_entropy.py +17 -10
  12. liger_kernel/ops/fused_linear_cross_entropy.py +0 -11
  13. liger_kernel/ops/fused_linear_jsd.py +1 -1
  14. liger_kernel/ops/jsd.py +19 -10
  15. liger_kernel/ops/layer_norm.py +6 -1
  16. liger_kernel/ops/qwen2vl_mrope.py +238 -0
  17. liger_kernel/ops/rms_norm.py +6 -1
  18. liger_kernel/ops/utils.py +5 -2
  19. liger_kernel/transformers/functional.py +128 -11
  20. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  21. liger_kernel/transformers/jsd.py +1 -4
  22. liger_kernel/transformers/monkey_patch.py +6 -4
  23. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  24. liger_kernel/transformers/trainer/__init__.py +6 -0
  25. liger_kernel/transformers/trainer/orpo_trainer.py +169 -0
  26. liger_kernel/utils.py +13 -0
  27. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/METADATA +71 -47
  28. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/RECORD +32 -22
  29. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/WHEEL +1 -1
  30. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/LICENSE +0 -0
  31. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/NOTICE +0 -0
  32. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py CHANGED
@@ -92,8 +92,8 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(
-        X_ptr + y
+    ori_X_y = tl.load(X_ptr + y).cast(
+        tl.float32
     )  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -106,8 +106,11 @@ def liger_cross_entropy_kernel(
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
@@ -141,8 +144,11 @@ def liger_cross_entropy_kernel(
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
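Note: an editor's illustration, not part of the diff, of why the kernel upcasts each block to float32 before the online-softmax accumulation; the script and its numbers are representative only.

import torch

torch.manual_seed(0)
logits = torch.randn(32000) * 10  # one vocab-sized row of logits
p_ref = torch.softmax(logits.double(), dim=0)

err_fp32 = (torch.softmax(logits.float(), dim=0) - p_ref).abs().max().item()
err_bf16 = (torch.softmax(logits.bfloat16(), dim=0).double() - p_ref).abs().max().item()
print(err_fp32, err_bf16)  # the bf16 softmax error is orders of magnitude larger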
@@ -279,11 +285,12 @@ def cross_entropy_forward(
         num_warps=32 if not is_hip() else 16,
     )

-    loss = torch.sum(loss_1d)
-    if return_z_loss == _TRUE.value:
-        z_loss = torch.sum(z_loss_1d)
+    if reduction == "none":
+        loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss == _TRUE.value else None
     else:
-        z_loss = None
+        loss = torch.sum(loss_1d)
+        z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None

     return loss, z_loss, _input

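Note: a hedged usage sketch, not part of the diff, of the new reduction="none" path; it assumes the module wrapper LigerCrossEntropyLoss in liger_kernel.transformers forwards its reduction argument down to this kernel, and the sizes are illustrative.

import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

ce = LigerCrossEntropyLoss(reduction="none")  # per-token losses instead of a summed scalar
logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
targets = torch.randint(0, 32000, (8,), device="cuda")
per_token_loss = ce(logits, targets)  # shape (8,), one entry per token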
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -26,7 +26,6 @@ def fused_linear_cross_entropy_forward(
     reduction="mean",
     softcap=None,
 ):
-    dtype = _input.dtype
     device = _input.device

     # inputs have shape: BT x H
@@ -74,9 +73,6 @@ def fused_linear_cross_entropy_forward(
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
         n_non_ignore = (target_chunk != ignore_index).sum().item()

-        # when doing CE, use the upcasted precision
-        logits_chunk = logits_chunk.float()
-
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
         target_chunk = target_chunk.contiguous()
@@ -103,13 +99,6 @@ def fused_linear_cross_entropy_forward(
             num_warps=32 if not is_hip() else 16,
         )

-        # gradient of logits_chunk is computed in-place by the above triton kernel.
-        # Following HuggingFace model source code, we do the forward and backward
-        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
-        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
-        # Propagating to lm_head's backward, we'll switch back to the original dtype.
-        logits_chunk = logits_chunk.to(dtype)
-
         # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
         # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
         # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
liger_kernel/ops/fused_linear_jsd.py CHANGED
@@ -202,7 +202,7 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
             teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
             teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
             shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-            jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+            jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
             ignore_index (int): the index to ignore. Default: -100
             temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`

liger_kernel/ops/jsd.py CHANGED
@@ -18,7 +18,7 @@ def _jsd_kernel(
     dX_ptr,
     dX_stride,
     label_ptr,
-    beta,
+    beta: tl.constexpr,
     n_non_ignore: int,
     ignore_index: tl.constexpr,
     n_cols,
@@ -50,17 +50,26 @@ def _jsd_kernel(
     X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

-    Q = tl.exp(X)
-    P = tl.exp(Y)
-    M = beta * P + (1 - beta) * Q
-    log_M = tl.log(M)
+    if beta == 0.0:  # forward KL
+        Y_prob = tl.exp(Y)
+        loss = Y_prob * (Y - X)
+        dX = -Y_prob
+    elif beta == 1.0:
+        X_prob = tl.exp(X)
+        loss = X_prob * (X - Y)
+        dX = loss + X_prob
+    else:
+        Q = tl.exp(X)
+        P = tl.exp(Y)
+        M = beta * P + (1 - beta) * Q
+        log_M = tl.log(M)
+
+        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
+        dX = (1 - beta) * Q * (X - log_M)

-    loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-    # reduction == "batchmean"
     loss = loss / n_non_ignore
+    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
-
-    dX = (1 - beta) * Q * (X - log_M) / n_non_ignore
     tl.store(dX_ptr + offsets, dX, mask=mask)

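Note: a PyTorch reference sketch, not part of the diff, of the generalized JSD the kernel computes per row, including the new beta == 0 and beta == 1 special cases; X and Y are log-probabilities as in the kernel, and the function name is illustrative.

import torch

def generalized_jsd_reference(X, Y, beta):
    # X: student log-probs (predict), Y: teacher log-probs (target), shape (BT, V)
    Q, P = X.exp(), Y.exp()
    if beta == 0.0:  # forward KL(P || Q)
        return (P * (Y - X)).sum(-1)
    if beta == 1.0:  # reverse KL(Q || P)
        return (Q * (X - Y)).sum(-1)
    M = beta * P + (1 - beta) * Q
    return (beta * P * Y + (1 - beta) * Q * X - M * M.log()).sum(-1)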
@@ -142,7 +151,7 @@ class LigerJSDFunction(torch.autograd.Function):
             _input (torch.Tensor): predict values with shape (BT, V) in logspace
             target (torch.Tensor): ground truth values with shape (BT, V) in logspace
             shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
-            beta (float): coefficient beta of generalized JSD in the open interval (0, 1)
+            beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
             ignore_index (int): the index to ignore. Default: -100

         Returns:
liger_kernel/ops/layer_norm.py CHANGED
@@ -180,8 +180,13 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape

+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
     _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)

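Note: the same device-aware lookup is repeated in rms_norm.py further below; factored out, it would look like this sketch (the helper name is illustrative, not something the package defines).

import torch

def get_sm_count(device: torch.device) -> int:
    # Number of parallel compute units used to size the partial dW/dB buffers.
    if device.type == "cuda":
        return torch.cuda.get_device_properties(device).multi_processor_count
    if device.type == "xpu":
        return torch.xpu.get_device_properties(device).gpu_subslice_count
    return 1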
liger_kernel/ops/qwen2vl_mrope.py ADDED
@@ -0,0 +1,238 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _triton_qwen2vl_mrope(
+    q_ptr,
+    k_ptr,
+    cos,
+    sin,
+    sl,
+    n_qh: tl.constexpr,
+    n_kh: tl.constexpr,
+    hd: tl.constexpr,
+    pad_n_qh: tl.constexpr,
+    pad_n_kh: tl.constexpr,
+    pad_hd: tl.constexpr,
+    mrope_section_t: tl.constexpr,
+    mrope_section_h: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BACKWARD_PASS: tl.constexpr = False,
+):
+    pid = tl.program_id(0)
+
+    # locate start address
+    q_ptr = q_ptr + pid * (n_qh * hd)
+    k_ptr = k_ptr + pid * (n_kh * hd)
+
+    # ####################################################################
+    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
+    # m of this program instance
+    # ####################################################################
+
+    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
+    # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
+    # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
+    # and pid % sl to get the sequence index.
+    # 2. We only need the left half of cos and sin matrix because the right half is just
+    # a clone of the left half.
+    t_end = mrope_section_t
+    h_end = t_end + mrope_section_h
+
+    cos_row_idx = pid % sl
+    t_cos = cos + cos_row_idx * hd
+    h_cos = t_cos + sl * hd
+    w_cos = h_cos + sl * hd
+    t_sin = sin + cos_row_idx * hd
+    h_sin = t_sin + sl * hd
+    w_sin = h_sin + sl * hd
+
+    cos_offsets = tl.arange(0, pad_hd // 2)
+    t_mask = cos_offsets < t_end
+    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
+    w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
+    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
+    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
+    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
+    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
+    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
+    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
+    cos_row = t_cos_row + h_cos_row + w_cos_row
+    sin_row = t_sin_row + h_sin_row + w_sin_row
+
+    # ####################################################################
+    # Load the left and right half of q and k for the current
+    # program instance (i.e. for the current token) separately
+    # ####################################################################
+    # left half of the head
+    first_half_q_offsets = (
+        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    )
+    first_half_k_offsets = (
+        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+    )
+    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
+    )
+    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
+    )
+    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    # right half of the head
+    second_half_q_offsets = first_half_q_offsets + (hd // 2)
+    second_half_k_offsets = first_half_k_offsets + (hd // 2)
+    second_q_mask = first_q_mask
+    second_k_mask = first_k_mask
+    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    if not BACKWARD_PASS:
+        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+    else:
+        # with some math, we can get:
+        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
+        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+
+
+def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
+
+    # transpose it back to the physical shape because Triton looks at the physical storage
+    # note: q and k are incontiguous before the transformation and will become contiguous after transpose
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = q.shape
+    n_kv_head = k.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
+    q = q.contiguous()
+    k = k.contiguous()
+    cos = cos.contiguous()
+    sin = sin.contiguous()
+
+    _triton_qwen2vl_mrope[(n_row,)](
+        q,
+        k,
+        cos,
+        sin,
+        seq_len,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        mrope_section[0],
+        mrope_section[1],
+        BLOCK_SIZE=BLOCK_SIZE,
+        BACKWARD_PASS=False,
+    )
+    return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
+    dq = dq.transpose(1, 2)
+    dk = dk.transpose(1, 2)
+
+    batch_size, seq_len, n_q_head, head_dim = dq.shape
+    n_kv_head = dk.shape[2]
+    pad_hd = triton.next_power_of_2(head_dim)
+    pad_n_q_head = triton.next_power_of_2(n_q_head)
+    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+    n_row = batch_size * seq_len
+
+    # ensure dq and dk are contiguous
+    dq = dq.contiguous()
+    dk = dk.contiguous()
+
+    # backward is similar to forward except swapping few ops
+    _triton_qwen2vl_mrope[(n_row,)](
+        dq,
+        dk,
+        cos,
+        sin,
+        seq_len,
+        n_q_head,
+        n_kv_head,
+        head_dim,
+        pad_n_q_head,
+        pad_n_kv_head,
+        pad_hd,
+        mrope_section[0],
+        mrope_section[1],
+        BLOCK_SIZE=BLOCK_SIZE,
+        BACKWARD_PASS=True,
+    )
+    return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
+class LigerQwen2VLMRopeFunction(torch.autograd.Function):
+    """
+    Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.
+
+    Please find the corresponding HuggingFace implementation here:
+    https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+    """
+
+    @staticmethod
+    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+        """
+        q size: (bsz, n_q_head, seq_len, head_dim)
+        k size: (bsz, n_kv_head, seq_len, head_dim)
+        cos size: (3, 1, seq_len, head_dim)
+        sin size: (3, 1, seq_len, head_dim)
+        """
+        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
+        ctx.save_for_backward(cos, sin)
+        ctx.mrope_section = mrope_section
+        return q, k
+
+    def backward(ctx, dq, dk):
+        """
+        dq size: (bsz, n_q_head, seq_len, head_dim)
+        dk size: (bsz, n_kv_head, seq_len, head_dim)
+        cos size: (3, 1, seq_len, head_dim)
+        sin size: (3, 1, seq_len, head_dim)
+        """
+
+        cos, sin = ctx.saved_tensors
+        mrope_section = ctx.mrope_section
+        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
+        return dq, dk, None, None, None, None
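Note: a hedged usage sketch, not part of the diff, for the new M-RoPE autograd function, using the tensor shapes documented in the forward() docstring above; the mrope_section split is illustrative (the three sections must sum to head_dim // 2) and a CUDA device with Triton is assumed.

import torch
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 8, 2, 16, 128
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")
cos = torch.randn(3, 1, seq_len, head_dim, device="cuda")  # 3 = temporal/height/width sections
sin = torch.randn(3, 1, seq_len, head_dim, device="cuda")
mrope_section = [16, 24, 24]  # t/h/w split of head_dim // 2, as in the Qwen2-VL config

q_rot, k_rot = LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section)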
liger_kernel/ops/rms_norm.py CHANGED
@@ -264,7 +264,12 @@ def rms_norm_backward(
     dY = dY.view(-1, dim)
     n_rows, n_cols = dY.shape

-    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    sm_count = 1
+    if X.device.type == "cuda":
+        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    elif X.device.type == "xpu":
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

liger_kernel/ops/utils.py CHANGED
@@ -20,6 +20,8 @@ import triton
 import triton.language as tl
 from packaging.version import Version

+from liger_kernel.utils import infer_device
+

 def is_hip() -> bool:
     return torch.version.hip is not None
@@ -69,10 +71,11 @@ def compare_version(package: str, operator: Callable, target: str):


 def get_amp_custom_fwd_bwd() -> Callable:
+    device = infer_device()
     if compare_version("torch", operator.ge, "2.4.0"):
         return (
-            functools.partial(torch.amp.custom_fwd, device_type="cuda"),
-            functools.partial(torch.amp.custom_bwd, device_type="cuda"),
+            functools.partial(torch.amp.custom_fwd, device_type=device),
+            functools.partial(torch.amp.custom_bwd, device_type=device),
         )
     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd

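Note: infer_device comes from the new liger_kernel/utils.py, whose body is not shown in this diff; stated purely as an assumption, a plausible implementation would be:

import torch

def infer_device() -> str:
    # Name of the accelerator backend handed to torch.amp's custom_fwd/custom_bwd.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"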
liger_kernel/transformers/functional.py CHANGED
@@ -10,21 +10,11 @@ from liger_kernel.ops.group_norm import LigerGroupNormFunction
 from liger_kernel.ops.jsd import LigerJSDFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.ops.rope import LigerRopeFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction

-liger_swiglu = LigerSiLUMulFunction.apply
-liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
-liger_geglu = LigerGELUMulFunction.apply
-liger_rms_norm = LigerRMSNormFunction.apply
-liger_rope = LigerRopeFunction.apply
-liger_layer_norm = LigerLayerNormFunction.apply
-liger_kl_div = LigerKLDivLossFunction.apply
-liger_jsd = LigerJSDFunction.apply
-liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
-liger_group_norm = LigerGroupNormFunction.apply
-

 # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
 # `weight` and `size_average` are placeholders and not implemented yet
@@ -54,3 +44,130 @@ def liger_cross_entropy(
     if not return_z_loss:
         return loss
     return loss, z_loss
+
+
+def liger_fused_linear_cross_entropy(
+    input,
+    weight,
+    target,
+    bias=None,
+    ignore_index: int = -100,
+    lse_square_scale: float = 0.0,
+    label_smoothing: float = 0.0,
+    reduction: str = "mean",
+    softcap: Optional[float] = None,
+):
+    return LigerFusedLinearCrossEntropyFunction.apply(
+        input,
+        weight,
+        target,
+        bias,
+        ignore_index,
+        lse_square_scale,
+        label_smoothing,
+        reduction,
+        softcap,
+    )
+
+
+def liger_fused_linear_jsd(
+    student_input,
+    student_weight,
+    teacher_input,
+    teacher_weight,
+    shift_labels=None,
+    jsd_beta: float = 0.5,
+    ignore_index: int = -100,
+    temperature: float = 1.0,
+):
+    return LigerFusedLinearJSDFunction.apply(
+        student_input,
+        student_weight,
+        teacher_input,
+        teacher_weight,
+        shift_labels,
+        jsd_beta,
+        ignore_index,
+        temperature,
+    )
+
+
+def liger_geglu(a, b):
+    return LigerGELUMulFunction.apply(a, b)
+
+
+def liger_group_norm(
+    X,
+    affine_scaling_weight,
+    affine_shifting_bias,
+    num_channels,
+    num_groups,
+    eps,
+):
+    return LigerGroupNormFunction.apply(
+        X,
+        affine_scaling_weight,
+        affine_shifting_bias,
+        num_channels,
+        num_groups,
+        eps,
+    )
+
+
+def liger_jsd(
+    input,
+    target,
+    shift_labels=None,
+    beta: float = 0.5,
+    ignore_index: int = -100,
+):
+    return LigerJSDFunction.apply(
+        input,
+        target,
+        shift_labels,
+        beta,
+        ignore_index,
+    )
+
+
+# conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
+# `size_average` and `mean` are being deprecated in torch API and are placeholders here
+def liger_kl_div(
+    input,
+    target,
+    size_average: bool = True,
+    reduce: bool = True,
+    reduction: str = "mean",
+    log_target: bool = False,
+    eps: float = 1e-10,
+):
+    # Note: the default reduction in torch is `mean`, but being `batchmean` in Liger
+    return LigerKLDivLossFunction.apply(
+        input,
+        target,
+        reduction,
+        log_target,
+        eps,
+    )
+
+
+def liger_layer_norm(X, W, B, eps):
+    return LigerLayerNormFunction.apply(X, W, B, eps)
+
+
+def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
+
+
+def liger_rms_norm(
+    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
+):
+    return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
+
+
+def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_swiglu(a, b):
+    return LigerSiLUMulFunction.apply(a, b)
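Note: a usage sketch, not part of the diff. Because the raw autograd .apply aliases removed above are now plain functions, the fused linear cross entropy can be called with keyword arguments; the sizes below are illustrative.

import torch
from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy

BT, H, V = 4, 256, 1024
x = torch.randn(BT, H, device="cuda", requires_grad=True)  # last hidden states
w = torch.randn(V, H, device="cuda", requires_grad=True)   # lm_head weight
y = torch.randint(0, V, (BT,), device="cuda")              # target token ids

loss = liger_fused_linear_cross_entropy(x, w, y, reduction="mean")
loss.backward()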
liger_kernel/transformers/fused_linear_jsd.py CHANGED
@@ -12,7 +12,7 @@ class LigerFusedLinearJSD(torch.nn.Module):
     the materialization of the large logits tensor.

     Args:
-        jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): The index to ignore in the target. Default: `-100`
         temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`

@@ -70,9 +70,6 @@ class LigerFusedLinearJSD(torch.nn.Module):

     def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
         super().__init__()
-        assert (
-            jsd_beta > 0 and jsd_beta < 1
-        ), f"beta must be greater than 0 and less than 1. Got: {jsd_beta}"
         assert temperature != 0, "temperature cannot be 0."
         self.jsd_beta = jsd_beta
         self.temperature = temperature
liger_kernel/transformers/jsd.py CHANGED
@@ -18,7 +18,7 @@ class LigerJSD(torch.nn.Module):
     :math:`P` denotes the teacher model and :math:`Q` denotes the student model.

     Args:
-        beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
+        beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
         ignore_index (int): The index to ignore in the target. Default: `-100`

     Shape:
@@ -58,9 +58,6 @@ class LigerJSD(torch.nn.Module):

     def __init__(self, beta: float = 0.5, ignore_index: int = -100):
         super().__init__()
-        assert (
-            beta > 0 and beta < 1
-        ), f"beta must be greater than 0 and less than 1. Got: {beta}"
         self.beta = beta
         self.ignore_index = ignore_index

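Note: a hedged usage sketch, not part of the diff. With the assertion removed, beta=0.0 (pure forward KL) and beta=1.0 (pure reverse KL) are now accepted; inputs are log-probabilities of shape (BT, V) as documented above, and the sizes are illustrative.

import torch
from liger_kernel.transformers import LigerJSD

jsd = LigerJSD(beta=0.0)  # forward KL(P || Q) between teacher P and student Q
student_logits = torch.randn(16, 1024, device="cuda")
teacher_logits = torch.randn(16, 1024, device="cuda")
loss = jsd(student_logits.log_softmax(-1), teacher_logits.log_softmax(-1))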
liger_kernel/transformers/monkey_patch.py CHANGED
@@ -36,6 +36,7 @@ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import (
     lce_forward_deprecated as qwen2_lce_forward_deprecated,
 )
+from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import (
@@ -610,9 +611,7 @@ def apply_liger_kernel_to_qwen2(
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss

-    # import pdb; pdb.set_trace()
     if fused_linear_cross_entropy:
-
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
             modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
         else:  # if version < 4.46.1
@@ -644,6 +643,7 @@


 def apply_liger_kernel_to_qwen2_vl(
+    rope: bool = True,
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
@@ -678,8 +678,10 @@ def apply_liger_kernel_to_qwen2_vl(
         lce_forward as qwen2_vl_lce_forward,
     )

-    # TODO: Support Qwen2-VL's multimodal RoPE implementation
-
+    if rope:
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
+            liger_multimodal_rotary_pos_emb
+        )
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
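Note: a hedged usage sketch, not part of the diff, showing the new rope flag; it assumes apply_liger_kernel_to_qwen2_vl is re-exported from liger_kernel.transformers like the other patch helpers.

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl

# Patch Qwen2-VL before instantiating the model so the Triton M-RoPE,
# RMSNorm, and fused linear cross entropy paths are picked up.
apply_liger_kernel_to_qwen2_vl(rope=True, rms_norm=True, fused_linear_cross_entropy=True)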