liger-kernel 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/env_report.py +46 -0
  2. liger_kernel/ops/cross_entropy.py +130 -63
  3. liger_kernel/ops/experimental/embedding.py +143 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
  5. liger_kernel/ops/geglu.py +54 -42
  6. liger_kernel/ops/kl_div.py +247 -0
  7. liger_kernel/ops/layer_norm.py +236 -0
  8. liger_kernel/ops/rms_norm.py +220 -84
  9. liger_kernel/ops/rope.py +91 -84
  10. liger_kernel/ops/swiglu.py +48 -41
  11. liger_kernel/ops/utils.py +12 -0
  12. liger_kernel/transformers/__init__.py +22 -0
  13. liger_kernel/transformers/auto_model.py +33 -0
  14. liger_kernel/transformers/cross_entropy.py +11 -1
  15. liger_kernel/transformers/experimental/embedding.py +28 -0
  16. liger_kernel/transformers/functional.py +19 -0
  17. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
  18. liger_kernel/transformers/geglu.py +4 -2
  19. liger_kernel/transformers/kl_div.py +13 -0
  20. liger_kernel/transformers/layer_norm.py +30 -0
  21. liger_kernel/transformers/model/gemma.py +138 -0
  22. liger_kernel/transformers/model/llama.py +1 -1
  23. liger_kernel/transformers/model/mistral.py +138 -0
  24. liger_kernel/transformers/model/mixtral.py +158 -0
  25. liger_kernel/transformers/model/phi3.py +136 -0
  26. liger_kernel/transformers/model/qwen2.py +135 -0
  27. liger_kernel/transformers/model/qwen2_vl.py +172 -0
  28. liger_kernel/transformers/monkey_patch.py +605 -14
  29. liger_kernel/transformers/rms_norm.py +23 -4
  30. liger_kernel/transformers/swiglu.py +24 -0
  31. liger_kernel/transformers/trainer_integration.py +2 -45
  32. liger_kernel-0.3.0.dist-info/METADATA +388 -0
  33. liger_kernel-0.3.0.dist-info/RECORD +42 -0
  34. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.1.0.dist-info/METADATA +0 -16
  36. liger_kernel-0.1.0.dist-info/RECORD +0 -27
  37. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/NOTICE +0 -0
  39. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/rms_norm.py CHANGED
@@ -1,26 +1,61 @@
+ """
+ This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+ See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+ The following line
+ https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
+ is based on code from Unsloth, located at:
+ https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+
+ Modifications made by Yanning Chen, 2024.
+ """
+
+ import operator
+
  import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
+ from liger_kernel.ops.utils import (
+     calculate_settings,
+     compare_version,
+     ensure_contiguous,
+ )
+
+ if compare_version("triton", operator.ge, "3.0.0"):
+     try:
+         # typical import path with dispatch available
+         from triton.language.extra.libdevice import rsqrt
+     except ModuleNotFoundError:
+         # for working with NGC containers
+         from triton.language.extra.cuda.libdevice import rsqrt
+ else:
+     from triton.language.math import rsqrt
+
+
+ _CASTING_MODE_NONE = tl.constexpr(-1)
+ _CASTING_MODE_LLAMA = tl.constexpr(0)
+ _CASTING_MODE_GEMMA = tl.constexpr(1)


  @triton.jit
- def _rms_norm_forward(
+ def _rms_norm_forward_kernel(
      Y_ptr,
      Y_row_stride,
      X_ptr,
      X_row_stride,
      W_ptr,
      W_row_stride,
-     r_ptr,
-     r_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
      n_cols,
      eps,
+     offset,
+     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
      BLOCK_SIZE: tl.constexpr,
  ):
      """
-     y_i = (x_i / (RMS)) * wi, RMS = sqrt(sum(x_i^2) / N)
+     y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

      Reference:
      1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@@ -34,42 +69,59 @@ def _rms_norm_forward(

      Y_ptr += row_idx * Y_row_stride
      X_ptr += row_idx * X_row_stride
-     r_ptr += row_idx * r_row_stride
+     RSTD_ptr += row_idx * RSTD_row_stride

      X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+     X_row_dtype = X_row.dtype
      W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

+     # On Llama, only rstd is computed on fp32
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(tl.float32)
+
+     # Gemma computes everything on fp32, and then casts back the output to the original dtype
+     if casting_mode == _CASTING_MODE_GEMMA:
+         W_row = W_row.to(tl.float32)
+         X_row = X_row.to(tl.float32)
+
      mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
-     inv_rms = tl.math.rsqrt(mean_square + eps)
+     rstd = rsqrt(mean_square + eps)

      # We can save time by caching rms with minimal memory overhead
      # because rms is much smaller compared to X_row, as rms is for each row.
      # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-     tl.store(r_ptr, inv_rms)
+     tl.store(RSTD_ptr, rstd)
+
+     X_row = X_row * rstd
+
+     # On Llama, the multiplication with the weight is done on the original dtype
+     if casting_mode == _CASTING_MODE_LLAMA:
+         X_row = X_row.to(X_row_dtype)

-     Y_row = X_row * inv_rms * W_row
+     Y_row = X_row * (offset + W_row)

      tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


  @triton.jit
- def _rms_norm_backward(
+ def _rms_norm_backward_kernel(
      dY_ptr,
      dY_row_stride,
      X_ptr,
      X_row_stride,
      W_ptr,
      W_row_stride,
-     r_ptr,
-     r_row_stride,
+     RSTD_ptr,
+     RSTD_row_stride,
      dW_ptr,
      dW_row_stride,
      n_cols,
-     eps,
+     offset,
+     casting_mode: tl.constexpr,
      BLOCK_SIZE: tl.constexpr,
  ):
      """
-     dx = (1 / RMS) * [dy * w - (1 / N) * (1 / RMS^2) * ((dy * w) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
+     dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
      dw = sum(dy * (x / RMS)). summation over BxT dimension
      """

@@ -79,75 +131,175 @@ def _rms_norm_backward(

      dY_ptr += row_idx * dY_row_stride
      X_ptr += row_idx * X_row_stride
-     r_ptr += row_idx * r_row_stride
+     RSTD_ptr += row_idx * RSTD_row_stride
      dW_ptr += row_idx * dW_row_stride

      dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
      X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
      W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
+     original_x_dtype = X_row.dtype

      # Get cached rms
-     inv_rms_row = tl.load(r_ptr)
-
-     dX_row = (inv_rms_row) * (
-         dY_row * W_row
-         - (1 / n_cols)
-         * inv_rms_row
-         * inv_rms_row
-         * tl.sum(dY_row * W_row * X_row, axis=0)
-         * X_row
+     rstd_row = tl.load(RSTD_ptr)
+
+     W_row = W_row + offset
+
+     X_row = X_row.to(tl.float32)
+
+     # Different backward graphs for different casting modes
+     if casting_mode == _CASTING_MODE_LLAMA:
+         m = (dY_row * W_row).to(tl.float32)
+
+     elif casting_mode == _CASTING_MODE_GEMMA:
+         dY_row, W_row = (
+             dY_row.to(tl.float32),
+             W_row.to(tl.float32),
+         )
+
+         m = dY_row * W_row
+
+     dX_row = rstd_row * m
+
+     dX_row += (rstd_row) * (
+         -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
      )
-     tl.store(dY_ptr + col_offsets, dX_row, mask=mask)

      # calculate the gradient of W
-     dW_row = dY_row * X_row * inv_rms_row
+     if casting_mode == _CASTING_MODE_LLAMA:
+         dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)
+     else:
+         # here X_row is already in fp32 (see previous if block)
+         dW_row = dY_row * (X_row * rstd_row)
+
+     tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
      tl.store(dW_ptr + col_offsets, dW_row, mask=mask)


+ _str_to_casting_mode = {
+     "llama": _CASTING_MODE_LLAMA.value,
+     "gemma": _CASTING_MODE_GEMMA.value,
+     "none": _CASTING_MODE_NONE.value,
+ }
+
+
+ def rms_norm_forward(X, W, eps, offset, casting_mode):
+     if not isinstance(casting_mode, int):
+         assert (
+             casting_mode in _str_to_casting_mode
+         ), f"Invalid casting mode: {casting_mode}"
+         casting_mode = _str_to_casting_mode[casting_mode]
+     else:
+         assert (
+             casting_mode in _str_to_casting_mode.values()
+         ), f"Invalid casting mode: {casting_mode}"
+
+     shape = X.shape
+     dim = shape[-1]
+     X = X.view(-1, dim)
+     n_rows, n_cols = X.shape
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
+     # RSTD is to cache rstd for each row
+     # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
+     rstd_dtype = (
+         torch.float32
+         if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
+         else X.dtype
+     )
+     RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
+
+     # Check constraints.
+     assert (
+         X.shape[1] == W.shape[0]
+     ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
+
+     _rms_norm_forward_kernel[(n_rows,)](
+         Y,
+         Y.stride(0),
+         X,
+         X.stride(0),
+         W,
+         W.stride(0),
+         RSTD,
+         RSTD.stride(0),
+         n_cols,
+         eps,
+         offset,
+         casting_mode,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
+
+
+ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
+     shape = dY.shape
+     dim = shape[-1]
+     dY = dY.view(-1, dim)
+     n_rows, n_cols = dY.shape
+     dW = torch.empty_like(
+         X,
+         dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype),
+     )
+
+     # Here we use dY to store the value of dX to save memory
+     _rms_norm_backward_kernel[(n_rows,)](
+         dY,
+         dY.stride(0),
+         X,
+         X.stride(0),
+         W,
+         W.stride(0),
+         RSTD,
+         RSTD.stride(0),
+         dW,
+         dW.stride(0),
+         n_cols,
+         offset,
+         casting_mode,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+     dX = dY.view(*shape)
+     dW = torch.sum(dW, dim=0).to(W.dtype)
+     return dX, dW
+
+
  class LigerRMSNormFunction(torch.autograd.Function):
+     """
+     Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
+     weight tensor `W`, with an optional offset and casting mode.
+
+     Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+     uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+     `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+     In addition, different models cast their inputs at different places during RMSNorm computation. For
+     example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
+     inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+     support the following casting modes (they match HuggingFace Transformers' implementations):
+     - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
+     - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
+     - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+     """
+
      @staticmethod
      @ensure_contiguous
-     def forward(ctx, X, W, eps):
+     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
          """
          X: (B, T, H) or (BxT, H)
          W: (H,)
          """
-
-         shape = X.shape
-         dim = shape[-1]
-         X = X.view(-1, dim)
-         n_rows, n_cols = X.shape
-         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-         Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-         # r is to cache (1/rms) for each row
-         r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
-
-         # Check constraints.
-         assert (
-             X.shape[1] == W.shape[0]
-         ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
-
-         _rms_norm_forward[(n_rows,)](
-             Y,
-             Y.stride(0),
-             X,
-             X.stride(0),
-             W,
-             W.stride(0),
-             r,
-             r.stride(0),
-             n_cols,
-             eps,
-             BLOCK_SIZE=BLOCK_SIZE,
-             num_warps=num_warps,
+         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
+             X, W, eps, offset, casting_mode
          )
-         ctx.eps = eps
+         ctx.offset = offset
+         ctx.casting_mode = casting_mode
          ctx.BLOCK_SIZE = BLOCK_SIZE
          ctx.num_warps = num_warps
-
-         ctx.save_for_backward(X, W, r)
-         return Y.view(*shape)
+         ctx.save_for_backward(X, W, RSTD)
+         return Y

      @staticmethod
      @ensure_contiguous
@@ -155,31 +307,15 @@ class LigerRMSNormFunction(torch.autograd.Function):
          """
          Y: (B, T, H) or (BxT, H)
          """
-
-         shape = dY.shape
-         dim = shape[-1]
-         dY = dY.view(-1, dim)
-         X, W, r = ctx.saved_tensors
-         n_rows, n_cols = dY.shape
-         dW = torch.zeros_like(X)
-
-         # Here we use dY to store the value of dX to save memory
-         _rms_norm_backward[(n_rows,)](
+         X, W, RSTD = ctx.saved_tensors
+         dX, dW = rms_norm_backward(
              dY,
-             dY.stride(0),
              X,
-             X.stride(0),
              W,
-             W.stride(0),
-             r,
-             r.stride(0),
-             dW,
-             dW.stride(0),
-             n_cols,
-             ctx.eps,
-             BLOCK_SIZE=ctx.BLOCK_SIZE,
-             num_warps=ctx.num_warps,
+             RSTD,
+             ctx.offset,
+             ctx.casting_mode,
+             ctx.BLOCK_SIZE,
+             ctx.num_warps,
          )
-         dX = dY.view(*shape)
-         dW = torch.sum(dW, dim=0)
-         return dX, dW, None
+         return dX, dW, None, None, None
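
For context, a minimal sketch of how the reworked rms_norm entry points above could be exercised once 0.3.0 is installed. The import path follows the file location in this diff; the tensor shapes, eps value, and CUDA device are illustrative assumptions, not something the diff prescribes.

import torch
from liger_kernel.ops.rms_norm import LigerRMSNormFunction

# Illustrative shapes: a (batch, seq_len, hidden) hidden state and a (hidden,) weight.
X = torch.randn(2, 128, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
W = torch.ones(4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# Llama-style casting: only the inverse RMS is computed in fp32, weight offset 0.0.
Y = LigerRMSNormFunction.apply(X, W, 1e-6, 0.0, "llama")

# Gemma-style casting: everything in fp32, weight offset 1.0, i.e. (X / RMS(X)) * (W + 1.0).
Y_gemma = LigerRMSNormFunction.apply(X, W, 1e-6, 1.0, "gemma")

Y.sum().backward()  # dX lands in X.grad and dW in W.grad via rms_norm_backward
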
liger_kernel/ops/rope.py CHANGED
@@ -13,8 +13,8 @@ def _triton_rope(
      cos_row_stride,
      sin,
      sin_row_stride,
+     sl,
      bs: tl.constexpr,
-     sl: tl.constexpr,
      n_qh: tl.constexpr,
      n_kh: tl.constexpr,
      hd: tl.constexpr,
@@ -117,6 +117,92 @@ def _triton_rope(
      tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


+ def rope_forward(q, k, cos, sin):
+
+     # transpose it back to the physical shape because Triton looks at the physical storage
+     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
+     q = q.transpose(1, 2)
+     k = k.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = q.shape
+     n_kv_head = k.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+     BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
+     q = q.contiguous()
+     k = k.contiguous()
+     cos = cos.contiguous()
+     sin = sin.contiguous()
+
+     _triton_rope[(n_row,)](
+         q,
+         q.stride(1),
+         k,
+         k.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         pad_n_q_head,
+         pad_n_kv_head,
+         pad_hd,
+         BLOCK_SIZE=BLOCK_SIZE,
+         BACKWARD_PASS=False,
+     )
+     return q.transpose(1, 2), k.transpose(1, 2), cos, sin
+
+
+ def rope_backward(dq, dk, cos, sin):
+     dq = dq.transpose(1, 2)
+     dk = dk.transpose(1, 2)
+
+     batch_size, seq_len, n_q_head, head_dim = dq.shape
+     n_kv_head = dk.shape[2]
+     pad_hd = triton.next_power_of_2(head_dim)
+     pad_n_q_head = triton.next_power_of_2(n_q_head)
+     pad_n_kv_head = triton.next_power_of_2(n_kv_head)
+     BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
+
+     n_row = batch_size * seq_len
+
+     # ensure dq and dk are contiguous
+     dq = dq.contiguous()
+     dk = dk.contiguous()
+
+     # backward is similar to forward except swapping few ops
+     _triton_rope[(n_row,)](
+         dq,
+         dq.stride(1),
+         dk,
+         dk.stride(1),
+         cos,
+         cos.stride(-2),
+         sin,
+         sin.stride(-2),
+         seq_len,
+         batch_size,
+         n_q_head,
+         n_kv_head,
+         head_dim,
+         pad_n_q_head,
+         pad_n_kv_head,
+         pad_hd,
+         BLOCK_SIZE=BLOCK_SIZE,
+         BACKWARD_PASS=True,
+     )
+     return dq.transpose(1, 2), dk.transpose(1, 2)
+
+
  class LigerRopeFunction(torch.autograd.Function):
      """
      Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
@@ -138,50 +224,9 @@ class LigerRopeFunction(torch.autograd.Function):
          cos size: (1, seq_len, head_dim)
          sin size: (1, seq_len, head_dim)
          """
-
-         # transpose it back to the physical shape because Triton looks at the physical storage
-         # note: q and k are incontiguous before the transformation and will become contiguous after transpose
-         q = q.transpose(1, 2)
-         k = k.transpose(1, 2)
-
-         batch_size, seq_len, n_q_head, head_dim = q.shape
-         n_kv_head = k.shape[2]
-         pad_hd = triton.next_power_of_2(head_dim)
-         pad_n_q_head = triton.next_power_of_2(n_q_head)
-         pad_n_kv_head = triton.next_power_of_2(n_kv_head)
-         BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
-
-         n_row = batch_size * seq_len
-
-         # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
-         q = q.contiguous()
-         k = k.contiguous()
-         cos = cos.contiguous()
-         sin = sin.contiguous()
-
-         _triton_rope[(n_row,)](
-             q,
-             q.stride(1),
-             k,
-             k.stride(1),
-             cos,
-             cos.stride(-2),
-             sin,
-             sin.stride(-2),
-             batch_size,
-             seq_len,
-             n_q_head,
-             n_kv_head,
-             head_dim,
-             pad_n_q_head,
-             pad_n_kv_head,
-             pad_hd,
-             BLOCK_SIZE=BLOCK_SIZE,
-             BACKWARD_PASS=False,
-         )
-
+         q, k, cos, sin = rope_forward(q, k, cos, sin)
          ctx.save_for_backward(cos, sin)
-         return q.transpose(1, 2), k.transpose(1, 2)
+         return q, k

      def backward(ctx, dq, dk):
          """
@@ -192,43 +237,5 @@ class LigerRopeFunction(torch.autograd.Function):
          """

          cos, sin = ctx.saved_tensors
-
-         dq = dq.transpose(1, 2)
-         dk = dk.transpose(1, 2)
-
-         batch_size, seq_len, n_q_head, head_dim = dq.shape
-         n_kv_head = dk.shape[2]
-         pad_hd = triton.next_power_of_2(head_dim)
-         pad_n_q_head = triton.next_power_of_2(n_q_head)
-         pad_n_kv_head = triton.next_power_of_2(n_kv_head)
-         BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)
-
-         n_row = batch_size * seq_len
-
-         # ensure dq and dk are contiguous
-         dq = dq.contiguous()
-         dk = dk.contiguous()
-
-         # backward is similar to forward except swapping few ops
-         _triton_rope[(n_row,)](
-             dq,
-             dq.stride(1),
-             dk,
-             dk.stride(1),
-             cos,
-             cos.stride(-2),
-             sin,
-             sin.stride(-2),
-             batch_size,
-             seq_len,
-             n_q_head,
-             n_kv_head,
-             head_dim,
-             pad_n_q_head,
-             pad_n_kv_head,
-             pad_hd,
-             BLOCK_SIZE=BLOCK_SIZE,
-             BACKWARD_PASS=True,
-         )
-
-         return dq.transpose(1, 2), dk.transpose(1, 2), None, None, None, None
+         dq, dk = rope_backward(dq, dk, cos, sin)
+         return dq, dk, None, None, None, None
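
Similarly, a minimal sketch of the extracted rope_forward/rope_backward entry points shown above. The head counts, sequence length, and randomly generated cos/sin tables are purely illustrative (in practice cos/sin come from the model's rotary embedding), and a CUDA device is assumed.

import torch
from liger_kernel.ops.rope import rope_forward, rope_backward

batch, n_q_head, n_kv_head, seq_len, head_dim = 2, 32, 8, 128, 64
q = torch.randn(batch, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(batch, n_kv_head, seq_len, head_dim, device="cuda")
# cos/sin are (1, seq_len, head_dim), matching the docstring above; random values
# stand in for a real rotary-embedding table in this sketch.
cos = torch.randn(1, seq_len, head_dim, device="cuda")
sin = torch.randn(1, seq_len, head_dim, device="cuda")

q_rot, k_rot, cos, sin = rope_forward(q, k, cos, sin)
dq, dk = rope_backward(torch.ones_like(q_rot), torch.ones_like(k_rot), cos, sin)
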
liger_kernel/ops/swiglu.py CHANGED
@@ -60,54 +60,61 @@ def _swiglu_backward_kernel(
      tl.store(b_ptr + col_offsets, db_row, mask=mask)


+ def swiglu_forward(a, b):
+     ori_shape = a.shape
+
+     n_cols = ori_shape[-1]
+     a = a.view(-1, n_cols)
+     b = b.view(-1, n_cols)
+     c = torch.empty_like(a)
+     n_rows = a.shape[0]
+
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     _swiglu_forward_kernel[(n_rows,)](
+         a,
+         b,
+         c,
+         c.stride(-2),
+         n_cols=n_cols,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+     return a, b, c.view(*ori_shape)
+
+
+ def swiglu_backward(a, b, dc):
+
+     ori_shape = dc.shape
+     n_cols = ori_shape[-1]
+     dc = dc.view(-1, n_cols)
+     n_rows = dc.shape[0]
+
+     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+     _swiglu_backward_kernel[(n_rows,)](
+         dc,
+         a,
+         b,
+         dc.stride(-2),
+         n_cols=n_cols,
+         BLOCK_SIZE=BLOCK_SIZE,
+         num_warps=num_warps,
+     )
+     return a.view(*ori_shape), b.view(*ori_shape)
+
+
  class LigerSiLUMulFunction(torch.autograd.Function):
      @staticmethod
      @ensure_contiguous
      def forward(ctx, a, b):
-         ori_shape = a.shape
-
-         n_cols = ori_shape[-1]
-         a = a.view(-1, n_cols)
-         b = b.view(-1, n_cols)
-         c = torch.zeros_like(a)
-         n_rows = a.shape[0]
-
-         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-         _swiglu_forward_kernel[(n_rows,)](
-             a,
-             b,
-             c,
-             c.stride(-2),
-             n_cols=n_cols,
-             BLOCK_SIZE=BLOCK_SIZE,
-             num_warps=num_warps,
-         )
-
+         a, b, c = swiglu_forward(a, b)
          ctx.save_for_backward(a, b)
-
-         return c.view(*ori_shape)
+         return c

      @staticmethod
      @ensure_contiguous
      def backward(ctx, dc):
-
-         ori_shape = dc.shape
-         n_cols = ori_shape[-1]
-         dc = dc.view(-1, n_cols)
          a, b = ctx.saved_tensors
-         n_rows = dc.shape[0]
-
-         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
-
-         _swiglu_backward_kernel[(n_rows,)](
-             dc,
-             a,
-             b,
-             dc.stride(-2),
-             n_cols=n_cols,
-             BLOCK_SIZE=BLOCK_SIZE,
-             num_warps=num_warps,
-         )
-
-         return a.view(*ori_shape), b.view(*ori_shape)
+         a, b = swiglu_backward(a, b, dc)
+         return a, b
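
Finally, a minimal sketch of the refactored SwiGLU entry points above. The intermediate size, CUDA device, and the gate/up naming (mirroring how a SwiGLU MLP would typically call this op) are illustrative assumptions.

import torch
from liger_kernel.ops.swiglu import LigerSiLUMulFunction, swiglu_forward

gate = torch.randn(2, 128, 11008, device="cuda", requires_grad=True)  # e.g. gate_proj output
up = torch.randn(2, 128, 11008, device="cuda", requires_grad=True)    # e.g. up_proj output

out = LigerSiLUMulFunction.apply(gate, up)  # silu(gate) * up via the Triton kernel
out.sum().backward()                        # gradients flow through swiglu_backward

# The kernel is now also reachable without autograd bookkeeping:
_, _, out_no_grad = swiglu_forward(gate.detach(), up.detach())
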