liger-kernel-nightly 0.6.2.dev20250919191028__py3-none-any.whl → 0.6.4.dev20251202054858__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic; see the package's registry page for more details.

Files changed (67)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +13 -4
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +18 -5
  6. liger_kernel/ops/cross_entropy.py +120 -63
  7. liger_kernel/ops/dyt.py +5 -2
  8. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  9. liger_kernel/ops/fused_linear_cross_entropy.py +43 -12
  10. liger_kernel/ops/geglu.py +2 -1
  11. liger_kernel/ops/group_norm.py +2 -1
  12. liger_kernel/ops/grpo_loss.py +3 -1
  13. liger_kernel/ops/layer_norm.py +88 -70
  14. liger_kernel/ops/poly_norm.py +390 -0
  15. liger_kernel/ops/rms_norm.py +7 -2
  16. liger_kernel/ops/tiled_mlp.py +136 -0
  17. liger_kernel/ops/utils.py +2 -0
  18. liger_kernel/transformers/__init__.py +33 -0
  19. liger_kernel/transformers/cross_entropy.py +8 -3
  20. liger_kernel/transformers/functional.py +29 -6
  21. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -3
  22. liger_kernel/transformers/grpo_loss.py +56 -1
  23. liger_kernel/transformers/model/falcon_h1.py +122 -0
  24. liger_kernel/transformers/model/gemma.py +19 -7
  25. liger_kernel/transformers/model/gemma2.py +22 -7
  26. liger_kernel/transformers/model/gemma3.py +52 -14
  27. liger_kernel/transformers/model/glm4.py +18 -5
  28. liger_kernel/transformers/model/glm4v.py +18 -5
  29. liger_kernel/transformers/model/glm4v_moe.py +25 -5
  30. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  31. liger_kernel/transformers/model/internvl.py +157 -0
  32. liger_kernel/transformers/model/llama.py +16 -6
  33. liger_kernel/transformers/model/llama4.py +18 -5
  34. liger_kernel/transformers/model/llava.py +18 -6
  35. liger_kernel/transformers/model/loss_utils.py +31 -3
  36. liger_kernel/transformers/model/mistral.py +17 -7
  37. liger_kernel/transformers/model/mixtral.py +24 -9
  38. liger_kernel/transformers/model/mllama.py +14 -5
  39. liger_kernel/transformers/model/olmo2.py +18 -5
  40. liger_kernel/transformers/model/olmo3.py +142 -0
  41. liger_kernel/transformers/model/output_classes.py +147 -0
  42. liger_kernel/transformers/model/paligemma.py +41 -5
  43. liger_kernel/transformers/model/phi3.py +16 -8
  44. liger_kernel/transformers/model/qwen2.py +18 -4
  45. liger_kernel/transformers/model/qwen2_5_vl.py +21 -8
  46. liger_kernel/transformers/model/qwen2_vl.py +24 -7
  47. liger_kernel/transformers/model/qwen3.py +22 -6
  48. liger_kernel/transformers/model/qwen3_moe.py +27 -7
  49. liger_kernel/transformers/model/qwen3_next.py +146 -0
  50. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  51. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  52. liger_kernel/transformers/model/smollm3.py +17 -7
  53. liger_kernel/transformers/model/smolvlm.py +158 -0
  54. liger_kernel/transformers/monkey_patch.py +729 -4
  55. liger_kernel/transformers/poly_norm.py +42 -0
  56. liger_kernel/transformers/rms_norm.py +7 -0
  57. liger_kernel/transformers/rope.py +43 -0
  58. liger_kernel/transformers/swiglu.py +17 -0
  59. liger_kernel/transformers/tiled_mlp.py +133 -0
  60. liger_kernel/utils.py +25 -0
  61. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/METADATA +13 -6
  62. liger_kernel_nightly-0.6.4.dev20251202054858.dist-info/RECORD +118 -0
  63. liger_kernel_nightly-0.6.2.dev20250919191028.dist-info/RECORD +0 -105
  64. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/LICENSE +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/NOTICE +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/WHEEL +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20250919191028.dist-info → liger_kernel_nightly-0.6.4.dev20251202054858.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,136 @@
1
+ import math
2
+
3
+ from typing import Callable
4
+ from typing import List
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+
11
+
12
class LigerTiledMLPFunction(torch.autograd.Function):
    """
    Memory-efficient MLP that processes its input in row tiles.

    Adapted from DeepSpeed's TiledMLP:
    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838

    The forward pass calls ``fn`` on one shard at a time without recording an
    autograd graph. The backward pass recomputes each shard's forward with grad
    enabled and back-propagates through it right away. As a result the MLP
    forward runs twice per iteration, or three times when activation
    checkpointing is also used.

    Args:
        fn: callable invoked as ``fn(mlp_module, x_shard)`` (e.g. ``mlp.forward``).
        mlp_module: the MLP ``nn.Module`` handed through to ``fn``.
        x: MLP input (hidden_states); sharded along dim -2 in forward.
        shards: number of chunks requested from ``torch.chunk``.
        compute_params: accepted for parity with DeepSpeed's signature; this
            implementation does not read it.

    Returns:
        The per-shard outputs concatenated back along dim -2.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        fn: Callable,
        mlp_module: torch.nn.Module,
        x: torch.Tensor,
        shards: int,
        compute_params: Optional[List[torch.nn.Parameter]] = None,
    ) -> torch.Tensor:
        ctx.fn, ctx.mlp_module, ctx.shards = fn, mlp_module, shards
        ctx.save_for_backward(x)

        # Split along dim -2, which is the sequence axis for both
        # [bs, seqlen, hidden_size] and [seqlen, hidden_size] (moe experts) inputs.
        pieces = list(torch.chunk(x, chunks=shards, dim=-2))
        with torch.no_grad():
            results = [fn(mlp_module, piece) for piece in pieces]
        return torch.cat(results, dim=-2)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, *grads) -> tuple:
        fn = ctx.fn
        (saved_x,) = ctx.saved_tensors
        mlp_module = ctx.mlp_module
        shards = ctx.shards

        # Rebuild x as a fresh tensor; detach() clears requires_grad, so copy
        # the original flag back onto it.
        needs_grad = saved_x.requires_grad
        leaf = saved_x.detach()
        leaf.requires_grad_(needs_grad)

        hidden_size = leaf.shape[-1]
        orig_shape = leaf.shape

        # Collapse every leading dim into one row axis so that narrowing rows
        # is stride-safe even when bs > 1.
        rows = leaf.view(-1, hidden_size)
        grad_out = grads[0].view(-1, hidden_size)
        grad_in = torch.zeros_like(rows)

        row_chunks = list(torch.chunk(rows, chunks=shards, dim=0))
        # torch.chunk gives every chunk except possibly the last the same length,
        # so chunk idx starts at idx * full_len.
        full_len = row_chunks[0].shape[0]

        for idx, chunk in enumerate(row_chunks):
            chunk.requires_grad_(needs_grad)

            start = idx * full_len
            length = chunk.shape[0]  # the last chunk may be shorter

            # Seed chunk.grad with a view into grad_in. This relies on autograd
            # accumulating into the existing .grad, so each shard's input
            # gradient lands in the matching rows of grad_in.
            chunk.grad = grad_in.narrow(0, start, length).view_as(chunk)
            grad_out_chunk = grad_out.narrow(0, start, length).view_as(chunk)

            with torch.enable_grad():
                recomputed = fn(mlp_module, chunk)
            torch.autograd.backward(recomputed, grad_out_chunk)

        # One slot per forward input: fn, mlp_module, x, shards, compute_params.
        return (None, None, grad_in.view(orig_shape), None, None)
99
+
100
+
101
def apply_tiled_mlp(
    fn: Callable,
    mlp_module: torch.nn.Module,
    x: torch.Tensor,
    num_shards: Optional[int] = None,
    compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
    """
    Run ``fn(mlp_module, ...)`` over ``x`` in tiles via ``LigerTiledMLPFunction``.

    Args:
        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x)).
        mlp_module: the MLP nn.Module object.
        x: input tensor of shape [bs, seqlen, hidden_size] or [seqlen, hidden_size].
        num_shards: number of shards. When None it defaults to
            ``ceil(seqlen / hidden_size)``. The value is always raised to at least 1.
        compute_params: forwarded unchanged to ``LigerTiledMLPFunction``.

    Returns:
        The concatenated output of ``fn`` over all shards.
    """
    if num_shards is None:
        # Default: roughly hidden_size rows of the sequence axis per shard.
        hidden_size, seqlen = x.shape[-1], x.shape[-2]
        num_shards = math.ceil(seqlen / hidden_size)

    return LigerTiledMLPFunction.apply(fn, mlp_module, x, max(1, num_shards), compute_params)
liger_kernel/ops/utils.py CHANGED
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
78
78
  functools.partial(torch.amp.custom_fwd, device_type=device),
79
79
  functools.partial(torch.amp.custom_bwd, device_type=device),
80
80
  )
81
+ if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
82
+ return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
81
83
  return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
82
84
 
83
85
 
@@ -15,6 +15,7 @@ from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
15
15
  from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
16
16
  from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
17
17
  from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
18
+ from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
18
19
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
19
20
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
20
21
  from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
@@ -23,6 +24,8 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F4
23
24
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
24
25
  from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
25
26
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
27
+ from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401
28
+ from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401
26
29
  from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
27
30
 
28
31
  # Static-only imports for IDEs and type checkers
@@ -30,6 +33,7 @@ if TYPE_CHECKING:
30
33
  from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
31
34
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
32
35
  from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
36
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1 # noqa: F401
33
37
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
34
38
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
35
39
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
@@ -38,6 +42,9 @@ if TYPE_CHECKING:
38
42
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
39
43
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
40
44
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
45
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
46
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
47
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
41
48
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
42
49
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
43
50
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
@@ -45,6 +52,7 @@ if TYPE_CHECKING:
45
52
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
46
53
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
47
54
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
55
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401
48
56
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
49
57
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
50
58
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
@@ -52,7 +60,11 @@ if TYPE_CHECKING:
52
60
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
53
61
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
54
62
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
63
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
64
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
65
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
55
66
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
67
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
56
68
 
57
69
 
58
70
  # Check if 'transformers' is installed
@@ -90,6 +102,7 @@ def __getattr__(name: str):
90
102
  monkey_patch_symbols = {
91
103
  "_apply_liger_kernel",
92
104
  "_apply_liger_kernel_to_instance",
105
+ "apply_liger_kernel_to_falcon_h1",
93
106
  "apply_liger_kernel_to_gemma",
94
107
  "apply_liger_kernel_to_gemma2",
95
108
  "apply_liger_kernel_to_gemma3",
@@ -98,6 +111,7 @@ def __getattr__(name: str):
98
111
  "apply_liger_kernel_to_glm4v",
99
112
  "apply_liger_kernel_to_glm4v_moe",
100
113
  "apply_liger_kernel_to_granite",
114
+ "apply_liger_kernel_to_internvl",
101
115
  "apply_liger_kernel_to_llama",
102
116
  "apply_liger_kernel_to_llava",
103
117
  "apply_liger_kernel_to_llama4",
@@ -105,6 +119,7 @@ def __getattr__(name: str):
105
119
  "apply_liger_kernel_to_mixtral",
106
120
  "apply_liger_kernel_to_mllama",
107
121
  "apply_liger_kernel_to_olmo2",
122
+ "apply_liger_kernel_to_olmo3",
108
123
  "apply_liger_kernel_to_paligemma",
109
124
  "apply_liger_kernel_to_phi3",
110
125
  "apply_liger_kernel_to_qwen2",
@@ -112,7 +127,13 @@ def __getattr__(name: str):
112
127
  "apply_liger_kernel_to_qwen2_vl",
113
128
  "apply_liger_kernel_to_qwen3",
114
129
  "apply_liger_kernel_to_qwen3_moe",
130
+ "apply_liger_kernel_to_qwen3_next",
131
+ "apply_liger_kernel_to_qwen3_vl",
132
+ "apply_liger_kernel_to_qwen3_vl_moe",
115
133
  "apply_liger_kernel_to_smollm3",
134
+ "apply_liger_kernel_to_smolvlm",
135
+ "apply_liger_kernel_to_hunyuan_v1_dense",
136
+ "apply_liger_kernel_to_hunyuan_v1_moe",
116
137
  }
117
138
 
118
139
  if name in monkey_patch_symbols:
@@ -133,6 +154,7 @@ __all__ = [
133
154
  "LigerJSD",
134
155
  "LigerLayerNorm",
135
156
  "LigerFusedAddRMSNorm",
157
+ "LigerPolyNorm",
136
158
  "LigerRMSNorm",
137
159
  "liger_rotary_pos_emb",
138
160
  "liger_llama4_text_rotary_pos_emb",
@@ -141,6 +163,8 @@ __all__ = [
141
163
  "LigerPhi3SwiGLUMLP",
142
164
  "LigerQwen3MoeSwiGLUMLP",
143
165
  "LigerSwiGLUMLP",
166
+ "LigerTiledGEGLUMLP",
167
+ "LigerTiledSwiGLUMLP",
144
168
  "LigerTVDLoss",
145
169
  "LigerKLDIVLoss",
146
170
  "LigerMultiTokenAttention",
@@ -155,6 +179,7 @@ if _TRANSFORMERS_AVAILABLE:
155
179
  "AutoLigerKernelForCausalLM",
156
180
  "_apply_liger_kernel",
157
181
  "_apply_liger_kernel_to_instance",
182
+ "apply_liger_kernel_to_falcon_h1",
158
183
  "apply_liger_kernel_to_gemma",
159
184
  "apply_liger_kernel_to_gemma2",
160
185
  "apply_liger_kernel_to_gemma3",
@@ -163,6 +188,7 @@ if _TRANSFORMERS_AVAILABLE:
163
188
  "apply_liger_kernel_to_glm4v",
164
189
  "apply_liger_kernel_to_glm4v_moe",
165
190
  "apply_liger_kernel_to_granite",
191
+ "apply_liger_kernel_to_internvl",
166
192
  "apply_liger_kernel_to_llama",
167
193
  "apply_liger_kernel_to_llava",
168
194
  "apply_liger_kernel_to_llama4",
@@ -170,6 +196,7 @@ if _TRANSFORMERS_AVAILABLE:
170
196
  "apply_liger_kernel_to_mixtral",
171
197
  "apply_liger_kernel_to_mllama",
172
198
  "apply_liger_kernel_to_olmo2",
199
+ "apply_liger_kernel_to_olmo3",
173
200
  "apply_liger_kernel_to_paligemma",
174
201
  "apply_liger_kernel_to_phi3",
175
202
  "apply_liger_kernel_to_qwen2",
@@ -177,6 +204,12 @@ if _TRANSFORMERS_AVAILABLE:
177
204
  "apply_liger_kernel_to_qwen2_vl",
178
205
  "apply_liger_kernel_to_qwen3",
179
206
  "apply_liger_kernel_to_qwen3_moe",
207
+ "apply_liger_kernel_to_qwen3_next",
208
+ "apply_liger_kernel_to_qwen3_vl",
209
+ "apply_liger_kernel_to_qwen3_vl_moe",
180
210
  "apply_liger_kernel_to_smollm3",
211
+ "apply_liger_kernel_to_smolvlm",
212
+ "apply_liger_kernel_to_hunyuan_v1_dense",
213
+ "apply_liger_kernel_to_hunyuan_v1_moe",
181
214
  ]
182
215
  )
@@ -3,6 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
 
5
5
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
6
+ from liger_kernel.transformers.functional import CrossEntropyOutput
6
7
 
7
8
 
8
9
  class LigerCrossEntropyLoss(torch.nn.Module):
@@ -15,6 +16,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
15
16
  reduction: str = "mean",
16
17
  softcap: Optional[float] = None,
17
18
  return_z_loss: bool = False,
19
+ return_token_accuracy: bool = False,
18
20
  ):
19
21
  super().__init__()
20
22
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -33,9 +35,10 @@ class LigerCrossEntropyLoss(torch.nn.Module):
33
35
  self.reduction = reduction
34
36
  self.softcap = softcap
35
37
  self.return_z_loss = return_z_loss
38
+ self.return_token_accuracy = return_token_accuracy
36
39
 
37
40
  def forward(self, _input: torch.Tensor, target: torch.Tensor):
38
- loss, z_loss = LigerCrossEntropyFunction.apply(
41
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
39
42
  _input,
40
43
  target,
41
44
  self.weight,
@@ -45,7 +48,9 @@ class LigerCrossEntropyLoss(torch.nn.Module):
45
48
  self.reduction,
46
49
  self.softcap,
47
50
  self.return_z_loss,
51
+ self.return_token_accuracy,
48
52
  )
49
- if not self.return_z_loss:
53
+ if not self.return_z_loss and not self.return_token_accuracy:
50
54
  return loss
51
- return loss, z_loss
55
+
56
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
@@ -1,5 +1,8 @@
1
+ from dataclasses import dataclass
1
2
  from typing import Optional
2
3
 
4
+ import torch
5
+
3
6
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
7
  from liger_kernel.ops.dyt import LigerDyTFunction
5
8
  from liger_kernel.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
@@ -12,6 +15,7 @@ from liger_kernel.ops.jsd import LigerJSDFunction
12
15
  from liger_kernel.ops.kl_div import LigerKLDivLossFunction
13
16
  from liger_kernel.ops.layer_norm import LigerLayerNormFunction
14
17
  from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
18
+ from liger_kernel.ops.poly_norm import LigerPolyNormFunction
15
19
  from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
16
20
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
17
21
  from liger_kernel.ops.rope import LigerRopeFunction
@@ -21,6 +25,13 @@ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
21
25
  from liger_kernel.ops.tvd import LigerTVDLossFunction
22
26
 
23
27
 
28
@dataclass
class CrossEntropyOutput:
    """Container returned by the Liger cross-entropy APIs when extra outputs are requested.

    Before ``return_token_accuracy`` existed, these APIs returned a plain
    ``(loss, z_loss)`` tuple when ``return_z_loss=True``. ``__iter__`` keeps that
    tuple-unpacking style working for existing callers.
    """

    # The cross-entropy loss.
    loss: torch.Tensor
    # Optional z-loss term; defaults to None.
    z_loss: Optional[torch.Tensor] = None
    # Optional token-level accuracy; defaults to None.
    token_accuracy: Optional[torch.Tensor] = None

    def __iter__(self):
        """Yield fields for tuple unpacking, for backward compatibility.

        Yields ``(loss, z_loss)`` when ``token_accuracy`` is None, matching the
        legacy 2-tuple. Otherwise yields ``(loss, z_loss, token_accuracy)``.
        """
        yield self.loss
        yield self.z_loss
        if self.token_accuracy is not None:
            yield self.token_accuracy
33
+
34
+
24
35
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
25
36
  # `weight` and `size_average` are placeholders and not implemented yet
26
37
  def liger_cross_entropy(
@@ -35,8 +46,9 @@ def liger_cross_entropy(
35
46
  lse_square_scale: float = 0.0,
36
47
  softcap: Optional[float] = None,
37
48
  return_z_loss: bool = False,
49
+ return_token_accuracy: bool = False,
38
50
  ):
39
- loss, z_loss = LigerCrossEntropyFunction.apply(
51
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
40
52
  input,
41
53
  target,
42
54
  weight,
@@ -46,10 +58,13 @@ def liger_cross_entropy(
46
58
  reduction,
47
59
  softcap,
48
60
  return_z_loss,
61
+ return_token_accuracy,
49
62
  )
50
- if not return_z_loss:
63
+
64
+ if not return_z_loss and not return_token_accuracy:
51
65
  return loss
52
- return loss, z_loss
66
+
67
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
53
68
 
54
69
 
55
70
  def liger_fused_linear_cross_entropy(
@@ -66,8 +81,9 @@ def liger_fused_linear_cross_entropy(
66
81
  return_z_loss: bool = False,
67
82
  accum_dtype=None,
68
83
  use_token_scaling: bool = False,
84
+ return_token_accuracy: bool = False,
69
85
  ):
70
- loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
86
+ loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
71
87
  input,
72
88
  weight,
73
89
  target,
@@ -81,10 +97,13 @@ def liger_fused_linear_cross_entropy(
81
97
  return_z_loss,
82
98
  accum_dtype,
83
99
  use_token_scaling,
100
+ return_token_accuracy,
84
101
  )
85
- if not return_z_loss:
102
+
103
+ if not return_z_loss and not return_token_accuracy:
86
104
  return loss
87
- return loss, z_loss
105
+
106
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
88
107
 
89
108
 
90
109
  def liger_fused_linear_jsd(
@@ -258,6 +277,10 @@ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama",
258
277
  return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
259
278
 
260
279
 
280
def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
    """Functional entry point for Liger's PolyNorm kernel.

    Forwards its arguments positionally, in the order
    ``(X, W, B, eps, in_place)``, to ``LigerPolyNormFunction.apply`` and
    returns its result.
    """
    forwarded = (X, W, B, eps, in_place)
    return LigerPolyNormFunction.apply(*forwarded)
282
+
283
+
261
284
  def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
262
285
  return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)
263
286
 
@@ -3,6 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
 
5
5
  from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
6
+ from liger_kernel.transformers.functional import CrossEntropyOutput
6
7
 
7
8
 
8
9
  class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
@@ -17,6 +18,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
17
18
  return_z_loss: bool = False,
18
19
  accum_dtype: Optional[torch.dtype] = None,
19
20
  use_token_scaling: bool = False,
21
+ return_token_accuracy: bool = False,
20
22
  ):
21
23
  super().__init__()
22
24
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -37,9 +39,10 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
37
39
  self.return_z_loss = return_z_loss
38
40
  self.accum_dtype = accum_dtype
39
41
  self.use_token_scaling = use_token_scaling
42
+ self.return_token_accuracy = return_token_accuracy
40
43
 
41
44
  def forward(self, lin_weight, _input, target, bias=None):
42
- loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
45
+ loss, z_loss, token_accuracy = LigerFusedLinearCrossEntropyFunction.apply(
43
46
  _input,
44
47
  lin_weight,
45
48
  target,
@@ -53,7 +56,9 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
53
56
  self.return_z_loss,
54
57
  self.accum_dtype,
55
58
  self.use_token_scaling,
59
+ self.return_token_accuracy,
56
60
  )
57
- if not self.return_z_loss:
61
+ if not self.return_z_loss and not self.return_token_accuracy:
58
62
  return loss
59
- return loss, z_loss
63
+
64
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
@@ -1,3 +1,6 @@
1
+ import torch
2
+
3
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
1
4
  from liger_kernel.ops.grpo_loss import GrpoLossFunction
2
5
 
3
6
 
@@ -13,12 +16,20 @@ def triton_grpo_loss(
13
16
  eps_low=0.2,
14
17
  eps_high=0.4,
15
18
  inplace=True,
19
+ loss_type="dapo",
20
+ max_completion_length=None,
21
+ importance_sampling_level="token",
22
+ reduce=False,
16
23
  ):
17
24
  assert logits is not None and completion_ids is not None and advantages is not None, (
18
25
  "must provide logits、completion_ids and advantages"
19
26
  )
27
+ if importance_sampling_level != "token":
28
+ raise ValueError(
29
+ f"Triton GRPO loss only supports token-level importance sampling. Got {importance_sampling_level}."
30
+ )
20
31
 
21
- return GrpoLossFunction.apply(
32
+ per_token_loss, per_token_kl, is_clipped = GrpoLossFunction.apply(
22
33
  logits,
23
34
  old_logp,
24
35
  ref_logp,
@@ -31,6 +42,50 @@ def triton_grpo_loss(
31
42
  eps_high,
32
43
  inplace,
33
44
  )
45
+ if not reduce:
46
+ return per_token_loss, per_token_kl, is_clipped
47
+
48
+ loss = _reduce_grpo_loss(
49
+ per_token_loss,
50
+ completion_mask,
51
+ loss_type=loss_type,
52
+ max_completion_length=max_completion_length,
53
+ )
54
+
55
+ metrics = []
56
+ if beta != 0.0 and per_token_kl is not None:
57
+ metrics.append(_masked_mean(per_token_kl, completion_mask))
58
+ metrics.append(_masked_mean(is_clipped.float(), completion_mask))
59
+ return loss, metrics
60
+
61
+
62
+ def _reduce_grpo_loss(per_token_loss, completion_mask, loss_type, max_completion_length):
63
+ mask = completion_mask
64
+ if mask is None:
65
+ mask = torch.ones_like(per_token_loss, dtype=per_token_loss.dtype, device=per_token_loss.device)
66
+ mask = mask.to(per_token_loss.dtype)
67
+
68
+ if loss_type == "grpo":
69
+ per_seq = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
70
+ return per_seq.mean()
71
+ if loss_type == "bnpo":
72
+ return (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0)
73
+ if loss_type == "dr_grpo":
74
+ if max_completion_length is None:
75
+ raise ValueError("max_completion_length must be provided when using loss_type='dr_grpo'")
76
+ batch = per_token_loss.shape[0]
77
+ return (per_token_loss * mask).sum() / (batch * max_completion_length)
78
+ if loss_type == "dapo":
79
+ normalizer = LigerFusedLinearPPOBase._compute_dapo_normalizer(mask)
80
+ return (per_token_loss * mask).sum() / normalizer
81
+ raise ValueError(f"Unsupported loss_type '{loss_type}' for Triton GRPO loss.")
82
+
83
+
84
+ def _masked_mean(values, mask):
85
+ if mask is None:
86
+ mask = torch.ones_like(values, dtype=values.dtype, device=values.device)
87
+ mask = mask.to(values.dtype)
88
+ return (values * mask).sum() / mask.sum().clamp(min=1.0)
34
89
 
35
90
 
36
91
  # This is a demo how to use grpo_loss in GRPOTrainer. The Trl version must be 0.16
@@ -0,0 +1,122 @@
1
+ from typing import TYPE_CHECKING
2
+ from typing import Optional
3
+ from typing import Union
4
+
5
+ import torch
6
+
7
+ if TYPE_CHECKING:
8
+ from transformers.models.falcon_h1.modeling_falcon_h1 import FalconHybridMambaAttentionDynamicCache
9
+
10
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
11
+ from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
12
+ from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
13
+
14
+
15
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional["FalconHybridMambaAttentionDynamicCache"] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[tuple, "LigerCausalLMOutputWithPast"]:
    r"""
    Falcon-H1 causal-LM forward that can compute the loss with Liger's
    ``LigerForCausalLMLoss`` directly from hidden states, without materializing logits.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    shift_labels (`torch.LongTensor`, *optional*, read from `kwargs`):
        Already-shifted labels. May be given instead of, or together with, `labels`. It is forwarded to the
        loss in both the fused path and the logits path.
    skip_logits (`bool`, *optional*):
        If True, the loss is computed straight from the hidden states and `logits` is returned as `None`.
        This requires `labels` or `shift_labels`. If None, it defaults to True in training mode whenever
        either label tensor is given.

    Returns:
        A `LigerCausalLMOutputWithPast` when `return_dict` is true. Otherwise a tuple
        `(loss, logits, *outputs[1:], token_accuracy)`, where `loss` and `token_accuracy` are included only
        when they are not None.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FalconH1ForCausalLM

    >>> model = FalconH1ForCausalLM.from_pretrained("...")
    >>> tokenizer = AutoTokenizer.from_pretrained("...")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Run the base model; outputs[0] is taken as the final hidden states below.
    # Note: any `shift_labels` in kwargs is still present at this point and reaches self.model.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    # Popped so it is passed explicitly (not twice via **kwargs) to whichever loss runs below.
    shift_labels = kwargs.pop("shift_labels", None)
    has_labels = labels is not None or shift_labels is not None
    logits = None
    loss = None
    token_accuracy = None

    if skip_logits and not has_labels:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and has_labels

    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(kept_hidden_states)
        if has_labels:
            # shift_labels was popped from kwargs above, so forward it explicitly.
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        if loss is not None:
            output = (loss,) + output
        if token_accuracy is not None:
            output = output + (token_accuracy,)
        return output

    # Return custom output class with token_accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
    )