liger-kernel 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. liger_kernel/ops/cross_entropy.py +5 -39
  2. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  3. liger_kernel/ops/fused_linear_cross_entropy.py +12 -9
  4. liger_kernel/ops/fused_linear_jsd.py +245 -0
  5. liger_kernel/ops/geglu.py +2 -2
  6. liger_kernel/ops/jsd.py +176 -0
  7. liger_kernel/ops/kl_div.py +2 -2
  8. liger_kernel/ops/rms_norm.py +67 -42
  9. liger_kernel/ops/swiglu.py +2 -2
  10. liger_kernel/ops/utils.py +62 -1
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/functional.py +4 -0
  13. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  14. liger_kernel/transformers/jsd.py +75 -0
  15. liger_kernel/transformers/model/gemma.py +124 -1
  16. liger_kernel/transformers/model/llama.py +135 -4
  17. liger_kernel/transformers/model/mistral.py +3 -0
  18. liger_kernel/transformers/model/mixtral.py +153 -2
  19. liger_kernel/transformers/model/mllama.py +274 -0
  20. liger_kernel/transformers/model/phi3.py +140 -2
  21. liger_kernel/transformers/model/qwen2.py +123 -2
  22. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  23. liger_kernel/transformers/monkey_patch.py +158 -7
  24. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/METADATA +60 -28
  25. liger_kernel-0.4.0.dist-info/NOTICE +58 -0
  26. liger_kernel-0.4.0.dist-info/RECORD +48 -0
  27. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/WHEEL +1 -1
  28. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  29. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  30. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/LICENSE +0 -0
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.0.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py
@@ -2,6 +2,8 @@ import torch
  import triton
  import triton.language as tl

+ from liger_kernel.ops.utils import element_mul_kernel, is_hip
+

  @triton.jit
  def liger_cross_entropy_kernel(
@@ -126,7 +128,7 @@ def liger_cross_entropy_kernel(
  # So we can safely calculate log (softmax(X_y)) without overflow
  loss = -(ori_X_y - m - tl.log(d))

- # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+ # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
  # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
  # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
  # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
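The simplification the comment refers to is cut off at this hunk boundary; it is presumably the standard identity sum(logsoftmax(x_i)) = sum(x_i - m) - V * log(d), which follows from logsoftmax(x_i) = x_i - m - log(d) with d = sum_j(e^(x_j - m)). This is background algebra, not text from the package source.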
@@ -159,42 +161,6 @@ def liger_cross_entropy_kernel(
  MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning


- @triton.jit
- def element_mul_kernel(
- X_ptr,
- X_stride,
- grad_output_ptr,
- n_cols,
- BLOCK_SIZE: tl.constexpr,
- ):
- """
- This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
- The multiplication is performed in-place on the tensor pointed by X_ptr.
-
- Parameters:
- X_ptr: Pointer to the input tensor.
- X_stride (int): The stride of the input tensor.
- grad_output_ptr: Pointer to the gradient output value.
- n_cols (int): The number of columns in the input tensor.
- BLOCK_SIZE (int): The block size for Triton operations.
- """
-
- # Get the program ID and convert it to int64 to avoid overflow
- program_id = tl.program_id(0).to(tl.int64)
-
- # Locate the start index
- X_ptr += program_id * X_stride
-
- # Load the gradient output value
- grad_output = tl.load(grad_output_ptr)
-
- # Perform the element-wise multiplication
- for i in range(0, n_cols, BLOCK_SIZE):
- X_offsets = i + tl.arange(0, BLOCK_SIZE)
- X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
- tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
-
-
  def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
  BT, V = _input.shape
  n_rows = BT
@@ -228,7 +194,7 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
  BLOCK_SIZE=BLOCK_SIZE,
  # TODO: 32 seems to give the best performance
  # Performance is quite sensitive to num_warps
- num_warps=32,
+ num_warps=32 if not is_hip() else 16,
  )

  loss = torch.sum(loss_1d)
@@ -253,7 +219,7 @@ def cross_entropy_backward(_input, grad_output):
  grad_output,
  V,
  BLOCK_SIZE=BLOCK_SIZE,
- num_warps=32,
+ num_warps=32 if not is_hip() else 16,
  )

  return _input
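For intuition, here is a minimal eager-mode PyTorch sketch (not part of the package; the function name smoothed_ce and the per-token, reduction-free form are illustrative) of the label-smoothed cross entropy that liger_cross_entropy_kernel fuses together with its backward pass:

    import torch
    import torch.nn.functional as F

    def smoothed_ce(logits: torch.Tensor, target: torch.Tensor, label_smoothing: float = 0.1) -> torch.Tensor:
        logprobs = F.log_softmax(logits, dim=-1)
        # H(q, p): negative log-probability of the true class
        nll = -logprobs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
        # H(u, p): cross entropy against the uniform distribution over the V classes
        smooth = -logprobs.mean(dim=-1)
        # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p), per token
        return (1 - label_smoothing) * nll + label_smoothing * smooth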
liger_kernel/ops/experimental/mm_int8int2.py
@@ -0,0 +1,355 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
+     values_per_item = 8 // bits
+     packed_shape = packed.shape
+
+     if len(packed_shape) == 1:
+         original_row_dim = packed_shape[0] * values_per_item
+         unpacked_shape = (original_row_dim,)
+     else:
+         original_row_dim = packed_shape[0] * values_per_item
+         unpacked_shape = (original_row_dim, *packed_shape[1:])
+
+     unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
+
+     for i in range(values_per_item):
+         start = i * packed_shape[0]
+         end = start + packed_shape[0]
+         mask = 3 << (2 * i)
+         unpacked[start:end] = (packed & mask) >> (2 * i)
+
+     unpacked = unpacked.to(torch.int32) - 1
+     return unpacked
+
+
+ def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
+     intweights += 1
+     original_shape = intweights.shape
+     values_per_item = 8 // bits
+     row_dim = (original_shape[0] + values_per_item - 1) // values_per_item
+
+     if len(original_shape) == 1:
+         packed_tensor_shape = (row_dim,)
+     else:
+         packed_tensor_shape = (row_dim, *original_shape[1:])
+
+     packed = torch.zeros(
+         packed_tensor_shape, device=intweights.device, dtype=torch.uint8
+     )
+     unpacked = intweights.to(torch.uint8)
+
+     def lshift(t: torch.Tensor, bits: int):
+         return t << bits
+
+     it = min(values_per_item, (original_shape[0] // row_dim) + 1)
+     for i in range(it):
+         start = i * row_dim
+         end = min(start + row_dim, original_shape[0])
+         packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)
+
+     return packed
+
+
+ def get_autotune_config():
+     return [
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 128,
+                 "BLOCK_SIZE_N": 256,
+                 "BLOCK_SIZE_K": 64,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=3,
+             num_warps=8,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 64,
+                 "BLOCK_SIZE_N": 256,
+                 "BLOCK_SIZE_K": 32,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 128,
+                 "BLOCK_SIZE_N": 128,
+                 "BLOCK_SIZE_K": 32,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 128,
+                 "BLOCK_SIZE_N": 64,
+                 "BLOCK_SIZE_K": 32,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 64,
+                 "BLOCK_SIZE_N": 128,
+                 "BLOCK_SIZE_K": 32,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 128,
+                 "BLOCK_SIZE_N": 32,
+                 "BLOCK_SIZE_K": 32,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 128,
+                 "BLOCK_SIZE_N": 256,
+                 "BLOCK_SIZE_K": 128,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=3,
+             num_warps=8,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 256,
+                 "BLOCK_SIZE_N": 128,
+                 "BLOCK_SIZE_K": 128,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=3,
+             num_warps=8,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 256,
+                 "BLOCK_SIZE_N": 64,
+                 "BLOCK_SIZE_K": 128,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 64,
+                 "BLOCK_SIZE_N": 256,
+                 "BLOCK_SIZE_K": 128,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 128,
+                 "BLOCK_SIZE_N": 128,
+                 "BLOCK_SIZE_K": 128,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 128,
+                 "BLOCK_SIZE_N": 64,
+                 "BLOCK_SIZE_K": 64,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 64,
+                 "BLOCK_SIZE_N": 128,
+                 "BLOCK_SIZE_K": 64,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 128,
+                 "BLOCK_SIZE_N": 32,
+                 "BLOCK_SIZE_K": 64,
+                 "GROUP_SIZE_M": 8,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+         triton.Config(
+             {
+                 "BLOCK_SIZE_M": 32,
+                 "BLOCK_SIZE_N": 32,
+                 "BLOCK_SIZE_K": 32,
+                 "GROUP_SIZE_M": 4,
+             },
+             num_stages=4,
+             num_warps=4,
+         ),
+     ]
+
+
+ @triton.autotune(
+     configs=get_autotune_config(),
+     key=["M", "N", "K"],
+ )
+ @triton.jit
+ def matmul_kernel(
+     a_ptr,
+     b_ptr,
+     c_ptr,
+     M,
+     N,
+     K: tl.constexpr,
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+ ):
+     # We want K / 4 to be divisible by BLOCK_SIZE_K so that the multiplication can be aligned
+     tl.static_assert(
+         K % (4 * BLOCK_SIZE_K) == 0,
+         "K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K",
+     )
+     # determine the block id in the 1D grid, pid <=> blockId in cuda
+     pid = tl.program_id(axis=0)
+     # number of blocks we would need in the M dimension
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     # number of blocks we would need in the N dimension
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     # blocks are grouped along the M dimension. num_pid_in_group computes how many blocks are grouped together,
+     # and group_id calculates the group to which the current block (pid) belongs.
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+
+     # pid of the first block in the group that the current block belongs to
+     first_pid_m = group_id * GROUP_SIZE_M
+
+     # pid_m : pid of the block along the M dimension of the output matrix, and pid_n : pid of the block along the N dimension of the output matrix
+     # remember that the grid of blocks is 1D, but we calculate pid_m and pid_n to locate the block pid's place in the output matrix
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     # offs_am represents the indices of elements within the block of matrix A with respect to the M dimension
+     # offs_bn represents the indices of elements within the block of matrix B with respect to the N dimension
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+     """
+     This part of the code generates pointers to the specific blocks of matrices A and B that the current thread block will process.
+
+     As described in the PyTorch documentation, a stride refers to the step size needed to move from one element to the next along a given dimension:
+
+     For matrix A: stride_am = A.stride(0) = K (stride along the rows), and stride_ak = A.stride(1) = 1 (stride along the columns).
+     For matrix B: stride_bk = B.stride(0) = N (stride along the rows), and stride_bn = B.stride(1) = 1 (stride along the columns).
+     Now, let's break down the pointer generation:
+
+     offs_am[:, None] creates a column of shape [BLOCK_SIZE_M, 1], which represents the row indices of matrix A that this block is processing. It is multiplied by K (the number of columns in matrix A) since A is stored in row-major order. So, the element at position (i, j) in A is located at index i*K + j in memory.
+     offs_k[None, :] creates a row vector representing the column indices of the block, i.e., a range from 0 to BLOCK_SIZE_K. This is used to compute the positions of the columns within the block.
+     When combined, the result has the shape [BLOCK_SIZE_M, BLOCK_SIZE_K], where each entry (i, j) points to the element in matrix A at position (i, j) for the current block.
+
+     The same logic is applied to matrix B, but the resulting shape is [BLOCK_SIZE_K, BLOCK_SIZE_N], representing the block of matrix B that the thread block will work on.
+     """
+     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+     # An accumulator matrix is initialized with zeros. It stores the intermediate results of the block matrix multiplication.
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
+     """
+     We split the loop into two layers. The outer loop runs 4 times, and each iteration focuses on a specific portion of matrix A.
+
+     For example, when i = 0, we’re only concerned with the blocks of matrix A that cover the range from 0 to K // (4 * BLOCK_SIZE_K).
+     Since matrix B is packed, its first dimension is effectively divided by 4. So, while we process the first segment of matrix A,
+     we still iterate over the entire first dimension of matrix B.
+
+     In each of the 4 iterations of the outer loop, we go through the full blocks of matrix B, but what changes is the data we extract.
+     Matrix B elements contain 4 weights, all packed into an int8 format, and during each iteration of the outer loop,
+     we extract a different weight by using bitwise shifting operations. This way, we access a unique weight on each pass.
+     """
+     for i in range(4):
+         b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+         for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
+             k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j
+             # load the block of matrix A
+             a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
+             # load the block of matrix B
+             b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0)
+             # when i = 0 for example, we only care about the first 2 bits of the elements of the matrix B, so we use the mask 00000011 to mask the other bits
+             mask = 3 << (2 * i)
+             # we shift right after applying the mask
+             b = (b_uint8 & mask) >> (2 * i)
+             # During the packing of the weights, it's easier to pack 0, 1, 2 than -1, 0, 1, so we add 1 to the weight tensor, and we subtract it here
+             tensor_full = tl.full((1,), 1, dtype=tl.int8)
+             # We accumulate the result of multiplication of the blocks along the K dimension in int32 to avoid any overflows or underflows.
+             accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)
+             # we move the pointers; for a_ptrs we move horizontally along the second dimension -> we use stride_ak=1
+             # for b_ptrs we move in a vertical way, along the rows -> we use stride_bk=N
+             a_ptrs += BLOCK_SIZE_K * stride_ak
+             b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     c = accumulator
+     # These lines compute the offsets into matrix C where the result of this block’s computation should be stored.
+     # stride_cm = N & stride_cn = 1
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     # we do a boundary check to ensure only elements within matrix bounds are stored
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ def matmul(a, b):
+     assert (
+         a.shape[1] == b.shape[0] * 4
+     ), "Incompatible dimensions, the weight matrix need to be packed"
+     assert a.is_contiguous(), "Matrix A must be contiguous"
+     M, K = a.shape
+     _, N = b.shape
+     # c is in int32 to avoid any overflows or underflows
+     c = torch.empty((M, N), device=a.device, dtype=torch.int32)
+     grid = lambda META: (
+         triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+     )
+     matmul_kernel[grid](
+         a,
+         b,
+         c,
+         M,
+         N,
+         K,
+         a.stride(0),
+         a.stride(1),
+         b.stride(0),
+         b.stride(1),
+         c.stride(0),
+         c.stride(1),
+     )
+     return c
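For orientation, a minimal usage sketch of the experimental int8 x int2 matmul above (not from the package). The shapes, the CUDA device, and the .clone() guard are illustrative assumptions; K = 2048 is chosen so K is a multiple of 4 * BLOCK_SIZE_K for every autotuned config (the largest BLOCK_SIZE_K is 128):

    import torch

    M, K, N = 256, 2048, 512
    a = torch.randint(-8, 8, (M, K), dtype=torch.int8, device="cuda")
    w = torch.randint(-1, 2, (K, N), dtype=torch.int32, device="cuda")  # ternary weights in {-1, 0, 1}

    w_packed = pack_weights(w.clone(), bits=2)  # (K // 4, N) uint8; clone() because pack_weights does intweights += 1 in place
    assert torch.equal(unpack_weights(w_packed, bits=2), w)  # the 2-bit packing round-trips losslessly

    c = matmul(a, w_packed)  # int32 result, equal to a.to(torch.int32) @ w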
liger_kernel/ops/fused_linear_cross_entropy.py
@@ -1,9 +1,12 @@
  import torch
  import triton

- from liger_kernel.ops.cross_entropy import (
+ from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
+ from liger_kernel.ops.utils import (
+ amp_custom_bwd,
+ amp_custom_fwd,
  element_mul_kernel,
- liger_cross_entropy_kernel,
+ is_hip,
  )

  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -21,9 +24,7 @@ def fused_linear_cross_entropy_forward(
  label_smoothing=0.0,
  reduction="mean",
  ):
- dtype = (
- torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
- )
+ dtype = _input.dtype
  device = _input.device

  # inputs have shape: BT x H
@@ -92,7 +93,7 @@ def fused_linear_cross_entropy_forward(
  label_smoothing=label_smoothing,
  reduction=reduction,
  BLOCK_SIZE=BLOCK_SIZE,
- num_warps=32,
+ num_warps=32 if not is_hip() else 16,
  )

  # gradient of logits_chunk is computed in-place by the above triton kernel.
@@ -157,7 +158,7 @@ def fused_linear_cross_entropy_backward(
  grad_output,
  H,
  BLOCK_SIZE=BLOCK_SIZE,
- num_warps=32,
+ num_warps=32 if not is_hip() else 16,
  )

  # handle grad_weight
@@ -171,7 +172,7 @@
  grad_output,
  H,
  BLOCK_SIZE=BLOCK_SIZE,
- num_warps=32,
+ num_warps=32 if not is_hip() else 16,
  )

  if grad_bias is not None:
@@ -184,13 +185,14 @@
  grad_output,
  1,
  BLOCK_SIZE=BLOCK_SIZE,
- num_warps=32,
+ num_warps=32 if not is_hip() else 16,
  )
  return grad_input, grad_weight, grad_bias


  class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  @staticmethod
+ @amp_custom_fwd
  def forward(
  ctx,
  _input,
@@ -230,6 +232,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  return loss

  @staticmethod
+ @amp_custom_bwd
  def backward(ctx, grad_output):
  (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
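The new amp_custom_fwd / amp_custom_bwd decorators take over the autocast handling that the removed torch.get_autocast_gpu_dtype() branch used to perform. Their definitions live in liger_kernel/ops/utils.py, which is not shown in this diff; a plausible sketch, assuming they are thin version-gated aliases for torch's AMP custom-autograd-function decorators (the 2.4.0 cutoff and the packaging dependency are assumptions):

    import functools

    import torch
    from packaging.version import Version

    # Hypothetical reconstruction -- the real helpers ship in liger_kernel.ops.utils.
    if Version(torch.__version__) >= Version("2.4.0"):
        # newer torch exposes device-agnostic decorators that take a device_type keyword
        amp_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
        amp_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")
    else:
        amp_custom_fwd = torch.cuda.amp.custom_fwd
        amp_custom_bwd = torch.cuda.amp.custom_bwd

With decorators like these applied to forward and backward, autocast casting is handled by torch itself, which is consistent with the forward pass now reading dtype directly from _input.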