liger-kernel-nightly 0.4.2.dev20241121225747__tar.gz → 0.4.2.dev20241122175637__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {liger_kernel_nightly-0.4.2.dev20241121225747/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241122175637}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/pyproject.toml +1 -1
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/dpo_loss.py +36 -4
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/fused_linear_preference.py +79 -27
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/cross_entropy.py +12 -6
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -11
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/LICENSE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/NOTICE +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/README.md +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/setup.cfg +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/functional.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/env_report.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/ops/utils.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/auto_model.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/functional.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/geglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/group_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/jsd.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/kl_div.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/layer_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/gemma.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/gemma2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/llama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/mistral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/mixtral.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/mllama.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/phi3.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/qwen2.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/rms_norm.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/rope.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/swiglu.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/transformers/trainer_integration.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/triton/__init__.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel/triton/monkey_patch.py +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
- {liger_kernel_nightly-0.4.2.dev20241121225747 → liger_kernel_nightly-0.4.2.dev20241122175637}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
pyproject.toml +1 -1

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.4.2.dev20241121225747"
+version = "0.4.2.dev20241122175637"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
```
src/liger_kernel/chunked_loss/dpo_loss.py +36 -4

```diff
@@ -9,15 +9,31 @@ from liger_kernel.chunked_loss.fused_linear_preference import (
 class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+    def preference_loss_fn(
+        chosen_logps,
+        rejected_logps,
+        ref_chosen_logps=None,
+        ref_rejected_logps=None,
+        beta=0.1,
+    ):
         """
         Compute DPO loss (Direct Preference Optimization).
         Args:
             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,).
+            ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,).
             beta (float): Weight for the direct preference loss.
         """
-        logits_diff = beta * (chosen_logps - rejected_logps)
+        if ref_chosen_logps is None:
+            ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
+        if ref_rejected_logps is None:
+            ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
+
+        chosen_logratios = chosen_logps - ref_chosen_logps
+        rejected_logratios = rejected_logps - ref_rejected_logps
+
+        logits_diff = beta * (chosen_logratios - rejected_logratios)
         losses = -F.logsigmoid(logits_diff)
         return losses.sum()
 
```
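In plain terms, the hunk above generalizes the DPO objective from beta * (log p_chosen - log p_rejected) to beta * ((log p_chosen - log p_chosen_ref) - (log p_rejected - log p_rejected_ref)). A minimal standalone sketch of the same computation outside the fused kernel, assuming (batch_size,) tensors as in the docstring:

```python
import torch
import torch.nn.functional as F

def dpo_loss_sketch(
    chosen_logps, rejected_logps, ref_chosen_logps=None, ref_rejected_logps=None, beta=0.1
):
    # Without a reference model, the log-ratios reduce to the raw log-probs.
    if ref_chosen_logps is None:
        ref_chosen_logps = torch.zeros_like(chosen_logps)
    if ref_rejected_logps is None:
        ref_rejected_logps = torch.zeros_like(rejected_logps)
    chosen_logratios = chosen_logps - ref_chosen_logps
    rejected_logratios = rejected_logps - ref_rejected_logps
    # -log sigmoid of the policy margin re-centered by the reference margin.
    return -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).sum()
```

The remaining dpo_loss.py hunks thread the new reference arguments through the autograd plumbing: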
```diff
@@ -28,10 +44,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         weight,
         target,
         bias=None,
+        ref_weight=None,
+        ref_bias=None,
         ignore_index=-100,
         beta=0.1,
         compute_nll_loss=True,
         compiled=True,
+        use_ref_model=True,
     ):
         """
         Fused linear layer with DPO (Direct Preference Optimization) loss.
@@ -48,6 +67,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             compiled=compiled,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
         )
 
     @staticmethod
@@ -55,7 +77,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         # Get gradients for _input, weight, bias, and target from the base class
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
         # Return these gradients, followed by None for the remaining inputs
-        return *grads, None, None, None, None
+        return *grads, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
```
```diff
@@ -69,26 +91,36 @@
         beta: float = 0.1,
         compute_nll_loss: bool = True,
         compiled: bool = True,
+        use_ref_model: bool = False,
     ):
         """
         Args:
             ignore_index (int): Index to ignore in the loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            use_ref_model (bool): Whether to use a reference model for the DPO loss.
         """
         super().__init__()
         self.ignore_index = ignore_index
         self.beta = beta
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
+        self.use_ref_model = use_ref_model
 
-    def forward(self, lin_weight, _input, target, bias=None):
+    def forward(
+        self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
+    ):
         return LigerFusedLinearDPOFunction.apply(
             _input,
             lin_weight,
             target,
             bias,
+            ref_weight,
+            ref_bias,
             self.ignore_index,
             self.beta,
             self.compute_nll_loss,
             self.compiled,
+            self.use_ref_model,
         )
```
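Putting the dpo_loss.py changes together, a hypothetical smoke test of the module-level API. The shapes here are my assumption (chosen sequences stacked before rejected ones along dim 0, which is how the base class splits the batch at shape[0] // 2), and the sizes B, T, H, V are illustrative; whether this runs as-is depends on the installed version:

```python
import torch

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

B, T, H, V = 2, 8, 32, 64  # illustrative sizes
loss_fn = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True)

_input = torch.randn(2 * B, T, H, requires_grad=True)  # [chosen; rejected] along dim 0
target = torch.randint(0, V, (2 * B, T))
lin_weight = torch.randn(V, H, requires_grad=True)
ref_weight = torch.randn(V, H)  # frozen reference lm_head; receives no gradients

loss = loss_fn(lin_weight, _input, target, ref_weight=ref_weight)
loss.backward()
```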
src/liger_kernel/chunked_loss/fused_linear_preference.py +79 -27

```diff
@@ -18,6 +18,42 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         """
         raise NotImplementedError("Preference loss function must be implemented.")
 
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+        compute_nll_loss=True,
+    ):
+        len_chosen_chunk = target_chunk.shape[0] // 2
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        chosen_nll_loss = 0.0
+        if compute_nll_loss:
+            chosen_nll_loss = F.nll_loss(
+                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                target_chunk[:len_chosen_chunk].view(-1),
+                reduction="sum",
+                ignore_index=ignore_index,
+            )
+
+        loss_mask = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+            -1
+        )
+        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        chosen_logps = average_log_prob[:len_chosen_chunk]
+        rejected_logps = average_log_prob[len_chosen_chunk:]
+        return chosen_logps, rejected_logps, chosen_nll_loss
+
     @staticmethod
     def forward(
         ctx,
```
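The new chunk_forward pushes a [chosen; rejected] chunk through one matmul and returns length-normalized per-sequence log-probs plus the optional NLL term. A self-contained re-derivation of the same quantities for a tiny chunk (all values and shapes illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
chunk, T, H, V, ignore_index = 4, 6, 8, 10, -100
x = torch.randn(chunk, T, H)          # first half chosen, second half rejected
w = torch.randn(V, H)
t = torch.randint(0, V, (chunk, T))
t[:, -1] = ignore_index               # pretend the last token of each row is padding

log_probs = F.log_softmax((x @ w.t()).float(), dim=-1)
mask = t != ignore_index
labels = torch.where(mask, t, 0)      # padded labels must still be valid gather indices
per_token = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
avg_logps = (per_token * mask).sum(-1) / mask.sum(-1)

chosen_logps, rejected_logps = avg_logps[: chunk // 2], avg_logps[chunk // 2 :]
print(chosen_logps.shape, rejected_logps.shape)  # torch.Size([2]) torch.Size([2])
```

The next hunks plumb the new reference options through forward into _compute_loss, and move the chunks computation below the torch.compile setup: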
```diff
@@ -32,6 +68,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         beta=0.1,
         compute_nll_loss=True,
         compiled=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
         **loss_kwargs,
     ):
         """
@@ -49,7 +88,11 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             ignore_index (int): Index to ignore for loss computation.
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute NLL loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -61,7 +104,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         grad_bias = torch.zeros_like(bias) if bias is not None else None
         loss_acc = torch.zeros((), device=_input.device)
 
-        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
         loss_func_to_call = partial(
             LigerFusedLinearPreferenceBase._compute_loss,
             preference_loss_fn=loss_fn,
@@ -70,6 +112,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             beta=beta,
             compute_nll_loss=compute_nll_loss,
             full_target=target,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
             **loss_kwargs,
         )
 
@@ -101,6 +146,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             accumulate_chunk = torch.compile(accumulate_chunk)
 
         len_chosen = target.shape[0] // 2
+        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
         _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
         _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
         _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
@@ -159,6 +205,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         alpha=1.0,
         beta=0.1,
         compute_nll_loss=True,
+        use_ref_model=False,
+        ref_weight=None,
+        ref_bias=None,
         **loss_kwargs,
     ):
         """
```
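For context, the base class splits the chosen and rejected halves into matching chunks and accumulates loss and gradients one [chosen; rejected] chunk at a time. A hedged sketch of just that chunking pattern (not the actual Liger accumulation code; sizes illustrative):

```python
import torch

CHUNK_SIZE = 1
_input = torch.randn(8, 4, 16)            # 4 chosen + 4 rejected sequences
target = torch.randint(0, 32, (8, 4))

len_chosen = target.shape[0] // 2
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
chosen_in = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
rejected_in = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
chosen_t = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
rejected_t = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

for c_in, r_in, c_t, r_t in zip(chosen_in, rejected_in, chosen_t, rejected_t):
    # Each step sees a [chosen; rejected] chunk, so downstream code can split
    # it at shape[0] // 2 exactly as it does for the full batch.
    input_chunk = torch.cat([c_in, r_in], dim=0)
    target_chunk = torch.cat([c_t, r_t], dim=0)
```

The final hunk in this file rewires _compute_loss to delegate to chunk_forward and, when a reference model is supplied, to run a second detached pass: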
```diff
@@ -173,38 +222,41 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             ignore_index (int): Index to ignore for loss computation.
             alpha (float): Weight for the NLL loss.
             beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute NLL loss.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-        len_chosen_chunk = target_chunk.shape[0] // 2
-
-        logits_chunk = input_chunk @ weight.t()
-        if bias is not None:
-            logits_chunk = logits_chunk + bias
-        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
-
-        chosen_nll_loss = 0.0
-        if compute_nll_loss:
-            chosen_nll_loss = F.nll_loss(
-                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                target_chunk[:len_chosen_chunk].view(-1),
-                reduction="sum",
+        chosen_logps, rejected_logps, chosen_nll_loss = (
+            LigerFusedLinearPreferenceBase.chunk_forward(
+                input_chunk,
+                weight,
+                target_chunk,
+                bias=bias,
                 ignore_index=ignore_index,
+                compute_nll_loss=compute_nll_loss,
             )
-            chosen_nll_loss = (
-                chosen_nll_loss
-                / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-            )
-
-        loss_mask = target_chunk != ignore_index
-        label_chunk = torch.where(loss_mask, target_chunk, 0)
-
-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-            -1
         )
-        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        chosen_nll_loss = (
+            chosen_nll_loss
+            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        )
 
-        chosen_logps = average_log_prob[:len_chosen_chunk]
-        rejected_logps = average_log_prob[len_chosen_chunk:]
+        if use_ref_model:
+            with torch.no_grad():
+                ref_chosen_logps, ref_rejected_logps, _ = (
+                    LigerFusedLinearPreferenceBase.chunk_forward(
+                        input_chunk,
+                        ref_weight,
+                        target_chunk,
+                        ref_bias,
+                        ignore_index=ignore_index,
+                        compute_nll_loss=False,
+                    )
+                )
+            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
+            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
 
         alignment_loss = preference_loss_fn(
             chosen_logps, rejected_logps, beta=beta, **loss_kwargs
```
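Note that the reference pass above runs under torch.no_grad(), so ref_weight and ref_bias contribute constants to the loss and never receive gradients. A minimal illustration of that pattern (hypothetical tensors, not Liger code):

```python
import torch

x = torch.randn(4, 8, requires_grad=True)
policy_w = torch.randn(16, 8, requires_grad=True)
ref_w = torch.randn(16, 8)  # frozen reference weights

policy_logits = x @ policy_w.t()
with torch.no_grad():
    ref_logits = x @ ref_w.t()  # no autograd graph is recorded for this matmul

loss = (policy_logits - ref_logits).sum()  # reference term is a constant offset
loss.backward()
print(policy_w.grad is not None, ref_w.grad)  # True None
```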
src/liger_kernel/ops/cross_entropy.py +12 -6

```diff
@@ -92,8 +92,8 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
     d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(
-        X_ptr + y
+    ori_X_y = tl.load(X_ptr + y).cast(
+        tl.float32
     )  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -106,8 +106,11 @@ def liger_cross_entropy_kernel(
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             X_block = softcap * tanh(X_block / softcap)
         block_max = tl.max(X_block)
@@ -141,8 +144,11 @@ def liger_cross_entropy_kernel(
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
-            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
-        )
+            X_ptr + X_offsets,
+            mask=X_offsets < n_cols,
+            other=float("-inf"),
+            # Ensure float32 precision for softmax calculation
+        ).cast(tl.float32)
         if HAS_SOFTCAPPING:
             intermediate = tanh(X_block / softcap)
             X_block = softcap * intermediate
```
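The cross_entropy.py hunks move the float32 upcast into the kernel itself: each block is cast right after tl.load, so fp16/bf16 logits no longer round during the online max/normalizer reduction. A hedged PyTorch mirror of the first-pass loop (block size and input are illustrative):

```python
import math
import torch

def online_softmax_stats(logits: torch.Tensor, block: int = 4):
    # Running max m and normalizer d, updated block by block; the .float()
    # plays the role of the kernel's new .cast(tl.float32) on each loaded block.
    m, d = float("-inf"), 0.0
    for i in range(0, logits.numel(), block):
        x = logits[i : i + block].float()
        new_m = max(m, x.max().item())
        d = d * math.exp(m - new_m) + torch.exp(x - new_m).sum().item()
        m = new_m
    return m, d

x = torch.randn(10, dtype=torch.bfloat16)
m, d = online_softmax_stats(x)
print(abs(m + math.log(d) - torch.logsumexp(x.float(), 0).item()) < 1e-4)  # True
```

This kernel-side cast is also why the Python-side upcast in fused_linear_cross_entropy.py below can be deleted.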
src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -11

```diff
@@ -26,7 +26,6 @@ def fused_linear_cross_entropy_forward(
     reduction="mean",
     softcap=None,
 ):
-    dtype = _input.dtype
     device = _input.device
 
     # inputs have shape: BT x H
@@ -74,9 +73,6 @@ def fused_linear_cross_entropy_forward(
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
        n_non_ignore = (target_chunk != ignore_index).sum().item()
 
-        # when doing CE, use the upcasted precision
-        logits_chunk = logits_chunk.float()
-
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
         target_chunk = target_chunk.contiguous()
@@ -103,13 +99,6 @@ def fused_linear_cross_entropy_forward(
             num_warps=32 if not is_hip() else 16,
         )
 
-        # gradient of logits_chunk is computed in-place by the above triton kernel.
-        # Following HuggingFace model source code, we do the forward and backward
-        # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
-        # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
-        # Propagating to lm_head's backward, we'll switch back to the original dtype.
-        logits_chunk = logits_chunk.to(dtype)
-
         # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
         # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
         # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
```
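With the kernel now upcasting per block, the Python-side logits_chunk.float() / logits_chunk.to(dtype) round-trip becomes dead code, which is what these three hunks delete. Beyond dropping two cast kernels per chunk, this avoids materializing an fp32 replica of each chunk_size x V logits slab; a back-of-the-envelope estimate with illustrative numbers (not taken from the source):

```python
# Rough size of the fp32 copy that is no longer materialized per chunk.
chunk_size, vocab = 1024, 128_256  # hypothetical chunk and vocab sizes
print(f"{chunk_size * vocab * 4 / 2**20:.0f} MiB per chunk")  # ~501 MiB
```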