liger-kernel 0.5.9__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/chunked_loss/dpo_loss.py +1 -1
  2. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  3. liger_kernel/chunked_loss/jsd_loss.py +2 -2
  4. liger_kernel/ops/dyt.py +113 -179
  5. liger_kernel/ops/grpo_loss.py +310 -0
  6. liger_kernel/ops/sparsemax.py +167 -0
  7. liger_kernel/transformers/__init__.py +5 -0
  8. liger_kernel/transformers/dyt.py +5 -3
  9. liger_kernel/transformers/fsdp.py +55 -0
  10. liger_kernel/transformers/functional.py +8 -0
  11. liger_kernel/transformers/grpo_loss.py +98 -0
  12. liger_kernel/transformers/model/gemma.py +0 -8
  13. liger_kernel/transformers/model/gemma2.py +0 -6
  14. liger_kernel/transformers/model/gemma3.py +0 -8
  15. liger_kernel/transformers/model/glm4.py +0 -6
  16. liger_kernel/transformers/model/llama.py +56 -11
  17. liger_kernel/transformers/model/llava.py +0 -8
  18. liger_kernel/transformers/model/mistral.py +0 -6
  19. liger_kernel/transformers/model/mixtral.py +0 -8
  20. liger_kernel/transformers/model/mllama.py +0 -7
  21. liger_kernel/transformers/model/olmo2.py +0 -6
  22. liger_kernel/transformers/model/paligemma.py +0 -8
  23. liger_kernel/transformers/model/phi3.py +0 -8
  24. liger_kernel/transformers/model/qwen2.py +0 -8
  25. liger_kernel/transformers/model/qwen2_5_vl.py +0 -6
  26. liger_kernel/transformers/model/qwen2_vl.py +0 -6
  27. liger_kernel/transformers/model/qwen3.py +0 -6
  28. liger_kernel/transformers/model/qwen3_moe.py +128 -0
  29. liger_kernel/transformers/monkey_patch.py +122 -13
  30. liger_kernel/transformers/sparsemax.py +16 -0
  31. liger_kernel/transformers/swiglu.py +21 -0
  32. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  33. liger_kernel/utils.py +11 -0
  34. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/METADATA +34 -20
  35. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/RECORD +39 -33
  36. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/WHEEL +1 -1
  37. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/LICENSE +0 -0
  38. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/licenses/NOTICE +0 -0
  39. {liger_kernel-0.5.9.dist-info → liger_kernel-0.5.10.dist-info}/top_level.txt +0 -0

liger_kernel/ops/sparsemax.py (new file)
@@ -0,0 +1,167 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from liger_kernel.ops.utils import calculate_settings
+ from liger_kernel.ops.utils import ensure_contiguous
+
+
+ @triton.jit
+ def _sparsemax_forward_kernel(
+     x_ptr,
+     x_stride_row,
+     sorted_x_ptr,
+     sorted_x_stride_row,
+     o_ptr,
+     o_stride_row,
+     n_cols,
+     BLOCK_SIZE: tl.constexpr,
+     num_warps: tl.constexpr,
+ ):
+     pid_row = tl.program_id(0)
+     ptr_x_data_row = x_ptr + pid_row * x_stride_row
+     ptr_sorted_x_data_row = sorted_x_ptr + pid_row * sorted_x_stride_row
+     ptr_output_row = o_ptr + pid_row * o_stride_row
+
+     offs = tl.arange(0, BLOCK_SIZE)
+     mask = offs < n_cols
+
+     z_sorted_block = tl.load(
+         ptr_sorted_x_data_row + offs,
+         mask=mask,
+         other=-float("inf"),
+         cache_modifier=".ca",
+     ).to(tl.float32)
+
+     z_valid = tl.where(mask, z_sorted_block, 0.0)
+     cssv = tl.cumsum(z_valid, 0)
+
+     r = (offs + 1).to(tl.float32)
+     safe_r = tl.where(mask, r, 1.0)
+
+     t_vec = (cssv - 1.0) / safe_r
+
+     support = (z_sorted_block > t_vec) & mask
+
+     k_int = tl.sum(support.to(tl.int32), 0)
+     k_clamped_int = tl.maximum(k_int, 1)
+     k = k_clamped_int.to(tl.float32)
+
+     s = tl.sum(tl.where(support, z_sorted_block, 0.0), 0)
+
+     tau = (s - 1.0) / k
+
+     x_block = tl.load(
+         ptr_x_data_row + offs,
+         mask=mask,
+         other=0.0,
+         cache_modifier=".ca",
+     ).to(tl.float32)
+
+     y = tl.maximum(x_block - tau, 0.0)
+
+     tl.store(
+         ptr_output_row + offs,
+         y.to(ptr_output_row.dtype.element_ty),
+         mask=mask,
+         cache_modifier=".cs",
+     )
+
+
+ @triton.jit
+ def _sparsemax_backward_kernel(
+     o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
+ ):
+     row = tl.program_id(0)
+     o_row = o_ptr + row * stride
+     go_row = go_ptr + row * stride
+     gi_row = gi_ptr + row * stride
+
+     offs = tl.arange(0, BLOCK_SIZE)
+
+     supp_cnt = tl.zeros((), tl.float32)
+     go_sum = tl.zeros((), tl.float32)
+
+     for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
+         offs_iter = i * BLOCK_SIZE + offs
+         mask_iter = offs_iter < n_cols
+         o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
+         go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
+         supp = o_val > 0.0
+         go_sum += tl.sum(tl.where(supp, go_val, 0.0))
+         supp_cnt += tl.sum(supp.to(tl.float32))
+
+     for i in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
+         offs_iter = i * BLOCK_SIZE + offs
+         mask_iter = offs_iter < n_cols
+         o_val = tl.load(o_row + offs_iter, mask=mask_iter, other=0.0, cache_modifier=".ca").to(tl.float32)
+         go_val = tl.load(go_row + offs_iter, mask=mask_iter, other=0.0).to(tl.float32)
+         supp = o_val > 0.0
+         gi_val = tl.where(
+             supp,
+             go_val - tl.cast(go_sum / tl.maximum(supp_cnt, 1e-6), gi_row.dtype.element_ty).to(tl.float32),
+             0.0,
+         )
+         tl.store(gi_row + offs_iter, gi_val.to(gi_row.dtype.element_ty), mask=mask_iter, cache_modifier=".wb")
+
+
+ class LigerSparsemaxFunction(torch.autograd.Function):
+     @staticmethod
+     @ensure_contiguous
+     def forward(ctx, x: torch.Tensor, dim: int):
+         if dim < 0:
+             dim += x.dim()
+         ctx.dim = dim
+
+         x_sw = x.transpose(dim, -1).contiguous()
+         n_cols = x_sw.size(-1)
+         n_rows = x_sw.numel() // n_cols
+         x_flat = x_sw.view(n_rows, n_cols)
+
+         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+         out_flat = torch.empty_like(x_flat)
+         grid = (n_rows,)
+
+         x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values
+
+         _sparsemax_forward_kernel[grid](
+             x_flat,
+             x_flat.stride(0),
+             x_sorted_flat,
+             x_sorted_flat.stride(0),
+             out_flat,
+             out_flat.stride(0),
+             n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+
+         ctx.save_for_backward(out_flat)
+         return out_flat.view_as(x_sw).transpose(dim, -1)
+
+     @staticmethod
+     @ensure_contiguous
+     def backward(ctx, grad_out: torch.Tensor):
+         (out_flat,) = ctx.saved_tensors
+         dim = ctx.dim
+
+         go_sw = grad_out.transpose(dim, -1).contiguous()
+         n_cols = go_sw.size(-1)
+         n_rows = go_sw.numel() // n_cols
+         go_flat = go_sw.view(n_rows, n_cols)
+
+         BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+         gi_flat = torch.empty_like(go_flat)
+         grid = (n_rows,)
+
+         _sparsemax_backward_kernel[grid](
+             out_flat,
+             go_flat,
+             gi_flat,
+             out_flat.stride(0),
+             n_cols,
+             BLOCK_SIZE=BLOCK_SIZE,
+             num_warps=num_warps,
+         )
+
+         return gi_flat.view_as(go_sw).transpose(dim, -1), None
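
Note: the forward kernel above is a row-wise implementation of the standard sparsemax projection (Martins & Astudillo, 2016). With the entries of a row \(z\) sorted in descending order, \(z_{(1)} \ge z_{(2)} \ge \dots\), it computes

\[
k(z) = \max\{\, r : 1 + r\,z_{(r)} > \textstyle\sum_{j \le r} z_{(j)} \,\}, \qquad
\tau(z) = \frac{\sum_{j \le k(z)} z_{(j)} - 1}{k(z)}, \qquad
\operatorname{sparsemax}(z)_i = \max(z_i - \tau(z),\, 0),
\]

where `support`, `k`, `s`, and `tau` in the kernel correspond to the support predicate, \(k(z)\), \(\sum_{j \le k(z)} z_{(j)}\), and \(\tau(z)\). The backward kernel subtracts the mean of the incoming gradient over the support and zeroes it elsewhere, which is exactly the Jacobian-vector product of this projection.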

liger_kernel/transformers/__init__.py
@@ -14,6 +14,7 @@ from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
  from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP  # noqa: F401
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
  from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401

@@ -40,6 +41,7 @@ if TYPE_CHECKING:
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
      from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
+     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401


  # Check if 'transformers' is installed
@@ -95,6 +97,7 @@ def __getattr__(name: str):
          "apply_liger_kernel_to_qwen2_5_vl",
          "apply_liger_kernel_to_qwen2_vl",
          "apply_liger_kernel_to_qwen3",
+         "apply_liger_kernel_to_qwen3_moe",
      }

      if name in monkey_patch_symbols:
@@ -118,6 +121,7 @@ __all__ = [
      "liger_rotary_pos_emb",
      "LigerBlockSparseTop2MLP",
      "LigerPhi3SwiGLUMLP",
+     "LigerQwen3MoeSwiGLUMLP",
      "LigerSwiGLUMLP",
      "LigerTVDLoss",
  ]
@@ -147,5 +151,6 @@ if _TRANSFORMERS_AVAILABLE:
              "apply_liger_kernel_to_qwen2_5_vl",
              "apply_liger_kernel_to_qwen2_vl",
              "apply_liger_kernel_to_qwen3",
+             "apply_liger_kernel_to_qwen3_moe",
          ]
      )
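
The new Qwen3 MoE entry points follow the same pattern as the other `apply_liger_kernel_to_*` functions exported here. A hedged usage sketch; the checkpoint id is an illustrative assumption, not taken from this diff:

from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe

# Patch the Qwen3 MoE modeling code in place before the model is instantiated.
apply_liger_kernel_to_qwen3_moe()

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")  # hypothetical checkpoint id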

liger_kernel/transformers/dyt.py
@@ -5,16 +5,18 @@ from liger_kernel.ops.dyt import LigerDyTFunction


  class LigerDyT(nn.Module):
-     def __init__(self, hidden_size, init_alpha=0.5):
+     def __init__(self, hidden_size, beta=True, init_alpha=0.5):
          super().__init__()
          self.hidden_size = hidden_size
          self.init_alpha = init_alpha
          self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
          self.gamma = nn.Parameter(torch.ones(hidden_size))
-         self.beta = nn.Parameter(torch.zeros(hidden_size))
+         self.beta = None
+         if beta:
+             self.beta = nn.Parameter(torch.zeros(hidden_size))

      def forward(self, x):
          return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)

      def extra_repr(self):
-         return f"{self.hidden_size}, init_alpha={self.init_alpha}"
+         return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta}"
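
With this change the bias term of LigerDyT becomes optional: the default `beta=True` keeps the previous behaviour, while `beta=False` leaves `self.beta` as `None`. A minimal sketch; the hidden size, input shape, and CUDA device are illustrative assumptions (the underlying Triton kernel needs a GPU):

import torch

from liger_kernel.transformers.dyt import LigerDyT

dyt = LigerDyT(hidden_size=2048).cuda()                      # bias kept: beta is a learnable parameter
dyt_no_bias = LigerDyT(hidden_size=2048, beta=False).cuda()  # bias dropped: self.beta stays None

x = torch.randn(2, 16, 2048, device="cuda")
y = dyt(x)  # roughly gamma * tanh(alpha * x) + beta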

liger_kernel/transformers/fsdp.py (new file)
@@ -0,0 +1,55 @@
+ from typing import Any
+ from typing import Callable
+
+ from torch.distributed.fsdp import FullyShardedDataParallel
+
+
+ class _FSDPForwardRedirection:
+     """
+     Modified based on
+     https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
+     Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
+     post-forward can be properly executed around the method call.
+     This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
+     the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
+     GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
+     will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
+     the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
+     its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
+     the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
+     """
+
+     def __call__(
+         self,
+         wrapper_module: FullyShardedDataParallel,
+         method: Callable,
+         *args: Any,
+         **kwargs: Any,
+     ):
+         """Reroutes a method call through the `wrapper_module`'s `forward` method.
+         Args:
+             wrapper_module: The module that has `original_module` wrapped.
+             original_module: The module that was wrapped inside `wrapper_module`.
+             method_name: The name of the method that should be called on the `original_module` after inputs get
+                 redirected through the `wrapper_module`'s `forward` method.
+             *args: The positional arguments to the method `method_name`. They will get passed to a patched
+                 `forward` method instead.
+             **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
+                 `forward` method instead.
+         """
+         assert isinstance(wrapper_module, FullyShardedDataParallel)
+         original_module = wrapper_module._fsdp_wrapped_module
+         original_forward = original_module.forward
+
+         def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
+             # Unpatch ourselves immediately before calling the method `method_name`
+             # because itself may want to call the real `forward`
+             original_module.forward = original_forward  # type: ignore[method-assign]
+             # Call the actual method e.g. `.training_step(...)`
+             out = method(*_args, **_kwargs)
+             return out
+
+         # Patch the original_module's forward so we can redirect the arguments back to the real method
+         original_module.forward = wrapped_forward  # type: ignore[method-assign]
+         wrapper_output = wrapper_module(*args, **kwargs)
+         return wrapper_output
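
For context, a hedged sketch of how `_FSDPForwardRedirection` can be used on its own; `fsdp_lm_head` is assumed to be a `FullyShardedDataParallel` instance wrapping a plain `nn.Linear` lm_head, and the helper names are illustrative, not from this diff:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel

from liger_kernel.transformers.fsdp import _FSDPForwardRedirection


def lm_head_logits(fsdp_lm_head: FullyShardedDataParallel, hidden_states: torch.Tensor) -> torch.Tensor:
    def _matmul(lm_head_module: torch.nn.Linear, hs: torch.Tensor) -> torch.Tensor:
        # Runs inside the FSDP root forward, so lm_head_module.weight is the full, all-gathered 2-D weight.
        return hs @ lm_head_module.weight.t()

    # Reroute the call through the wrapper's forward so its pre-/post-forward hooks gather and free the shards.
    return _FSDPForwardRedirection()(fsdp_lm_head, _matmul, fsdp_lm_head.module, hidden_states)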

liger_kernel/transformers/functional.py
@@ -12,6 +12,7 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
  from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
  from liger_kernel.ops.rope import LigerRopeFunction
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
  from liger_kernel.ops.tvd import LigerTVDLossFunction

@@ -159,6 +160,13 @@ def liger_kl_div(
      )


+ def liger_sparsemax(
+     input,
+     dim: int = -1,
+ ):
+     return LigerSparsemaxFunction.apply(input, dim)
+
+
  def liger_tvd(
      input,
      target,
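
The new `liger_sparsemax` function is a thin functional wrapper over `LigerSparsemaxFunction`. A minimal usage sketch; the tensor shape and CUDA device are illustrative assumptions (the Triton kernels require a GPU):

import torch

from liger_kernel.transformers.functional import liger_sparsemax

x = torch.randn(4, 128, device="cuda", requires_grad=True)
probs = liger_sparsemax(x, dim=-1)

# Like softmax, each row sums to 1; unlike softmax, low-scoring entries are exactly zero.
assert torch.allclose(probs.sum(dim=-1), torch.ones(4, device="cuda"), atol=1e-5)

probs.sum().backward()  # gradients flow through the Triton backward kernel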

liger_kernel/transformers/grpo_loss.py (new file)
@@ -0,0 +1,98 @@
+ from liger_kernel.ops.grpo_loss import GrpoLossFunction
+
+
+ def triton_grpo_loss(
+     logits,
+     old_logp,
+     ref_logp,
+     completion_ids,
+     advantages,
+     completion_mask=None,
+     temperature=0.9,
+     beta=0.04,
+     eps_low=0.2,
+     eps_high=0.4,
+     inplace=True,
+ ):
+     assert logits is not None and completion_ids is not None and advantages is not None, (
+         "must provide logits, completion_ids and advantages"
+     )
+
+     return GrpoLossFunction.apply(
+         logits,
+         old_logp,
+         ref_logp,
+         completion_ids,
+         advantages,
+         completion_mask,
+         temperature,
+         beta,
+         eps_low,
+         eps_high,
+         inplace,
+     )
+
+
+ # This is a demo of how to use grpo_loss in GRPOTrainer. The trl version must be 0.16.
+ """
+ import torch
+ import trl
+ assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
+ from trl.extras.profiling import profiling_decorator
+
+ @profiling_decorator
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+     # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+     logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+     return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)
+
+ @profiling_decorator
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+     if return_outputs:
+         raise ValueError("The GRPOTrainer does not support returning outputs")
+     # Compute the per-token log probabilities for the model
+
+     prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+     completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+     input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+     attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+     logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+     logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+
+     ref_per_token_logps = inputs["ref_per_token_logps"]
+     advantages = inputs["advantages"]
+     old_per_token_logps = inputs["old_per_token_logps"]
+
+
+     per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits,
+         old_per_token_logps,
+         ref_per_token_logps,
+         completion_ids,
+         advantages,
+         completion_mask,
+         self.temperature,
+         self.beta,
+         self.epsilon_low,
+         self.epsilon_high,)
+     loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+
+     # Log the metrics
+     mode = "eval" if self.control.should_evaluate else "train"
+
+     if self.beta != 0.0:
+         mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+         self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+     clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+     self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+     return loss
+
+ trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
+ trl.GRPOTrainer.compute_loss = compute_loss
+ trigger = None
+ """
+
+ # Add this line as the first line of grpo.py in open-r1:
+ """
+ from liger_kernel.transformers.grpo_loss import trigger
+ """

liger_kernel/transformers/model/gemma.py
@@ -8,18 +8,12 @@ import torch
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
- from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -129,8 +123,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,

liger_kernel/transformers/model/gemma2.py
@@ -9,10 +9,6 @@ import torch
  from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import HybridCache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
- from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -136,8 +132,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,

liger_kernel/transformers/model/gemma3.py
@@ -9,13 +9,9 @@ import torch.nn as nn
  from transformers.cache_utils import Cache
  from transformers.cache_utils import HybridCache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.gemma3.modeling_gemma3 import _CONFIG_FOR_DOC
- from transformers.models.gemma3.modeling_gemma3 import GEMMA3_INPUTS_DOCSTRING
  from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import is_torchdynamo_compiling
  from transformers.utils import logging
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -25,8 +21,6 @@ logger = logging.get_logger(__name__)


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def causal_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -141,8 +135,6 @@ def causal_forward(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def multimodal_forward(
      self,
      input_ids: torch.LongTensor = None,

liger_kernel/transformers/model/glm4.py
@@ -6,18 +6,12 @@ from typing import Union
  import torch

  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
- from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,

liger_kernel/transformers/model/llama.py
@@ -7,23 +7,23 @@ from typing import Union
  import torch
  import torch.nn.functional as F

+ from torch.distributed.fsdp import FullyShardedDataParallel
  from torch.nn import CrossEntropyLoss
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
- from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

+ from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+ from liger_kernel.utils import PEFT_AVAILABLE

  if TYPE_CHECKING:
      from transformers.cache_utils import Cache

+ if PEFT_AVAILABLE:
+     from peft.utils.other import ModulesToSaveWrapper
+

- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -137,8 +137,6 @@ def lce_forward_deprecated(


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,
@@ -221,12 +219,12 @@
      loss = None
      # if in training mode, don't materialize logits
      if self.training and (labels is not None or shift_labels is not None):
-         loss = LigerForCausalLMLoss(
+         loss = lce_maybe_trainable_lm_head(
+             self,
              hidden_states=kept_hidden_states,
-             lm_head_weight=self.lm_head.weight,
+             hidden_size=self.config.hidden_size,
              labels=labels,
              shift_labels=shift_labels,
-             hidden_size=self.config.hidden_size,
              **loss_kwargs,
          )

@@ -251,3 +249,50 @@
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
      )
+
+
+ def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
+     lm_head = self.lm_head
+
+     # Unwrap the module if lm_head has been added as a trainable module in the PEFT LoRA configuration,
+     # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
+     # from the unwrapped module.
+     # See https://huggingface.co/docs/peft/package_reference/lora for reference.
+     if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
+         lm_head = lm_head.modules_to_save.default
+
+     # If FSDP is used and lm_head is trainable, e.g. during full fine-tuning or with LoRA,
+     # reading the lm_head module weights and calling the kernel must be done within the FSDP forward pass
+     # so the module's entire parameters are summoned and kept in memory during the kernel execution.
+     if isinstance(lm_head, FullyShardedDataParallel):
+         return _FSDPForwardRedirection()(
+             lm_head,
+             _liger_for_causal_lm_loss,
+             lm_head.module,
+             hidden_states,
+             hidden_size,
+             labels,
+             shift_labels,
+             **loss_kwargs,
+         )
+
+     # FSDP is not used, so we can read the lm_head weights and call the kernel directly
+     return _liger_for_causal_lm_loss(
+         lm_head=self.lm_head,
+         hidden_states=hidden_states,
+         hidden_size=hidden_size,
+         labels=labels,
+         shift_labels=shift_labels,
+         **loss_kwargs,
+     )
+
+
+ def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
+     return LigerForCausalLMLoss(
+         hidden_states=hidden_states,
+         lm_head_weight=lm_head.weight,
+         labels=labels,
+         hidden_size=hidden_size,
+         shift_labels=shift_labels,
+         **loss_kwargs,
+     )
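
The PEFT branch in `lce_maybe_trainable_lm_head` is taken when `lm_head` is listed in `modules_to_save`, which makes PEFT wrap it in a `ModulesToSaveWrapper`. A hedged configuration sketch that would exercise that path; the model id and LoRA hyperparameters are illustrative, not taken from this diff:

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # installs the patched lce_forward shown above
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # hypothetical checkpoint id

lora_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["lm_head"],  # lm_head becomes trainable and is wrapped in ModulesToSaveWrapper
)
model = get_peft_model(model, lora_config)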

liger_kernel/transformers/model/llava.py
@@ -5,19 +5,13 @@ from typing import Union

  import torch

- from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
- from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
  from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
- from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import is_torchdynamo_compiling
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


- @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward_deprecated(
      self,
      input_ids: torch.LongTensor = None,
@@ -210,9 +204,7 @@ def lce_forward_deprecated(
      )


- @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,

liger_kernel/transformers/model/mistral.py
@@ -7,18 +7,12 @@ import torch

  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
- from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
- from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
- from transformers.utils import add_start_docstrings_to_model_forward
- from transformers.utils import replace_return_docstrings
  from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


  @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
- @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
      self,
      input_ids: torch.LongTensor = None,