liger-kernel 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. liger_kernel/env_report.py +2 -0
  2. liger_kernel/ops/cross_entropy.py +144 -65
  3. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +31 -11
  5. liger_kernel/ops/fused_linear_jsd.py +245 -0
  6. liger_kernel/ops/geglu.py +2 -2
  7. liger_kernel/ops/group_norm.py +322 -0
  8. liger_kernel/ops/jsd.py +176 -0
  9. liger_kernel/ops/kl_div.py +2 -2
  10. liger_kernel/ops/rms_norm.py +92 -46
  11. liger_kernel/ops/swiglu.py +2 -2
  12. liger_kernel/ops/utils.py +62 -1
  13. liger_kernel/transformers/__init__.py +3 -0
  14. liger_kernel/transformers/cross_entropy.py +44 -12
  15. liger_kernel/transformers/functional.py +38 -1
  16. liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
  17. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  18. liger_kernel/transformers/group_norm.py +56 -0
  19. liger_kernel/transformers/jsd.py +75 -0
  20. liger_kernel/transformers/model/gemma.py +124 -1
  21. liger_kernel/transformers/model/gemma2.py +277 -0
  22. liger_kernel/transformers/model/llama.py +135 -4
  23. liger_kernel/transformers/model/mistral.py +3 -0
  24. liger_kernel/transformers/model/mixtral.py +153 -2
  25. liger_kernel/transformers/model/mllama.py +274 -0
  26. liger_kernel/transformers/model/phi3.py +140 -2
  27. liger_kernel/transformers/model/qwen2.py +123 -2
  28. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  29. liger_kernel/transformers/monkey_patch.py +258 -68
  30. liger_kernel/transformers/rms_norm.py +11 -3
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +63 -29
  32. liger_kernel-0.4.1.dist-info/NOTICE +58 -0
  33. liger_kernel-0.4.1.dist-info/RECORD +51 -0
  34. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  36. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  37. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
liger_kernel/env_report.py

@@ -4,11 +4,13 @@ import sys
 
 def print_env_report():
     """
+
    Prints a report of the environment. Useful for debugging and reproducibility.
    Usage:
    ```
    python -m liger_kernel.env_report
    ```
+
    """
    print("Environment Report:")
    print("-------------------")
liger_kernel/ops/cross_entropy.py

@@ -1,7 +1,25 @@
+import operator
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl
 
+from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh
+
+_TRUE = tl.constexpr(1)
+_FALSE = tl.constexpr(0)
+
 
 @triton.jit
 def liger_cross_entropy_kernel(
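The version-gated import above relies on a compare_version helper from liger_kernel.ops.utils (which now also exports element_mul_kernel and is_hip). A minimal sketch of what such a helper can look like, assuming an importlib.metadata plus packaging route; the actual implementation in liger_kernel.ops.utils may differ:

import operator
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def compare_version(package: str, op, target: str) -> bool:
    # Compare an installed package's version against a target using `op`,
    # e.g. compare_version("triton", operator.ge, "3.0.0").
    try:
        installed = Version(version(package))
    except PackageNotFoundError:
        return False
    return op(installed, Version(target))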
@@ -10,13 +28,18 @@ def liger_cross_entropy_kernel(
     Y_ptr,
     Y_stride,
     loss_ptr,
+    z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
     ignore_index,
+    lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
+    softcap,
+    RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -28,13 +51,18 @@ def liger_cross_entropy_kernel(
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
     loss_ptr: Pointer to tensor to store the loss.
+    z_loss_ptr: Pointer to tensor to store the z loss. No-op if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
     n_non_ignore (int): The number of non-ignored elements in the batch.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
+    lse_square_scale (float): The scale factor applied to (logsumexp(_input))^2, added to the loss for training stability.
+    RETURN_Z_LOSS (int): Whether to store the z loss to z_loss_ptr. Must be 0 or 1.
     reduction (str): The string for the reduction to apply
+    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_SOFTCAPPING (bool): Whether to apply soft-capping.
     """
 
     # https://github.com/triton-lang/triton/issues/1058
@@ -56,6 +84,7 @@ def liger_cross_entropy_kernel(
         return
 
     loss_ptr += program_id * loss_stride
+    z_loss_ptr += program_id * loss_stride
 
     # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
     # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -66,6 +95,8 @@ def liger_cross_entropy_kernel(
     ori_X_y = tl.load(
         X_ptr + y
     )  # we need to store the original value of X_y for the loss calculation
+    if HAS_SOFTCAPPING:
+        ori_X_y = softcap * tanh(ori_X_y / softcap)
 
     # Label smoothing is a general case of normal cross entropy
     # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
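Soft-capping squashes logits smoothly into (-softcap, +softcap), as used by models such as Gemma 2. A one-line PyTorch equivalent of the capping applied to ori_X_y above (illustrative only; the kernel does this in Triton):

import torch

def soft_cap(logits: torch.Tensor, softcap: float) -> torch.Tensor:
    # tanh saturates at +/-1, so outputs stay strictly inside (-softcap, +softcap)
    return softcap * torch.tanh(logits / softcap)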
@@ -77,6 +108,8 @@ def liger_cross_entropy_kernel(
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
@@ -85,32 +118,49 @@ def liger_cross_entropy_kernel(
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
 
+    # log (sum(e^(X_i))) = log (sum(e^(max(X)) * e^(X_i - max(X))))
+    #                    = log (e^(max(X)) * sum(e^(X_i - max(X))))
+    #                    = max(X) + log (sum(e^(X_i - max(X)))) = m + log d
+    lse = m + tl.log(d)
+
     # 4. [Online Softmax] Second pass: compute gradients
     # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
     # For label smoothing:
-    # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
+    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
     # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
     #      = dx_i - (1 - label_smoothing) / N
-    #
+    # With z loss:
+    # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
+    # dx_y = dx_i - (1 - label_smoothing) / N
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-    # For label smoothing:
-    # dx_i = (softmax(x_y) - label_smoothing / V), V = n_cols, i != y
-    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
-    #      = dx_i - (1 - label_smoothing)
 
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            intermediate = tanh(X_block / softcap)
+            X_block = softcap * intermediate
+        # softmax(x_i)
+        X_block = tl.exp(X_block - m) / d
+        # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+        X_block += 2 * lse_square_scale * lse * X_block
+        # smoothing term
+        X_block += -eps
+        # special handling of dx_y
+        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+        # reduction scale
         if reduction == "mean":
-            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
-        else:
-            X_block = tl.exp(X_block - m) / d - eps
+            X_block = X_block / (n_non_ignore)
+        # chain rule of the soft-cap:
+        # d(softcap * tanh(x / softcap)) / dx = 1 - tanh^2(x / softcap)
+        if HAS_SOFTCAPPING:
+            X_block = X_block * (1 - intermediate * intermediate)
 
         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
 
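A small PyTorch reference check of the analytic gradient used in this second pass (soft-capped cross entropy plus z loss, reduction="sum", no label smoothing). The vocab size and constants are made up for the example:

import torch

torch.manual_seed(0)
V, softcap, scale = 8, 30.0, 1e-4  # hypothetical vocab size / constants
x = torch.randn(V, dtype=torch.float64, requires_grad=True)
y = 3

z = softcap * torch.tanh(x / softcap)  # soft-capped logits
lse = torch.logsumexp(z, dim=0)        # lse = m + log d
loss = lse - z[y] + scale * lse * lse  # cross entropy + z loss
loss.backward()

# Kernel's analytic form: ((1 + 2 * scale * lse) * softmax(z) - onehot(y)),
# then the soft-cap chain rule (1 - tanh^2(x / softcap)).
with torch.no_grad():
    p = torch.softmax(z, dim=0)
    grad = (1 + 2 * scale * lse) * p
    grad[y] -= 1.0
    grad *= 1 - torch.tanh(x / softcap) ** 2
assert torch.allclose(grad, x.grad)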
@@ -122,35 +172,35 @@ def liger_cross_entropy_kernel(
 
     # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
     #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+    #      = X_y - m - log d = X_y - lse
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
-    loss = -(ori_X_y - m - tl.log(d))
+    loss = lse - ori_X_y
 
-    # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
     #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
     # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
-    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
+    #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
-        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
+        smooth_loss = scaled_x_sum + label_smoothing * lse
         loss = loss * (1 - label_smoothing) + smooth_loss
 
+    # An auxiliary loss, z_loss
+    # Refer to the "Loss function" section (p. 14) of the PaLM paper: https://www.jmlr.org/papers/v24/22-1144.html
+    z_loss = lse_square_scale * lse * lse
+    loss += z_loss
     # Normalize the loss by the number of non-ignored elements if reduction is "mean"
     if reduction == "mean":
+        z_loss = z_loss / n_non_ignore
         loss = loss / n_non_ignore
 
-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
-    if reduction == "mean":
-        X_y += -(1 - label_smoothing) / (n_non_ignore)
-    else:
-        X_y += -(1 - label_smoothing)
-
     tl.store(loss_ptr, loss)
-    tl.store(X_ptr + y, X_y)
+    if RETURN_Z_LOSS == _TRUE:
+        tl.store(z_loss_ptr, z_loss)
 
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
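In plain PyTorch, the auxiliary z loss added above corresponds to the following sketch (ce_with_z_loss is a hypothetical helper for illustration, not part of the package):

import torch
import torch.nn.functional as F

def ce_with_z_loss(logits, target, lse_square_scale=1e-4):
    # z_loss = lse_square_scale * logsumexp(logits)^2, per PaLM's loss-function
    # section; it pulls log Z toward 0 to keep the softmax well normalized.
    lse = torch.logsumexp(logits, dim=-1)
    ce = F.cross_entropy(logits, target, reduction="none")
    z_loss = lse_square_scale * lse * lse
    return (ce + z_loss).mean(), z_loss.mean()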
@@ -159,43 +209,32 @@ def liger_cross_entropy_kernel(
 MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
 
 
-@triton.jit
-def element_mul_kernel(
-    X_ptr,
-    X_stride,
-    grad_output_ptr,
-    n_cols,
-    BLOCK_SIZE: tl.constexpr,
-):
-    """
-    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
-    The multiplication is performed in-place on the tensor pointed by X_ptr.
-
-    Parameters:
-    X_ptr: Pointer to the input tensor.
-    X_stride (int): The stride of the input tensor.
-    grad_output_ptr: Pointer to the gradient output value.
-    n_cols (int): The number of columns in the input tensor.
-    BLOCK_SIZE (int): The block size for Triton operations.
-    """
-
-    # Get the program ID and convert it to int64 to avoid overflow
-    program_id = tl.program_id(0).to(tl.int64)
-
-    # Locate the start index
-    X_ptr += program_id * X_stride
+_bool_to_return_z_loss = {
+    True: _TRUE.value,
+    False: _FALSE.value,
+}
 
-    # Load the gradient output value
-    grad_output = tl.load(grad_output_ptr)
-
-    # Perform the element-wise multiplication
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
-        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
 
+def cross_entropy_forward(
+    _input,
+    target,
+    ignore_index,
+    lse_square_scale,
+    label_smoothing,
+    reduction,
+    softcap,
+    return_z_loss,
+):
+    if not isinstance(return_z_loss, int):
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        return_z_loss = _bool_to_return_z_loss[return_z_loss]
+    else:
+        assert (
+            return_z_loss in _bool_to_return_z_loss
+        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
 
-def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
     BT, V = _input.shape
     n_rows = BT
 
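One Python subtlety in the validation above: bool is a subclass of int, so isinstance(True, int) is True and a plain True/False argument actually takes the else branch unconverted. The downstream comparisons still behave, because True == 1 == _TRUE.value:

# bool is a subclass of int in Python
assert isinstance(True, int) and isinstance(False, int)
# so equality against the constexpr values still holds
assert True == 1 and False == 0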
@@ -203,6 +242,10 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
 
     # unreduced loss
     loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    if return_z_loss == _TRUE.value:
+        z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
+    else:
+        z_loss_1d = loss_1d  # dummy ptr when return_z_loss == False
 
     n_non_ignore = (target != ignore_index).sum().item()
 
@@ -219,20 +262,30 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
         loss_ptr=loss_1d,
+        z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
+        lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
+        softcap=softcap if softcap is not None else 0.0,
+        RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
-        num_warps=32,
+        num_warps=32 if not is_hip() else 16,
     )
 
     loss = torch.sum(loss_1d)
-    return loss, _input
+    if return_z_loss == _TRUE.value:
+        z_loss = torch.sum(z_loss_1d)
+    else:
+        z_loss = None
+
+    return loss, z_loss, _input
 
 
 def cross_entropy_backward(_input, grad_output):
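is_hip() halves num_warps on AMD GPUs, here and in cross_entropy_backward below. A plausible reading (an assumption, not stated in the diff): a HIP wavefront is 64 lanes versus CUDA's 32-thread warp, so 16 wavefronts launch the same number of threads per Triton program as 32 warps do on NVIDIA:

# threads per Triton program, NVIDIA vs. AMD wavefront64 targets
assert 32 * 32 == 16 * 64 == 1024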
@@ -253,7 +306,7 @@ def cross_entropy_backward(_input, grad_output):
             grad_output,
             V,
             BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
+            num_warps=32 if not is_hip() else 16,
         )
 
     return _input
@@ -267,7 +320,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(
-        ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
+        ctx,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -277,33 +338,48 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
         ignore_index (int): The index to ignore in the target.
+        lse_square_scale (float): The scale factor applied to (logsumexp(_input))^2, added to the loss for training stability.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
+        return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`
 
         Returns:
-        tensor: The computed loss.
+        tuple: A tuple of the computed loss and z loss. The elements are tensors or None.
         """
-        loss, _input = cross_entropy_forward(
-            _input, target, ignore_index, label_smoothing, reduction
+        loss, z_loss, _input = cross_entropy_forward(
+            _input,
+            target,
+            ignore_index,
+            lse_square_scale,
+            label_smoothing,
+            reduction,
+            softcap,
+            return_z_loss,
         )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
         ctx.save_for_backward(_input.detach())
-        return loss
+        ctx.return_z_loss = return_z_loss
+
+        return loss, z_loss
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_output2):
         """
         The backward pass of the Liger Cross Entropy loss.
 
         Parameters:
         ctx : The context object with saved tensors.
         grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-
+        grad_output2 (tensor): Unused; the incoming gradient for the z loss output.
         Returns:
         tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
         """
+        if ctx.return_z_loss:
+            del grad_output2  # z_loss is only for logging
+
         (_input,) = ctx.saved_tensors
         _input = cross_entropy_backward(_input, grad_output)
         return (
@@ -312,4 +388,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
+            None,
         )
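Putting the new interface together, a usage sketch (shapes, device, and argument values are illustrative; apply takes the arguments positionally, in the order of forward):

import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

logits = torch.randn(4, 128, device="cuda", requires_grad=True)
target = torch.randint(0, 128, (4,), device="cuda")

# _input, target, ignore_index, lse_square_scale, label_smoothing,
# reduction, softcap, return_z_loss
loss, z_loss = LigerCrossEntropyFunction.apply(
    logits, target, -100, 1e-4, 0.0, "mean", 30.0, True
)
loss.backward()  # z_loss is for logging only; its incoming gradient is dropped

Since forward now takes eight arguments after ctx, autograd requires backward to return eight values: the gradient for _input plus a None for each non-differentiable argument, which is why the diff appends three more Nones to the return tuple.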