liger-kernel 0.0.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

Files changed (40)
  1. liger_kernel/env_report.py +46 -0
  2. liger_kernel/ops/cross_entropy.py +130 -63
  3. liger_kernel/ops/experimental/embedding.py +143 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
  5. liger_kernel/ops/geglu.py +54 -42
  6. liger_kernel/ops/kl_div.py +247 -0
  7. liger_kernel/ops/layer_norm.py +236 -0
  8. liger_kernel/ops/rms_norm.py +235 -81
  9. liger_kernel/ops/rope.py +91 -84
  10. liger_kernel/ops/swiglu.py +64 -57
  11. liger_kernel/ops/utils.py +12 -0
  12. liger_kernel/transformers/__init__.py +23 -0
  13. liger_kernel/transformers/auto_model.py +33 -0
  14. liger_kernel/transformers/cross_entropy.py +11 -1
  15. liger_kernel/transformers/experimental/embedding.py +28 -0
  16. liger_kernel/transformers/functional.py +19 -0
  17. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
  18. liger_kernel/transformers/geglu.py +4 -2
  19. liger_kernel/transformers/kl_div.py +13 -0
  20. liger_kernel/transformers/layer_norm.py +30 -0
  21. liger_kernel/transformers/model/gemma.py +138 -0
  22. liger_kernel/transformers/model/llama.py +1 -1
  23. liger_kernel/transformers/model/mistral.py +138 -0
  24. liger_kernel/transformers/model/mixtral.py +158 -0
  25. liger_kernel/transformers/model/phi3.py +136 -0
  26. liger_kernel/transformers/model/qwen2.py +135 -0
  27. liger_kernel/transformers/model/qwen2_vl.py +172 -0
  28. liger_kernel/transformers/monkey_patch.py +629 -8
  29. liger_kernel/transformers/rms_norm.py +23 -4
  30. liger_kernel/transformers/swiglu.py +24 -0
  31. liger_kernel/transformers/trainer_integration.py +2 -0
  32. liger_kernel/triton/monkey_patch.py +0 -2
  33. liger_kernel-0.3.0.dist-info/METADATA +388 -0
  34. liger_kernel-0.3.0.dist-info/RECORD +42 -0
  35. {liger_kernel-0.0.1.dist-info → liger_kernel-0.3.0.dist-info}/WHEEL +1 -1
  36. liger_kernel-0.0.1.dist-info/METADATA +0 -16
  37. liger_kernel-0.0.1.dist-info/RECORD +0 -26
  38. {liger_kernel-0.0.1.dist-info → liger_kernel-0.3.0.dist-info}/LICENSE +0 -0
  39. {liger_kernel-0.0.1.dist-info → liger_kernel-0.3.0.dist-info}/NOTICE +0 -0
  40. {liger_kernel-0.0.1.dist-info → liger_kernel-0.3.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py CHANGED
@@ -1,28 +1,66 @@
+ """
+ This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+ See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+ The following line
+ https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
+ is based on code from Unsloth, located at:
+ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+
+ Modifications made by Yanning Chen, 2024.
+ """
+
+ import operator
+
  import torch
  import triton
  import triton.language as tl
 
- from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+ from liger_kernel.ops.utils import (
+     calculate_settings,
+     compare_version,
+     ensure_contiguous,
+ )
+
+ if compare_version("triton", operator.ge, "3.0.0"):
+     try:
+         # typical import path with dispatch available
+         from triton.language.extra.libdevice import rsqrt
+     except ModuleNotFoundError:
+         # for working with NGC containers
+         from triton.language.extra.cuda.libdevice import rsqrt
+ else:
+     from triton.language.math import rsqrt
+
+
+ _CASTING_MODE_NONE = tl.constexpr(-1)
+ _CASTING_MODE_LLAMA = tl.constexpr(0)
+ _CASTING_MODE_GEMMA = tl.constexpr(1)
 
 
  @triton.jit
- def _rms_norm_forward(
+ def _rms_norm_forward_kernel(
      Y_ptr,
      Y_row_stride,
      X_ptr,
      X_row_stride,
      W_ptr,
      W_row_stride,
-     r_ptr,
-     r_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
      n_cols,
      eps,
+     offset,
+     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
      BLOCK_SIZE: tl.constexpr,
  ):
      """
+     y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N)
+
      Reference:
      1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
      2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+     3. https://arxiv.org/pdf/1910.07467
      """
 
      row_idx = tl.program_id(0)
@@ -31,137 +69,253 @@ def _rms_norm_forward(
 
      Y_ptr += row_idx * Y_row_stride
      X_ptr += row_idx * X_row_stride
-     r_ptr += row_idx * r_row_stride
+     RSTD_ptr += row_idx * RSTD_row_stride
 
      X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+     X_row_dtype = X_row.dtype
      W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
 
-     row_var = tl.sum(X_row * X_row, axis=0) / n_cols
-     inv_var = tl.math.rsqrt(row_var + eps)
+     # On Llama, only rstd is computed on fp32
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(tl.float32)
+
+     # Gemma computes everything on fp32, and then casts back the output to the original dtype
+     if casting_mode == _CASTING_MODE_GEMMA:
+         W_row = W_row.to(tl.float32)
+         X_row = X_row.to(tl.float32)
+
+     mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
+     rstd = rsqrt(mean_square + eps)
+
+     # We can save time by caching rms with minimal memory overhead
+     # because rms is much smaller compared to X_row, as rms is for each row.
+     # On the computation side, it saves 4 operations (*, sum, /, sqrt).
+     tl.store(RSTD_ptr, rstd)
 
-     # trick: row_var is tiny compared to X_row because it just has one per row we can save 4 ops (*, sum, /, rqrt) if we cache it
-     tl.store(r_ptr, inv_var)
+     X_row = X_row * rstd
 
-     normed = X_row * inv_var
+     # On Llama, the multiplication with the weight is done on the original dtype
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(X_row_dtype)
 
-     output = normed * W_row
-     tl.store(Y_ptr + col_offsets, output, mask=mask)
+     Y_row = X_row * (offset + W_row)
+
+     tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
 
 
  @triton.jit
- def _rms_norm_backward(
+ def _rms_norm_backward_kernel(
      dY_ptr,
      dY_row_stride,
      X_ptr,
      X_row_stride,
      W_ptr,
      W_row_stride,
-     r_ptr,
-     r_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
      dW_ptr,
      dW_row_stride,
      n_cols,
-     eps,
+     offset,
+     casting_mode: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
  ):
      """
-     dx = (1 / var(x)) * (dy * w - (1/N) * (dy * w) dot x) * x
-     dw = sum(dy * (x / var(x)))
+     dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+     dw = sum(dy * (x / RMS)). summation over the BxT dimension
      """
+
      row_idx = tl.program_id(0)
      col_offsets = tl.arange(0, BLOCK_SIZE)
      mask = col_offsets < n_cols
 
      dY_ptr += row_idx * dY_row_stride
      X_ptr += row_idx * X_row_stride
-     r_ptr += row_idx * r_row_stride
+     RSTD_ptr += row_idx * RSTD_row_stride
      dW_ptr += row_idx * dW_row_stride
 
      dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
      X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
      W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+     original_x_dtype = X_row.dtype
 
-     # Get saved row variance
-     inv_var = tl.load(r_ptr)
+     # Get cached rms
+     rstd_row = tl.load(RSTD_ptr)
 
-     normed = X_row * inv_var
+     W_row = W_row + offset
 
-     dY_W = dY_row * W_row
-     dY_normed = dY_row * normed
+     X_row = X_row.to(tl.float32)
+
+     # Different backward graphs for different casting modes
+     if casting_mode == _CASTING_MODE_LLAMA:
+         m = (dY_row * W_row).to(tl.float32)
+
+     elif casting_mode == _CASTING_MODE_GEMMA:
+         dY_row, W_row = (
+             dY_row.to(tl.float32),
+             W_row.to(tl.float32),
+         )
 
-     rowsum_dY_normed = tl.sum(dY_W * normed, axis=0)
-     output = inv_var / n_cols * (n_cols * dY_W - normed * rowsum_dY_normed)
-     tl.store(dY_ptr + col_offsets, output, mask=mask)
+         m = dY_row * W_row
+
+     dX_row = rstd_row * m
+
+     dX_row += (rstd_row) * (
+         -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
+     )
 
      # calculate the gradient of W
-     tl.store(dW_ptr + col_offsets, dY_normed, mask=mask)
+     if casting_mode == _CASTING_MODE_LLAMA:
+         dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)
+     else:
+         # here X_row is already in fp32 (see previous if block)
+         dW_row = dY_row * (X_row * rstd_row)
 
+     tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
+     tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
 
- class LigerRMSNormFunction(torch.autograd.Function):
-     @staticmethod
-     @ensure_contiguous
-     def forward(ctx, X, W, eps):
-         shape = X.shape
-         dim = shape[-1]
-         X = X.view(-1, dim)
-         n_rows, n_cols = X.shape
-         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
 
-         Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-         r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
+ _str_to_casting_mode = {
+     "llama": _CASTING_MODE_LLAMA.value,
+     "gemma": _CASTING_MODE_GEMMA.value,
+     "none": _CASTING_MODE_NONE.value,
+ }
 
-         # Check constraints.
+
+ def rms_norm_forward(X, W, eps, offset, casting_mode):
+     if not isinstance(casting_mode, int):
+         assert (
+             casting_mode in _str_to_casting_mode
+         ), f"Invalid casting mode: {casting_mode}"
+         casting_mode = _str_to_casting_mode[casting_mode]
+     else:
          assert (
-             X.shape[1] == W.shape[0]
-         ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+             casting_mode in _str_to_casting_mode.values()
+         ), f"Invalid casting mode: {casting_mode}"
 
-         _rms_norm_forward[(n_rows,)](
-             Y,
-             Y.stride(0),
-             X,
-             X.stride(0),
-             W,
-             W.stride(0),
-             r,
-             r.stride(0),
-             n_cols,
-             eps,
-             BLOCK_SIZE=BLOCK_SIZE,
-             num_warps=num_warps,
+     shape = X.shape
+     dim = shape[-1]
+     X = X.view(-1, dim)
+     n_rows, n_cols = X.shape
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+     # RSTD is to cache rstd for each row
+     # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
+     rstd_dtype = (
+         torch.float32
+         if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
+         else X.dtype
+     )
+     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
+
+     # Check constraints.
+     assert (
+         X.shape[1] == W.shape[0]
+     ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+
+     _rms_norm_forward_kernel[(n_rows,)](
+         Y,
+         Y.stride(0),
+         X,
+         X.stride(0),
+         W,
+         W.stride(0),
+         RSTD,
+         RSTD.stride(0),
+         n_cols,
+         eps,
+         offset,
+         casting_mode,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
+
+
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
+     shape = dY.shape
+     dim = shape[-1]
+     dY = dY.view(-1, dim)
+     n_rows, n_cols = dY.shape
+     dW = torch.empty_like(
+         X,
+         dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype),
+     )
+
+     # Here we use dY to store the value of dX to save memory
+     _rms_norm_backward_kernel[(n_rows,)](
+         dY,
+         dY.stride(0),
+         X,
+         X.stride(0),
+         W,
+         W.stride(0),
+         RSTD,
+         RSTD.stride(0),
+         dW,
+         dW.stride(0),
+         n_cols,
+         offset,
+         casting_mode,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+     dX = dY.view(*shape)
+     dW = torch.sum(dW, dim=0).to(W.dtype)
+     return dX, dW
+
+
+ class LigerRMSNormFunction(torch.autograd.Function):
+     """
+     Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
+     weight tensor `W`, with an optional offset and casting mode.
+
+     Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+     uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+     `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+     In addition, different models cast their inputs at different places during RMSNorm computation. For
+     example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
+     inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+     support the following casting modes (they match HuggingFace Transformers' implementations):
+     - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
+     - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
+     - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+     """
+
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
+         """
+         X: (B, T, H) or (BxT, H)
+         W: (H,)
+         """
+         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
+             X, W, eps, offset, casting_mode
          )
-         ctx.eps = eps
+         ctx.offset = offset
+         ctx.casting_mode = casting_mode
          ctx.BLOCK_SIZE = BLOCK_SIZE
          ctx.num_warps = num_warps
-
-         ctx.save_for_backward(X, W, r)
-         return Y.view(*shape)
+         ctx.save_for_backward(X, W, RSTD)
+         return Y
 
      @staticmethod
      @ensure_contiguous
      def backward(ctx, dY):
-         shape = dY.shape
-         dim = shape[-1]
-         dY = dY.view(-1, dim)
-         X, W, r = ctx.saved_tensors
-         n_rows, n_cols = dY.shape
-         dW = torch.zeros_like(X)
-
-         _rms_norm_backward[(n_rows,)](
+         """
+         Y: (B, T, H) or (BxT, H)
+         """
+         X, W, RSTD = ctx.saved_tensors
+         dX, dW = rms_norm_backward(
              dY,
-             dY.stride(0),
              X,
-             X.stride(0),
              W,
-             W.stride(0),
-             r,
-             r.stride(0),
-             dW,
-             dW.stride(0),
-             n_cols,
-             ctx.eps,
-             BLOCK_SIZE=ctx.BLOCK_SIZE,
-             num_warps=ctx.num_warps,
+             RSTD,
+             ctx.offset,
+             ctx.casting_mode,
+             ctx.BLOCK_SIZE,
+             ctx.num_warps,
          )
-         dX = dY.view(*shape)
-         dW = torch.sum(dW, dim=0)
-         return dX, dW, None
+         return dX, dW, None, None, None
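The new kernels change both the math and the public signature of `LigerRMSNormFunction`. As a quick orientation (not part of the package diff), the sketch below re-expresses the forward formula from the kernel docstring, `y = (x / RMS(x)) * (offset + w)`, as eager PyTorch, including the 'llama'/'gemma'/'none' casting behavior, and shows how the five-argument `apply` introduced above would be called. The helper name and the example values are illustrative assumptions; the commented `apply` call requires a CUDA device with Triton installed.

```python
import torch

def rms_norm_reference(X, W, eps=1e-6, offset=0.0, casting_mode="llama"):
    """Eager sketch of the Triton kernel: y = (x / RMS(x)) * (offset + w)."""
    out_dtype = X.dtype
    if casting_mode == "gemma":
        # Gemma: compute everything in fp32, cast the result back at the end
        X, W = X.float(), W.float()
    # Llama: only the mean-square / inverse-RMS part runs in fp32
    Xf = X.float() if casting_mode == "llama" else X
    rstd = torch.rsqrt(Xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = (Xf * rstd).to(out_dtype) if casting_mode == "llama" else Xf * rstd
    return (normed * (offset + W)).to(out_dtype)

# Calling the Triton-backed autograd function with the 0.3.0 signature
# (requires CUDA + Triton; offset=1.0 mirrors Gemma's (W + 1) weighting):
#   from liger_kernel.ops.rms_norm import LigerRMSNormFunction
#   Y = LigerRMSNormFunction.apply(X, W, 1e-6, 1.0, "gemma")
```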
liger_kernel/ops/rope.py CHANGED
@@ -13,8 +13,8 @@ def _triton_rope(
      cos_row_stride,
      sin,
      sin_row_stride,
+     sl,
      bs: tl.constexpr,
-     sl: tl.constexpr,
      n_qh: tl.constexpr,
      n_kh: tl.constexpr,
      hd: tl.constexpr,
@@ -117,6 +117,92 @@ def _triton_rope(
      tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
 
 
+ def rope_forward(q, k, cos, sin):
+
+     # transpose it back to the physical shape because Triton looks at the physical storage
+     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = q.shape
+     n_kv_head = k.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+     BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
+     q = q.contiguous()
+     k = k.contiguous()
+     cos = cos.contiguous()
+     sin = sin.contiguous()
+
+     _triton_rope[(n_row,)](
+         q,
+         q.stride(1),
+         k,
+         k.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         pad_n_q_head,
+         pad_n_kv_head,
+         pad_hd,
+         BLOCK_SIZE=BLOCK_SIZE,
+         BACKWARD_PASS=False,
+     )
+     return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+ def rope_backward(dq, dk, cos, sin):
+     dq = dq.transpose(1, 2)
+     dk = dk.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = dq.shape
+     n_kv_head = dk.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+     BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure dq and dk are contiguous
+     dq = dq.contiguous()
+     dk = dk.contiguous()
+
+     # backward is similar to forward except swapping few ops
+     _triton_rope[(n_row,)](
+         dq,
+         dq.stride(1),
+         dk,
+         dk.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         pad_n_q_head,
+         pad_n_kv_head,
+         pad_hd,
+         BLOCK_SIZE=BLOCK_SIZE,
+         BACKWARD_PASS=True,
+     )
+     return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
  class LigerRopeFunction(torch.autograd.Function):
      """
      Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
@@ -138,50 +224,9 @@ class LigerRopeFunction(torch.autograd.Function):
          cos size: (1, seq_len, head_dim)
          sin size: (1, seq_len, head_dim)
          """
-
-         # transpose it back to the physical shape because Triton looks at the physical storage
-         # note: q and k are incontiguous before the transformation and will become contiguous after transpose
-         q = q.transpose(1, 2)
-         k = k.transpose(1, 2)
-
-         batch_size, seq_len, n_q_head, head_dim = q.shape
-         n_kv_head = k.shape[2]
-         pad_hd = triton.next_power_of_2(head_dim)
-         pad_n_q_head = triton.next_power_of_2(n_q_head)
-         pad_n_kv_head = triton.next_power_of_2(n_kv_head)
-         BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
-
-         n_row = batch_size * seq_len
-
-         # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
-         q = q.contiguous()
-         k = k.contiguous()
-         cos = cos.contiguous()
-         sin = sin.contiguous()
-
-         _triton_rope[(n_row,)](
-             q,
-             q.stride(1),
-             k,
-             k.stride(1),
-             cos,
-             cos.stride(-2),
-             sin,
-             sin.stride(-2),
-             batch_size,
-             seq_len,
-             n_q_head,
-             n_kv_head,
-             head_dim,
-             pad_n_q_head,
-             pad_n_kv_head,
-             pad_hd,
-             BLOCK_SIZE=BLOCK_SIZE,
-             BACKWARD_PASS=False,
-         )
-
+         q, k, cos, sin = rope_forward(q, k, cos, sin)
          ctx.save_for_backward(cos, sin)
-         return q.transpose(1, 2), k.transpose(1, 2)
+         return q, k
 
      def backward(ctx, dq, dk):
          """
@@ -192,43 +237,5 @@ class LigerRopeFunction(torch.autograd.Function):
          """
 
          cos, sin = ctx.saved_tensors
-
-         dq = dq.transpose(1, 2)
-         dk = dk.transpose(1, 2)
-
-         batch_size, seq_len, n_q_head, head_dim = dq.shape
-         n_kv_head = dk.shape[2]
-         pad_hd = triton.next_power_of_2(head_dim)
-         pad_n_q_head = triton.next_power_of_2(n_q_head)
-         pad_n_kv_head = triton.next_power_of_2(n_kv_head)
-         BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
-
-         n_row = batch_size * seq_len
-
-         # ensure dq and dk are contiguous
-         dq = dq.contiguous()
-         dk = dk.contiguous()
-
-         # backward is similar to forward except swapping few ops
-         _triton_rope[(n_row,)](
-             dq,
-             dq.stride(1),
-             dk,
-             dk.stride(1),
-             cos,
-             cos.stride(-2),
-             sin,
-             sin.stride(-2),
-             batch_size,
-             seq_len,
-             n_q_head,
-             n_kv_head,
-             head_dim,
-             pad_n_q_head,
-             pad_n_kv_head,
-             pad_hd,
-             BLOCK_SIZE=BLOCK_SIZE,
-             BACKWARD_PASS=True,
-         )
-
-         return dq.transpose(1, 2), dk.transpose(1, 2), None, None, None, None
+         dq, dk = rope_backward(dq, dk, cos, sin)
+         return dq, dk, None, None, None, None
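For completeness, here is a hedged usage sketch (not part of the diff) of the refactored `LigerRopeFunction`. Shapes follow the class docstring above: `q`/`k` enter as `(batch, n_heads, seq_len, head_dim)` and are transposed internally, while `cos`/`sin` are `(1, seq_len, head_dim)`. The sizes below are illustrative assumptions, a CUDA device with Triton is assumed, and the trailing optional arguments implied by the six gradients returned from `backward` are left at their defaults.

```python
import torch
from liger_kernel.ops.rope import LigerRopeFunction  # module refactored in 0.3.0

# Illustrative sizes; head_dim must match the cos/sin tables built below.
batch, n_q_heads, n_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64
device = "cuda"

q = torch.randn(batch, n_q_heads, seq_len, head_dim, device=device, requires_grad=True)
k = torch.randn(batch, n_kv_heads, seq_len, head_dim, device=device, requires_grad=True)

# RoPE tables in the (1, seq_len, head_dim) layout the docstring expects:
# half-dim angles repeated along the last axis, as in HF-style rotary embeddings.
inv_freq = 10000.0 ** (-torch.arange(0, head_dim, 2, device=device).float() / head_dim)
angles = torch.outer(torch.arange(seq_len, device=device).float(), inv_freq)
emb = torch.cat((angles, angles), dim=-1)
cos, sin = emb.cos()[None, :, :], emb.sin()[None, :, :]

q_rot, k_rot = LigerRopeFunction.apply(q, k, cos, sin)
(q_rot.sum() + k_rot.sum()).backward()  # backward re-runs the kernel with BACKWARD_PASS=True
```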