liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +307 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +63 -0
- liger_kernel/ops/__init__.py +141 -0
- liger_kernel/ops/backends/README.md +151 -0
- liger_kernel/ops/backends/__init__.py +13 -0
- liger_kernel/ops/backends/_ascend/__init__.py +5 -0
- liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
- liger_kernel/ops/backends/registry.py +61 -0
- liger_kernel/ops/cross_entropy.py +383 -114
- liger_kernel/ops/dyt.py +160 -0
- liger_kernel/ops/experimental/embedding.py +141 -0
- liger_kernel/ops/experimental/mm_int8int2.py +349 -0
- liger_kernel/ops/fused_add_rms_norm.py +416 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
- liger_kernel/ops/fused_linear_jsd.py +228 -0
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +66 -64
- liger_kernel/ops/group_norm.py +306 -0
- liger_kernel/ops/grpo_loss.py +312 -0
- liger_kernel/ops/jsd.py +201 -0
- liger_kernel/ops/kl_div.py +262 -0
- liger_kernel/ops/layer_norm.py +320 -0
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +390 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +484 -88
- liger_kernel/ops/rope.py +122 -117
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +68 -65
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +82 -3
- liger_kernel/transformers/__init__.py +218 -6
- liger_kernel/transformers/auto_model.py +38 -0
- liger_kernel/transformers/cross_entropy.py +52 -7
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +26 -0
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +301 -0
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
- liger_kernel/transformers/fused_linear_jsd.py +95 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +6 -7
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +153 -0
- liger_kernel/transformers/jsd.py +70 -0
- liger_kernel/transformers/kl_div.py +12 -0
- liger_kernel/transformers/layer_norm.py +24 -0
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +261 -0
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +332 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/gpt_oss.py +211 -0
- liger_kernel/transformers/model/hunyuan_v1.py +134 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +221 -41
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +145 -0
- liger_kernel/transformers/model/mixtral.py +293 -0
- liger_kernel/transformers/model/mllama.py +269 -0
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/olmo3.py +142 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +433 -0
- liger_kernel/transformers/model/phi3.py +120 -0
- liger_kernel/transformers/model/qwen2.py +259 -0
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +159 -0
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2816 -21
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +75 -5
- liger_kernel/transformers/rope.py +47 -3
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +62 -6
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/trainer_integration.py +2 -45
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -5
- liger_kernel/utils.py +96 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
- liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
- liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0

liger_kernel/transformers/multi_token_attention.py
@@ -0,0 +1,64 @@
+import math
+
+import torch
+import torch.nn as nn
+
+from torch.nn.modules.utils import _pair
+
+from liger_kernel.ops import LigerMultiTokenAttentionFunction
+
+
+class LigerMultiTokenAttention(nn.Module):
+    r"""
+    Multi-Token Attention:
+        out = mask_{0}(conv2d(softmax(mask_{-\inf}(scores))))
+
+    Reference: https://arxiv.org/pdf/2504.00927
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        sparse: bool = False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.sparse = sparse
+
+        self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, *self.kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_channels))
+        else:
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def forward(self, scores: torch.Tensor) -> torch.Tensor:
+        return LigerMultiTokenAttentionFunction.apply(
+            scores,
+            self.weight,
+            self.bias,
+            self.stride,
+            self.padding,
+            self.dilation,
+            self.groups,
+            self.sparse,
+        )
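
A minimal usage sketch for the new `LigerMultiTokenAttention` wrapper (not part of the package diff). The import path is inferred from the file location; the CUDA device and the (batch, heads, seq_len, seq_len) score-map layout are assumptions:

```python
import torch

from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention

# Mix 8 attention-score maps with a 3x3 convolution, keeping the spatial size (padding=1).
mta = LigerMultiTokenAttention(in_channels=8, out_channels=8, kernel_size=3, padding=1).cuda()
scores = torch.randn(2, 8, 128, 128, device="cuda")  # assumed (batch, heads, seq_len, seq_len)
mixed = mta(scores)                                   # softmax, then masked conv2d; same shape here
print(mixed.shape)
```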

liger_kernel/transformers/poly_norm.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops import LigerPolyNormFunction
+
+
+class LigerPolyNorm(nn.Module):
+    """
+    PolyNorm layer wrapper for Liger kernel.
+
+    PolyNorm formula:
+        y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
+        where norm(u) = u / sqrt(mean(u²) + ε)
+
+    Reference:
+        https://github.com/BryceZhuo/PolyCom/
+
+    Args:
+        eps: epsilon for numerical stability (default: 1e-6)
+        in_place: whether to in-place modify grad_output in backward to save memory (default: True).
+            Set to True to save memory if grad_output is not needed elsewhere.
+    """
+
+    def __init__(self, eps=1e-6, in_place=True):
+        super().__init__()
+        # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0
+        self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
+        self.bias = nn.Parameter(torch.tensor(1.0))
+        self.variance_epsilon = eps
+        self.in_place = in_place
+
+    def forward(self, hidden_states):
+        return LigerPolyNormFunction.apply(
+            hidden_states,
+            self.weight,
+            self.bias,
+            self.variance_epsilon,
+            self.in_place,
+        )
+
+    def extra_repr(self):
+        return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}"
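
A short usage sketch for the `LigerPolyNorm` wrapper (not part of the package diff); the import path follows the file location and a CUDA device is assumed:

```python
import torch

from liger_kernel.transformers.poly_norm import LigerPolyNorm

norm = LigerPolyNorm(eps=1e-6).cuda()
x = torch.randn(2, 16, 4096, device="cuda")   # assumed (batch, seq_len, hidden_size) layout
y = norm(x)                                   # y = w0*norm(x^3) + w1*norm(x^2) + w2*norm(x) + b
print(y.shape)
print(norm)                                   # extra_repr: weight shape, eps, in_place
```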

liger_kernel/transformers/qwen2vl_mrope.py
@@ -0,0 +1,20 @@
+from liger_kernel.ops import LigerQwen2VLMRopeFunction
+
+
+def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    """
+    Applies Multimodal Rotary Positional Embedding (M-RoPE) operation to query and key states.
+
+    Args:
+        q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
+        k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
+        cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
+        mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
+        unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the M-RoPE operation.
+    """
+
+    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
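
A usage sketch for `liger_multimodal_rotary_pos_emb` (not part of the package diff). Shapes follow the docstring above; the random cos/sin tables and the [16, 24, 24] section split (modeled on Qwen2-VL-style configs, summing to head_dim // 2) are assumptions:

```python
import torch

from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb

bsz, n_head, seq_len, head_dim = 1, 8, 32, 128
q = torch.randn(bsz, n_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_head, seq_len, head_dim, device="cuda")
# Leading dim of 3 carries the temporal/height/width components; real tables come from the model.
cos = torch.randn(3, bsz, seq_len, head_dim, device="cuda")
sin = torch.randn(3, bsz, seq_len, head_dim, device="cuda")
mrope_section = [16, 24, 24]  # assumed split over head_dim // 2
q_rot, k_rot = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
print(q_rot.shape, k_rot.shape)
```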

liger_kernel/transformers/rms_norm.py
@@ -1,16 +1,86 @@
 import torch
 import torch.nn as nn
 
-from liger_kernel.ops
+from liger_kernel.ops import LigerRMSNormFunction
 
 
 class LigerRMSNorm(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        hidden_size,
+        eps=1e-6,
+        offset=0.0,
+        casting_mode="llama",
+        init_fn="ones",
+        in_place=True,
+        row_mode=None,
+    ):
         super().__init__()
-
-
+        assert init_fn in [
+            "ones",
+            "zeros",
+        ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
+        self.variance_epsilon, self.offset, self.casting_mode, self.in_place, self.row_mode = (
+            eps,
+            offset,
+            casting_mode,
+            in_place,
+            row_mode,
+        )
 
     def forward(self, hidden_states):
         return LigerRMSNormFunction.apply(
-            hidden_states,
+            hidden_states,
+            self.weight,
+            self.variance_epsilon,
+            self.offset,
+            self.casting_mode,
+            self.in_place,
+            self.row_mode,
         )
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}, row_mode={self.row_mode}"
+
+
+class LigerRMSNormForGemma(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=True, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
+
+
+class LigerRMSNormForOlmo2(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForGlm4(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
+
+
+class LigerRMSNormForQwen3Next(LigerRMSNorm):
+    def __init__(
+        self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False, row_mode=None
+    ):
+        super().__init__(hidden_size, eps, offset, casting_mode, init_fn, in_place, row_mode)
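
A usage sketch for the reworked `LigerRMSNorm` and one of its model-specific presets (not part of the package diff; import path inferred from the file location, CUDA assumed). The subclasses only change the default offset, casting_mode, init_fn, and in_place values:

```python
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma

hidden = torch.randn(2, 16, 4096, device="cuda")

llama_style = LigerRMSNorm(hidden_size=4096, eps=1e-6).cuda()   # init_fn "ones", offset 0.0, casting_mode "llama"
gemma_style = LigerRMSNormForGemma(hidden_size=4096).cuda()     # init_fn "zeros", offset 1.0, casting_mode "gemma"
print(llama_style(hidden).shape, gemma_style(hidden).shape)
print(llama_style)                                              # extra_repr shows eps/offset/in_place/row_mode
```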

liger_kernel/transformers/rope.py
@@ -1,4 +1,8 @@
-from
+from typing import Tuple
+
+import torch
+
+from liger_kernel.ops import LigerRopeFunction
 
 
 def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -8,8 +12,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Args:
         q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
         k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
-        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
-        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
+        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
         position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
         unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
 
@@ -18,3 +22,43 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """
 
     return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+
+def liger_rotary_pos_emb_vision(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Modified version of liger_rotary_pos_emb for qwen3_vl's apply_rotary_pos_emb_vision function.
+    Manually transposes the input and output to match the expected shape for liger_rotary_pos_emb.
+    Reference: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L116
+
+    Args:
+        q (torch.Tensor): The query tensor of shape (seq_length, num_heads, head_dim),
+            with stride (num_heads * head_dim, head_dim, 1).
+        k (torch.Tensor): The key tensor of shape (seq_length, num_heads, head_dim),
+            with stride (num_heads * head_dim, head_dim, 1). Same as q.
+        cos (torch.Tensor): The cosine tensor of shape (seq_length, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (seq_length, head_dim).
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors with the same shape and stride as inputs.
+    """
+    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
+
+    # transpose to (1, num_heads, seq_length, head_dim) and cast to float32 to match liger_rotary_pos_emb input shape
+    # also unsqueeze for batch dim
+    q32 = q.to(torch.float32).unsqueeze(0).transpose(1, 2)
+    k32 = k.to(torch.float32).unsqueeze(0).transpose(1, 2)
+    cos32 = cos.to(torch.float32)
+    sin32 = sin.to(torch.float32)
+
+    q_out, k_out = liger_rotary_pos_emb(q32, k32, cos32, sin32)
+
+    # transpose back to (seq_length, num_heads, head_dim) and cast back to original dtype
+    # also squeeze out batch dim
+    q_out = q_out.transpose(1, 2).squeeze(0).to(orig_q_dtype)
+    k_out = k_out.transpose(1, 2).squeeze(0).to(orig_k_dtype)
+    return q_out, k_out
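
Usage sketch for `liger_rotary_pos_emb` and the new `liger_rotary_pos_emb_vision` helper (not part of the package diff). The import path is inferred from the file location; random cos/sin placeholders stand in for the model's real rotary tables:

```python
import torch

from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision

# Decoder-style call: (bsz, n_head, seq_len, head_dim) with broadcastable cos/sin.
q = torch.randn(1, 8, 32, 128, device="cuda")
k = torch.randn(1, 2, 32, 128, device="cuda")          # GQA: fewer KV heads than Q heads
cos = torch.randn(1, 32, 128, device="cuda")           # placeholder rotary tables
sin = torch.randn(1, 32, 128, device="cuda")
q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin)

# Vision-style call: (seq_length, num_heads, head_dim) in, same shape and dtype out.
qv = torch.randn(32, 8, 128, device="cuda", dtype=torch.bfloat16)
kv = torch.randn(32, 8, 128, device="cuda", dtype=torch.bfloat16)
qv_rot, kv_rot = liger_rotary_pos_emb_vision(qv, kv, cos.squeeze(0), sin.squeeze(0))
print(q_rot.shape, qv_rot.shape, qv_rot.dtype)
```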

liger_kernel/transformers/sparsemax.py
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+
+from liger_kernel.ops import LigerSparsemaxFunction
+
+
+class LigerSparsemax(nn.Module):
+    def __init__(self, dim: int = -1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return LigerSparsemaxFunction.apply(x, self.dim)
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}"
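
A usage sketch for the `LigerSparsemax` module (not part of the package diff; import path inferred from the file location, CUDA assumed):

```python
import torch

from liger_kernel.transformers.sparsemax import LigerSparsemax

sparsemax = LigerSparsemax(dim=-1).cuda()
logits = torch.randn(4, 10, device="cuda")
probs = sparsemax(logits)          # rows sum to 1 but, unlike softmax, may contain exact zeros
print(probs.sum(dim=-1))
print(sparsemax)                   # extra_repr: dim=-1
```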

liger_kernel/transformers/swiglu.py
@@ -1,6 +1,6 @@
 import torch.nn as nn
 
-from liger_kernel.ops
+from liger_kernel.ops import LigerSiLUMulFunction
 
 
 class LigerSwiGLUMLP(nn.Module):
@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
 
 
 class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,5 +33,64 @@ class LigerBlockSparseTop2MLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
         return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
+
+
+class LigerPhi3SwiGLUMLP(nn.Module):
+    """
+    Patch Phi3MLP to use LigerSiLUMulFunction
+    https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        up_states = self.gate_up_proj(x)
+        gate, up_states = up_states.chunk(2, dim=-1)
+        return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
+
+
+class LigerQwen3MoeSwiGLUMLP(nn.Module):
+    """
+    Patch Qwen3MoeMLP to use LigerSiLUMulFunction.
+    https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L57
+    """
+
+    def __init__(self, config, intermediate_size=None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
+
+
+class LigerHunyuanV1SwiGLUMLP(nn.Module):
+    def __init__(self, config, layer_idx=None, is_shared_mlp=False):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.layer_idx = layer_idx
+        if config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"Activation function {config.hidden_act} not supported.")
+
+    def forward(self, x):
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
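
A usage sketch for one of the new SwiGLU wrappers (not part of the package diff). Real code passes the model's own config object; the SimpleNamespace stand-in, the sizes, and the CUDA device are assumptions:

```python
from types import SimpleNamespace

import torch

from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP

config = SimpleNamespace(hidden_size=1024, intermediate_size=2816, hidden_act="silu")  # stand-in for Phi3Config
mlp = LigerPhi3SwiGLUMLP(config).cuda()
x = torch.randn(2, 16, 1024, device="cuda")
print(mlp(x).shape)   # fused gate_up projection, chunked into gate/up, then Liger SiLU*mul
```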

liger_kernel/transformers/tiled_mlp.py
@@ -0,0 +1,133 @@
+from typing import Optional
+
+import torch.nn as nn
+
+from liger_kernel.ops import LigerGELUMulFunction
+from liger_kernel.ops import LigerSiLUMulFunction
+from liger_kernel.ops import apply_tiled_mlp
+
+
+class LigerTiledGEGLUMLP(nn.Module):
+    """
+    Memory-efficient GEGLU MLP using tiled computation.
+
+    This module combines GEGLU activation with tiled processing to handle
+    very long sequences efficiently. The forward pass is recomputed during
+    backward to save memory.
+
+    Args:
+        config: Model configuration with hidden_size and intermediate_size attributes
+        num_shards: Number of shards to split the sequence. If None, automatically
+            calculated as ceil(seqlen / hidden_size)
+    """
+
+    def __init__(self, config, num_shards: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.num_shards = num_shards
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+        # Validate activation function
+        if hasattr(config, "hidden_act") and config.hidden_act not in [
+            "gelu",
+            "gelu_new",
+            "gelu_pytorch_tanh",
+        ]:
+            raise ValueError(f"LigerTiledGEGLUMLP requires GELU activation, got {config.hidden_act}")
+
+    def _mlp_forward(self, module, x):
+        """Internal MLP forward function for tiled computation."""
+        gate = module.gate_proj(x)
+        up = module.up_proj(x)
+        return module.down_proj(LigerGELUMulFunction.apply(gate, up))
+
+    def forward(self, x):
+        """
+        Forward pass with tiled computation.
+
+        Args:
+            x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                or [seq_len, hidden_size]
+
+        Returns:
+            Output tensor of the same shape as input
+        """
+        compute_params = [
+            self.gate_proj.weight,
+            self.up_proj.weight,
+            self.down_proj.weight,
+        ]
+
+        return apply_tiled_mlp(
+            fn=self._mlp_forward,
+            mlp_module=self,
+            x=x,
+            num_shards=self.num_shards,
+            compute_params=compute_params,
+        )
+
+
+class LigerTiledSwiGLUMLP(nn.Module):
+    """
+    Memory-efficient SwiGLU MLP using tiled computation.
+
+    This module combines SwiGLU activation with tiled processing to handle
+    very long sequences efficiently. The forward pass is recomputed during
+    backward to save memory.
+
+    Args:
+        config: Model configuration with hidden_size and intermediate_size attributes
+        num_shards: Number of shards to split the sequence. If None, automatically
+            calculated as ceil(seqlen / hidden_size)
+    """
+
+    def __init__(self, config, num_shards: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.num_shards = num_shards
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+        # Validate activation function
+        if hasattr(config, "hidden_act") and config.hidden_act not in ["silu", "swish"]:
+            raise ValueError(f"LigerTiledSwiGLUMLP requires SiLU/Swish activation, got {config.hidden_act}")
+
+    def _mlp_forward(self, module, x):
+        """Internal MLP forward function for tiled computation."""
+        gate = module.gate_proj(x)
+        up = module.up_proj(x)
+        return module.down_proj(LigerSiLUMulFunction.apply(gate, up))
+
+    def forward(self, x):
+        """
+        Forward pass with tiled computation.
+
+        Args:
+            x: Input tensor of shape [batch_size, seq_len, hidden_size]
+                or [seq_len, hidden_size]
+
+        Returns:
+            Output tensor of the same shape as input
+        """
+        compute_params = [
+            self.gate_proj.weight,
+            self.up_proj.weight,
+            self.down_proj.weight,
+        ]
+
+        return apply_tiled_mlp(
+            fn=self._mlp_forward,
+            mlp_module=self,
+            x=x,
+            num_shards=self.num_shards,
+            compute_params=compute_params,
+        )
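
A usage sketch for `LigerTiledSwiGLUMLP` (not part of the package diff). The SimpleNamespace config, the shard count, and the CUDA device are assumptions; the point is that a long sequence is processed shard-by-shard and the forward is recomputed in backward:

```python
from types import SimpleNamespace

import torch

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

config = SimpleNamespace(hidden_size=1024, intermediate_size=2816, hidden_act="silu")  # stand-in config
mlp = LigerTiledSwiGLUMLP(config, num_shards=4).cuda()   # num_shards=None would derive it from seq_len / hidden_size
x = torch.randn(1, 8192, 1024, device="cuda", requires_grad=True)
out = mlp(x)
out.mean().backward()                                    # shards are recomputed here instead of stored
print(out.shape, x.grad.shape)
```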

liger_kernel/transformers/trainer/orpo_trainer.py
@@ -0,0 +1,130 @@
+from typing import Dict
+from typing import List
+from typing import Literal
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from torch.distributed.fsdp import FullyShardedDataParallel
+from trl.trainer import ORPOTrainer
+
+from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
+from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
+
+
+class LigerORPOTrainer(ORPOTrainer):
+    def concatenated_forward(
+        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """
+        Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+        We do this to avoid doing two forward passes, because it's faster for FSDP.
+        """
+        concatenated_batch = self.concatenated_inputs(
+            batch,
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+            padding_value=self.padding_value,
+            device=self.accelerator.device,
+        )
+
+        model_kwargs = (
+            {
+                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
+            }
+            if self.is_encoder_decoder
+            else {}
+        )
+
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        if self.is_encoder_decoder:
+            labels = concatenated_batch["concatenated_labels"].clone()
+        else:
+            labels = concatenated_batch["concatenated_input_ids"].clone()
+            attention_mask = concatenated_batch["concatenated_attention_mask"]
+            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
+
+        if isinstance(model, FullyShardedDataParallel):
+            outputs = _FSDPForwardRedirection()(
+                model,
+                model._fsdp_wrapped_module.model,
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+        else:
+            if isinstance(model, torch.nn.DataParallel):
+                model = model.module
+            outputs = model.model(
+                concatenated_batch["concatenated_input_ids"],
+                attention_mask=concatenated_batch["concatenated_attention_mask"],
+                use_cache=False,
+                **model_kwargs,
+            )
+
+        orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
+
+        def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target):
+            return orpo_loss_fn(
+                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target
+            )
+
+        orpo_loss, aux_outputs = _FSDPForwardRedirection()(
+            model,
+            orpo_partial,
+            model.lm_head,
+            outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
+            concatenated_batch["concatenated_labels"][:, 1:]
+            if not self.is_encoder_decoder
+            else concatenated_batch["concatenated_labels"],
+            labels[:, 1:] if not self.is_encoder_decoder else labels,
+        )
+        # if aux_loss_enabled, add the aux_loss to the orpo_loss
+        if self.aux_loss_enabled:
+            orpo_loss += self.aux_loss_coef * outputs.aux_loss
+
+        return orpo_loss, aux_outputs
+
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: Dict[str, Union[List, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+        metrics = {}
+        loss, aux_outputs = self.concatenated_forward(model, batch)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = aux_outputs[:5]
+
+        # return loss, metrics
+        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[5:]
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
+        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
+        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
+        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
+        for k, v in metrics.items():
+            metrics[k] = v.item()
+
+        return loss, metrics
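
The trainer above never materializes full-vocabulary logits: it hands the lm_head weights and the last hidden states straight to `LigerFusedLinearORPOLoss`. A toy sketch of that fused call (not part of the package diff), mirroring the `orpo_partial` closure; the chosen-then-rejected batch layout, the shapes, and the omitted label shift are simplifications and assumptions:

```python
import torch
import torch.nn as nn

from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

B, T, H, V = 2, 16, 512, 1000                       # B chosen + B rejected sequences concatenated on dim 0
lm_head = nn.Linear(H, V, bias=False).cuda()
hidden = torch.randn(2 * B, T, H, device="cuda", requires_grad=True)
labels = torch.randint(0, V, (2 * B, T), device="cuda")

loss_fn = LigerFusedLinearORPOLoss(ignore_index=-100, beta=0.1)
# Same call shape as orpo_partial above: weight, hidden states, labels, optional bias, nll_target.
orpo_loss, aux_outputs = loss_fn(lm_head.weight, hidden, labels, lm_head.bias, nll_target=labels)
orpo_loss.backward()
print(orpo_loss.item(), len(aux_outputs))
```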