liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/chunked_loss/functional.py +2 -0
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
  4. liger_kernel/chunked_loss/grpo_loss.py +134 -60
  5. liger_kernel/chunked_loss/jsd_loss.py +12 -7
  6. liger_kernel/ops/cross_entropy.py +3 -2
  7. liger_kernel/ops/dyt.py +225 -0
  8. liger_kernel/ops/fused_linear_jsd.py +2 -1
  9. liger_kernel/ops/jsd.py +32 -12
  10. liger_kernel/ops/kl_div.py +15 -8
  11. liger_kernel/ops/layer_norm.py +14 -1
  12. liger_kernel/ops/rms_norm.py +12 -1
  13. liger_kernel/transformers/__init__.py +133 -15
  14. liger_kernel/transformers/dyt.py +20 -0
  15. liger_kernel/transformers/functional.py +5 -0
  16. liger_kernel/transformers/gema3_rms.py +8 -0
  17. liger_kernel/transformers/model/gemma.py +17 -20
  18. liger_kernel/transformers/model/gemma2.py +17 -21
  19. liger_kernel/transformers/model/gemma3.py +335 -0
  20. liger_kernel/transformers/model/llama.py +17 -19
  21. liger_kernel/transformers/model/llava.py +369 -0
  22. liger_kernel/transformers/model/loss_utils.py +64 -0
  23. liger_kernel/transformers/model/mistral.py +28 -25
  24. liger_kernel/transformers/model/mixtral.py +20 -26
  25. liger_kernel/transformers/model/mllama.py +17 -19
  26. liger_kernel/transformers/model/olmo2.py +17 -20
  27. liger_kernel/transformers/model/paligemma.py +397 -0
  28. liger_kernel/transformers/model/phi3.py +17 -19
  29. liger_kernel/transformers/model/qwen2.py +17 -19
  30. liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  31. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  32. liger_kernel/transformers/monkey_patch.py +392 -13
  33. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
  34. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
  35. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
  36. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  37. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
  38. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
  39. {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
liger_kernel/ops/kl_div.py

@@ -6,6 +6,7 @@ import triton.language as tl
 
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 
 def get_num_warps(BLOCK_SIZE):

@@ -115,9 +116,12 @@ def _kldiv_kernel_backward(
 
 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
     BT, V = y_pred.shape
-
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
     reduction = _str_to_reduction_mode[reduction]
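The forward hunk above (and the matching backward hunk below) now caps `BLOCK_SIZE` at 8192 and forces `num_warps = 32` when `infer_device()` reports an Intel XPU, while other devices keep the existing `MAX_FUSED_SIZE` / `get_num_warps` path. A minimal sketch of that selection logic pulled out as a standalone helper; `pick_launch_config` is hypothetical, and the `MAX_FUSED_SIZE` value and `get_num_warps` heuristic below are stand-ins for the ones defined in `kl_div.py`:

```python
import triton

MAX_FUSED_SIZE = 65536 // 4  # stand-in; the real constant lives in kl_div.py


def get_num_warps(block_size: int) -> int:
    # simplified stand-in for the heuristic in kl_div.py
    return max(4, min(32, block_size // 2048))


def pick_launch_config(vocab_size: int, device_type: str):
    """Return (BLOCK_SIZE, num_warps) the way the patched kernels choose them."""
    if device_type == "xpu":
        # XPU path: smaller block cap, fixed warp count
        return min(8192, triton.next_power_of_2(vocab_size)), 32
    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(vocab_size))
    return block_size, get_num_warps(block_size)


# e.g. pick_launch_config(32000, "xpu") -> (8192, 32)
```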
@@ -155,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
 
 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
 
@@ -185,9 +192,9 @@ class LigerKLDivLossFunction(torch.autograd.Function):
     Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
     ```python
     if log_target:
-        loss = target * (target.log() - input)
-    else:
         loss = target.exp() * (target - input)
+    else:
+        loss = target * (target.log() - input)
     ```,
     then the loss is reduced according to the `reduction` parameter.
     as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
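The docstring fix above swaps the two branches so the pseudo-code matches PyTorch's pointwise `KLDivLoss` definition (`input` is expected in log-space; `target` is in log-space only when `log_target=True`). A quick sanity check of the corrected formula against `torch.nn.functional.kl_div`, not taken from the package:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = torch.randn(4, 8).log_softmax(dim=-1)  # `input`: log-probabilities
probs = torch.randn(4, 8).softmax(dim=-1)          # `target`: probabilities

# log_target=False: loss = target * (target.log() - input)
manual = probs * (probs.log() - log_probs)
assert torch.allclose(manual, F.kl_div(log_probs, probs, reduction="none", log_target=False))

# log_target=True: loss = target.exp() * (target - input), with target given in log-space
log_target = probs.log()
manual_log = log_target.exp() * (log_target - log_probs)
assert torch.allclose(manual_log, F.kl_div(log_probs, log_target, reduction="none", log_target=True))
```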
liger_kernel/ops/layer_norm.py

@@ -154,6 +154,11 @@ def layer_norm_forward(X, W, B, eps):
             f"must match weight size (W.shape[0]={W.shape[0]})"
         )
 
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
     _layer_norm_forward_kernel[(n_rows,)](
         Y,
         Y.stride(0),

@@ -171,6 +176,7 @@ def layer_norm_forward(X, W, B, eps):
         eps,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
     )
     return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 

@@ -185,7 +191,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     if X.device.type == "cuda":
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
     DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
     _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)

@@ -208,6 +214,12 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         if X.dtype == torch.float16
         else tl.float32  # fallback to float32 for other types
     )
+
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+
     _layer_norm_backward_kernel[grid](
         X,
         W,

@@ -227,6 +239,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
         rows_per_program,
         BLOCK_SIZE=BLOCK_SIZE,
         dtype=triton_dtype,
+        **kernel_args,  # XPU-specific optimization
     )
 
     DW = _DW.sum(dim=0).to(W.dtype)
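The same pattern repeats across these kernels: an empty `kernel_args` dict that only picks up Intel-specific launch options (`grf_mode="large"`, plus `num_warps`/`num_stages` overrides for the backward kernel) when the tensor lives on an XPU, so CUDA/ROCm launches stay unchanged. A minimal sketch of the pattern around a toy Triton kernel; the copy kernel and `launch_copy` helper are illustrative, not from the package:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(X, Y, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(X + offsets, mask=mask)
    tl.store(Y + offsets, vals, mask=mask)


def launch_copy(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    n = x.numel()
    BLOCK_SIZE = 1024
    # Same device-conditional kwargs pattern as the patched layer_norm/rms_norm launches:
    # extra launch options are only passed on XPU, other backends see no change.
    kernel_args = {}
    if x.device.type == "xpu":
        kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
    grid = (triton.cdiv(n, BLOCK_SIZE),)
    _copy_kernel[grid](x, y, n, BLOCK_SIZE=BLOCK_SIZE, **kernel_args)
    return y
```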
liger_kernel/ops/rms_norm.py

@@ -223,6 +223,10 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     # Check constraints.
     assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
     _rms_norm_forward_kernel[(n_rows,)](
         Y,
         Y.stride(0),

@@ -238,6 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
         casting_mode,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
     )
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 

@@ -252,7 +257,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     if X.device.type == "cuda":
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
-        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+        sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

@@ -267,6 +272,11 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     else:
         dX = torch.zeros_like(dY)
 
+    # XPU-specific optimization
+    kernel_args = {}
+    if X.device.type == "xpu":
+        kernel_args["grf_mode"] = "large"
+
     _rms_norm_backward_kernel[grid](
         dY,
         dY.stride(0),

@@ -288,6 +298,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         casting_mode,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=num_warps,
+        **kernel_args,  # XPU-specific optimization
     )
     dX = dX.view(*shape)
     dW = _dW.sum(dim=0).to(W.dtype)
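Both backward passes also size their partial weight-gradient buffer (`_DW` / `_dW`) with one row per compute unit, and on XPU that count now comes from `gpu_eu_count` rather than `gpu_subslice_count`. A hedged sketch of the per-device lookup; `partial_grad_rows` is a hypothetical helper and only covers the two device types shown in this diff:

```python
import torch


def partial_grad_rows(device: torch.device) -> int:
    """Rows in the partial dW buffer, mirroring the patched backward passes."""
    if device.type == "cuda":
        return torch.cuda.get_device_properties(device).multi_processor_count
    if device.type == "xpu":
        # 0.5.7 uses the execution-unit count instead of the subslice count here
        return torch.xpu.get_device_properties(device).gpu_eu_count
    raise ValueError(f"unsupported device type: {device.type}")


# e.g. _dW = torch.empty((partial_grad_rows(X.device), n_cols), dtype=torch.float32, device=W.device)
```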
liger_kernel/transformers/__init__.py

@@ -1,27 +1,145 @@
-from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
+import importlib
+
+from typing import TYPE_CHECKING
+
+# Always-safe imports (independent of 'transformers')
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss  # noqa: F401
+from liger_kernel.transformers.dyt import LigerDyT  # noqa: F401
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss  # noqa: F401
 from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD  # noqa: F401
 from liger_kernel.transformers.geglu import LigerGEGLUMLP  # noqa: F401
 from liger_kernel.transformers.jsd import LigerJSD  # noqa: F401
 from liger_kernel.transformers.layer_norm import LigerLayerNorm  # noqa: F401
-from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
-from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
-from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
 from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401
+
+# Static-only imports for IDEs and type checkers
+if TYPE_CHECKING:
+    from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
+
+
+# Check if 'transformers' is installed
+try:
+    import transformers  # noqa: F401
+
+    _TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    _TRANSFORMERS_AVAILABLE = False
+
+
+def is_transformers_available() -> bool:
+    """
+    Returns True if the 'transformers' package is available.
+    Useful for conditional logic in downstream code.
+    """
+    return _TRANSFORMERS_AVAILABLE
+
+
+def __getattr__(name: str):
+    """
+    Handles lazy access to transformer-dependent attributes.
+    If 'transformers' is not installed, raises a user-friendly ImportError.
+    """
+    if not _TRANSFORMERS_AVAILABLE:
+        raise ImportError(
+            f"The attribute '{name}' requires the 'transformers' library, which is not installed.\n"
+            f"Please install it with `pip install transformers` to use this functionality."
+        )
+
+    if name == "AutoLigerKernelForCausalLM":
+        module = importlib.import_module("liger_kernel.transformers.auto_model")
+        return getattr(module, name)
+
+    monkey_patch_symbols = {
+        "_apply_liger_kernel",
+        "_apply_liger_kernel_to_instance",
+        "apply_liger_kernel_to_gemma",
+        "apply_liger_kernel_to_gemma2",
+        "apply_liger_kernel_to_gemma3",
+        "apply_liger_kernel_to_gemma3_text",
+        "apply_liger_kernel_to_granite",
+        "apply_liger_kernel_to_llama",
+        "apply_liger_kernel_to_llava",
+        "apply_liger_kernel_to_mistral",
+        "apply_liger_kernel_to_mixtral",
+        "apply_liger_kernel_to_mllama",
+        "apply_liger_kernel_to_olmo2",
+        "apply_liger_kernel_to_paligemma",
+        "apply_liger_kernel_to_phi3",
+        "apply_liger_kernel_to_qwen2",
+        "apply_liger_kernel_to_qwen2_5_vl",
+        "apply_liger_kernel_to_qwen2_vl",
+    }
+
+    if name in monkey_patch_symbols:
+        module = importlib.import_module("liger_kernel.transformers.monkey_patch")
+        return getattr(module, name)
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+# Shared symbols in all environments
+__all__ = [
+    "is_transformers_available",
+    "LigerCrossEntropyLoss",
+    "LigerDyT",
+    "LigerFusedLinearCrossEntropyLoss",
+    "LigerFusedLinearJSD",
+    "LigerGEGLUMLP",
+    "LigerJSD",
+    "LigerLayerNorm",
+    "LigerRMSNorm",
+    "liger_rotary_pos_emb",
+    "LigerBlockSparseTop2MLP",
+    "LigerPhi3SwiGLUMLP",
+    "LigerSwiGLUMLP",
+    "LigerTVDLoss",
+]
+
+# Add transformer-dependent symbols only if available
+if _TRANSFORMERS_AVAILABLE:
+    __all__.extend(
+        [
+            "AutoLigerKernelForCausalLM",
+            "_apply_liger_kernel",
+            "_apply_liger_kernel_to_instance",
+            "apply_liger_kernel_to_gemma",
+            "apply_liger_kernel_to_gemma2",
+            "apply_liger_kernel_to_gemma3",
+            "apply_liger_kernel_to_gemma3_text",
+            "apply_liger_kernel_to_granite",
+            "apply_liger_kernel_to_llama",
+            "apply_liger_kernel_to_llava",
+            "apply_liger_kernel_to_mistral",
+            "apply_liger_kernel_to_mixtral",
+            "apply_liger_kernel_to_mllama",
+            "apply_liger_kernel_to_olmo2",
+            "apply_liger_kernel_to_paligemma",
+            "apply_liger_kernel_to_phi3",
+            "apply_liger_kernel_to_qwen2",
+            "apply_liger_kernel_to_qwen2_5_vl",
+            "apply_liger_kernel_to_qwen2_vl",
+        ]
+    )
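With this layout, importing the package no longer requires `transformers`: the pure-kernel modules import eagerly, while the monkey-patch entry points resolve on first attribute access through the module-level `__getattr__` (PEP 562). A usage sketch, assuming liger-kernel 0.5.7 is installed:

```python
import liger_kernel.transformers as liger

# Kernel modules such as LigerRMSNorm are imported eagerly and never touch 'transformers'.
norm = liger.LigerRMSNorm(hidden_size=4096)

if liger.is_transformers_available():
    # First access goes through __getattr__, which lazily imports
    # liger_kernel.transformers.monkey_patch and returns the patch function.
    liger.apply_liger_kernel_to_llama()
else:
    # Without 'transformers' installed, the same access raises the ImportError shown above.
    print("transformers not installed; only the raw kernels are usable")
```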
liger_kernel/transformers/dyt.py

@@ -0,0 +1,20 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.dyt import LigerDyTFunction
+
+
+class LigerDyT(nn.Module):
+    def __init__(self, hidden_size, init_alpha=0.5):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.init_alpha = init_alpha
+        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
+        self.gamma = nn.Parameter(torch.ones(hidden_size))
+        self.beta = nn.Parameter(torch.zeros(hidden_size))
+
+    def forward(self, x):
+        return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
+
+    def extra_repr(self):
+        return f"{self.hidden_size}, init_alpha={self.init_alpha}"
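`LigerDyT` is the module wrapper for the new Triton DyT (Dynamic Tanh) op, with a scalar `alpha` plus per-channel `gamma`/`beta`. A usage sketch; the eager reference at the end is an assumption about `LigerDyTFunction`'s math (the standard DyT formula `gamma * tanh(alpha * x) + beta`), since `liger_kernel/ops/dyt.py` itself is not shown in this diff:

```python
import torch

from liger_kernel.transformers import LigerDyT  # exported at package level in 0.5.7

hidden_size = 2048
dyt = LigerDyT(hidden_size, init_alpha=0.5).to("cuda")  # the op runs a Triton kernel, so a GPU is assumed
x = torch.randn(2, 128, hidden_size, device="cuda", requires_grad=True)

y = dyt(x)          # fused Triton forward
y.sum().backward()  # fused backward through LigerDyTFunction

# Assumed eager equivalent (standard DyT): gamma * tanh(alpha * x) + beta
y_ref = dyt.gamma * torch.tanh(dyt.alpha * x) + dyt.beta
```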
liger_kernel/transformers/functional.py

@@ -1,6 +1,7 @@
 from typing import Optional
 
 from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
+from liger_kernel.ops.dyt import LigerDyTFunction
 from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
 from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
 from liger_kernel.ops.geglu import LigerGELUMulFunction

@@ -192,3 +193,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
 def liger_swiglu(a, b):
     return LigerSiLUMulFunction.apply(a, b)
+
+
+def liger_dyt(x, alpha, gamma, beta):
+    return LigerDyTFunction.apply(x, alpha, gamma, beta)
liger_kernel/transformers/gema3_rms.py

@@ -0,0 +1,8 @@
+from .rms_norm import LigerRMSNorm
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
liger_kernel/transformers/model/gemma.py

@@ -12,8 +12,10 @@ from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
 from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)

@@ -126,6 +128,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(

@@ -141,7 +144,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -151,10 +154,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 

@@ -200,24 +205,16 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
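Outside of training, the patched `lce_forward` now accepts either an `int` (keep the last N positions) or a 1D index tensor for `logits_to_keep`, and only converts the `int` case into a slice. A small sketch of just that selection step on a dummy hidden-state tensor; `select_positions` is an illustrative helper, not model code:

```python
import torch


def select_positions(hidden_states: torch.Tensor, logits_to_keep):
    """Mirror of the slicing logic in the patched lce_forward (illustrative only)."""
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]


h = torch.randn(2, 10, 16)                            # (batch, seq_len, hidden)
assert select_positions(h, 0).shape == (2, 10, 16)    # 0 -> slice(0, None): keep every position
assert select_positions(h, 1).shape == (2, 1, 16)     # keep only the last position
idx = torch.tensor([0, 4, 9])                         # 1D indices, e.g. for packed sequences
assert select_positions(h, idx).shape == (2, 3, 16)
```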
liger_kernel/transformers/model/gemma2.py

@@ -13,8 +13,10 @@ from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
 from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 logger = logging.getLogger(__name__)
 

@@ -133,6 +135,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(

@@ -148,7 +151,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -158,10 +161,12 @@ def lce_forward(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 

@@ -212,27 +217,18 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-        # We do the same thing as ForCausalLMLoss but using Liger FLCE
-
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(
-            softcap=self.config.final_logit_softcapping,
-            reduction=reduction,
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            final_logit_softcapping=self.config.final_logit_softcapping,
+            **loss_kwargs,
         )
 
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
-
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
             logits = torch.tanh(logits)
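In both Gemma patches, the inline shift/flatten/fused-linear-cross-entropy block (including Gemma2's `final_logit_softcapping` handling) is replaced by a single `LigerForCausalLMLoss` call from the new `liger_kernel/transformers/model/loss_utils.py`, whose body is not shown here. As a reference for what that helper replaces, here is the deleted logic reassembled into a standalone function (reconstructed from the removed lines above; it is not the actual `loss_utils` implementation):

```python
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


def flce_causal_lm_loss(hidden_states, lm_head_weight, labels, hidden_size,
                        final_logit_softcapping=None, **loss_kwargs):
    """Hypothetical helper mirroring the code deleted from lce_forward."""
    # shift so that tokens < n predict token n
    shift_hidden_states = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # flatten tokens
    shift_hidden_states = shift_hidden_states.view(-1, hidden_size)
    shift_labels = shift_labels.view(-1)

    reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
    lce = LigerFusedLinearCrossEntropyLoss(softcap=final_logit_softcapping, reduction=reduction)

    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    if reduction == "sum":
        loss /= loss_kwargs["num_items_in_batch"]
    return loss
```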