liger-kernel-nightly 0.4.2.dev20241117192137__tar.gz → 0.4.2.dev20241119054456__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.4.2.dev20241117192137/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241119054456}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/pyproject.toml +1 -1
- liger_kernel_nightly-0.4.2.dev20241119054456/src/liger_kernel/chunked_loss/cpo_loss.py +61 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/chunked_loss/fused_linear_preference.py +6 -1
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/monkey_patch.py +0 -2
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/README.md +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/setup.cfg +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "liger_kernel_nightly"
|
7
|
-
version = "0.4.2.
|
7
|
+
version = "0.4.2.dev20241119054456"
|
8
8
|
description = "Efficient Triton kernels for LLM Training"
|
9
9
|
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
|
10
10
|
readme = { file = "README.md", content-type = "text/markdown" }
|
@@ -0,0 +1,61 @@
|
|
1
|
+
import torch.nn.functional as F
|
2
|
+
|
3
|
+
from liger_kernel.chunked_loss.fused_linear_preference import (
|
4
|
+
LigerFusedLinearPreferenceBase,
|
5
|
+
)
|
6
|
+
|
7
|
+
|
8
|
+
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
    """Fused final linear layer + CPO loss as a ``torch.autograd.Function``.

    CPO (Contrastive Preference Optimization) combines an NLL term on the
    chosen responses (weighted by ``alpha``) with a log-sigmoid preference
    term on the scaled difference between chosen and rejected log
    probabilities. The chunked forward/backward machinery lives in
    ``LigerFusedLinearPreferenceBase``; this class only supplies the
    preference term and forwards its arguments to the base class.
    """

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """
        Compute the CPO preference term for one chunk.

        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Scale applied to the chosen/rejected log-probability
                difference before the log-sigmoid.

        Returns:
            torch.Tensor: Scalar ``sum(logsigmoid(beta * (chosen_logps - rejected_logps)))``.
            Because log-sigmoid is never positive, this is a log-likelihood
            (<= 0). The base class subtracts it from the weighted NLL term, so
            a larger chosen-vs-rejected margin lowers the total loss.
        """
        logits = beta * (chosen_logps - rejected_logps)
        # Sum, not mean: the base class divides the per-chunk alignment loss by
        # the number of chosen/rejected pairs in the whole batch
        # (full_target.shape[0] // 2) before accumulating across chunks.
        # Averaging here as well would shrink the CPO term by the per-chunk
        # pair count whenever a chunk holds more than one pair.
        loss = F.logsigmoid(logits).sum()
        return loss

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        compute_nll_loss=True,
        compiled=True,
    ):
        """
        Fused linear layer with CPO (Contrastive Preference Optimization) loss.
        Handles both the forward and backward pass of the final linear layer with CPO loss.
        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.

        Args:
            _input: Hidden states fed to the final linear layer.
            weight: Weight of the final linear layer.
            target: Target token ids.
            bias: Optional bias of the final linear layer.
            ignore_index (int): Index to ignore for loss computation.
            beta (float): Scale for the preference log-probability difference.
            alpha (float): Weight for the NLL loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            compiled (bool): Whether to use torch compile for chunk accumulation.

        Returns:
            Whatever ``LigerFusedLinearPreferenceBase.forward`` returns for
            these arguments.
        """
        return LigerFusedLinearPreferenceBase.forward(
            ctx,
            _input,
            weight,
            target,
            bias,
            loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
            compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compiled=compiled,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per ``forward`` input (excluding ``ctx``).

        The first four entries of the base-class gradients line up with the
        tensor arguments ``_input``, ``weight``, ``target`` and ``bias``.
        ``None`` is returned for the five non-tensor arguments
        ``ignore_index``, ``beta``, ``alpha``, ``compute_nll_loss`` and
        ``compiled``.
        """
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        return *grads, None, None, None, None, None
|
@@ -29,6 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
29
29
|
chunk_size=1,
|
30
30
|
compute_nll_loss=True,
|
31
31
|
ignore_index=-100,
|
32
|
+
alpha=1.0,
|
32
33
|
beta=0.1,
|
33
34
|
compiled=True,
|
34
35
|
):
|
@@ -45,6 +46,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
45
46
|
chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
|
46
47
|
compute_nll_loss (bool): Whether to compute NLL loss.
|
47
48
|
ignore_index (int): Index to ignore for loss computation.
|
49
|
+
alpha (float): Weight for the NLL loss.
|
48
50
|
beta (float): Weight for the odds ratio loss.
|
49
51
|
compiled (bool): Whether to use torch compile for chunk accumulation.
|
50
52
|
"""
|
@@ -62,6 +64,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
62
64
|
LigerFusedLinearPreferenceBase._compute_loss,
|
63
65
|
preference_loss_fn=loss_fn,
|
64
66
|
ignore_index=ignore_index,
|
67
|
+
alpha=alpha,
|
65
68
|
beta=beta,
|
66
69
|
compute_nll_loss=compute_nll_loss,
|
67
70
|
full_target=target,
|
@@ -149,6 +152,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
149
152
|
preference_loss_fn=None,
|
150
153
|
full_target=None,
|
151
154
|
ignore_index=-100,
|
155
|
+
alpha=1.0,
|
152
156
|
beta=0.1,
|
153
157
|
compute_nll_loss=True,
|
154
158
|
**loss_kwargs,
|
@@ -163,6 +167,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
163
167
|
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
|
164
168
|
full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
|
165
169
|
ignore_index (int): Index to ignore for loss computation.
|
170
|
+
alpha (float): Weight for the NLL loss.
|
166
171
|
beta (float): Weight for the odds ratio loss.
|
167
172
|
loss_kwargs (dict): Additional arguments for the loss function.
|
168
173
|
"""
|
@@ -202,5 +207,5 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
202
207
|
)
|
203
208
|
alignment_loss = alignment_loss / (full_target.shape[0] // 2)
|
204
209
|
|
205
|
-
loss = chosen_nll_loss - alignment_loss
|
210
|
+
loss = alpha * chosen_nll_loss - alignment_loss
|
206
211
|
return loss, (alignment_loss, chosen_logps, rejected_logps)
|
@@ -610,9 +610,7 @@ def apply_liger_kernel_to_qwen2(
|
|
610
610
|
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
|
611
611
|
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
612
612
|
|
613
|
-
# import pdb; pdb.set_trace()
|
614
613
|
if fused_linear_cross_entropy:
|
615
|
-
|
616
614
|
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
|
617
615
|
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
618
616
|
else: # if version < 4.46.1
|
@@ -4,6 +4,7 @@ README.md
|
|
4
4
|
pyproject.toml
|
5
5
|
src/liger_kernel/env_report.py
|
6
6
|
src/liger_kernel/chunked_loss/__init__.py
|
7
|
+
src/liger_kernel/chunked_loss/cpo_loss.py
|
7
8
|
src/liger_kernel/chunked_loss/dpo_loss.py
|
8
9
|
src/liger_kernel/chunked_loss/fused_linear_preference.py
|
9
10
|
src/liger_kernel/chunked_loss/orpo_loss.py
|
File without changes
|
{liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054456}/NOTICE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|