liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of liger-kernel-nightly might be problematic.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/multi_token_attention.py
ADDED
@@ -0,0 +1,64 @@
+import math
+
+import torch
+import torch.nn as nn
+
+from torch.nn.modules.utils import _pair
+
+from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
+
+
+class LigerMultiTokenAttention(nn.Module):
+    r"""
+    Multi-Token Attention:
+        out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
+
+    Reference: https://arxiv.org/pdf/2504.00927
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        sparse: bool = False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.sparse = sparse
+
+        self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, *self.kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_channels))
+        else:
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def forward(self, scores: torch.Tensor) -> torch.Tensor:
+        return LigerMultiTokenAttentionFunction.apply(
+            scores,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.sparse,
+        )
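Illustrative usage (not part of the diff): a minimal sketch of how the new LigerMultiTokenAttention module above could be driven. The score tensor shape (batch, heads, seq_len, seq_len), the CUDA placement, and the channel/padding choices are assumptions for the example, not taken from the package.

# Hypothetical usage sketch for LigerMultiTokenAttention (shapes and device assumed).
import torch

from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention

# Attention scores for 8 heads over a length-128 sequence; the Triton kernel
# is expected to run on a CUDA device.
scores = torch.randn(2, 8, 128, 128, device="cuda", dtype=torch.float32)

# in_channels/out_channels are the number of heads mixed by the convolution over scores;
# kernel_size=3 with padding=1 keeps the (seq_len, seq_len) map the same size.
mta = LigerMultiTokenAttention(in_channels=8, out_channels=8, kernel_size=3, padding=1).cuda()
out = mta(scores)  # same shape as `scores`: (2, 8, 128, 128)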
liger_kernel/transformers/poly_norm.py
ADDED
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.poly_norm import LigerPolyNormFunction
+
+
+class LigerPolyNorm(nn.Module):
+    """
+    PolyNorm layer wrapper for Liger kernel.
+
+    PolyNorm formula:
+        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+    where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Reference:
+        https://github.com/BryceZhuo/PolyCom/
+
+    Args:
+        eps: epsilon for numerical stability (default: 1e-6)
+        in_place: whether to in-place modify grad_output in backward to save memory (default: False).
+            Set to True to save memory if grad_output is not needed elsewhere.
+    """
+
+    def __init__(self, eps=1e-6, in_place=True):
+        super().__init__()
+        # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0
+        self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
+        self.bias = nn.Parameter(torch.tensor(1.0))
+        self.variance_epsilon = eps
+        self.in_place = in_place
+
+    def forward(self, hidden_states):
+        return LigerPolyNormFunction.apply(
+            hidden_states,
+            self.weight,
+            self.bias,
+            self.variance_epsilon,
+            self.in_place,
+        )
+
+    def extra_repr(self):
+        return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}"
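Illustrative usage (not part of the diff): a hedged sketch of the LigerPolyNorm wrapper above. The hidden size, batch/sequence dimensions, and CUDA placement are assumptions for the example.

# Hypothetical usage sketch for LigerPolyNorm (shapes and device assumed).
import torch

from liger_kernel.transformers.poly_norm import LigerPolyNorm

norm = LigerPolyNorm(eps=1e-6).cuda()
x = torch.randn(2, 16, 4096, device="cuda")  # (batch, seq_len, hidden_size)

# y = w0*norm(x^3) + w1*norm(x^2) + w2*norm(x) + b, computed by the Triton kernel.
y = norm(x)
print(norm)  # extra_repr reports weight_shape=(3,), eps, and in_place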
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
|
5
|
+
"""
|
|
6
|
+
Applies Multimodal Rotary Positional Embedding (M-RoPE) operation to query and key states.
|
|
7
|
+
|
|
8
|
+
Args:
|
|
9
|
+
q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
|
|
10
|
+
k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
|
|
11
|
+
cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
|
|
12
|
+
sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
|
|
13
|
+
mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
|
|
14
|
+
unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
|
|
15
|
+
|
|
16
|
+
Returns:
|
|
17
|
+
Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the M-RoPE operation.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
|
|
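Illustrative usage (not part of the diff): a sketch of calling liger_multimodal_rotary_pos_emb with shapes following the docstring above. The concrete head counts and the mrope_section values (which mirror Qwen2-VL's defaults and must sum to head_dim/2) are assumptions for the example.

# Hypothetical usage sketch for liger_multimodal_rotary_pos_emb (shapes follow the docstring).
import torch

from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 1, 16, 2, 64, 128
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")
# cos/sin carry a leading dimension of 3 for the temporal/height/width rotary components.
cos = torch.randn(3, bsz, seq_len, head_dim, device="cuda")
sin = torch.randn(3, bsz, seq_len, head_dim, device="cuda")

# mrope_section splits the head_dim/2 rotary channels among temporal, height, and width.
q_rot, k_rot = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section=[16, 24, 24])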
liger_kernel/transformers/rms_norm.py
CHANGED
@@ -6,20 +6,27 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 
 class LigerRMSNorm(nn.Module):
     def __init__(
-        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
+        self,
+        hidden_size,
+        eps=1e-6,
+        offset=0.0,
+        casting_mode="llama",
+        init_fn="ones",
+        in_place=True,
+        row_mode=None,
     ):
         super().__init__()
         assert init_fn in [
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
-        self.variance_epsilon, self.offset, self.casting_mode = (
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
            eps,
            offset,
            casting_mode,
+            in_place,
+            row_mode,
         )
 
     def forward(self, hidden_states):
@@ -29,7 +36,51 @@ class LigerRMSNorm(nn.Module):
             self.variance_epsilon,
             self.offset,
             self.casting_mode,
+            self.in_place,
+            self.row_mode,
         )
 
     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}"
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
+
+
+class LigerRMSNormForGemma(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
+
+
+class LigerRMSNormForOlmo2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGlm4(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForQwen3Next(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
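Illustrative usage (not part of the diff): a sketch of the extended LigerRMSNorm and one of its new model-specific presets. The hidden size, dtype, and CUDA placement are assumptions for the example.

# Hypothetical usage sketch for LigerRMSNorm and its presets (shapes, dtype, device assumed).
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm, LigerRMSNormForGemma2

hidden_size = 4096
x = torch.randn(2, 16, hidden_size, device="cuda", dtype=torch.bfloat16)

# Llama-style RMSNorm: weight initialized to ones, no offset, in-place backward allowed.
norm = LigerRMSNorm(hidden_size).to(device="cuda", dtype=torch.bfloat16)

# Gemma2-style preset: +1.0 offset, zero-initialized weight, "gemma" casting, out-of-place backward.
gemma2_norm = LigerRMSNormForGemma2(hidden_size).to(device="cuda", dtype=torch.bfloat16)

y = norm(x)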
liger_kernel/transformers/rope.py
CHANGED
@@ -1,3 +1,8 @@
+from typing import Optional
+from typing import Tuple
+
+import torch
+
 from liger_kernel.ops.rope import LigerRopeFunction
 
 
@@ -8,8 +13,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Args:
         q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
         k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
-        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
-        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
+        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
         position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
         unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
 
@@ -18,3 +23,41 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
 
     return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_rotary_pos_emb_with_cast(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: Optional[torch.Tensor] = None,
+    unsqueeze_dim: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    q32 = q.to(torch.float32)
+    k32 = k.to(torch.float32)
+    cos32 = cos.to(torch.float32)
+    sin32 = sin.to(torch.float32)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+    return q_out.to(orig_q_dtype), k_out.to(orig_k_dtype)
+
+
+def liger_rotary_pos_emb_with_cast_and_leading_batch(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    position_ids: Optional[torch.Tensor] = None,
+    unsqueeze_dim: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    q32 = q.to(torch.float32).unsqueeze(0)
+    k32 = k.to(torch.float32).unsqueeze(0)
+    cos32 = cos.to(torch.float32).unsqueeze(0)
+    sin32 = sin.to(torch.float32).unsqueeze(0)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
+    return q_out.to(orig_q_dtype).squeeze(0), k_out.to(orig_k_dtype).squeeze(0)
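Illustrative usage (not part of the diff): a sketch of the new cast helper above, which upcasts to float32 for the rotation and casts back to the original dtype. The tensor shapes and CUDA placement are assumptions for the example.

# Hypothetical usage sketch for liger_rotary_pos_emb_with_cast (shapes and device assumed).
import torch

from liger_kernel.transformers.rope import liger_rotary_pos_emb_with_cast

bsz, n_head, seq_len, head_dim = 2, 32, 128, 64
q = torch.randn(bsz, n_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(bsz, n_head, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(1, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(1, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

# Rotation is computed in float32 internally; results come back in the input dtype.
q_rot, k_rot = liger_rotary_pos_emb_with_cast(q, k, cos, sin)
assert q_rot.dtype == torch.bfloat16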
liger_kernel/transformers/softmax.py
ADDED
@@ -0,0 +1,12 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.softmax import LigerSoftmaxFunction
+
+
+class LigerSoftmax(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor):
+        return LigerSoftmaxFunction.apply(x)
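Illustrative usage (not part of the diff): a sketch of LigerSoftmax. The module exposes no dim argument, so the assumption here is that the kernel normalizes over the last dimension; the shape and CUDA placement are also assumed.

# Hypothetical usage sketch for LigerSoftmax (shape, device, and reduction dim assumed).
import torch

from liger_kernel.transformers.softmax import LigerSoftmax

softmax = LigerSoftmax()
x = torch.randn(4, 128, 128, device="cuda")
probs = softmax(x)  # assumed: softmax over the last dimension, computed by the Triton kernel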
liger_kernel/transformers/sparsemax.py
ADDED
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
+
+
+class LigerSparsemax(nn.Module):
+    def __init__(self, dim: int = -1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return LigerSparsemaxFunction.apply(x, self.dim)
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}"
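Illustrative usage (not part of the diff): a sketch of LigerSparsemax; the input shape and CUDA placement are assumptions for the example.

# Hypothetical usage sketch for LigerSparsemax (shape and device assumed).
import torch

from liger_kernel.transformers.sparsemax import LigerSparsemax

sparsemax = LigerSparsemax(dim=-1)
logits = torch.randn(4, 10, device="cuda")
p = sparsemax(logits)
# Sparsemax projects onto the probability simplex: rows sum to 1 and many entries are exactly 0.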
liger_kernel/transformers/swiglu.py
CHANGED
@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
 
 
 class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
         return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
 
 
@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(
-            self.hidden_size, 2 * self.intermediate_size, bias=False
-        )
+        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         if config.hidden_act not in ["silu", "swish"]:
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
@@ -62,3 +56,24 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         up_states = self.gate_up_proj(x)
         gate, up_states = up_states.chunk(2, dim=-1)
         return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
+
+
+class LigerQwen3MoeSwiGLUMLP(nn.Module):
+    """
+    Patch Qwen3MoeMLP to use LigerSiLUMulFunction.
+    https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57
+    """
+
+    def __init__(self, config, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
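Illustrative usage (not part of the diff): a sketch of the new LigerQwen3MoeSwiGLUMLP above. The SimpleNamespace stub stands in for a transformers Qwen3-MoE config and its field values are assumptions; in practice the class is applied via the library's monkey patching.

# Hypothetical usage sketch for LigerQwen3MoeSwiGLUMLP (config stub and sizes assumed).
from types import SimpleNamespace

import torch

from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP

config = SimpleNamespace(hidden_size=1024, intermediate_size=768, hidden_act="silu")

# intermediate_size can be overridden per expert, as the MoE block does.
mlp = LigerQwen3MoeSwiGLUMLP(config, intermediate_size=768).cuda()
x = torch.randn(2, 16, 1024, device="cuda")
y = mlp(x)  # down_proj(SiLU(gate_proj(x)) * up_proj(x)), with the SiLU*mul fused in Triton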
liger_kernel/transformers/tiled_mlp.py
ADDED
@@ -0,0 +1,133 @@
+from typing import Optional
+
+import torch.nn as nn
+
+from liger_kernel.ops.geglu import LigerGELUMulFunction
+from liger_kernel.ops.swiglu import LigerSiLUMulFunction
+from liger_kernel.ops.tiled_mlp import apply_tiled_mlp
+
+
+class LigerTiledGEGLUMLP(nn.Module):
+    """
+    Memory-efficient GEGLU MLP using tiled computation.
+
+    This module combines GEGLU activation with tiled processing to handle
+    very long sequences efficiently. The forward pass is recomputed during
+    backward to save memory.
+
+    Args:
+        config: Model configuration with hidden_size and intermediate_size attributes
+        num_shards: Number of shards to split the sequence. If None, automatically
+            calculated as ceil(seqlen / hidden_size)
+    """
+
+    def __init__(self, config, num_shards: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.num_shards = num_shards
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+        # Validate activation function
+        if hasattr(config, "hidden_act") and config.hidden_act not in [
+            "gelu",
+            "gelu_new",
+            "gelu_pytorch_tanh",
+        ]:
+            raise ValueError(f"LigerTiledGEGLUMLP requires GELU activation, got {config.hidden_act}")
+
+    def _mlp_forward(self, module, x):
+        """Internal MLP forward function for tiled computation."""
+        gate = module.gate_proj(x)
+        up = module.up_proj(x)
+        return module.down_proj(LigerGELUMulFunction.apply(gate, up))
+
+    def forward(self, x):
+        """
+        Forward pass with tiled computation.
+
+        Args:
+            x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                or [seq_len, hidden_size]
+
+        Returns:
+            Output tensor of the same shape as input
+        """
+        compute_params = [
+            self.gate_proj.weight,
+            self.up_proj.weight,
+            self.down_proj.weight,
+        ]
+
+        return apply_tiled_mlp(
+            fn=self._mlp_forward,
+            mlp_module=self,
+            x=x,
+            num_shards=self.num_shards,
+            compute_params=compute_params,
+        )
+
+
+class LigerTiledSwiGLUMLP(nn.Module):
+    """
+    Memory-efficient SwiGLU MLP using tiled computation.
+
+    This module combines SwiGLU activation with tiled processing to handle
+    very long sequences efficiently. The forward pass is recomputed during
+    backward to save memory.
+
+    Args:
+        config: Model configuration with hidden_size and intermediate_size attributes
+        num_shards: Number of shards to split the sequence. If None, automatically
+            calculated as ceil(seqlen / hidden_size)
+    """
+
+    def __init__(self, config, num_shards: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.num_shards = num_shards
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+        # Validate activation function
+        if hasattr(config, "hidden_act") and config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"LigerTiledSwiGLUMLP requires SiLU/Swish activation, got {config.hidden_act}")
+
+    def _mlp_forward(self, module, x):
+        """Internal MLP forward function for tiled computation."""
+        gate = module.gate_proj(x)
+        up = module.up_proj(x)
+        return module.down_proj(LigerSiLUMulFunction.apply(gate, up))
+
+    def forward(self, x):
+        """
+        Forward pass with tiled computation.
+
+        Args:
+            x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                or [seq_len, hidden_size]
+
+        Returns:
+            Output tensor of the same shape as input
+        """
+        compute_params = [
+            self.gate_proj.weight,
+            self.up_proj.weight,
+            self.down_proj.weight,
+        ]
+
+        return apply_tiled_mlp(
+            fn=self._mlp_forward,
+            mlp_module=self,
+            x=x,
+            num_shards=self.num_shards,
+            compute_params=compute_params,
+        )
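Illustrative usage (not part of the diff): a sketch of LigerTiledSwiGLUMLP on a long sequence. The config stub, sequence length, and CUDA placement are assumptions; the point is that leaving num_shards=None lets the module pick roughly ceil(seq_len / hidden_size) shards and recompute activations in backward to cap peak memory.

# Hypothetical usage sketch for LigerTiledSwiGLUMLP (config stub, sizes, and device assumed).
from types import SimpleNamespace

import torch

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

config = SimpleNamespace(hidden_size=2048, intermediate_size=8192, hidden_act="silu")
mlp = LigerTiledSwiGLUMLP(config).cuda()  # num_shards=None -> ceil(seq_len / hidden_size) shards

# A long sequence: the MLP is applied shard by shard, trading recompute for peak memory.
x = torch.randn(1, 32768, 2048, device="cuda", requires_grad=True)
y = mlp(x)
y.sum().backward()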
liger_kernel/transformers/trainer/orpo_trainer.py
ADDED
@@ -0,0 +1,130 @@
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from torch.distributed.fsdp import FullyShardedDataParallel
+from trl.trainer import ORPOTrainer
+
+from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
+from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
+
+
+class LigerORPOTrainer(ORPOTrainer):
+    def concatenated_forward(
+        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """
+        Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+        We do this to avoid doing two forward passes, because it's faster for FSDP.
+        """
+        concatenated_batch = self.concatenated_inputs(
+            batch,
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+            padding_value=self.padding_value,
+            device=self.accelerator.device,
+        )
+
+        model_kwargs = (
+            {
+                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
+            }
+            if self.is_encoder_decoder
+            else {}
+        )
+
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        if self.is_encoder_decoder:
+            labels = concatenated_batch["concatenated_labels"].clone()
+        else:
+            labels = concatenated_batch["concatenated_input_ids"].clone()
+            attention_mask = concatenated_batch["concatenated_attention_mask"]
+            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
+
+        if isinstance(model, FullyShardedDataParallel):
+            outputs = _FSDPForwardRedirection()(
+                model,
+                model._fsdp_wrapped_module.model,
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+        else:
+            if isinstance(model, torch.nn.DataParallel):
+                model = model.module
+            outputs = model.model(
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+
+        orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
+
+        def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
+            return orpo_loss_fn(
+                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
+            )
+
+        orpo_loss, aux_outputs = _FSDPForwardRedirection()(
+            model,
+            orpo_partial,
+            model.lm_head,
+            outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+            concatenated_batch["concatenated_labels"][:, 1:]
+            if not self.is_encoder_decoder
+            else concatenated_batch["concatenated_labels"],
+            labels[:, 1:] if not self.is_encoder_decoder else labels,
+        )
+        # if aux_loss_enabled, add the aux_loss to the orpo_loss
+        if self.aux_loss_enabled:
+            orpo_loss += self.aux_loss_coef * outputs.aux_loss
+
+        return orpo_loss, aux_outputs
+
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: Dict[str, Union[List, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+        metrics = {}
+        loss, aux_outputs = self.concatenated_forward(model, batch)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = aux_outputs[:5]
+
+        # return loss, metrics
+        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[5:]
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
+        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
+        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
+        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+        for k, v in metrics.items():
+            metrics[k] = v.item()
+
+        return loss, metrics
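Illustrative usage (not part of the diff): a hedged training sketch for LigerORPOTrainer. The model checkpoint, the preference dataset, and the trl argument names (ORPOConfig fields, tokenizer= vs. processing_class=) are assumptions that depend on your trl/transformers versions; the trainer otherwise behaves like trl's ORPOTrainer but routes the loss through LigerFusedLinearORPOLoss.

# Hypothetical training sketch for LigerORPOTrainer (checkpoint, dataset, and trl arg names assumed).
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig

from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer

model_name = "meta-llama/Llama-3.2-1B-Instruct"  # any causal LM checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Any preference dataset with prompt/chosen/rejected columns.
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

args = ORPOConfig(output_dir="liger-orpo", beta=0.1, max_length=1024, per_device_train_batch_size=2)

trainer = LigerORPOTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,  # newer trl releases name this argument `processing_class`
)
trainer.train()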
liger_kernel/transformers/tvd.py
ADDED
@@ -0,0 +1,13 @@
+import torch.nn as nn
+
+from liger_kernel.ops.tvd import LigerTVDLossFunction
+
+
+class LigerTVDLoss(nn.Module):
+    def __init__(self, reduction="batchmean", ignore_index: int = -100):
+        super(LigerTVDLoss, self).__init__()
+        self.reduction = reduction
+        self.ignore_index = ignore_index
+
+    def forward(self, p, q, shift_labels=None):
+        return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index)
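Illustrative usage (not part of the diff): a sketch of LigerTVDLoss for distillation-style training. The assumption here is that p and q are probability distributions over the vocabulary (e.g. student and teacher softmax outputs); shapes and device are also assumed.

# Hypothetical usage sketch for LigerTVDLoss (probability inputs, shapes, and device assumed).
import torch

from liger_kernel.transformers.tvd import LigerTVDLoss

tvd = LigerTVDLoss(reduction="batchmean", ignore_index=-100)

batch_seq, vocab = 8, 32000
p = torch.softmax(torch.randn(batch_seq, vocab, device="cuda"), dim=-1)  # student probs
q = torch.softmax(torch.randn(batch_seq, vocab, device="cuda"), dim=-1)  # teacher probs
shift_labels = torch.randint(0, vocab, (batch_seq,), device="cuda")      # positions labeled -100 are ignored

loss = tvd(p, q, shift_labels)  # total variation distance, 0.5 * sum |p - q| per row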
liger_kernel/triton/__init__.py
CHANGED
@@ -37,6 +37,4 @@ def apply_liger_triton_cache_manager():
     Experimental feature to get around transient FileNotFoundError in triton compilation.
     For more details please see https://github.com/triton-lang/triton/pull/4295
     """
-    os.environ["TRITON_CACHE_MANAGER"] = (
-        "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
-    )
+    os.environ["TRITON_CACHE_MANAGER"] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
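Illustrative usage (not part of the diff): opting in to the Liger Triton cache manager shown above before any kernels compile.

# Hypothetical usage sketch for the experimental Triton cache manager patch.
from liger_kernel.triton import apply_liger_triton_cache_manager

# Points TRITON_CACHE_MANAGER at liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager,
# working around transient FileNotFoundError during Triton compilation (see the PR linked above).
apply_liger_triton_cache_manager()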