liger-kernel-nightly 0.4.0.dev20241107052928__py3-none-any.whl → 0.6.3.dev20251121010306__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic; consult the registry's advisory page for this release for more details.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +8 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
- liger_kernel/chunked_loss/cpo_loss.py +157 -0
- liger_kernel/chunked_loss/dpo_loss.py +229 -0
- liger_kernel/chunked_loss/functional.py +17 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
- liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
- liger_kernel/chunked_loss/grpo_loss.py +304 -0
- liger_kernel/chunked_loss/jsd_loss.py +200 -0
- liger_kernel/chunked_loss/kto_loss.py +210 -0
- liger_kernel/chunked_loss/orpo_loss.py +144 -0
- liger_kernel/chunked_loss/simpo_loss.py +165 -0
- liger_kernel/env_report.py +21 -4
- liger_kernel/ops/cross_entropy.py +235 -84
- liger_kernel/ops/dyt.py +157 -0
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_add_rms_norm.py +412 -0
- liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
- liger_kernel/ops/fused_linear_jsd.py +17 -34
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +7 -18
- liger_kernel/ops/group_norm.py +305 -0
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/jsd.py +46 -21
- liger_kernel/ops/kl_div.py +23 -19
- liger_kernel/ops/layer_norm.py +150 -86
- liger_kernel/ops/llama4_rope.py +225 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/poly_norm.py +386 -0
- liger_kernel/ops/qwen2vl_mrope.py +222 -0
- liger_kernel/ops/rms_norm.py +314 -84
- liger_kernel/ops/rope.py +32 -34
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +5 -9
- liger_kernel/ops/tiled_mlp.py +136 -0
- liger_kernel/ops/tvd.py +207 -0
- liger_kernel/ops/utils.py +8 -4
- liger_kernel/transformers/__init__.py +199 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +33 -20
- liger_kernel/transformers/dyt.py +22 -0
- liger_kernel/transformers/experimental/__init__.py +5 -0
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +291 -13
- liger_kernel/transformers/fused_add_rms_norm.py +39 -0
- liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +50 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/jsd.py +2 -7
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/llama4_rope.py +93 -0
- liger_kernel/transformers/model/falcon_h1.py +122 -0
- liger_kernel/transformers/model/gemma.py +77 -77
- liger_kernel/transformers/model/gemma2.py +283 -0
- liger_kernel/transformers/model/gemma3.py +331 -0
- liger_kernel/transformers/model/glm4.py +141 -0
- liger_kernel/transformers/model/glm4v.py +163 -0
- liger_kernel/transformers/model/glm4v_moe.py +172 -0
- liger_kernel/transformers/model/internvl.py +157 -0
- liger_kernel/transformers/model/llama.py +128 -79
- liger_kernel/transformers/model/llama4.py +121 -0
- liger_kernel/transformers/model/llava.py +344 -0
- liger_kernel/transformers/model/loss_utils.py +95 -0
- liger_kernel/transformers/model/mistral.py +68 -64
- liger_kernel/transformers/model/mixtral.py +75 -91
- liger_kernel/transformers/model/mllama.py +63 -68
- liger_kernel/transformers/model/olmo2.py +141 -0
- liger_kernel/transformers/model/output_classes.py +147 -0
- liger_kernel/transformers/model/paligemma.py +432 -0
- liger_kernel/transformers/model/phi3.py +59 -213
- liger_kernel/transformers/model/qwen2.py +75 -72
- liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
- liger_kernel/transformers/model/qwen2_vl.py +78 -98
- liger_kernel/transformers/model/qwen3.py +136 -0
- liger_kernel/transformers/model/qwen3_moe.py +152 -0
- liger_kernel/transformers/model/qwen3_next.py +146 -0
- liger_kernel/transformers/model/qwen3_vl.py +150 -0
- liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
- liger_kernel/transformers/model/smollm3.py +199 -0
- liger_kernel/transformers/model/smolvlm.py +158 -0
- liger_kernel/transformers/monkey_patch.py +2106 -289
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/poly_norm.py +42 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/transformers/rms_norm.py +57 -6
- liger_kernel/transformers/rope.py +45 -2
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +23 -8
- liger_kernel/transformers/tiled_mlp.py +133 -0
- liger_kernel/transformers/trainer/__init__.py +4 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
- liger_kernel/transformers/tvd.py +13 -0
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- liger_kernel/utils.py +71 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
- liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
- liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
    """Autograd function fusing the final linear (LM head) projection with the SimPO preference loss."""

    @staticmethod
    def preference_loss_fn(
        chosen_logps,
        rejected_logps,
        full_target,
        beta=0.1,
        gamma=0.5,
        label_smoothing=0.0,
    ):
        """
        Paper: https://arxiv.org/pdf/2405.14734

        Formula:
        L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]

        Where:
        - π_θ(y|x): Policy (model) probability
        - y_w: Chosen sequence
        - y_l: Rejected sequence
        - |y_w|, |y_l|: Sequence lengths
        - σ: Sigmoid function
        - β: beta weight
        - γ: gamma margin term

        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            full_target: Non chunked full target tensor
            beta (float): beta weight
            gamma (float): gamma margin term
            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
        """
        logits = beta * (chosen_logps - rejected_logps) - gamma
        # Label-smoothed sigmoid cross-entropy on the preference margin.
        # full_target stacks the chosen and rejected halves, hence the // 2
        # to normalize by the number of preference pairs.
        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
            full_target.shape[0] // 2
        )

        chosen_rewards = beta * chosen_logps
        rejected_rewards = beta * rejected_logps

        return loss, chosen_rewards, rejected_rewards

    @classmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        label_smoothing=0.0,
        compute_nll_loss=False,
        compiled=True,
        gamma=0.5,
        chunk_size=1,
    ):
        """
        Fused linear layer with SimPO loss.
        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
            ignore_index (int): Index to ignore in loss computation
            beta (float): Weight for the odds ratio loss
            alpha (float): Weight for the alpha parameter
            label_smoothing (float): Label smoothing factor
            compute_nll_loss (bool): Whether to compute the NLL loss
            compiled (bool): Whether to use torch compile
            gamma (float): Weight for the gamma parameter
            chunk_size (int): Size of chunks for processing
        Returns:
            torch.Tensor: Computed loss
        """
        # Delegate the chunked fused-linear machinery to the base class; it will
        # call cls.preference_loss_fn on each chunk's (chosen, rejected) log-probs.
        return super().forward(
            cls=cls,
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            label_smoothing=label_smoothing,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
            gamma=gamma,
            chunk_size=chunk_size,
        )

    @staticmethod
    def backward(ctx, *grad_output):
        # Only the first four gradients (w.r.t. _input, weight, target, bias)
        # are propagated; the trailing Nones line up with the eight non-tensor
        # hyperparameters of forward, which receive no gradient.
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        return *grads, None, None, None, None, None, None, None, None
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class LigerFusedLinearSimPOLoss(torch.nn.Module):
    """Module wrapper around the fused linear + SimPO loss autograd function."""

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        label_smoothing: float = 0.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        gamma: float = 0.5,
        chunk_size: int = 1,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the odds ratio loss.
            alpha (float): Weight for the alpha parameter.
            label_smoothing (float): Label smoothing factor.
            compute_nll_loss (bool): Whether to compute the NLL loss.
            compiled (bool): Whether to use the torch compiled kernel.
            gamma (float): Weight for the gamma parameter.
            chunk_size (int): Size of chunks for processing.
        """
        super().__init__()
        # Stash the hyperparameters; they are forwarded positionally to the
        # autograd function on every call.
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.label_smoothing = label_smoothing
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled
        self.gamma = gamma
        self.chunk_size = chunk_size

    def forward(self, lin_weight, _input, target, bias=None):
        """Compute the fused SimPO loss for the given projection weight, input, and target."""
        loss_args = (
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.label_smoothing,
            self.compute_nll_loss,
            self.compiled,
            self.gamma,
            self.chunk_size,
        )
        return LigerFusedLinearSimPOFunction.apply(*loss_args)
|
liger_kernel/env_report.py
CHANGED
|
@@ -1,31 +1,42 @@
|
|
|
1
1
|
import platform
|
|
2
2
|
import sys
|
|
3
3
|
|
|
4
|
+
from importlib.metadata import version
|
|
5
|
+
|
|
4
6
|
|
|
5
7
|
def print_env_report():
|
|
6
8
|
"""
|
|
7
|
-
|
|
9
|
+
|
|
10
|
+
Prints a report of the environment. Useful for debugging and reproducibility.
|
|
8
11
|
Usage:
|
|
9
12
|
```
|
|
10
13
|
python -m liger_kernel.env_report
|
|
11
14
|
```
|
|
15
|
+
|
|
12
16
|
"""
|
|
13
17
|
print("Environment Report:")
|
|
14
18
|
print("-------------------")
|
|
15
19
|
print(f"Operating System: {platform.platform()}")
|
|
16
20
|
print(f"Python version: {sys.version.split()[0]}")
|
|
17
21
|
|
|
22
|
+
try:
|
|
23
|
+
print(f"Liger Kernel version: {version('liger-kernel')}")
|
|
24
|
+
except ImportError:
|
|
25
|
+
print("Liger Kernel: Not installed")
|
|
26
|
+
|
|
18
27
|
try:
|
|
19
28
|
import torch
|
|
20
29
|
|
|
21
30
|
print(f"PyTorch version: {torch.__version__}")
|
|
22
|
-
cuda_version = (
|
|
23
|
-
torch.version.cuda if torch.cuda.is_available() else "Not available"
|
|
24
|
-
)
|
|
31
|
+
cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
|
|
25
32
|
print(f"CUDA version: {cuda_version}")
|
|
33
|
+
hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
|
|
34
|
+
print(f"HIP(ROCm) version: {hip_version}")
|
|
35
|
+
|
|
26
36
|
except ImportError:
|
|
27
37
|
print("PyTorch: Not installed")
|
|
28
38
|
print("CUDA version: Unable to query")
|
|
39
|
+
print("HIP(ROCm) version: Unable to query")
|
|
29
40
|
|
|
30
41
|
try:
|
|
31
42
|
import triton
|
|
@@ -41,6 +52,12 @@ def print_env_report():
|
|
|
41
52
|
except ImportError:
|
|
42
53
|
print("Transformers: Not installed")
|
|
43
54
|
|
|
55
|
+
try:
|
|
56
|
+
xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
|
|
57
|
+
print(f"XPU version: {xpu_version}")
|
|
58
|
+
except ImportError:
|
|
59
|
+
print("XPU version: Unable to query")
|
|
60
|
+
|
|
44
61
|
|
|
45
62
|
if __name__ == "__main__":
|
|
46
63
|
print_env_report()
|