liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (69)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +3 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
  9. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  10. liger_kernel/chunked_loss/grpo_loss.py +160 -0
  11. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  12. liger_kernel/chunked_loss/kto_loss.py +172 -0
  13. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  14. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  15. liger_kernel/env_report.py +5 -12
  16. liger_kernel/ops/cross_entropy.py +102 -51
  17. liger_kernel/ops/experimental/embedding.py +1 -3
  18. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  19. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  20. liger_kernel/ops/fused_linear_jsd.py +14 -32
  21. liger_kernel/ops/geglu.py +6 -17
  22. liger_kernel/ops/group_norm.py +11 -28
  23. liger_kernel/ops/jsd.py +5 -9
  24. liger_kernel/ops/kl_div.py +8 -11
  25. liger_kernel/ops/layer_norm.py +23 -12
  26. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  27. liger_kernel/ops/rms_norm.py +14 -32
  28. liger_kernel/ops/rope.py +31 -33
  29. liger_kernel/ops/swiglu.py +4 -8
  30. liger_kernel/ops/tvd.py +207 -0
  31. liger_kernel/ops/utils.py +3 -2
  32. liger_kernel/transformers/__init__.py +19 -24
  33. liger_kernel/transformers/auto_model.py +6 -13
  34. liger_kernel/transformers/cross_entropy.py +7 -9
  35. liger_kernel/transformers/experimental/embedding.py +1 -3
  36. liger_kernel/transformers/functional.py +28 -7
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
  38. liger_kernel/transformers/geglu.py +1 -4
  39. liger_kernel/transformers/group_norm.py +9 -15
  40. liger_kernel/transformers/jsd.py +1 -3
  41. liger_kernel/transformers/kl_div.py +1 -3
  42. liger_kernel/transformers/layer_norm.py +3 -9
  43. liger_kernel/transformers/model/gemma.py +18 -40
  44. liger_kernel/transformers/model/gemma2.py +19 -41
  45. liger_kernel/transformers/model/llama.py +22 -48
  46. liger_kernel/transformers/model/mistral.py +14 -26
  47. liger_kernel/transformers/model/mixtral.py +24 -54
  48. liger_kernel/transformers/model/mllama.py +16 -36
  49. liger_kernel/transformers/model/olmo2.py +124 -0
  50. liger_kernel/transformers/model/phi3.py +18 -40
  51. liger_kernel/transformers/model/qwen2.py +18 -40
  52. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  53. liger_kernel/transformers/monkey_patch.py +214 -144
  54. liger_kernel/transformers/rms_norm.py +4 -4
  55. liger_kernel/transformers/rope.py +2 -2
  56. liger_kernel/transformers/swiglu.py +2 -8
  57. liger_kernel/transformers/trainer/__init__.py +1 -3
  58. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  59. liger_kernel/transformers/tvd.py +13 -0
  60. liger_kernel/triton/__init__.py +1 -3
  61. liger_kernel/triton/monkey_patch.py +1 -3
  62. liger_kernel/utils.py +49 -0
  63. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
  64. liger_kernel-0.5.4.dist-info/RECORD +74 -0
  65. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
  66. liger_kernel-0.5.2.dist-info/RECORD +0 -65
  67. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
  68. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
  69. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
liger_kernel/ops/cross_entropy.py

@@ -1,11 +1,14 @@
  import operator
+
  from typing import Optional

  import torch
  import triton
  import triton.language as tl

- from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip

  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -17,9 +20,6 @@ if compare_version("triton", operator.ge, "3.0.0"):
  else:
  from triton.language.math import tanh

- _TRUE = tl.constexpr(1)
- _FALSE = tl.constexpr(0)
-

  @triton.jit
  def liger_cross_entropy_kernel(
@@ -27,11 +27,14 @@ def liger_cross_entropy_kernel(
  X_stride,
  Y_ptr,
  Y_stride,
+ weight_ptr,
  loss_ptr,
  z_loss_ptr,
  loss_stride,
  n_cols,
  n_non_ignore,
+ sum_non_ignore_weight,
+ weight_sum,
  ignore_index,
  lse_square_scale: tl.constexpr,
  label_smoothing: tl.constexpr,
@@ -39,6 +42,7 @@ def liger_cross_entropy_kernel(
  softcap,
  RETURN_Z_LOSS: tl.constexpr,
  BLOCK_SIZE: tl.constexpr,
+ HAS_WEIGHT: tl.constexpr,
  HAS_SOFTCAPPING: tl.constexpr,
  ):
  """
@@ -50,18 +54,22 @@ def liger_cross_entropy_kernel(
  X_stride (int): The stride of the input tensor.
  Y_ptr: Pointer to target tensor.
  Y_stride (int): The stride of the target tensor.
+ weight_ptr: Pointer to weight tensor.
  loss_ptr: Pointer to tensor to store the loss.
  z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
  loss_stride (int): The stride of the loss tensor.
  n_cols (int): The number of columns in the input tensor.
- n_non_ignore (int): The number of non-ignored elements in the batch.
+ n_non_ignore (flaot): The number of non-ignored elements in the batch.
+ sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
+ weight_sum (float): The sum of weight tensor.
  ignore_index (int): The index to ignore in the target.
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
  lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
- RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
  reduction (str): The string for the reduction to apply
  softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
+ RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
  BLOCK_SIZE (int): The block size for Triton operations.
+ HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
  """

@@ -84,7 +92,11 @@ def liger_cross_entropy_kernel(
  return

  loss_ptr += program_id * loss_stride
- z_loss_ptr += program_id * loss_stride
+ if RETURN_Z_LOSS:
+ z_loss_ptr += program_id * loss_stride
+
+ if HAS_WEIGHT:
+ weight_y = tl.load(weight_ptr + y).cast(tl.float32)

  # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
  # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -92,9 +104,7 @@ def liger_cross_entropy_kernel(
  # 3. [Online softmax] first pass: find max + sum
  m = float("-inf") # m is the max value. use the notation from the paper
  d = 0.0 # d is the sum. use the notation from the paper
- ori_X_y = tl.load(X_ptr + y).cast(
- tl.float32
- ) # we need to store the original value of X_y for the loss calculation
+ ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
  if HAS_SOFTCAPPING:
  ori_X_y = softcap * tanh(ori_X_y / softcap)

@@ -116,7 +126,11 @@ def liger_cross_entropy_kernel(
  block_max = tl.max(X_block)
  if label_smoothing > 0:
  # scale X beforehand to avoid overflow
- scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+ if HAS_WEIGHT:
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+ scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0))
+ else:
+ scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
  m_new = tl.maximum(m, block_max)
  d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
  m = m_new
@@ -152,18 +166,41 @@ def liger_cross_entropy_kernel(
  if HAS_SOFTCAPPING:
  intermediate = tanh(X_block / softcap)
  X_block = softcap * intermediate
- # softmax(x_i)
- X_block = tl.exp(X_block - m) / d
- # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
- X_block += 2 * lse_square_scale * lse * X_block
- # smoothing term
- X_block += -eps
- # special handle dx_y
- X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
- # reduction scale
- if reduction == "mean":
- X_block = X_block / (n_non_ignore)
- # chain rule
+
+ if not HAS_WEIGHT:
+ # softmax(x_i)
+ X_block = tl.exp(X_block - m) / d
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+ X_block += 2 * lse_square_scale * lse * X_block
+ # smoothing term
+ X_block += -eps
+ # special handle dx_y
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+ # reduction scale
+ if reduction == "mean":
+ X_block = X_block / n_non_ignore
+ else:
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+ softmax_X = tl.exp(X_block - m) / d
+ # derivative of original_loss
+ dloss_ori = (1 - label_smoothing) * softmax_X
+ # specially handle dx_y
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+ dloss_ori = dloss_ori * weight_y
+ # derivative of smooth_loss
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+ # derivative of z-loss
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
+ # reduction scale
+ if reduction == "mean":
+ dloss_ori = dloss_ori / sum_non_ignore_weight
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+ dz_loss = dz_loss / n_non_ignore
+ # derivative of total_loss
+ X_block = dloss_ori + dloss_smooth + dz_loss
+
+ # chain rule softcapping
  # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
  if HAS_SOFTCAPPING:
  X_block = X_block * (1 - intermediate * intermediate)
@@ -182,6 +219,8 @@ def liger_cross_entropy_kernel(
  # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
  # So we can safely calculate log (softmax(X_y)) without overflow
  loss = lse - ori_X_y
+ if HAS_WEIGHT:
+ loss = weight_y * loss

  # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
  # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -192,20 +231,27 @@ def liger_cross_entropy_kernel(
  # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
  # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
  if label_smoothing > 0:
- smooth_loss = scaled_x_sum + label_smoothing * lse
+ if HAS_WEIGHT:
+ smooth_loss = scaled_x_sum + eps * lse * weight_sum
+ else:
+ smooth_loss = scaled_x_sum + label_smoothing * lse
  loss = loss * (1 - label_smoothing) + smooth_loss

  # An auxiliary loss, z_loss
  # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
  z_loss = lse_square_scale * lse * lse
- loss += z_loss
  # Normalize the loss by the number of non-ignored elements if reduction is "mean"
  if reduction == "mean":
+ if HAS_WEIGHT:
+ loss = loss / sum_non_ignore_weight
+ else:
+ loss = loss / n_non_ignore
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
  z_loss = z_loss / n_non_ignore
- loss = loss / n_non_ignore
+ loss += z_loss

  tl.store(loss_ptr, loss)
- if RETURN_Z_LOSS == _TRUE:
+ if RETURN_Z_LOSS:
  tl.store(z_loss_ptr, z_loss)

@@ -215,15 +261,10 @@ def liger_cross_entropy_kernel(
  MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning


- _bool_to_return_z_loss = {
- True: _TRUE.value,
- False: _FALSE.value,
- }
-
-
  def cross_entropy_forward(
  _input,
  target,
+ weight,
  ignore_index,
  lse_square_scale,
  label_smoothing,
@@ -231,15 +272,7 @@ def cross_entropy_forward(
  softcap,
  return_z_loss,
  ):
- if not isinstance(return_z_loss, int):
- assert (
- return_z_loss in _bool_to_return_z_loss
- ), f"return_z_loss must be True or False. Got: {return_z_loss}"
- return_z_loss = _bool_to_return_z_loss[return_z_loss]
- else:
- assert (
- return_z_loss in _bool_to_return_z_loss
- ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+ assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"

  BT, V = _input.shape
  n_rows = BT
@@ -248,12 +281,22 @@ def cross_entropy_forward(

  # unreduced loss
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
- if return_z_loss == _TRUE.value:
- z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
- else:
- z_loss_1d = loss_1d # dummy ptr when return_z_loss == False
-
- n_non_ignore = (target != ignore_index).sum().item()
+ z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+
+ target_mask = target != ignore_index
+ n_non_ignore = target_mask.sum().item()
+ sum_non_ignore_weight = n_non_ignore
+ weight_sum = 0.0
+ if weight is not None:
+ assert weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
+ assert torch.is_floating_point(weight), (
+ f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
+ )
+ sum_non_ignore_weight = torch.gather(weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+ weight_sum = weight.sum().item()
+ # ensure weight is contiguous
+ if weight.stride(-1) != 1:
+ weight = weight.contiguous()

  # ensure _input and target are contiguous in the last dimension
  if _input.stride(-1) != 1:
@@ -267,18 +310,22 @@ def cross_entropy_forward(
  X_stride=_input.stride(-2),
  Y_ptr=target,
  Y_stride=target.stride(-1), # always 1
+ weight_ptr=weight, # dummy if None
  loss_ptr=loss_1d,
  z_loss_ptr=z_loss_1d,
  loss_stride=loss_1d.stride(-1), # always 1
  n_cols=V,
  n_non_ignore=n_non_ignore,
+ sum_non_ignore_weight=sum_non_ignore_weight,
  ignore_index=ignore_index,
+ weight_sum=weight_sum,
  lse_square_scale=lse_square_scale,
  label_smoothing=label_smoothing,
  reduction=reduction,
- softcap=softcap if softcap is not None else 0.0,
+ softcap=softcap,
  RETURN_Z_LOSS=return_z_loss,
  BLOCK_SIZE=BLOCK_SIZE,
+ HAS_WEIGHT=True if weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
  # TODO: 32 seems to give the best performance
  # Performance is quite sensitive to num_warps
@@ -287,10 +334,10 @@ def cross_entropy_forward(

  if reduction == "none":
  loss = loss_1d
- z_loss = z_loss_1d if return_z_loss == _TRUE.value else None
+ z_loss = z_loss_1d if return_z_loss else None
  else:
  loss = torch.sum(loss_1d)
- z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None
+ z_loss = torch.sum(z_loss_1d) if return_z_loss else None

  return loss, z_loss, _input

@@ -330,6 +377,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  ctx,
  _input: torch.Tensor,
  target: torch.Tensor,
+ weight: Optional[torch.FloatTensor],
  ignore_index: int = -100,
  lse_square_scale: float = 0.0,
  label_smoothing: float = 0.0,
@@ -344,6 +392,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  ctx : The context object.
  _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
  target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+ weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
  ignore_index (int): The index to ignore in the target.
  lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -357,6 +406,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  loss, z_loss, _input = cross_entropy_forward(
  _input,
  target,
+ weight,
  ignore_index,
  lse_square_scale,
  label_smoothing,
@@ -398,4 +448,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
  None,
  None,
  None,
+ None,
  )
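
The cross_entropy.py changes above add optional per-class weighting (`weight_ptr` / `HAS_WEIGHT`) and replace the constexpr `_TRUE`/`_FALSE` flags with a plain `return_z_loss` bool. A minimal sketch of how the updated `LigerCrossEntropyFunction` might be called, assuming the positional order of the `forward` signature shown in the diff (the trailing `reduction`/`softcap`/`return_z_loss` arguments, shapes, and values here are illustrative, and a CUDA device with Triton is assumed):

```python
import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

# Illustrative sizes: 8 rows (B*T) over a 128-class vocabulary (V).
logits = torch.randn(8, 128, device="cuda", requires_grad=True)
target = torch.randint(0, 128, (8,), device="cuda")
class_weight = torch.rand(128, device="cuda")  # per-class rescaling weight of size V

# Positional args assumed to follow forward(ctx, _input, target, weight,
# ignore_index, lse_square_scale, label_smoothing, reduction, softcap, return_z_loss).
loss, z_loss = LigerCrossEntropyFunction.apply(
    logits, target, class_weight, -100, 0.0, 0.0, "mean", None, False
)
loss.backward()  # z_loss is None here because return_z_loss=False
```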
liger_kernel/ops/experimental/embedding.py

@@ -34,9 +34,7 @@ def embedding_forward_kernel(
  )

  output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
- tl.store(
- output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
- )
+ tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :])


  @triton.jit
liger_kernel/ops/experimental/mm_int8int2.py

@@ -37,9 +37,7 @@ def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
  else:
  packed_tensor_shape = (row_dim, *original_shape[1:])

- packed = torch.zeros(
- packed_tensor_shape, device=intweights.device, dtype=torch.uint8
- )
+ packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
  unpacked = intweights.to(torch.uint8)

  def lshift(t: torch.Tensor, bits: int):
@@ -327,17 +325,13 @@ def matmul_kernel(


  def matmul(a, b):
- assert (
- a.shape[1] == b.shape[0] * 4
- ), "Incompatible dimensions, the weight matrix need to be packed"
+ assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
  assert a.is_contiguous(), "Matrix A must be contiguous"
  M, K = a.shape
  _, N = b.shape
  # c is in int32 to avoid any overflows or underflows
  c = torch.empty((M, N), device=a.device, dtype=torch.int32)
- grid = lambda META: (
- triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
- )
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
  matmul_kernel[grid](
  a,
  b,
liger_kernel/ops/fused_linear_cross_entropy.py

@@ -2,12 +2,10 @@ import torch
  import triton

  from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
- from liger_kernel.ops.utils import (
- amp_custom_bwd,
- amp_custom_fwd,
- element_mul_kernel,
- is_hip,
- )
+ from liger_kernel.ops.utils import amp_custom_bwd
+ from liger_kernel.ops.utils import amp_custom_fwd
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip

  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -19,13 +17,16 @@ def fused_linear_cross_entropy_forward(
  _input,
  weight,
  target,
+ ce_weight=None,
  bias=None,
  ignore_index=-100,
  lse_square_scale=0.0,
  label_smoothing=0.0,
  reduction="mean",
  softcap=None,
+ return_z_loss=False,
  ):
+ assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
  device = _input.device

  # inputs have shape: BT x H
@@ -40,21 +41,32 @@ def fused_linear_cross_entropy_forward(
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

  inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
- chunk_size = triton.next_power_of_2(
- triton.cdiv(BT, inc_factor)
- ) # (BT + inc_factor - 1) // inc_factor
+ chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size

- grad_weight = (
- torch.zeros_like(weight, device=device) if weight.requires_grad else None
- )
+ grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
  grad_input = torch.zeros_like(_input, device=device)
  grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
  # we use fp32 for loss accumulator
  loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
-
- # NOTE: skip .item() here to avoid CUDA synchronization
- total_n_non_ignore = (target != ignore_index).sum()
+ z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
+
+ # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
+ target_mask = target != ignore_index
+ total_n_non_ignore = target_mask.sum().item()
+ total_sum_non_ignore_ce_weight = total_n_non_ignore
+ ce_weight_sum = 0.0
+ if ce_weight is not None:
+ assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}"
+ assert torch.is_floating_point(ce_weight), (
+ f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}"
+ )
+ total_sum_non_ignore_ce_weight = (
+ torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item()
+ )
+ ce_weight_sum = ce_weight.sum().item()
+ if ce_weight.stride(-1) != 1:
+ ce_weight = ce_weight.contiguous()

  for chunk_id in range(num_chunks):
  start_idx = chunk_id * chunk_size
@@ -65,13 +77,14 @@ def fused_linear_cross_entropy_forward(
  logits_chunk = _input_chunk @ weight.t() # chunk_size x V
  if bias is not None:
  logits_chunk = logits_chunk + bias
+
  target_chunk = target[start_idx:end_idx] # chunk_size,

  n_rows = logits_chunk.shape[0]

  # unreduced loss
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
- n_non_ignore = (target_chunk != ignore_index).sum().item()
+ z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None

  # ensure _input and target are contiguous
  logits_chunk = logits_chunk.contiguous()
@@ -83,45 +96,42 @@ def fused_linear_cross_entropy_forward(
  X_stride=logits_chunk.stride(-2),
  Y_ptr=target_chunk,
  Y_stride=target_chunk.stride(-1), # always 1
+ weight_ptr=ce_weight,
  loss_ptr=loss_1d_slice,
- z_loss_ptr=loss_1d_slice, # dummy ptr, not used
+ z_loss_ptr=z_loss_1d_slice,
  loss_stride=loss_1d_slice.stride(-1), # always 1
  n_cols=V,
- n_non_ignore=n_non_ignore,
+ n_non_ignore=total_n_non_ignore,
+ sum_non_ignore_weight=total_sum_non_ignore_ce_weight,
+ weight_sum=ce_weight_sum,
  ignore_index=ignore_index,
  lse_square_scale=lse_square_scale,
  label_smoothing=label_smoothing,
  reduction=reduction,
- softcap=softcap if softcap is not None else 0.0,
- RETURN_Z_LOSS=0, # False
+ softcap=softcap,
+ RETURN_Z_LOSS=return_z_loss,
+ HAS_WEIGHT=True if ce_weight is not None else False,
  HAS_SOFTCAPPING=True if softcap is not None else False,
  BLOCK_SIZE=BLOCK_SIZE,
  num_warps=32 if not is_hip() else 16,
  )

- # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
- # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
- # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
- # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens.
- # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients.
-
- if reduction == "mean":
- alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
- else:
- alpha = 1.0
-
- loss_1d[start_idx:end_idx] = loss_1d_slice * alpha
- grad_logits_chunk = logits_chunk * alpha # chunk_size x V
+ loss_1d[start_idx:end_idx] = loss_1d_slice
+ if return_z_loss:
+ z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
+ grad_logits_chunk = logits_chunk # chunk_size x V

  grad_input[start_idx:end_idx] = grad_logits_chunk @ weight

  if grad_weight is not None:
  torch.addmm(
  input=grad_weight,
- mat1=logits_chunk.t(),
+ mat1=logits_chunk.t().to(
+ _input_chunk.dtype
+ ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
  mat2=_input_chunk,
  out=grad_weight,
- alpha=alpha,
+ alpha=1.0,
  beta=1.0,
  )

@@ -130,18 +140,22 @@ def fused_linear_cross_entropy_forward(
  input=grad_bias,
  other=logits_chunk.sum(dim=0),
  out=grad_bias,
- alpha=alpha,
+ alpha=1.0,
  )

- loss = torch.sum(loss_1d)
- return loss, grad_input, grad_weight, grad_bias
+ if reduction == "none":
+ loss = loss_1d
+ z_loss = z_loss_1d if return_z_loss else None

+ else:
+ loss = torch.sum(loss_1d)
+ z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+ return loss, z_loss, grad_input, grad_weight, grad_bias

- def fused_linear_cross_entropy_backward(
- grad_output, grad_input, grad_weight, grad_bias
- ):
+
+ def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
  # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
- if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+ if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
  # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
  # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
  BT, H = grad_input.shape
@@ -195,11 +209,13 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  weight,
  target,
  bias=None,
+ ce_weight=None,
  ignore_index=-100,
  lse_square_scale=0.0,
  label_smoothing=0.0,
  reduction="mean",
  softcap=None,
+ return_z_loss: bool = False,
  ):
  """
  Fusing the last linear layer with cross-entropy loss
@@ -214,21 +230,24 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  target: (B*T) where each value is in [0, V-1]
  weight: (V, H) where V is the number of classes
  bias: (V) where V is the number of classes
+ ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype
  ignore_index: the index to ignore in the target
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
  reduction: reduction to apply
  """

- loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
- _input,
- weight,
- target,
- bias,
- ignore_index,
- lse_square_scale,
- label_smoothing,
- reduction,
- softcap,
+ loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+ _input=_input,
+ weight=weight,
+ target=target,
+ bias=bias,
+ ce_weight=ce_weight,
+ ignore_index=ignore_index,
+ lse_square_scale=lse_square_scale,
+ label_smoothing=label_smoothing,
+ reduction=reduction,
+ softcap=softcap,
+ return_z_loss=return_z_loss,
  )
  # downcast to dtype and store for backward
  ctx.save_for_backward(
@@ -236,13 +255,28 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
  grad_weight.detach() if grad_weight is not None else None,
  grad_bias.detach() if bias is not None else None,
  )
- return loss
+ ctx.return_z_loss = return_z_loss
+ return loss, z_loss

  @staticmethod
  @amp_custom_bwd
- def backward(ctx, grad_output):
+ def backward(ctx, grad_output, grad_output2):
+ if ctx.return_z_loss:
+ del grad_output2 # z_loss is only for logging
  (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
  grad_output, grad_input, grad_weight, grad_bias
  )
- return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
+ return (
+ grad_input,
+ grad_weight,
+ None,
+ grad_bias,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
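
Taken together, the fused_linear_cross_entropy.py changes thread a new `ce_weight` tensor and a `return_z_loss` flag through `fused_linear_cross_entropy_forward`, and `forward` now returns a `(loss, z_loss)` pair with a matching extra gradient slot in `backward`. A minimal sketch of a call against the updated `LigerFusedLinearCrossEntropyFunction`, assuming the positional order of the `forward` signature shown above (shapes, values, and the helper names are illustrative; a CUDA device with Triton is assumed):

```python
import torch

from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction

B_T, H, V = 8, 64, 128  # illustrative sizes: B*T rows, hidden size H, vocab size V
hidden = torch.randn(B_T, H, device="cuda", requires_grad=True)
lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (B_T,), device="cuda")
ce_weight = torch.rand(V, device="cuda")  # per-class rescaling weight of size V

# Positional args assumed to follow forward(ctx, _input, weight, target, bias,
# ce_weight, ignore_index, lse_square_scale, label_smoothing, reduction, softcap, return_z_loss).
loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
    hidden, lm_head_weight, target, None, ce_weight, -100, 0.0, 0.0, "mean", None, True
)
loss.backward()  # per the diff, z_loss is only for logging and receives no gradient
```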