liger-kernel-nightly 0.4.0.dev20241108173943__tar.gz → 0.4.0.dev20241108174843__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (56)
  1. {liger_kernel_nightly-0.4.0.dev20241108173943/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.0.dev20241108174843}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/cross_entropy.py +46 -17
  4. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/fused_linear_cross_entropy.py +6 -1
  5. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/cross_entropy.py +27 -17
  6. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +23 -10
  7. liger_kernel_nightly-0.4.0.dev20241108174843/src/liger_kernel/transformers/model/gemma2.py +277 -0
  8. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/monkey_patch.py +21 -3
  9. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  10. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
  11. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/LICENSE +0 -0
  12. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/NOTICE +0 -0
  13. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/README.md +0 -0
  14. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/setup.cfg +0 -0
  15. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/env_report.py +0 -0
  16. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/__init__.py +0 -0
  17. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  18. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  19. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  20. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/geglu.py +0 -0
  21. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/group_norm.py +0 -0
  22. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/jsd.py +0 -0
  23. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/kl_div.py +0 -0
  24. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/layer_norm.py +0 -0
  25. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/rms_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/rope.py +0 -0
  27. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/swiglu.py +0 -0
  28. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/ops/utils.py +0 -0
  29. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/__init__.py +0 -0
  30. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/auto_model.py +0 -0
  31. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  32. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/functional.py +0 -0
  33. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  34. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/geglu.py +0 -0
  35. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/group_norm.py +0 -0
  36. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/jsd.py +0 -0
  37. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/kl_div.py +0 -0
  38. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/layer_norm.py +0 -0
  39. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/__init__.py +0 -0
  40. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/gemma.py +0 -0
  41. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/llama.py +0 -0
  42. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/mistral.py +0 -0
  43. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  44. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/mllama.py +0 -0
  45. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/phi3.py +0 -0
  46. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  47. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  48. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/rms_norm.py +0 -0
  49. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/rope.py +0 -0
  50. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/swiglu.py +0 -0
  51. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  52. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/triton/__init__.py +0 -0
  53. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel/triton/monkey_patch.py +0 -0
  54. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  55. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  56. {liger_kernel_nightly-0.4.0.dev20241108173943 → liger_kernel_nightly-0.4.0.dev20241108174843}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.0.dev20241108173943
+Version: 0.4.0.dev20241108174843
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "liger_kernel_nightly"
-version = "0.4.0.dev20241108173943"
+version = "0.4.0.dev20241108174843"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }

src/liger_kernel/ops/cross_entropy.py
@@ -1,8 +1,21 @@
+import operator
+from typing import Optional
+
 import torch
 import triton
 import triton.language as tl

-from liger_kernel.ops.utils import element_mul_kernel, is_hip
+from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import tanh
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import tanh
+else:
+    from triton.language.math import tanh

 _TRUE = tl.constexpr(1)
 _FALSE = tl.constexpr(0)
@@ -23,8 +36,10 @@ def liger_cross_entropy_kernel(
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
     reduction: tl.constexpr,  # set it as constexpr since reduction is always known at compile time
+    softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -45,7 +60,9 @@ def liger_cross_entropy_kernel(
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
     RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1.
     reduction (str): The string for the reduction to apply
+    softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
     BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
     """

     # https://github.com/triton-lang/triton/issues/1058
@@ -78,6 +95,8 @@ def liger_cross_entropy_kernel(
     ori_X_y = tl.load(
         X_ptr + y
     )  # we need to store the original value of X_y for the loss calculation
+    if HAS_SOFTCAPPING:
+        ori_X_y = softcap * tanh(ori_X_y / softcap)

     # Label smoothing is a general case of normal cross entropy
     # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
@@ -89,6 +108,8 @@ def liger_cross_entropy_kernel(
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
@@ -122,15 +143,24 @@ def liger_cross_entropy_kernel(
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
+        if HAS_SOFTCAPPING:
+            intermediate = tanh(X_block / softcap)
+            X_block = softcap * intermediate
         # softmax(x_i)
         X_block = tl.exp(X_block - m) / d
         # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
         X_block += 2 * lse_square_scale * lse * X_block
         # smoothing term
         X_block += -eps
+        # special handle dx_y
+        X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
         # reduction scale
         if reduction == "mean":
             X_block = X_block / (n_non_ignore)
+        # chain rule
+        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+        if HAS_SOFTCAPPING:
+            X_block = X_block * (1 - intermediate * intermediate)

         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

@@ -151,7 +181,7 @@ def liger_cross_entropy_kernel(
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
     #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
     # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
-    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
+    #          = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
@@ -168,17 +198,9 @@ def liger_cross_entropy_kernel(
         z_loss = z_loss / n_non_ignore
         loss = loss / n_non_ignore

-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
-    X_y = tl.load(X_ptr + y)
-    if reduction == "mean":
-        X_y += -(1 - label_smoothing) / (n_non_ignore)
-    else:
-        X_y += -(1 - label_smoothing)
-
     tl.store(loss_ptr, loss)
     if RETURN_Z_LOSS == _TRUE:
         tl.store(z_loss_ptr, z_loss)
-    tl.store(X_ptr + y, X_y)


 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -200,6 +222,7 @@ def cross_entropy_forward(
     lse_square_scale,
     label_smoothing,
     reduction,
+    softcap,
     return_z_loss,
 ):
     if not isinstance(return_z_loss, int):
@@ -247,8 +270,10 @@ def cross_entropy_forward(
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
+        softcap=softcap if softcap is not None else 0.0,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
@@ -296,13 +321,14 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
-        _input,
-        target,
-        ignore_index=-100,
-        lse_square_scale=0.0,
-        label_smoothing=0.0,
-        reduction="mean",
-        return_z_loss=False,
+        _input: torch.Tensor,
+        target: torch.Tensor,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         """
         The forward pass of the Liger Cross Entropy loss.
@@ -315,6 +341,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         reduction (str): The reduction to apply to the output: "none" | "mean | "sum".
+        softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap).
         return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`

         Returns:
@@ -327,6 +354,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             lse_square_scale,
             label_smoothing,
             reduction,
+            softcap,
             return_z_loss,
         )
         # TODO: investigation
@@ -362,4 +390,5 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
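
Note: the kernel soft-caps logits as `softcap * tanh(x / softcap)` and, per the chain-rule comment above, scales the gradient by `1 - tanh^2(x / softcap)`. A minimal PyTorch sketch (not part of the package; the cap value and shapes are illustrative) that checks this derivative against autograd:

```python
import torch


def softcap_fn(x: torch.Tensor, softcap: float) -> torch.Tensor:
    # same transform the kernel applies when HAS_SOFTCAPPING is set
    return softcap * torch.tanh(x / softcap)


softcap = 30.0  # illustrative cap value
x = torch.randn(8, dtype=torch.float64, requires_grad=True)

softcap_fn(x, softcap).sum().backward()

# analytic derivative used in the kernel's backward pass
manual_grad = 1 - torch.tanh(x / softcap) ** 2
assert torch.allclose(x.grad, manual_grad)
```
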

src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -24,6 +24,7 @@ def fused_linear_cross_entropy_forward(
     lse_square_scale=0.0,
     label_smoothing=0.0,
     reduction="mean",
+    softcap=None,
 ):
     dtype = _input.dtype
     device = _input.device
@@ -95,7 +96,9 @@ def fused_linear_cross_entropy_forward(
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
+            softcap=softcap if softcap is not None else 0.0,
             RETURN_Z_LOSS=0,  # False
+            HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
@@ -207,6 +210,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         lse_square_scale=0.0,
         label_smoothing=0.0,
         reduction="mean",
+        softcap=None,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -234,6 +238,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
             lse_square_scale,
             label_smoothing,
             reduction,
+            softcap,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
@@ -250,4 +255,4 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
         )
-        return (grad_input, grad_weight, None, grad_bias, None, None, None, None)
+        return (grad_input, grad_weight, None, grad_bias, None, None, None, None, None)
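
With `softcap` threaded through `fused_linear_cross_entropy_forward`, the fused path is intended to match a reference that materializes the logits, soft-caps them, and applies standard cross entropy (the same fallback Gemma2 uses in its non-fused branch below). A hedged equivalence sketch, assuming a CUDA device since the loss is a Triton kernel; sizes and the tolerance are illustrative:

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

BT, H, V, softcap = 16, 64, 128, 30.0  # (batch * seq_len, hidden, vocab); illustrative
hidden = torch.randn(BT, H, device="cuda")
weight = torch.randn(V, H, device="cuda")
target = torch.randint(0, V, (BT,), device="cuda")

fused_loss = LigerFusedLinearCrossEntropyLoss(softcap=softcap)(weight, hidden, target)

# reference: materialize logits, soft-cap them, then plain cross entropy
logits = softcap * torch.tanh((hidden @ weight.T) / softcap)
ref_loss = torch.nn.functional.cross_entropy(logits, target)

assert torch.allclose(fused_loss, ref_loss, atol=1e-4)
```
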

src/liger_kernel/transformers/cross_entropy.py
@@ -1,34 +1,43 @@
-import torch.nn as nn
+from typing import Optional
+
+import torch

 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction


-class LigerCrossEntropyLoss(nn.Module):
+class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
-        ignore_index=-100,
-        lse_square_scale=0.0,
-        label_smoothing=0.0,
-        reduction="mean",
-        return_z_loss=False,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert reduction in {
+            "mean",
+            "sum",
+            "none",
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        assert (
+            softcap is None or softcap > 0
+        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
         self.reduction = reduction
+        self.softcap = softcap
         self.return_z_loss = return_z_loss

-        assert (self.label_smoothing >= 0) and (
-            self.label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
-        assert self.reduction in {
-            "mean",
-            "sum",
-            "none",
-        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
-
-    def forward(self, _input, target):
+    def forward(self, _input: torch.Tensor, target: torch.Tensor):
         loss, z_loss = LigerCrossEntropyFunction.apply(
             _input,
             target,
@@ -36,6 +45,7 @@ class LigerCrossEntropyLoss(nn.Module):
             self.lse_square_scale,
             self.label_smoothing,
             self.reduction,
+            self.softcap,
             self.return_z_loss,
         )
         if not self.return_z_loss:
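
The module-level wrapper now accepts `softcap` alongside the existing options and validates it in `__init__`. A short usage sketch (values are illustrative; the underlying op is a Triton kernel and runs on GPU):

```python
import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

ce = LigerCrossEntropyLoss(softcap=30.0, label_smoothing=0.1, reduction="mean")

logits = torch.randn(8, 128, device="cuda", requires_grad=True)  # (batch * seq_len, vocab)
target = torch.randint(0, 128, (8,), device="cuda")

loss = ce(logits, target)
loss.backward()
```
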

src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -1,26 +1,38 @@
-import torch.nn as nn
+from typing import Optional
+
+import torch

 from liger_kernel.ops.fused_linear_cross_entropy import (
     LigerFusedLinearCrossEntropyFunction,
 )


-class LigerFusedLinearCrossEntropyLoss(nn.Module):
+class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
-        ignore_index=-100,
-        label_smoothing=0.0,
-        reduction="mean",
-        lse_square_scale=0.0,
+        ignore_index: int = -100,
+        lse_square_scale: float = 0.0,
+        label_smoothing: float = 0.0,
+        reduction: str = "mean",
+        softcap: Optional[float] = None,
     ):
         super().__init__()
+        assert (label_smoothing >= 0) and (
+            label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
+        assert reduction in {
+            "mean",
+            "sum",
+            "none",
+        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        assert (
+            softcap is None or softcap > 0
+        ), f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ignore_index = ignore_index
+        self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
         self.reduction = reduction
-        self.lse_square_scale = lse_square_scale
-        assert (self.label_smoothing >= 0) and (
-            self.label_smoothing <= 1
-        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
+        self.softcap = softcap

     def forward(self, lin_weight, _input, target, bias=None):
         return LigerFusedLinearCrossEntropyFunction.apply(
@@ -32,4 +44,5 @@ class LigerFusedLinearCrossEntropyLoss(nn.Module):
             self.lse_square_scale,
             self.label_smoothing,
             self.reduction,
+            self.softcap,
         )
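
Note the argument order on this module: the projection weight comes first, then the flattened hidden states and targets, which is how the Gemma2 forward below calls it against `lm_head.weight`. A hedged sketch with illustrative sizes:

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden_size, vocab_size = 64, 256  # illustrative
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False, device="cuda")
flce = LigerFusedLinearCrossEntropyLoss(softcap=30.0)

hidden_states = torch.randn(4 * 8, hidden_size, device="cuda")  # already flattened to (B*T, H)
labels = torch.randint(0, vocab_size, (4 * 8,), device="cuda")

# weight first, then inputs and targets; the full logits tensor is never materialized
loss = flce(lm_head.weight, hidden_states, labels)
```
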

src/liger_kernel/transformers/model/gemma2.py
@@ -0,0 +1,277 @@
+import logging
+from typing import Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import HybridCache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.gemma2.modeling_gemma2 import (
+    _CONFIG_FOR_DOC,
+    GEMMA2_INPUTS_DOCSTRING,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```"""
+
+    if self.training and self.config._attn_implementation != "eager":
+        logger.warning_once(
+            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+        )
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten
+
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss(
+            softcap=self.config.final_logit_softcapping
+        )
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+    else:
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        logits = self.lm_head(hidden_states)
+        if self.config.final_logit_softcapping is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+    >>> prompt = "What is your favorite condiment?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "What is your favorite condiment?"
+    ```"""
+
+    if self.training and self.config._attn_implementation != "eager":
+        logger.warning_once(
+            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+        )
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None):
+        # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+        lce = LigerFusedLinearCrossEntropyLoss(
+            softcap=self.config.final_logit_softcapping,
+            reduction=reduction,
+        )
+
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        if reduction == "sum":
+            loss /= loss_kwargs["num_items_in_batch"]
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if self.config.final_logit_softcapping is not None:
+            logits = logits / self.config.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.config.final_logit_softcapping
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
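
One detail worth calling out in the new `lce_forward`: when the Trainer supplies `num_items_in_batch` (gradient-accumulation-aware loss in recent transformers versions), the loss is computed with `reduction="sum"` and divided by that global token count, so every token is weighted equally across micro-batches instead of averaging per micro-batch. A small pure-PyTorch sketch of the difference (numbers are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 16
# two micro-batches with different numbers of tokens
logits_a, labels_a = torch.randn(3, vocab), torch.randint(0, vocab, (3,))
logits_b, labels_b = torch.randn(5, vocab), torch.randint(0, vocab, (5,))
num_items_in_batch = 3 + 5

# per-micro-batch mean, then averaged: weights both micro-batches equally
naive = (F.cross_entropy(logits_a, labels_a) + F.cross_entropy(logits_b, labels_b)) / 2

# sum per micro-batch, divided by the global token count:
# every token carries the same weight, which is what the "sum" path reproduces
corrected = (
    F.cross_entropy(logits_a, labels_a, reduction="sum")
    + F.cross_entropy(logits_b, labels_b, reduction="sum")
) / num_items_in_batch

print(naive, corrected)  # generally differ when the token counts differ
```
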

src/liger_kernel/transformers/monkey_patch.py
@@ -14,6 +14,10 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forwa
 from liger_kernel.transformers.model.gemma import (
     lce_forward_deprecated as gemma_lce_forward_deprecated,
 )
+from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
+from liger_kernel.transformers.model.gemma2 import (
+    lce_forward_deprecated as gemma2_lce_forward_deprected,
+)
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import (
     lce_forward_deprecated as llama_lce_forward_deprecated,
@@ -252,7 +256,7 @@ def apply_liger_kernel_to_mistral(
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models

     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -445,7 +449,8 @@

 def apply_liger_kernel_to_gemma2(
     rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
     model: PreTrainedModel = None,
@@ -456,12 +461,19 @@

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
     """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

@@ -479,6 +491,12 @@
         modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
     if cross_entropy:
         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

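
With the new flag, the Gemma2 patch defaults to the fused linear cross entropy path and rejects enabling both `cross_entropy` and `fused_linear_cross_entropy` at once. A hedged usage sketch (the checkpoint name is taken from the docstring example above; the patch should be applied before the model is instantiated, or an already-loaded instance can be passed via `model=`):

```python
from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2

# patch the Gemma2 modeling classes before creating the model
apply_liger_kernel_to_gemma2(
    rope=True,
    cross_entropy=False,              # default; mutually exclusive with the fused path
    fused_linear_cross_entropy=True,  # default; logits are not materialized during training
    rms_norm=True,
    geglu=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",
    attn_implementation="eager",  # recommended for Gemma2 training per the warning in lce_forward
)
```
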

src/liger_kernel_nightly.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.4.0.dev20241108173943
+Version: 0.4.0.dev20241108174843
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation

src/liger_kernel_nightly.egg-info/SOURCES.txt
@@ -37,6 +37,7 @@ src/liger_kernel/transformers/trainer_integration.py
 src/liger_kernel/transformers/experimental/embedding.py
 src/liger_kernel/transformers/model/__init__.py
 src/liger_kernel/transformers/model/gemma.py
+src/liger_kernel/transformers/model/gemma2.py
 src/liger_kernel/transformers/model/llama.py
 src/liger_kernel/transformers/model/mistral.py
 src/liger_kernel/transformers/model/mixtral.py