liger-kernel-nightly 0.4.2.dev20241119054456__py3-none-any.whl → 0.4.2.dev20241119054537__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,238 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
@triton.jit
def _triton_qwen2vl_mrope(
    q_ptr,
    k_ptr,
    cos,
    sin,
    sl,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
    cos_bs: tl.constexpr = 1,
):
    # Rotates, in place, every q and k head of one token (one program instance
    # per token) using the multimodal (temporal / height / width) cos and sin
    # rows selected for that token.
    #
    # q_ptr / k_ptr: base pointers of q and k; each program advances them by
    #     pid * n_heads * hd, i.e. one token's worth of heads.
    # cos / sin: base pointers addressed as three consecutive sections
    #     (t, h, w), each holding cos_bs * sl rows of hd elements.
    # sl: sequence length, used to split pid into (batch index, position).
    # n_qh / n_kh / hd: real head counts and head dim; pad_*: the tile sizes
    #     used for tl.arange, masked back down to the real sizes.
    # mrope_section_t / mrope_section_h: number of half-head channels taken
    #     from the t and h sections; the remaining channels up to hd // 2 come
    #     from the w section.
    # BLOCK_SIZE: accepted for interface compatibility; not read in this body.
    # BACKWARD_PASS: when True, the sign of the sin terms is flipped.
    # cos_bs: batch size of cos / sin. The default of 1 reproduces the
    #     addressing for a (3, 1, sl, hd) layout shared by every batch entry.
    pid = tl.program_id(0)

    # locate start address of this token's heads
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################

    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
    #    effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
    #    being the fastest changing dimension. Thus we can simply do pid // sl to get
    #    the batch index and pid % sl to get the sequence index.
    # 2. We only need the left half of cos and sin matrix because the right half is just
    #    a clone of the left half.
    t_end = mrope_section_t
    h_end = t_end + mrope_section_h

    seq_idx = pid % sl
    # With cos_bs == 1 this is always 0, so every batch entry shares one row set.
    cos_batch_idx = (pid // sl) % cos_bs
    row_offset = (cos_batch_idx * sl + seq_idx) * hd
    # Distance between the t, h and w sections of cos / sin.
    section_stride = cos_bs * sl * hd

    t_cos = cos + row_offset
    h_cos = t_cos + section_stride
    w_cos = h_cos + section_stride
    t_sin = sin + row_offset
    h_sin = t_sin + section_stride
    w_sin = h_sin + section_stride

    # The three masks are disjoint, so summing the masked loads (other=0)
    # stitches the t / h / w channel ranges into a single half-row.
    cos_offsets = tl.arange(0, pad_hd // 2)
    t_mask = cos_offsets < t_end
    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
    w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = (
        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_half_k_offsets = (
        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
    )
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
    )
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
        sin_row.dtype
    )

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
        sin_row.dtype
    )

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # with some math, we can get:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


def _cos_batch_size(cos, batch_size, seq_len, head_dim):
    """Return the batch size the kernel should assume for cos / sin.

    Returns ``batch_size`` only when ``cos`` holds exactly
    ``3 * batch_size * seq_len * head_dim`` elements, i.e. one row set per
    batch entry. Any other size falls back to 1, which is the shared
    ``(3, 1, seq_len, head_dim)`` addressing.
    """
    per_batch_numel = 3 * batch_size * seq_len * head_dim
    return batch_size if cos.numel() == per_batch_numel else 1


def _launch_qwen2vl_mrope(q, k, cos, sin, mrope_section, backward_pass):
    """Run the M-RoPE kernel over q / k (or their gradients).

    Shared by the forward and backward entry points, which differ only in the
    ``BACKWARD_PASS`` flag. Returns ``(q, k, cos, sin)`` where q and k are
    transposed back to the caller's layout and cos / sin are the contiguous
    tensors that were handed to the kernel.
    """
    # transpose it back to the physical shape because Triton looks at the physical storage
    # note: q and k are non-contiguous before the transformation and will become
    # contiguous after transpose
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    # one program instance per (batch, position) token
    n_row = batch_size * seq_len

    # ensure tensors passed into the kernel are contiguous. It will be no-op if
    # they are already contiguous
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    _triton_qwen2vl_mrope[(n_row,)](
        q,
        k,
        cos,
        sin,
        seq_len,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=backward_pass,
        cos_bs=_cos_batch_size(cos, batch_size, seq_len, head_dim),
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin


def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
    """Apply M-RoPE to q and k.

    Args:
        q: query tensor, (bsz, n_q_head, seq_len, head_dim).
        k: key tensor, (bsz, n_kv_head, seq_len, head_dim).
        cos: cosine tensor, (3, 1, seq_len, head_dim) or
            (3, bsz, seq_len, head_dim).
        sin: sine tensor, same layout as ``cos``.
        mrope_section: sequence whose first two entries are the channel counts
            of the temporal and height sections.

    Returns:
        ``(q, k, cos, sin)``: the rotated q and k in the input layout, plus the
        contiguous cos and sin that were passed to the kernel.
    """
    return _launch_qwen2vl_mrope(q, k, cos, sin, mrope_section, False)


def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
    """Propagate gradients through M-RoPE.

    Args:
        dq: gradient w.r.t. the rotated q, (bsz, n_q_head, seq_len, head_dim).
        dk: gradient w.r.t. the rotated k, (bsz, n_kv_head, seq_len, head_dim).
        cos: cosine tensor used in the forward pass.
        sin: sine tensor used in the forward pass.
        mrope_section: same value that was used in the forward pass.

    Returns:
        ``(dq, dk)``: gradients w.r.t. the un-rotated q and k.
    """
    # backward is similar to forward except swapping few ops
    dq, dk, _, _ = _launch_qwen2vl_mrope(dq, dk, cos, sin, mrope_section, True)
    return dq, dk
204
+
205
+
206
class LigerQwen2VLMRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        """
        Rotate q and k and stash what backward needs.

        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, 1, seq_len, head_dim)
        sin size: (3, 1, seq_len, head_dim)
        mrope_section: forwarded to qwen2vl_mrope_forward and kept on ctx.
        unsqueeze_dim: accepted for signature compatibility; not read here.

        Returns the rotated (q, k).
        """
        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
        # Save the cos / sin returned by the forward helper so backward reuses
        # exactly the tensors the forward pass used.
        ctx.save_for_backward(cos, sin)
        # mrope_section is not a tensor, so it is stored as a plain attribute.
        ctx.mrope_section = mrope_section
        return q, k

    @staticmethod
    def backward(ctx, dq, dk):
        """
        Compute gradients w.r.t. q and k.

        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, 1, seq_len, head_dim)
        sin size: (3, 1, seq_len, head_dim)

        Returns one entry per forward input: gradients for q and k, and None
        for cos, sin, mrope_section and unsqueeze_dim.
        """
        cos, sin = ctx.saved_tensors
        mrope_section = ctx.mrope_section
        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
        return dq, dk, None, None, None, None
@@ -10,6 +10,7 @@ from liger_kernel.ops.group_norm import LigerGroupNormFunction
10
10
  from liger_kernel.ops.jsd import LigerJSDFunction
11
11
  from liger_kernel.ops.kl_div import LigerKLDivLossFunction
12
12
  from liger_kernel.ops.layer_norm import LigerLayerNormFunction
13
+ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
13
14
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
14
15
  from liger_kernel.ops.rope import LigerRopeFunction
15
16
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
@@ -19,6 +20,7 @@ liger_fused_linear_cross_entropy = LigerFusedLinearCrossEntropyFunction.apply
19
20
  liger_geglu = LigerGELUMulFunction.apply
20
21
  liger_rms_norm = LigerRMSNormFunction.apply
21
22
  liger_rope = LigerRopeFunction.apply
23
+ liger_qwen2vl_mrope = LigerQwen2VLMRopeFunction.apply
22
24
  liger_layer_norm = LigerLayerNormFunction.apply
23
25
  liger_kl_div = LigerKLDivLossFunction.apply
24
26
  liger_jsd = LigerJSDFunction.apply
@@ -36,6 +36,7 @@ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forwa
36
36
  from liger_kernel.transformers.model.qwen2 import (
37
37
  lce_forward_deprecated as qwen2_lce_forward_deprecated,
38
38
  )
39
+ from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
39
40
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
40
41
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
41
42
  from liger_kernel.transformers.swiglu import (
@@ -642,6 +643,7 @@ def apply_liger_kernel_to_qwen2(
642
643
 
643
644
 
644
645
  def apply_liger_kernel_to_qwen2_vl(
646
+ rope: bool = True,
645
647
  cross_entropy: bool = False,
646
648
  fused_linear_cross_entropy: bool = True,
647
649
  rms_norm: bool = True,
@@ -676,8 +678,10 @@ def apply_liger_kernel_to_qwen2_vl(
676
678
  lce_forward as qwen2_vl_lce_forward,
677
679
  )
678
680
 
679
- # TODO: Support Qwen2-VL's multimodal RoPE implementation
680
-
681
+ if rope:
682
+ modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
683
+ liger_multimodal_rotary_pos_emb
684
+ )
681
685
  if rms_norm:
682
686
  # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
683
687
  modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
@@ -0,0 +1,20 @@
1
+ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
2
+
3
+
4
def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """
    Run the Multimodal Rotary Positional Embedding (M-RoPE) on query and key states.

    This is a thin functional entry point that forwards every argument,
    unchanged and in order, to ``LigerQwen2VLMRopeFunction.apply``.

    Args:
        q (torch.Tensor): Query states, shape (bsz, n_q_head, seq_len, head_dim).
        k (torch.Tensor): Key states, shape (bsz, n_kv_head, seq_len, head_dim).
        cos (torch.Tensor): Cosine table, shape (3, 1, seq_len, head_dim).
        sin (torch.Tensor): Sine table, shape (3, 1, seq_len, head_dim).
        mrope_section (List[int]): Channel-dimension split for the temporal, height and width parts of the rope calculation.
        unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors with M-RoPE applied.
    """
    rotated_q_and_k = LigerQwen2VLMRopeFunction.apply(
        q, k, cos, sin, mrope_section, unsqueeze_dim
    )
    return rotated_q_and_k
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241119054456
3
+ Version: 0.4.2.dev20241119054537
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -13,6 +13,7 @@ liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d
13
13
  liger_kernel/ops/jsd.py,sha256=anWfdioucxZy4JQfTvbHBR-IQrZKeH-gBF1MHwwTuTQ,5781
14
14
  liger_kernel/ops/kl_div.py,sha256=03FNXfvCb6M-56hhFepAFV9p6brArPR6KOKkdGD34mw,8374
15
15
  liger_kernel/ops/layer_norm.py,sha256=unGMYMOPqtkM9aTrokhcqgPmsV2AUN7Yzv86isVB9OI,7422
16
+ liger_kernel/ops/qwen2vl_mrope.py,sha256=xZvQnhkSTjU-k6KiiRn9e0SYO1ESs1jmuZFMICduLpc,8552
16
17
  liger_kernel/ops/rms_norm.py,sha256=LAxCiFjpBbb7TDh9pOzsVmDGAR7eEbTDnEhjSd6TX_M,11583
17
18
  liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
18
19
  liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
@@ -22,7 +23,7 @@ liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh
22
23
  liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
23
24
  liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
24
25
  liger_kernel/transformers/cross_entropy.py,sha256=yEm_YQ7oa3_BzT3hdW6KrAslduhSqWcJQVNZZDcWCg4,1758
25
- liger_kernel/transformers/functional.py,sha256=Hd4WvxNqOJHM9HmRfAQueRnmOy5WU9nFsFygB5Iv8Xs,2000
26
+ liger_kernel/transformers/functional.py,sha256=jwTHmyjOVC1_I-6ztY1EbbRqPIfFHojcHrP2c4P6U4I,2123
26
27
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=_i0PXSp5iZ9pKXdEeZ4lvHCENJYjV4y74yz3ZRG5XQg,1484
27
28
  liger_kernel/transformers/fused_linear_jsd.py,sha256=MJ-KjmLZnakuoVpnbDGkd95DQgvESniyrRWYzollVZM,4066
28
29
  liger_kernel/transformers/geglu.py,sha256=QcrME_8ooIn0xa59LaC0aoOdRrBIFd11Y0bAyF0NfCw,1130
@@ -30,7 +31,8 @@ liger_kernel/transformers/group_norm.py,sha256=FJ9R7mS9G1wO-GRIQ6QKSmIhnZ6nQ6GIk
30
31
  liger_kernel/transformers/jsd.py,sha256=W-5CypO2mx4-bUWOxq1KScfCdoXlLoYbtt5xBnRzMs4,3056
31
32
  liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
32
33
  liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
33
- liger_kernel/transformers/monkey_patch.py,sha256=Qk8jTO1AO6-knod7w8LtZKVIvm5gapsHInBMCjy6zR8,38233
34
+ liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
35
+ liger_kernel/transformers/qwen2vl_mrope.py,sha256=SfSQVwOe7ArrVfpmIdfZrdzCxmcj7V-YQp9zDu17-ao,1043
34
36
  liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
35
37
  liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
36
38
  liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
@@ -48,9 +50,9 @@ liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5P
48
50
  liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
49
51
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
50
52
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
51
- liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
52
- liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/METADATA,sha256=FDmZnvTxvl1UbpHLw6hwcuMTHMGHdTi_1GS9N7OhZoQ,21556
53
- liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
54
- liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
55
- liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
56
- liger_kernel_nightly-0.4.2.dev20241119054456.dist-info/RECORD,,
53
+ liger_kernel_nightly-0.4.2.dev20241119054537.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
54
+ liger_kernel_nightly-0.4.2.dev20241119054537.dist-info/METADATA,sha256=BoWJXhq2CldcpNJomRA4lKZ4J2AEr9tVfrjGyxq3EdM,21556
55
+ liger_kernel_nightly-0.4.2.dev20241119054537.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
56
+ liger_kernel_nightly-0.4.2.dev20241119054537.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
57
+ liger_kernel_nightly-0.4.2.dev20241119054537.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
58
+ liger_kernel_nightly-0.4.2.dev20241119054537.dist-info/RECORD,,