liger-kernel-nightly 0.5.6.dev20250403190551__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  3. liger_kernel/chunked_loss/dpo_loss.py +61 -3
  4. liger_kernel/chunked_loss/functional.py +2 -0
  5. liger_kernel/chunked_loss/fused_linear_distillation.py +13 -2
  6. liger_kernel/chunked_loss/fused_linear_ppo.py +35 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  8. liger_kernel/chunked_loss/grpo_loss.py +76 -5
  9. liger_kernel/chunked_loss/jsd_loss.py +25 -9
  10. liger_kernel/ops/__init__.py +141 -0
  11. liger_kernel/ops/backends/README.md +151 -0
  12. liger_kernel/ops/backends/__init__.py +13 -0
  13. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  14. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  15. liger_kernel/ops/backends/registry.py +61 -0
  16. liger_kernel/ops/cross_entropy.py +124 -64
  17. liger_kernel/ops/dyt.py +115 -180
  18. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  19. liger_kernel/ops/fused_linear_cross_entropy.py +115 -22
  20. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  21. liger_kernel/ops/geglu.py +3 -2
  22. liger_kernel/ops/group_norm.py +2 -1
  23. liger_kernel/ops/grpo_loss.py +312 -0
  24. liger_kernel/ops/jsd.py +2 -1
  25. liger_kernel/ops/kl_div.py +13 -6
  26. liger_kernel/ops/layer_norm.py +146 -78
  27. liger_kernel/ops/llama4_rope.py +225 -0
  28. liger_kernel/ops/multi_token_attention.py +207 -0
  29. liger_kernel/ops/poly_norm.py +390 -0
  30. liger_kernel/ops/rms_norm.py +283 -56
  31. liger_kernel/ops/rope.py +1 -1
  32. liger_kernel/ops/softmax.py +201 -0
  33. liger_kernel/ops/sparsemax.py +179 -0
  34. liger_kernel/ops/swiglu.py +1 -1
  35. liger_kernel/ops/tiled_mlp.py +136 -0
  36. liger_kernel/ops/utils.py +2 -0
  37. liger_kernel/transformers/__init__.py +205 -19
  38. liger_kernel/transformers/cross_entropy.py +9 -4
  39. liger_kernel/transformers/dyt.py +6 -4
  40. liger_kernel/transformers/experimental/__init__.py +5 -0
  41. liger_kernel/transformers/experimental/embedding.py +1 -1
  42. liger_kernel/transformers/fsdp.py +55 -0
  43. liger_kernel/transformers/functional.py +122 -20
  44. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  45. liger_kernel/transformers/fused_linear_cross_entropy.py +16 -5
  46. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  47. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  48. liger_kernel/transformers/geglu.py +1 -1
  49. liger_kernel/transformers/group_norm.py +1 -1
  50. liger_kernel/transformers/grpo_loss.py +153 -0
  51. liger_kernel/transformers/jsd.py +1 -1
  52. liger_kernel/transformers/kl_div.py +1 -1
  53. liger_kernel/transformers/layer_norm.py +1 -1
  54. liger_kernel/transformers/llama4_rope.py +93 -0
  55. liger_kernel/transformers/model/falcon_h1.py +122 -0
  56. liger_kernel/transformers/model/gemma.py +50 -25
  57. liger_kernel/transformers/model/gemma2.py +55 -23
  58. liger_kernel/transformers/model/gemma3.py +117 -120
  59. liger_kernel/transformers/model/glm4.py +141 -0
  60. liger_kernel/transformers/model/glm4v.py +163 -0
  61. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  62. liger_kernel/transformers/model/gpt_oss.py +211 -0
  63. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  64. liger_kernel/transformers/model/internvl.py +157 -0
  65. liger_kernel/transformers/model/llama.py +102 -25
  66. liger_kernel/transformers/model/llama4.py +121 -0
  67. liger_kernel/transformers/model/llava.py +111 -136
  68. liger_kernel/transformers/model/loss_utils.py +50 -12
  69. liger_kernel/transformers/model/mistral.py +36 -23
  70. liger_kernel/transformers/model/mixtral.py +45 -25
  71. liger_kernel/transformers/model/mllama.py +39 -22
  72. liger_kernel/transformers/model/olmo2.py +40 -20
  73. liger_kernel/transformers/model/olmo3.py +142 -0
  74. liger_kernel/transformers/model/output_classes.py +147 -0
  75. liger_kernel/transformers/model/paligemma.py +50 -14
  76. liger_kernel/transformers/model/phi3.py +47 -177
  77. liger_kernel/transformers/model/qwen2.py +48 -21
  78. liger_kernel/transformers/model/qwen2_5_vl.py +62 -103
  79. liger_kernel/transformers/model/qwen2_vl.py +59 -108
  80. liger_kernel/transformers/model/qwen3.py +136 -0
  81. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  82. liger_kernel/transformers/model/qwen3_next.py +146 -0
  83. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  84. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  85. liger_kernel/transformers/model/smollm3.py +199 -0
  86. liger_kernel/transformers/model/smolvlm.py +158 -0
  87. liger_kernel/transformers/monkey_patch.py +1678 -160
  88. liger_kernel/transformers/multi_token_attention.py +64 -0
  89. liger_kernel/transformers/poly_norm.py +42 -0
  90. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  91. liger_kernel/transformers/rms_norm.py +48 -5
  92. liger_kernel/transformers/rope.py +45 -1
  93. liger_kernel/transformers/softmax.py +12 -0
  94. liger_kernel/transformers/sparsemax.py +16 -0
  95. liger_kernel/transformers/swiglu.py +39 -1
  96. liger_kernel/transformers/tiled_mlp.py +133 -0
  97. liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  98. liger_kernel/transformers/tvd.py +1 -1
  99. liger_kernel/utils.py +36 -0
  100. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/METADATA +68 -38
  101. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  102. liger_kernel/transformers/gema3_rms.py +0 -8
  103. liger_kernel_nightly-0.5.6.dev20250403190551.dist-info/RECORD +0 -82
  104. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/NOTICE +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250403190551.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,179 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import ensure_contiguous
9
+
10
+
11
@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,
    x_stride_row,
    sorted_x_ptr,
    sorted_x_stride_row,
    o_ptr,
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Sparsemax forward pass; each program instance handles one row.

    The threshold ``tau`` is derived from the row read through
    ``sorted_x_ptr`` (the host wrapper passes it sorted in descending order),
    then ``max(x - tau, 0)`` is written for the unsorted row read through
    ``x_ptr``.

    The row is processed as a single block: only columns ``< BLOCK_SIZE`` are
    visited, and lanes at or beyond ``n_cols`` are masked. ``num_warps`` is
    not read inside the body.
    """
    row_idx = tl.program_id(0)
    x_row_ptr = x_ptr + row_idx * x_stride_row
    sorted_row_ptr = sorted_x_ptr + row_idx * sorted_x_stride_row
    out_row_ptr = o_ptr + row_idx * o_stride_row

    col = tl.arange(0, BLOCK_SIZE)
    in_bounds = col < n_cols

    z_sorted = tl.load(
        sorted_row_ptr + col,
        mask=in_bounds,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)

    # Padding lanes were loaded as -inf; zero them so they cannot poison the
    # running sum.
    prefix_sum = tl.cumsum(tl.where(in_bounds, z_sorted, 0.0), 0)

    # 1-based position of every sorted entry. Padding lanes get 1.0 purely to
    # keep the division well-defined; they are excluded from the support below.
    rank = tl.where(in_bounds, (col + 1).to(tl.float32), 1.0)
    candidate_tau = (prefix_sum - 1.0) / rank

    # An entry belongs to the support when it exceeds its own candidate threshold.
    in_support = (z_sorted > candidate_tau) & in_bounds

    # Clamp the support size to at least 1 so tau never divides by zero.
    support_size = tl.maximum(tl.sum(in_support.to(tl.int32), 0), 1).to(tl.float32)
    support_sum = tl.sum(tl.where(in_support, z_sorted, 0.0), 0)
    tau = (support_sum - 1.0) / support_size

    x_vals = tl.load(
        x_row_ptr + col,
        mask=in_bounds,
        other=0.0,
        cache_modifier=".ca",
    ).to(tl.float32)

    result = tl.maximum(x_vals - tau, 0.0)

    tl.store(
        out_row_ptr + col,
        result.to(out_row_ptr.dtype.element_ty),
        mask=in_bounds,
        cache_modifier=".cs",
    )
71
+
72
+
73
@triton.jit
def _sparsemax_backward_kernel(
    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
    """Sparsemax backward pass; each program instance handles one row.

    For columns where the forward output is positive (the support) the input
    gradient is ``grad_out - mean(grad_out over the support)``; everywhere
    else it is 0. The row is walked twice in ``BLOCK_SIZE`` chunks: first to
    accumulate the support statistics, then to write the gradient.

    The same ``stride`` is applied to ``o_ptr``, ``go_ptr`` and ``gi_ptr``.
    ``num_warps`` is not read inside the body.
    """
    row_idx = tl.program_id(0)
    out_row_ptr = o_ptr + row_idx * stride
    grad_out_row_ptr = go_ptr + row_idx * stride
    grad_in_row_ptr = gi_ptr + row_idx * stride

    lane = tl.arange(0, BLOCK_SIZE)

    support_count = tl.zeros((), tl.float32)
    support_grad_sum = tl.zeros((), tl.float32)

    # Pass 1: size of the support and sum of upstream gradients over it.
    for blk in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        cols = blk * BLOCK_SIZE + lane
        valid = cols < n_cols
        out_vals = tl.load(out_row_ptr + cols, mask=valid, other=0.0, cache_modifier=".ca").to(tl.float32)
        grad_vals = tl.load(grad_out_row_ptr + cols, mask=valid, other=0.0).to(tl.float32)
        # Masked lanes load 0.0 and therefore never count as support.
        in_support = out_vals > 0.0
        support_grad_sum += tl.sum(tl.where(in_support, grad_vals, 0.0))
        support_count += tl.sum(in_support.to(tl.float32))

    # Mean gradient over the support, rounded through the output dtype before
    # being used in float32. The 1e-6 floor guards an empty support.
    mean_grad = tl.cast(
        support_grad_sum / tl.maximum(support_count, 1e-6),
        grad_in_row_ptr.dtype.element_ty,
    ).to(tl.float32)

    # Pass 2: write the input gradient.
    for blk in tl.range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        cols = blk * BLOCK_SIZE + lane
        valid = cols < n_cols
        out_vals = tl.load(out_row_ptr + cols, mask=valid, other=0.0, cache_modifier=".ca").to(tl.float32)
        grad_vals = tl.load(grad_out_row_ptr + cols, mask=valid, other=0.0).to(tl.float32)
        in_support = out_vals > 0.0
        grad_in = tl.where(in_support, grad_vals - mean_grad, 0.0)
        tl.store(
            grad_in_row_ptr + cols,
            grad_in.to(grad_in_row_ptr.dtype.element_ty),
            mask=valid,
            cache_modifier=".wb",
        )
108
+
109
+
110
def _sparsemax_forward(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the sparsemax forward kernel along ``dim``.

    ``dim`` is moved to the last axis, the tensor is flattened to
    ``(n_rows, n_cols)``, and one kernel program is launched per row. Each row
    is also sorted in descending order (in float32) on the host side, because
    the kernel derives its threshold from the sorted values.

    Args:
        x: input tensor.
        dim: axis to normalise over; negative values count from the end.

    Returns:
        ``(y, out_flat)``: ``y`` is the result laid out like ``x``;
        ``out_flat`` is the same data as a ``(n_rows, n_cols)`` tensor, which
        the autograd function saves for the backward pass.
    """
    dim = dim + x.dim() if dim < 0 else dim

    swapped = x.transpose(dim, -1).contiguous()
    n_cols = swapped.size(-1)
    n_rows = swapped.numel() // n_cols
    x_flat = swapped.view(n_rows, n_cols)
    sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values

    out_flat = torch.empty_like(x_flat)
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    # One program per flattened row.
    _sparsemax_forward_kernel[(n_rows,)](
        x_flat,
        x_flat.stride(0),
        sorted_flat,
        sorted_flat.stride(0),
        out_flat,
        out_flat.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    # Undo the axis swap so the result lines up with the caller's layout.
    return out_flat.view_as(swapped).transpose(dim, -1), out_flat
136
+
137
+
138
def _sparsemax_backward(
    grad_out: torch.Tensor,
    out_flat: torch.Tensor,
    dim: int,
) -> torch.Tensor:
    """Run the sparsemax backward kernel along ``dim``.

    Args:
        grad_out: upstream gradient, laid out like the forward output.
        out_flat: flattened ``(n_rows, n_cols)`` forward output, as returned
            by ``_sparsemax_forward``.
        dim: axis the forward pass normalised over.

    Returns:
        Gradient with respect to the forward input, laid out like ``grad_out``.
    """
    grad_swapped = grad_out.transpose(dim, -1).contiguous()
    n_cols = grad_swapped.size(-1)
    n_rows = grad_swapped.numel() // n_cols
    grad_flat = grad_swapped.view(n_rows, n_cols)

    dx_flat = torch.empty_like(grad_flat)
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    # One program per flattened row. The kernel applies out_flat's row stride
    # to all three buffers.
    _sparsemax_backward_kernel[(n_rows,)](
        out_flat,
        grad_flat,
        dx_flat,
        out_flat.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    return dx_flat.view_as(grad_swapped).transpose(dim, -1)
163
+
164
+
165
class LigerSparsemaxFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton sparsemax forward/backward helpers."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        """Apply sparsemax to ``x`` along ``dim``.

        The flattened forward output is saved for backward, which only needs
        to know which entries ended up positive.
        """
        result, flat_result = _sparsemax_forward(x, dim)
        ctx.save_for_backward(flat_result)
        ctx.dim = dim
        return result

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        """Return the gradient for ``x``; ``dim`` is not differentiable, hence ``None``."""
        (flat_result,) = ctx.saved_tensors
        return _sparsemax_backward(grad_out, flat_result, ctx.dim), None
@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
26
26
  # sigmoid requires type float32
27
27
  a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
28
28
  b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
29
- c_row = silu(a_row) * b_row
29
+ c_row = silu(a_row).cast(b_row.dtype) * b_row
30
30
  tl.store(c_ptr + col_offsets, c_row, mask=mask)
31
31
 
32
32
 
@@ -0,0 +1,136 @@
1
+ import math
2
+
3
+ from typing import Callable
4
+ from typing import List
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+
11
+
12
class LigerTiledMLPFunction(torch.autograd.Function):
    """
    Tiled MLP, based on DeepSpeed's TiledMLP:
    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838

    The input is split into shards along the sequence axis and the MLP is run
    on one shard at a time, which greatly reduces the peak memory needed for
    very long sequences.

    `forward` is re-computed shard by shard inside `backward`, so the MLP
    forward runs twice per iteration (three times when activation
    checkpointing is also in use).

    Args:
        fn: the function to call on sharded inputs (e.g., mlp.forward)
        mlp_module: the MLP nn.Module object
        x: the input to MLP.forward (hidden_states)
        shards: how many shards to use
        compute_params: a list of weights engaged in the compute
            (accepted for API parity; not read anywhere in this class)

    Returns:
        the computed hidden_states
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        fn: Callable,
        mlp_module: torch.nn.Module,
        x: torch.Tensor,
        shards: int,
        compute_params: Optional[List[torch.nn.Parameter]] = None,
    ) -> torch.Tensor:
        """Run ``fn(mlp_module, shard)`` on every sequence shard and concatenate the results."""
        ctx.save_for_backward(x)
        ctx.fn, ctx.mlp_module, ctx.shards = fn, mlp_module, shards

        # x is [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts);
        # dim=-2 is the sequence axis in both layouts.
        pieces = list(x.chunk(shards, dim=-2))
        with torch.no_grad():
            # No graph is kept here; backward recomputes each shard.
            outputs = [fn(mlp_module, piece) for piece in pieces]

        return torch.cat(outputs, dim=-2)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, *grads) -> tuple:
        """Recompute the MLP shard by shard and back-propagate each shard separately.

        Returns one entry per ``forward`` argument; only ``x`` gets a gradient
        from this function's return value.
        """
        (saved_x,) = ctx.saved_tensors
        fn, mlp_module, shards = ctx.fn, ctx.mlp_module, ctx.shards

        x_requires_grad = saved_x.requires_grad
        x = saved_x.detach()
        # detach() clears requires_grad, so put the original flag back
        x.requires_grad_(x_requires_grad)

        # x is [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        x_shape_orig = x.shape
        hidden_size = x_shape_orig[-1]

        # Collapse bs and seqlen into one axis so narrowing along it does not
        # run into stride issues when bs > 1.
        x = x.view(-1, hidden_size)
        incoming_grad = grads[0].view(-1, hidden_size)
        x_grad = torch.zeros_like(x)

        x_shards = list(x.chunk(shards, dim=0))

        for i, x_shard in enumerate(x_shards):
            x_shard.requires_grad_(x_requires_grad)

            # The last shard is shorter when the row count is not divisible by
            # `shards`; every earlier shard has the length of the first one.
            shard_step = x_shard.shape[0]
            shard_offset = i * x_shards[0].shape[0]

            # Point the shard's .grad at its slice of x_grad so the per-shard
            # backward accumulates straight into the full gradient buffer.
            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)

            with torch.enable_grad():
                output = fn(mlp_module, x_shard)
            torch.autograd.backward(output, incoming_grad_shard)

        # restore the caller's original shape
        return (None, None, x_grad.view(x_shape_orig), None, None)
99
+
100
+
101
def apply_tiled_mlp(
    fn: Callable,
    mlp_module: torch.nn.Module,
    x: torch.Tensor,
    num_shards: Optional[int] = None,
    compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
    """
    Run an MLP through `LigerTiledMLPFunction` to reduce peak memory.

    Args:
        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
        mlp_module: the MLP nn.Module object
        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        num_shards: number of shards to use. If None, it defaults to
            ceil(seqlen / hidden_size). Values below 1 are raised to 1.
        compute_params: list of parameters for DeepSpeed ZeRO optimization
            (forwarded unchanged)

    Returns:
        the per-shard outputs of `fn` concatenated along the sequence axis
    """
    if num_shards is None:
        # The last two axes are (seqlen, hidden_size) for both supported layouts.
        seqlen, hidden_size = x.shape[-2], x.shape[-1]
        num_shards = math.ceil(seqlen / hidden_size)

    # Never use fewer than one shard.
    shard_count = max(1, num_shards)

    return LigerTiledMLPFunction.apply(fn, mlp_module, x, shard_count, compute_params)
liger_kernel/ops/utils.py CHANGED
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
78
78
  functools.partial(torch.amp.custom_fwd, device_type=device),
79
79
  functools.partial(torch.amp.custom_bwd, device_type=device),
80
80
  )
81
+ if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
82
+ return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
81
83
  return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
82
84
 
83
85
 
@@ -1,32 +1,218 @@
1
- from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
1
+ import importlib
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ # Always-safe imports (independent of 'transformers')
2
6
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
3
7
  from liger_kernel.transformers.dyt import LigerDyT # noqa: F401
8
+ from liger_kernel.transformers.fused_add_rms_norm import LigerFusedAddRMSNorm # noqa: F401
4
9
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
5
10
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
6
11
  from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
7
12
  from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
13
+ from liger_kernel.transformers.kl_div import LigerKLDIVLoss # noqa: F401
8
14
  from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
9
- from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
10
- from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
11
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
12
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
13
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
14
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
15
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
16
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
17
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
18
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
19
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
20
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
21
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
22
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
23
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
24
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
25
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
26
- from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
15
+ from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
16
+ from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
17
+ from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
18
+ from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
27
19
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
28
20
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
21
+ from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
22
+ from liger_kernel.transformers.sparsemax import LigerSparsemax # noqa: F401
29
23
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
30
24
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
25
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
31
26
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
27
+ from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401
28
+ from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401
32
29
  from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
30
+
31
+ # Static-only imports for IDEs and type checkers
32
+ if TYPE_CHECKING:
33
+ from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
34
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
35
+ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
36
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_falcon_h1 # noqa: F401
37
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
38
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
39
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
40
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
41
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
42
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
43
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
44
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss # noqa: F401
45
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
46
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
47
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
48
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401
49
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
50
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
51
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
52
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
53
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
54
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
55
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
56
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3 # noqa: F401
57
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
58
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
59
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
60
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
61
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
62
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3 # noqa: F401
63
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe # noqa: F401
64
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_next # noqa: F401
65
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl # noqa: F401
66
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl_moe # noqa: F401
67
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3 # noqa: F401
68
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm # noqa: F401
69
+
70
+
71
# Probe once, at import time, whether 'transformers' can be imported.
try:
    import transformers  # noqa: F401
except ImportError:
    _TRANSFORMERS_AVAILABLE = False
else:
    _TRANSFORMERS_AVAILABLE = True


def is_transformers_available() -> bool:
    """
    Report whether the 'transformers' package could be imported when this
    module was loaded. Handy for conditional logic in downstream code.
    """
    return _TRANSFORMERS_AVAILABLE
86
+
87
+
88
def __getattr__(name: str):
    """
    Lazily resolve the transformer-dependent public names of this package.

    Only names this hook knows about are treated as lazy attributes:
    ``AutoLigerKernelForCausalLM`` (from ``auto_model``) and the
    ``apply_liger_kernel_to_*`` helpers (from ``monkey_patch``).

    Args:
        name: attribute being looked up on the module.

    Returns:
        The requested object, imported on first access.

    Raises:
        AttributeError: ``name`` is not a lazily provided attribute. This is
            raised regardless of whether 'transformers' is installed, so
            ``hasattr``, ``getattr(..., default)`` and the import machinery's
            submodule probing keep working on a transformers-free install.
        ImportError: ``name`` is a known lazy attribute but 'transformers' is
            not installed.
    """
    monkey_patch_symbols = {
        "_apply_liger_kernel",
        "_apply_liger_kernel_to_instance",
        "apply_liger_kernel_to_falcon_h1",
        "apply_liger_kernel_to_gemma",
        "apply_liger_kernel_to_gemma2",
        "apply_liger_kernel_to_gemma3",
        "apply_liger_kernel_to_gemma3_text",
        "apply_liger_kernel_to_glm4",
        "apply_liger_kernel_to_glm4v",
        "apply_liger_kernel_to_glm4v_moe",
        "apply_liger_kernel_to_gpt_oss",
        "apply_liger_kernel_to_granite",
        "apply_liger_kernel_to_internvl",
        "apply_liger_kernel_to_llama",
        "apply_liger_kernel_to_llava",
        "apply_liger_kernel_to_llama4",
        "apply_liger_kernel_to_mistral",
        "apply_liger_kernel_to_mixtral",
        "apply_liger_kernel_to_mllama",
        "apply_liger_kernel_to_olmo2",
        "apply_liger_kernel_to_olmo3",
        "apply_liger_kernel_to_paligemma",
        "apply_liger_kernel_to_phi3",
        "apply_liger_kernel_to_qwen2",
        "apply_liger_kernel_to_qwen2_5_vl",
        "apply_liger_kernel_to_qwen2_vl",
        "apply_liger_kernel_to_qwen3",
        "apply_liger_kernel_to_qwen3_moe",
        "apply_liger_kernel_to_qwen3_next",
        "apply_liger_kernel_to_qwen3_vl",
        "apply_liger_kernel_to_qwen3_vl_moe",
        "apply_liger_kernel_to_smollm3",
        "apply_liger_kernel_to_smolvlm",
        "apply_liger_kernel_to_hunyuan_v1_dense",
        "apply_liger_kernel_to_hunyuan_v1_moe",
    }

    # Classify the name first. Anything this hook does not provide must raise
    # AttributeError (the module-level __getattr__ contract), not ImportError.
    if name == "AutoLigerKernelForCausalLM":
        module_path = "liger_kernel.transformers.auto_model"
    elif name in monkey_patch_symbols:
        module_path = "liger_kernel.transformers.monkey_patch"
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")

    # The name is one of ours; now the missing dependency is the real problem.
    if not _TRANSFORMERS_AVAILABLE:
        raise ImportError(
            f"The attribute '{name}' requires the 'transformers' library, which is not installed.\n"
            f"Please install it with `pip install transformers` to use this functionality."
        )

    return getattr(importlib.import_module(module_path), name)
146
+
147
+
148
# Public names that are importable whether or not 'transformers' is installed.
__all__ = [
    "is_transformers_available",
    "LigerCrossEntropyLoss",
    "LigerDyT",
    "LigerFusedLinearCrossEntropyLoss",
    "LigerFusedLinearJSD",
    "LigerGEGLUMLP",
    "LigerJSD",
    "LigerLayerNorm",
    "LigerFusedAddRMSNorm",
    "LigerPolyNorm",
    "LigerRMSNorm",
    "liger_rotary_pos_emb",
    "liger_llama4_text_rotary_pos_emb",
    "liger_llama4_vision_rotary_pos_emb",
    "LigerBlockSparseTop2MLP",
    "LigerPhi3SwiGLUMLP",
    "LigerQwen3MoeSwiGLUMLP",
    "LigerSwiGLUMLP",
    "LigerTiledGEGLUMLP",
    "LigerTiledSwiGLUMLP",
    "LigerTVDLoss",
    "LigerKLDIVLoss",
    "LigerMultiTokenAttention",
    "LigerSoftmax",
    "LigerSparsemax",
]

# Names resolved lazily through the module-level __getattr__; advertise them
# only when 'transformers' is actually installed.
if _TRANSFORMERS_AVAILABLE:
    __all__ += [
        "AutoLigerKernelForCausalLM",
        "_apply_liger_kernel",
        "_apply_liger_kernel_to_instance",
        "apply_liger_kernel_to_falcon_h1",
        "apply_liger_kernel_to_gemma",
        "apply_liger_kernel_to_gemma2",
        "apply_liger_kernel_to_gemma3",
        "apply_liger_kernel_to_gemma3_text",
        "apply_liger_kernel_to_glm4",
        "apply_liger_kernel_to_glm4v",
        "apply_liger_kernel_to_glm4v_moe",
        "apply_liger_kernel_to_gpt_oss",
        "apply_liger_kernel_to_granite",
        "apply_liger_kernel_to_internvl",
        "apply_liger_kernel_to_llama",
        "apply_liger_kernel_to_llava",
        "apply_liger_kernel_to_llama4",
        "apply_liger_kernel_to_mistral",
        "apply_liger_kernel_to_mixtral",
        "apply_liger_kernel_to_mllama",
        "apply_liger_kernel_to_olmo2",
        "apply_liger_kernel_to_olmo3",
        "apply_liger_kernel_to_paligemma",
        "apply_liger_kernel_to_phi3",
        "apply_liger_kernel_to_qwen2",
        "apply_liger_kernel_to_qwen2_5_vl",
        "apply_liger_kernel_to_qwen2_vl",
        "apply_liger_kernel_to_qwen3",
        "apply_liger_kernel_to_qwen3_moe",
        "apply_liger_kernel_to_qwen3_next",
        "apply_liger_kernel_to_qwen3_vl",
        "apply_liger_kernel_to_qwen3_vl_moe",
        "apply_liger_kernel_to_smollm3",
        "apply_liger_kernel_to_smolvlm",
        "apply_liger_kernel_to_hunyuan_v1_dense",
        "apply_liger_kernel_to_hunyuan_v1_moe",
    ]
@@ -2,7 +2,8 @@ from typing import Optional
2
2
 
3
3
  import torch
4
4
 
5
- from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
5
+ from liger_kernel.ops import LigerCrossEntropyFunction
6
+ from liger_kernel.transformers.functional import CrossEntropyOutput
6
7
 
7
8
 
8
9
  class LigerCrossEntropyLoss(torch.nn.Module):
@@ -15,6 +16,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
15
16
  reduction: str = "mean",
16
17
  softcap: Optional[float] = None,
17
18
  return_z_loss: bool = False,
19
+ return_token_accuracy: bool = False,
18
20
  ):
19
21
  super().__init__()
20
22
  assert (label_smoothing >= 0) and (label_smoothing <= 1), (
@@ -33,9 +35,10 @@ class LigerCrossEntropyLoss(torch.nn.Module):
33
35
  self.reduction = reduction
34
36
  self.softcap = softcap
35
37
  self.return_z_loss = return_z_loss
38
+ self.return_token_accuracy = return_token_accuracy
36
39
 
37
40
  def forward(self, _input: torch.Tensor, target: torch.Tensor):
38
- loss, z_loss = LigerCrossEntropyFunction.apply(
41
+ loss, z_loss, token_accuracy = LigerCrossEntropyFunction.apply(
39
42
  _input,
40
43
  target,
41
44
  self.weight,
@@ -45,7 +48,9 @@ class LigerCrossEntropyLoss(torch.nn.Module):
45
48
  self.reduction,
46
49
  self.softcap,
47
50
  self.return_z_loss,
51
+ self.return_token_accuracy,
48
52
  )
49
- if not self.return_z_loss:
53
+ if not self.return_z_loss and not self.return_token_accuracy:
50
54
  return loss
51
- return loss, z_loss
55
+
56
+ return CrossEntropyOutput(loss=loss, z_loss=z_loss, token_accuracy=token_accuracy)
@@ -1,20 +1,22 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
3
 
4
- from liger_kernel.ops.dyt import LigerDyTFunction
4
+ from liger_kernel.ops import LigerDyTFunction
5
5
 
6
6
 
7
7
class LigerDyT(nn.Module):
    """Module wrapper around ``LigerDyTFunction``.

    Holds a scalar ``alpha``, a per-channel ``gamma`` and, optionally, a
    per-channel ``beta``. The actual computation is delegated to
    ``LigerDyTFunction`` (presumably the DyT form
    ``gamma * tanh(alpha * x) + beta`` — confirm against that function).

    Args:
        hidden_size: length of the ``gamma`` / ``beta`` vectors.
        beta: when truthy, create a learnable ``beta`` initialised to zeros;
            otherwise ``self.beta`` is ``None``.
        init_alpha: initial value of the scalar ``alpha``.
    """

    def __init__(self, hidden_size, beta=True, init_alpha=0.5):
        super().__init__()
        self.hidden_size = hidden_size
        self.init_alpha = init_alpha
        self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        # beta stays None when the shift is disabled; forward passes it through as-is.
        self.beta = None
        if beta:
            self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        """Apply ``LigerDyTFunction`` to ``x`` with this module's parameters."""
        return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)

    def extra_repr(self):
        """Return a one-line summary for ``repr(module)``."""
        # Report whether the shift is enabled. Formatting self.beta directly
        # would dump the whole parameter tensor into the module repr.
        return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta is not None}"
@@ -0,0 +1,5 @@
1
+ from liger_kernel.transformers.experimental.embedding import LigerEmbedding # noqa: F401
2
+
3
+ __all__ = [
4
+ "LigerEmbedding",
5
+ ]
@@ -3,7 +3,7 @@ from typing import Optional
3
3
  import torch
4
4
  import torch.nn as nn
5
5
 
6
- from liger_kernel.ops.experimental.embedding import LigerEmbeddingFunction
6
+ from liger_kernel.ops import LigerEmbeddingFunction
7
7
 
8
8
 
9
9
  class LigerEmbedding(nn.Module):