liger-kernel-nightly 0.4.0.dev20241106174658__tar.gz → 0.4.0.dev20241107052928__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (55) hide show
  1. {liger_kernel_nightly-0.4.0.dev20241106174658/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.0.dev20241107052928}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/cross_entropy.py +104 -20
  4. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/fused_linear_cross_entropy.py +14 -2
  5. liger_kernel_nightly-0.4.0.dev20241107052928/src/liger_kernel/transformers/cross_entropy.py +43 -0
  6. liger_kernel_nightly-0.4.0.dev20241107052928/src/liger_kernel/transformers/fused_linear_cross_entropy.py +35 -0
  7. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  8. liger_kernel_nightly-0.4.0.dev20241106174658/src/liger_kernel/transformers/cross_entropy.py +0 -21
  9. liger_kernel_nightly-0.4.0.dev20241106174658/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -21
  10. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/LICENSE +0 -0
  11. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/NOTICE +0 -0
  12. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/README.md +0 -0
  13. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/setup.cfg +0 -0
  14. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/env_report.py +0 -0
  15. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/__init__.py +0 -0
  16. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  17. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  18. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  19. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/geglu.py +0 -0
  20. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/jsd.py +0 -0
  21. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/kl_div.py +0 -0
  22. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/layer_norm.py +0 -0
  23. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/rms_norm.py +0 -0
  24. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/rope.py +0 -0
  25. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/swiglu.py +0 -0
  26. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/ops/utils.py +0 -0
  27. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/__init__.py +0 -0
  28. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/auto_model.py +0 -0
  29. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  30. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/functional.py +0 -0
  31. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  32. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/geglu.py +0 -0
  33. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/jsd.py +0 -0
  34. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/kl_div.py +0 -0
  35. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/layer_norm.py +0 -0
  36. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/model/__init__.py +0 -0
  37. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/model/gemma.py +0 -0
  38. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/model/llama.py +0 -0
  39. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/model/mistral.py +0 -0
  40. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  41. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/model/mllama.py +0 -0
  42. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/model/phi3.py +0 -0
  43. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  44. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  45. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  46. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/rms_norm.py +0 -0
  47. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/rope.py +0 -0
  48. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/swiglu.py +0 -0
  49. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  50. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/triton/__init__.py +0 -0
  51. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel/triton/monkey_patch.py +0 -0
  52. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  53. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  54. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  55. {liger_kernel_nightly-0.4.0.dev20241106174658 → liger_kernel_nightly-0.4.0.dev20241107052928}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.0.dev20241106174658
3
+ Version: 0.4.0.dev20241107052928
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.0.dev20241106174658"
7
+ version = "0.4.0.dev20241107052928"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -4,6 +4,9 @@ import triton.language as tl
4
4
 
5
5
  from liger_kernel.ops.utils import element_mul_kernel, is_hip
6
6
 
7
+ _TRUE = tl.constexpr(1)
8
+ _FALSE = tl.constexpr(0)
9
+
7
10
 
8
11
  @triton.jit
9
12
  def liger_cross_entropy_kernel(
@@ -12,12 +15,15 @@ def liger_cross_entropy_kernel(
12
15
  Y_ptr,
13
16
  Y_stride,
14
17
  loss_ptr,
18
+ z_loss_ptr,
15
19
  loss_stride,
16
20
  n_cols,
17
21
  n_non_ignore,
18
22
  ignore_index,
23
+ lse_square_scale: tl.constexpr,
19
24
  label_smoothing: tl.constexpr,
20
25
  reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
26
+ RETURN_Z_LOSS: tl.constexpr,
21
27
  BLOCK_SIZE: tl.constexpr,
22
28
  ):
23
29
  """
@@ -30,11 +36,14 @@ def liger_cross_entropy_kernel(
30
36
  Y_ptr: Pointer to target tensor.
31
37
  Y_stride (int): The stride of the target tensor.
32
38
  loss_ptr: Pointer to tensor to store the loss.
39
+ z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
33
40
  loss_stride (int): The stride of the loss tensor.
34
41
  n_cols (int): The number of columns in the input tensor.
35
42
  n_non_ignore (int): The number of non-ignored elements in the batch.
36
43
  ignore_index (int): The index to ignore in the target.
37
44
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
45
+ lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ^ 2, which is added to the loss for training stability.
46
+ RETURN_Z_LOSS (int): Flag that decides whether to store the z loss to z_loss_ptr. It must be 0 or 1.
38
47
  reduction (str): The string for the reduction to apply
39
48
  BLOCK_SIZE (int): The block size for Triton operations.
40
49
  """
@@ -58,6 +67,7 @@ def liger_cross_entropy_kernel(
58
67
  return
59
68
 
60
69
  loss_ptr += program_id * loss_stride
70
+ z_loss_ptr += program_id * loss_stride
61
71
 
62
72
  # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
63
73
  # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -87,32 +97,40 @@ def liger_cross_entropy_kernel(
87
97
  d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
88
98
  m = m_new
89
99
 
100
+ # log (sum(e^(X_i))) = log (sum(e ^ (max(X)) * e ^ (X_i - max(X))))
101
+ # = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
102
+ # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
103
+ lse = m + tl.log(d)
104
+
90
105
  # 4. [Online Softmax] Second pass: compute gradients
91
106
  # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
92
107
  # dx_y = (softmax(x_y) - 1) / N
93
108
  # dx_i = softmax(x_i) / N, i != y
94
109
  # For label smoothing:
95
- # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
110
+ # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
96
111
  # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
97
112
  # = dx_i - (1 - label_smoothing) / N
98
- #
113
+ # With Z loss:
114
+ # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
115
+ # dx_y = dx_i - (1 - label_smoothing) / N
99
116
  # For 'sum' reduction, no normalization is applied:
100
117
  # dx_y = softmax(x_y) - 1
101
118
  # dx_i = softmax(x_i), for i ≠ y
102
- # For label smoothing:
103
- # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y
104
- # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
105
- # = dx_i - (1 - label_smoothing)
106
119
 
107
120
  for i in range(0, n_cols, BLOCK_SIZE):
108
121
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
109
122
  X_block = tl.load(
110
123
  X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
111
124
  )
125
+ # softmax(x_i)
126
+ X_block = tl.exp(X_block - m) / d
127
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
128
+ X_block += 2 * lse_square_scale * lse * X_block
129
+ # smoothing term
130
+ X_block += -eps
131
+ # reduction scale
112
132
  if reduction == "mean":
113
- X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
114
- else:
115
- X_block = tl.exp(X_block - m) / d - eps
133
+ X_block = X_block / (n_non_ignore)
116
134
 
117
135
  tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
118
136
 
@@ -124,9 +142,10 @@ def liger_cross_entropy_kernel(
124
142
 
125
143
  # loss = log (softmax(X_y)) = log (e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
126
144
  # = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
145
+ # = X_y - m - log d = X_y - lse
127
146
  # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
128
147
  # So we can safely calculate log (softmax(X_y)) without overflow
129
- loss = -(ori_X_y - m - tl.log(d))
148
+ loss = lse - ori_X_y
130
149
 
131
150
  # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
132
151
  # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -137,11 +156,16 @@ def liger_cross_entropy_kernel(
137
156
  # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
138
157
  # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
139
158
  if label_smoothing > 0:
140
- smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
159
+ smooth_loss = scaled_x_sum + label_smoothing * lse
141
160
  loss = loss * (1 - label_smoothing) + smooth_loss
142
161
 
162
+ # An auxiliary loss, z_loss
163
+ # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html
164
+ z_loss = lse_square_scale * lse * lse
165
+ loss += z_loss
143
166
  # Normalize the loss by the number of non-ignored elements if reduction is "mean"
144
167
  if reduction == "mean":
168
+ z_loss = z_loss / n_non_ignore
145
169
  loss = loss / n_non_ignore
146
170
 
147
171
  # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing)) / N`
@@ -152,6 +176,8 @@ def liger_cross_entropy_kernel(
152
176
  X_y += -(1 - label_smoothing)
153
177
 
154
178
  tl.store(loss_ptr, loss)
179
+ if RETURN_Z_LOSS == _TRUE:
180
+ tl.store(z_loss_ptr, z_loss)
155
181
  tl.store(X_ptr + y, X_y)
156
182
 
157
183
 
@@ -161,7 +187,31 @@ def liger_cross_entropy_kernel(
161
187
  MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
162
188
 
163
189
 
164
- def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
190
+ _bool_to_return_z_loss = {
191
+ True: _TRUE.value,
192
+ False: _FALSE.value,
193
+ }
194
+
195
+
196
+ def cross_entropy_forward(
197
+ _input,
198
+ target,
199
+ ignore_index,
200
+ lse_square_scale,
201
+ label_smoothing,
202
+ reduction,
203
+ return_z_loss,
204
+ ):
205
+ if not isinstance(return_z_loss, int):
206
+ assert (
207
+ return_z_loss in _bool_to_return_z_loss
208
+ ), f"return_z_loss must be True or False. Got: {return_z_loss}"
209
+ return_z_loss = _bool_to_return_z_loss[return_z_loss]
210
+ else:
211
+ assert (
212
+ return_z_loss in _bool_to_return_z_loss
213
+ ), f"return_z_loss must be True or False. Got: {return_z_loss}"
214
+
165
215
  BT, V = _input.shape
166
216
  n_rows = BT
167
217
 
@@ -169,6 +219,10 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
169
219
 
170
220
  # unreduced loss
171
221
  loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
222
+ if return_z_loss == _TRUE.value:
223
+ z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
224
+ else:
225
+ z_loss_1d = loss_1d # dummy ptr when return_z_loss == False
172
226
 
173
227
  n_non_ignore = (target != ignore_index).sum().item()
174
228
 
@@ -185,12 +239,15 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
185
239
  Y_ptr=target,
186
240
  Y_stride=target.stride(-1), # always 1
187
241
  loss_ptr=loss_1d,
242
+ z_loss_ptr=z_loss_1d,
188
243
  loss_stride=loss_1d.stride(-1), # always 1
189
244
  n_cols=V,
190
245
  n_non_ignore=n_non_ignore,
191
246
  ignore_index=ignore_index,
247
+ lse_square_scale=lse_square_scale,
192
248
  label_smoothing=label_smoothing,
193
249
  reduction=reduction,
250
+ RETURN_Z_LOSS=return_z_loss,
194
251
  BLOCK_SIZE=BLOCK_SIZE,
195
252
  # TODO: 32 seems to give the best performance
196
253
  # Performance is quite sensitive to num_warps
@@ -198,7 +255,12 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
198
255
  )
199
256
 
200
257
  loss = torch.sum(loss_1d)
201
- return loss, _input
258
+ if return_z_loss == _TRUE.value:
259
+ z_loss = torch.sum(z_loss_1d)
260
+ else:
261
+ z_loss = None
262
+
263
+ return loss, z_loss, _input
202
264
 
203
265
 
204
266
  def cross_entropy_backward(_input, grad_output):
@@ -233,7 +295,14 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
233
295
 
234
296
  @staticmethod
235
297
  def forward(
236
- ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
298
+ ctx,
299
+ _input,
300
+ target,
301
+ ignore_index=-100,
302
+ lse_square_scale=0.0,
303
+ label_smoothing=0.0,
304
+ reduction="mean",
305
+ return_z_loss=False,
237
306
  ):
238
307
  """
239
308
  The forward pass of the Liger Cross Entropy loss.
@@ -243,33 +312,46 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
243
312
  _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
244
313
  target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
245
314
  ignore_index (int): The index to ignore in the target.
315
+ lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ^ 2, which is added to the loss for training stability.
246
316
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
247
317
  reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
318
+ return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
248
319
 
249
320
  Returns:
250
- tensor: The computed loss.
321
+ tuple: A tuple (loss, z_loss) with the computed loss and z loss. The elements are tensors or None.
251
322
  """
252
- loss, _input = cross_entropy_forward(
253
- _input, target, ignore_index, label_smoothing, reduction
323
+ loss, z_loss, _input = cross_entropy_forward(
324
+ _input,
325
+ target,
326
+ ignore_index,
327
+ lse_square_scale,
328
+ label_smoothing,
329
+ reduction,
330
+ return_z_loss,
254
331
  )
255
332
  # TODO: investigation
256
333
  # If we don't detach the _input tensor, the memory will double
257
334
  # Not sure why but seems that there will be a time both grad and value exist but in different location
258
335
  ctx.save_for_backward(_input.detach())
259
- return loss
336
+ ctx.return_z_loss = return_z_loss
337
+
338
+ return loss, z_loss
260
339
 
261
340
  @staticmethod
262
- def backward(ctx, grad_output):
341
+ def backward(ctx, grad_output, grad_ouput2):
263
342
  """
264
343
  The backward pass of the Liger Cross Entropy loss.
265
344
 
266
345
  Parameters:
267
346
  ctx : The context object with saved tensors.
268
347
  grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
269
-
348
+ grad_output2 (tensor): The gradient with respect to the z_loss output. Unused, since z_loss is only for logging.
270
349
  Returns:
271
350
  tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
272
351
  """
352
+ if ctx.return_z_loss:
353
+ del grad_ouput2 # z_loss is only for logging
354
+
273
355
  (_input,) = ctx.saved_tensors
274
356
  _input = cross_entropy_backward(_input, grad_output)
275
357
  return (
@@ -278,4 +360,6 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
278
360
  None,
279
361
  None,
280
362
  None,
363
+ None,
364
+ None,
281
365
  )
@@ -21,6 +21,7 @@ def fused_linear_cross_entropy_forward(
21
21
  target,
22
22
  bias=None,
23
23
  ignore_index=-100,
24
+ lse_square_scale=0.0,
24
25
  label_smoothing=0.0,
25
26
  reduction="mean",
26
27
  ):
@@ -86,12 +87,15 @@ def fused_linear_cross_entropy_forward(
86
87
  Y_ptr=target_chunk,
87
88
  Y_stride=target_chunk.stride(-1), # always 1
88
89
  loss_ptr=loss_1d_slice,
90
+ z_loss_ptr=loss_1d_slice, # dummy ptr, not used
89
91
  loss_stride=loss_1d_slice.stride(-1), # always 1
90
92
  n_cols=V,
91
93
  n_non_ignore=n_non_ignore,
92
94
  ignore_index=ignore_index,
95
+ lse_square_scale=lse_square_scale,
93
96
  label_smoothing=label_smoothing,
94
97
  reduction=reduction,
98
+ RETURN_Z_LOSS=0, # False
95
99
  BLOCK_SIZE=BLOCK_SIZE,
96
100
  num_warps=32 if not is_hip() else 16,
97
101
  )
@@ -200,6 +204,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
200
204
  target,
201
205
  bias=None,
202
206
  ignore_index=-100,
207
+ lse_square_scale=0.0,
203
208
  label_smoothing=0.0,
204
209
  reduction="mean",
205
210
  ):
@@ -221,7 +226,14 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
221
226
  reduction: reduction to apply
222
227
  """
223
228
  loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
224
- _input, weight, target, bias, ignore_index, label_smoothing, reduction
229
+ _input,
230
+ weight,
231
+ target,
232
+ bias,
233
+ ignore_index,
234
+ lse_square_scale,
235
+ label_smoothing,
236
+ reduction,
225
237
  )
226
238
  # downcast to dtype and store for backward
227
239
  ctx.save_for_backward(
@@ -238,4 +250,4 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
238
250
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
239
251
  grad_output, grad_input, grad_weight, grad_bias
240
252
  )
241
- return (grad_input, grad_weight, None, grad_bias, None, None, None)
253
+ return (grad_input, grad_weight, None, grad_bias, None, None, None, None)
@@ -0,0 +1,43 @@
1
+ import torch.nn as nn
2
+
3
+ from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
+
5
+
6
+ class LigerCrossEntropyLoss(nn.Module):
7
+ def __init__(
8
+ self,
9
+ ignore_index=-100,
10
+ lse_square_scale=0.0,
11
+ label_smoothing=0.0,
12
+ reduction="mean",
13
+ return_z_loss=False,
14
+ ):
15
+ super().__init__()
16
+ self.ignore_index = ignore_index
17
+ self.lse_square_scale = lse_square_scale
18
+ self.label_smoothing = label_smoothing
19
+ self.reduction = reduction
20
+ self.return_z_loss = return_z_loss
21
+
22
+ assert (self.label_smoothing >= 0) and (
23
+ self.label_smoothing <= 1
24
+ ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
25
+ assert self.reduction in {
26
+ "mean",
27
+ "sum",
28
+ "none",
29
+ }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
30
+
31
+ def forward(self, _input, target):
32
+ loss, z_loss = LigerCrossEntropyFunction.apply(
33
+ _input,
34
+ target,
35
+ self.ignore_index,
36
+ self.lse_square_scale,
37
+ self.label_smoothing,
38
+ self.reduction,
39
+ self.return_z_loss,
40
+ )
41
+ if not self.return_z_loss:
42
+ return loss
43
+ return loss, z_loss
@@ -0,0 +1,35 @@
1
+ import torch.nn as nn
2
+
3
+ from liger_kernel.ops.fused_linear_cross_entropy import (
4
+ LigerFusedLinearCrossEntropyFunction,
5
+ )
6
+
7
+
8
+ class LigerFusedLinearCrossEntropyLoss(nn.Module):
9
+ def __init__(
10
+ self,
11
+ ignore_index=-100,
12
+ label_smoothing=0.0,
13
+ reduction="mean",
14
+ lse_square_scale=0.0,
15
+ ):
16
+ super().__init__()
17
+ self.ignore_index = ignore_index
18
+ self.label_smoothing = label_smoothing
19
+ self.reduction = reduction
20
+ self.lse_square_scale = lse_square_scale
21
+ assert (self.label_smoothing >= 0) and (
22
+ self.label_smoothing <= 1
23
+ ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
24
+
25
+ def forward(self, lin_weight, _input, target, bias=None):
26
+ return LigerFusedLinearCrossEntropyFunction.apply(
27
+ _input,
28
+ lin_weight,
29
+ target,
30
+ bias,
31
+ self.ignore_index,
32
+ self.lse_square_scale,
33
+ self.label_smoothing,
34
+ self.reduction,
35
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.0.dev20241106174658
3
+ Version: 0.4.0.dev20241107052928
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -1,21 +0,0 @@
1
- from torch.nn import CrossEntropyLoss
2
-
3
- from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
-
5
-
6
- class LigerCrossEntropyLoss(CrossEntropyLoss):
7
- def __init__(self, *args, **kwargs):
8
- super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
9
- assert (self.label_smoothing >= 0) and (
10
- self.label_smoothing <= 1
11
- ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
12
- assert self.reduction in {
13
- "mean",
14
- "sum",
15
- "none",
16
- }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
17
-
18
- def forward(self, _input, target):
19
- return LigerCrossEntropyFunction.apply(
20
- _input, target, self.ignore_index, self.label_smoothing, self.reduction
21
- )
@@ -1,21 +0,0 @@
1
- from torch.nn import CrossEntropyLoss
2
-
3
- from liger_kernel.ops.fused_linear_cross_entropy import (
4
- LigerFusedLinearCrossEntropyFunction,
5
- )
6
-
7
-
8
- class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
9
- def __init__(self, *args, **kwargs):
10
- super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)
11
-
12
- def forward(self, lin_weight, _input, target, bias=None):
13
- return LigerFusedLinearCrossEntropyFunction.apply(
14
- _input,
15
- lin_weight,
16
- target,
17
- bias,
18
- self.ignore_index,
19
- self.label_smoothing,
20
- self.reduction,
21
- )