liger-kernel-nightly 0.6.3.dev20251106220336__py3-none-any.whl → 0.6.3.dev20251118154655__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,136 @@
1
+ import math
2
+
3
+ from typing import Callable
4
+ from typing import List
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+
11
+
12
class LigerTiledMLPFunction(torch.autograd.Function):
    """
    Based on DeepSpeed's TiledMLP:
    https://github.com/deepspeedai/DeepSpeed/blob/v0.18.2/deepspeed/runtime/sequence_parallel/ulysses_sp.py#L838

    Perform a tiled MLP computation to massively reduce memory usage needed to compute MLP
    when using very long sequence lengths.

    This module re-computes `forward` in the `backward`. So the `forward` occurs twice each iteration.
    And if you're using activation checkpointing it then occurs thrice.

    Args:
        fn: the function to call on sharded inputs (e.g., mlp.forward)
        mlp_module: the MLP nn.Module object
        x: the input to MLP.forward (hidden_states)
        shards: how many shards to use
        compute_params: a list of weights engaged in the compute
            (accepted for API compatibility; not referenced in this implementation)

    Returns:
        the computed hidden_states
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        fn: Callable,
        mlp_module: torch.nn.Module,
        x: torch.Tensor,
        shards: int,
        compute_params: Optional[List[torch.nn.Parameter]] = None,
    ) -> torch.Tensor:
        # Stash everything needed to re-run `fn` shard-by-shard in backward.
        ctx.fn = fn
        ctx.mlp_module = mlp_module
        ctx.shards = shards
        ctx.save_for_backward(x)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        # dim=-2 is the sequence dimension in either layout.
        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
        # no_grad: intermediate activations are intentionally discarded here; the
        # forward is re-computed per shard inside backward() instead.
        with torch.no_grad():
            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
        output_unsharded = torch.cat(output_shards, dim=-2)

        return output_unsharded

    @staticmethod
    @ensure_contiguous
    def backward(ctx, *grads) -> tuple:
        fn = ctx.fn
        (x,) = ctx.saved_tensors
        mlp_module = ctx.mlp_module
        shards = ctx.shards

        x_requires_grad = x.requires_grad
        x = x.detach()
        # detach() unsets x.requires_grad, so restore it
        x.requires_grad_(x_requires_grad)

        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
        hidden_size = x.shape[-1]
        x_shape_orig = x.shape

        # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
        x = x.view(-1, hidden_size)
        # grads[0] is the gradient w.r.t. the single tensor output of forward().
        incoming_grad = grads[0].view(-1, hidden_size)
        x_grad = torch.zeros_like(x)

        x_shards = list(torch.chunk(x, chunks=shards, dim=0))

        for i, x_shard in enumerate(x_shards):
            x_shard.requires_grad_(x_requires_grad)

            # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
            shard_step = x_shards[i].shape[0]
            shard_offset = i * x_shards[0].shape[0]

            # Pre-assign `.grad` as a narrow()ed view into x_grad so that
            # autograd accumulates this shard's input gradient directly into
            # x_grad, avoiding a second full-size gradient buffer.
            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)

            # Re-run the forward for this shard with grad tracking enabled
            # (backward() itself executes under no-grad mode by default),
            # then backprop the matching slice of the incoming gradient.
            with torch.enable_grad():
                output = fn(mlp_module, x_shard)
                torch.autograd.backward(output, incoming_grad_shard)

        # unflatten
        x_grad = x_grad.view(x_shape_orig)

        # Gradients only for `x`; fn/mlp_module/shards/compute_params get None.
        return (None, None, x_grad, None, None)
99
+
100
+
101
def apply_tiled_mlp(
    fn: Callable,
    mlp_module: torch.nn.Module,
    x: torch.Tensor,
    num_shards: Optional[int] = None,
    compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
    """
    Apply tiled MLP computation for memory efficiency.

    Args:
        fn: the function to call on sharded inputs (e.g., lambda module, x: module(x))
        mlp_module: the MLP nn.Module object
        x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
        num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
        compute_params: list of parameters for DeepSpeed ZeRO optimization

    Returns:
        output tensor with the same shape as input
    """
    if num_shards is None:
        # x is [bs, seqlen, hidden_size] or [seqlen, hidden_size]; choose one
        # shard per hidden_size-sized slice of the sequence dimension.
        num_shards = math.ceil(x.shape[-2] / x.shape[-1])

    # Never go below a single shard (e.g. when seqlen < hidden_size).
    shard_count = max(1, num_shards)

    return LigerTiledMLPFunction.apply(fn, mlp_module, x, shard_count, compute_params)
@@ -24,6 +24,8 @@ from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F4
24
24
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
25
25
  from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP # noqa: F401
26
26
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
27
+ from liger_kernel.transformers.tiled_mlp import LigerTiledGEGLUMLP # noqa: F401
28
+ from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP # noqa: F401
27
29
  from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
28
30
 
29
31
  # Static-only imports for IDEs and type checkers
@@ -155,6 +157,8 @@ __all__ = [
155
157
  "LigerPhi3SwiGLUMLP",
156
158
  "LigerQwen3MoeSwiGLUMLP",
157
159
  "LigerSwiGLUMLP",
160
+ "LigerTiledGEGLUMLP",
161
+ "LigerTiledSwiGLUMLP",
158
162
  "LigerTVDLoss",
159
163
  "LigerKLDIVLoss",
160
164
  "LigerMultiTokenAttention",
@@ -0,0 +1,133 @@
1
+ from typing import Optional
2
+
3
+ import torch.nn as nn
4
+
5
+ from liger_kernel.ops.geglu import LigerGELUMulFunction
6
+ from liger_kernel.ops.swiglu import LigerSiLUMulFunction
7
+ from liger_kernel.ops.tiled_mlp import apply_tiled_mlp
8
+
9
+
10
class LigerTiledGEGLUMLP(nn.Module):
    """
    Memory-efficient GEGLU MLP using tiled computation.

    Combines the fused GEGLU activation with sequence tiling so that very long
    sequences can be processed without materializing full-size intermediate
    activations; the forward pass is recomputed during backward to save memory.

    Args:
        config: Model configuration with hidden_size and intermediate_size attributes
        num_shards: Number of shards to split the sequence. If None, automatically
            calculated as ceil(seqlen / hidden_size)
    """

    def __init__(self, config, num_shards: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_shards = num_shards

        hidden, intermediate = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)

        # Only GELU-family activations are compatible with the fused GEGLU kernel.
        _missing = object()
        act = getattr(config, "hidden_act", _missing)
        if act is not _missing and act not in ("gelu", "gelu_new", "gelu_pytorch_tanh"):
            raise ValueError(f"LigerTiledGEGLUMLP requires GELU activation, got {config.hidden_act}")

    def _mlp_forward(self, module, x):
        """Internal MLP forward function for tiled computation."""
        fused = LigerGELUMulFunction.apply(module.gate_proj(x), module.up_proj(x))
        return module.down_proj(fused)

    def forward(self, x):
        """
        Forward pass with tiled computation.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size]
                or [seq_len, hidden_size]

        Returns:
            Output tensor of the same shape as input
        """
        params = [self.gate_proj.weight, self.up_proj.weight, self.down_proj.weight]
        return apply_tiled_mlp(
            fn=self._mlp_forward,
            mlp_module=self,
            x=x,
            num_shards=self.num_shards,
            compute_params=params,
        )
73
+
74
+
75
class LigerTiledSwiGLUMLP(nn.Module):
    """
    Memory-efficient SwiGLU MLP using tiled computation.

    Combines the fused SwiGLU activation with sequence tiling so that very long
    sequences can be processed without materializing full-size intermediate
    activations; the forward pass is recomputed during backward to save memory.

    Args:
        config: Model configuration with hidden_size and intermediate_size attributes
        num_shards: Number of shards to split the sequence. If None, automatically
            calculated as ceil(seqlen / hidden_size)
    """

    def __init__(self, config, num_shards: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_shards = num_shards

        hidden, intermediate = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)

        # Only SiLU/Swish activations are compatible with the fused SwiGLU kernel.
        _missing = object()
        act = getattr(config, "hidden_act", _missing)
        if act is not _missing and act not in ("silu", "swish"):
            raise ValueError(f"LigerTiledSwiGLUMLP requires SiLU/Swish activation, got {config.hidden_act}")

    def _mlp_forward(self, module, x):
        """Internal MLP forward function for tiled computation."""
        fused = LigerSiLUMulFunction.apply(module.gate_proj(x), module.up_proj(x))
        return module.down_proj(fused)

    def forward(self, x):
        """
        Forward pass with tiled computation.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size]
                or [seq_len, hidden_size]

        Returns:
            Output tensor of the same shape as input
        """
        params = [self.gate_proj.weight, self.up_proj.weight, self.down_proj.weight]
        return apply_tiled_mlp(
            fn=self._mlp_forward,
            mlp_module=self,
            x=x,
            num_shards=self.num_shards,
            compute_params=params,
        )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.3.dev20251106220336
3
+ Version: 0.6.3.dev20251118154655
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -38,11 +38,12 @@ liger_kernel/ops/rope.py,sha256=v-7JHRrv-5ImoROkpKfl30WwWI4qTa2tAl7zQeB4ml4,8956
38
38
  liger_kernel/ops/softmax.py,sha256=tgORx6MK1IDDtZKqGarj0IPIVjqAIEUXXYPiinhRdtI,5864
39
39
  liger_kernel/ops/sparsemax.py,sha256=AeWe1xgkHJFEKWTj2vu_0hj7LztGvjqXAps-QTpCY0U,5087
40
40
  liger_kernel/ops/swiglu.py,sha256=D7nd4u_LInwsIRNCDdY77lqnTz8-W5dJrpEAt8zEO_A,3033
41
+ liger_kernel/ops/tiled_mlp.py,sha256=eyMFsFFgHch8a_6R6IYRG24_jqKg5GF_BQUoQuAG8SY,4529
41
42
  liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
42
43
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
43
44
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
44
45
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
45
- liger_kernel/transformers/__init__.py,sha256=iV1X0gH1JXwgeb7AeY8Ryv7q3r44MLQvSvn79yIVDzw,9874
46
+ liger_kernel/transformers/__init__.py,sha256=XX1ySRgZXeQe0or-6GNclAsNQG_VkABQlkwqpB1Wn8A,10090
46
47
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
47
48
  liger_kernel/transformers/cross_entropy.py,sha256=DMtHkKrVJDSsels7KgGQJqrXkEAd6Zopcdr-5oRmQgE,2010
48
49
  liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
@@ -68,6 +69,7 @@ liger_kernel/transformers/rope.py,sha256=VMlDZI6zss9mLaLcN5XCE_ktmYRwAi_Eh4TIgO6
68
69
  liger_kernel/transformers/softmax.py,sha256=yadlAgE4V2JByMwrDDa2s5SUBp8Jgd57xwnVvAWoBaI,264
69
70
  liger_kernel/transformers/sparsemax.py,sha256=0lQA0UEOs4mu8CMruZ3VLhImxQVXJWhPsAKUsYA7vj8,403
70
71
  liger_kernel/transformers/swiglu.py,sha256=LZ8YeLIdv2k46JleZMjzubGk98smt6t780kSgcVLsQk,3454
72
+ liger_kernel/transformers/tiled_mlp.py,sha256=J51-kpzwikDMMhT5bX-RZCKMaXBK6zZc1bhgRYTK5F0,4651
71
73
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
72
74
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
73
75
  liger_kernel/transformers/experimental/__init__.py,sha256=oQqk-f32JYgWEP9DJCj6ty6bbJSGrdXsFDQFwGeX6vI,127
@@ -106,9 +108,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
106
108
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
107
109
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
108
110
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
109
- liger_kernel_nightly-0.6.3.dev20251106220336.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
110
- liger_kernel_nightly-0.6.3.dev20251106220336.dist-info/METADATA,sha256=dy_9atp4YioeU8GBh82zuDxFpz-nYGyfStlvUa4RxwY,24777
111
- liger_kernel_nightly-0.6.3.dev20251106220336.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
112
- liger_kernel_nightly-0.6.3.dev20251106220336.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
113
- liger_kernel_nightly-0.6.3.dev20251106220336.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
114
- liger_kernel_nightly-0.6.3.dev20251106220336.dist-info/RECORD,,
111
+ liger_kernel_nightly-0.6.3.dev20251118154655.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
112
+ liger_kernel_nightly-0.6.3.dev20251118154655.dist-info/METADATA,sha256=vo51Vsu8KGOEZOqaFBhwOogKYG2kQBPJhMC4-KUpyeY,24777
113
+ liger_kernel_nightly-0.6.3.dev20251118154655.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
114
+ liger_kernel_nightly-0.6.3.dev20251118154655.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
115
+ liger_kernel_nightly-0.6.3.dev20251118154655.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
116
+ liger_kernel_nightly-0.6.3.dev20251118154655.dist-info/RECORD,,