liger-kernel 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/__init__.py +0 -0
- liger_kernel/chunked_loss/__init__.py +4 -0
- liger_kernel/chunked_loss/cpo_loss.py +107 -0
- liger_kernel/chunked_loss/dpo_loss.py +135 -0
- liger_kernel/chunked_loss/functional.py +9 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +386 -0
- liger_kernel/chunked_loss/orpo_loss.py +113 -0
- liger_kernel/chunked_loss/simpo_loss.py +115 -0
- liger_kernel/env_report.py +22 -0
- liger_kernel/ops/cross_entropy.py +17 -10
- liger_kernel/ops/fused_linear_cross_entropy.py +1 -11
- liger_kernel/ops/fused_linear_jsd.py +1 -1
- liger_kernel/ops/jsd.py +19 -10
- liger_kernel/ops/layer_norm.py +6 -1
- liger_kernel/ops/qwen2vl_mrope.py +238 -0
- liger_kernel/ops/rms_norm.py +6 -1
- liger_kernel/ops/utils.py +5 -2
- liger_kernel/transformers/__init__.py +1 -0
- liger_kernel/transformers/functional.py +128 -11
- liger_kernel/transformers/fused_linear_jsd.py +1 -4
- liger_kernel/transformers/jsd.py +1 -4
- liger_kernel/transformers/model/qwen2_vl.py +43 -17
- liger_kernel/transformers/monkey_patch.py +11 -6
- liger_kernel/transformers/orpo_trainer.py +171 -0
- liger_kernel/transformers/qwen2vl_mrope.py +20 -0
- liger_kernel/utils.py +13 -0
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/METADATA +80 -123
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/RECORD +33 -20
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/WHEEL +1 -1
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/LICENSE +0 -0
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/NOTICE +0 -0
- {liger_kernel-0.4.1.dist-info → liger_kernel-0.5.0.dist-info}/top_level.txt +0 -0
liger_kernel/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,4 @@
|
|
|
1
|
+
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
|
|
2
|
+
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
|
|
3
|
+
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
|
|
4
|
+
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
from liger_kernel.chunked_loss.fused_linear_preference import (
|
|
5
|
+
LigerFusedLinearPreferenceBase,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
    """Autograd function pairing the fused linear head with the CPO preference term."""

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
        """
        CPO preference term (paper: https://arxiv.org/pdf/2401.08417).

        Formula:
        L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]

        Where:
        - π_θ(y|x): Policy (model) probability
        - y_w: Chosen sequence
        - y_l: Rejected sequence
        - σ: Sigmoid function
        - β: Temperature parameter
        - E: Expected value over the dataset D
        - D: Dataset of preferences

        NOTE(review): this function returns the *un-negated* averaged
        log-sigmoid; the leading minus of the formula is presumably applied by
        the caller in LigerFusedLinearPreferenceBase -- TODO confirm.

        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            full_target (torch.Tensor): Non chunked full target tensor. Only its
                first dimension is read, to normalise the summed loss.
            beta (float): Weight for the CPO loss.

        Returns:
            torch.Tensor: Sum of log-sigmoid margins divided by ``full_target.shape[0] // 2``.
        """
        # Presumably chosen and rejected rows are stacked along dim 0, so half
        # of the rows is the number of preference pairs -- TODO confirm.
        num_pairs = full_target.shape[0] // 2
        scaled_margin = (chosen_logps - rejected_logps) * beta
        return F.logsigmoid(scaled_margin).sum() / num_pairs

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        compute_nll_loss=True,
        compiled=True,
    ):
        """Delegate to the shared preference base, plugging in the CPO loss term."""
        options = dict(
            loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
        )
        return LigerFusedLinearPreferenceBase.forward(
            ctx, _input, weight, target, bias, **options
        )

    @staticmethod
    def backward(ctx, *grad_output):
        """Return gradients for the nine ``forward`` inputs.

        The first four entries (for ``_input``, ``weight``, ``target``,
        ``bias``) come from the base class; the remaining five inputs
        (``ignore_index``, ``beta``, ``alpha``, ``compute_nll_loss``,
        ``compiled``) are plain configuration values and get ``None``.
        """
        # The tuple of incoming gradients is handed over as a single argument.
        base_grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)
        return (*base_grads[:4], None, None, None, None, None)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class LigerFusedLinearCPOLoss(torch.nn.Module):
    """
    Fused linear layer with CPO loss.

    Thin ``nn.Module`` wrapper that stores the loss configuration and forwards
    it to ``LigerFusedLinearCPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the CPO loss.
            alpha (float): Forwarded unchanged as ``alpha``; presumably the
                weight of the NLL term -- confirm against the preference base.
            compute_nll_loss (bool): Whether to compute the NLL loss.
            compiled (bool): Forwarded unchanged as ``compiled``; presumably
                toggles torch.compile in the base -- TODO confirm.
        """
        super().__init__()
        self.ignore_index, self.beta, self.alpha = ignore_index, beta, alpha
        self.compute_nll_loss, self.compiled = compute_nll_loss, compiled

    def forward(self, lin_weight, _input, target, bias=None):
        """Compute the fused CPO loss.

        Note the argument order: the module takes the weight first, while the
        underlying autograd function takes the input first.
        """
        tensor_args = (_input, lin_weight, target, bias)
        config_args = (
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
        )
        return LigerFusedLinearCPOFunction.apply(*tensor_args, *config_args)
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
from liger_kernel.chunked_loss.fused_linear_preference import (
|
|
5
|
+
LigerFusedLinearPreferenceBase,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
    """Autograd function pairing the fused linear head with the DPO preference term."""

    @staticmethod
    def preference_loss_fn(
        chosen_logps,
        rejected_logps,
        full_target,
        ref_chosen_logps=None,
        ref_rejected_logps=None,
        beta=0.1,
    ):
        """
        DPO preference term (paper: https://arxiv.org/pdf/2305.18290).

        Formula:
        L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]

        Where:
        - π(y|x): Policy (model) probability
        - π_ref(y|x): Reference model probability
        - y_w: Chosen sequence
        - y_l: Rejected sequence
        - β: Weight for the direct preference loss
        - E: Expected value over the dataset

        Args:
            chosen_logps: Log probabilities of chosen tokens (batch_size,)
            rejected_logps: Log probabilities of rejected tokens (batch_size,)
            full_target: Non chunked full target tensor; only its first
                dimension is read, to normalise the summed loss.
            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,).
                Treated as zero when omitted.
            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,).
                Treated as zero when omitted.
            beta: Weight for the direct preference loss

        Returns:
            torch.Tensor: Negated sum of log-sigmoid margins divided by
            ``full_target.shape[0] // 2``.
        """
        # A missing reference contributes a zero log-probability, which
        # reduces the log-ratio to the policy log-probability itself.
        ref_chosen_logps = (
            torch.tensor(0.0, device=chosen_logps.device)
            if ref_chosen_logps is None
            else ref_chosen_logps
        )
        ref_rejected_logps = (
            torch.tensor(0.0, device=rejected_logps.device)
            if ref_rejected_logps is None
            else ref_rejected_logps
        )

        chosen_margin = chosen_logps - ref_chosen_logps
        rejected_margin = rejected_logps - ref_rejected_logps
        scaled_gap = beta * (chosen_margin - rejected_margin)

        # Presumably chosen and rejected rows are stacked along dim 0, so half
        # of the rows is the number of preference pairs -- TODO confirm.
        num_pairs = full_target.shape[0] // 2
        return -F.logsigmoid(scaled_gap).sum() / num_pairs

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ref_weight=None,
        ref_bias=None,
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=True,
        compiled=True,
        use_ref_model=True,
    ):
        """Delegate to the shared preference base, plugging in the DPO loss term."""
        call_kwargs = dict(
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
            ignore_index=ignore_index,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
        )
        return LigerFusedLinearPreferenceBase.forward(**call_kwargs)

    @staticmethod
    def backward(ctx, *grad_output):
        """Return gradients for the eleven ``forward`` inputs.

        The first four entries (for ``_input``, ``weight``, ``target``,
        ``bias``) come from the base class; the remaining seven inputs
        (``ref_weight``, ``ref_bias``, ``ignore_index``, ``beta``,
        ``compute_nll_loss``, ``compiled``, ``use_ref_model``) get ``None``.
        """
        # The tuple of incoming gradients is handed over as a single argument.
        base_grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)
        return (*base_grads[:4], None, None, None, None, None, None, None)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class LigerFusedLinearDPOLoss(torch.nn.Module):
    """
    Fused linear layer with DPO loss.

    Thin ``nn.Module`` wrapper that stores the loss configuration and forwards
    it to ``LigerFusedLinearDPOFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        use_ref_model: bool = False,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the direct preference loss.
            compute_nll_loss (bool): Whether to compute the NLL loss.
            compiled (bool): Whether to use the torch compiled kernel.
            use_ref_model (bool): Whether to use a reference model for the DPO loss.
        """
        super().__init__()
        self.ignore_index, self.beta = ignore_index, beta
        self.compute_nll_loss, self.compiled = compute_nll_loss, compiled
        self.use_ref_model = use_ref_model

    def forward(
        self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
    ):
        """Compute the fused DPO loss.

        Note the argument order: the module takes the weight first, while the
        underlying autograd function takes the input first.
        """
        tensor_args = (_input, lin_weight, target, bias, ref_weight, ref_bias)
        config_args = (
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
            self.use_ref_model,
        )
        return LigerFusedLinearDPOFunction.apply(*tensor_args, *config_args)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
|
|
2
|
+
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
|
|
3
|
+
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
|
|
4
|
+
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
|
|
5
|
+
|
|
6
|
+
# Functional entry points: each alias is the ``apply`` of the matching
# autograd Function imported above (listed in the same order as the imports).
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch.nn import functional as F
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class LigerFusedLinearDistillationBase(torch.autograd.Function):
    """Base autograd function for a fused linear layer with a distillation loss.

    The student and teacher hidden states are projected to logits chunk by
    chunk; for every chunk the combined hard (cross-entropy) and soft
    (distillation) loss and its gradients w.r.t. the student input, weight and
    bias are computed during ``forward`` and stored for ``backward``.
    """

    @staticmethod
    @abstractmethod
    def distillation_loss_fn(student_logits, teacher_logits, temperature):
        """
        Compute distillation loss.
        Args:
            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
            temperature (float): Temperature passed through from ``forward``.
        """
        raise NotImplementedError("Distillation loss function must be implemented.")

    @staticmethod
    def chunk_forward(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        ignore_index=-100,
        compute_ce_loss=True,
    ):
        """Project one chunk to logits and compute its summed hard (CE) loss.

        Args:
            student_input_chunk (torch.Tensor): Chunk of student input. Shape: (chunk_size, student_hidden_size).
            student_weight (torch.Tensor): Student weight. Shape: (vocab_size, student_hidden_size).
            teacher_input_chunk (torch.Tensor): Chunk of teacher input. Shape: (chunk_size, teacher_hidden_size).
            teacher_weight (torch.Tensor): Teacher weight. Shape: (vocab_size, teacher_hidden_size).
            target_chunk (torch.Tensor): Chunk of target labels. Shape: (chunk_size,).
            student_bias (torch.Tensor, optional): Student bias. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Teacher bias. Shape: (vocab_size,).
            ignore_index (int): Target value excluded from the CE loss.
            compute_ce_loss (bool): Whether to compute the CE loss.

        Returns:
            tuple: ``(student_logits_chunk, teacher_logits_chunk, ce_loss)`` where
            ``ce_loss`` is the *summed* NLL over the chunk, or the float ``0.0``
            when ``compute_ce_loss`` is false.
        """
        # Student
        student_logits_chunk = student_input_chunk @ student_weight.t()
        if student_bias is not None:
            student_logits_chunk += student_bias
        # Log-probabilities are computed in float32.
        student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)

        # Teacher: no gradient is tracked through the teacher projection.
        with torch.no_grad():
            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
            if teacher_bias is not None:
                teacher_logits_chunk += teacher_bias

        # The hard/task loss
        ce_loss = 0.0
        if compute_ce_loss:
            ce_loss = F.nll_loss(
                student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
                target_chunk.view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        return student_logits_chunk, teacher_logits_chunk, ce_loss

    @staticmethod
    def _compute_loss(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        distillation_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        temperature=1.0,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, using a knowledge distillation loss function.
        Args:
            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
                Called as ``distillation_loss_fn(student_logits, teacher_logits, temperature, **loss_kwargs)``.
            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full (un-chunked) target tensor. Only its first dimension is read,
                to normalise both the hard and the soft loss.
            ignore_index (int): Index to ignore for loss computation.
            temperature (float): Temperature handed to the distillation loss function.
            weight_hard_loss (float): Weight for hard loss.
            weight_soft_loss (float): Weight for soft loss.
            compute_ce_loss (bool): Whether to compute CE loss.
            loss_kwargs (dict): Additional arguments for the loss function.

        Returns:
            tuple: ``(loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk))``;
            the second element is auxiliary output for ``torch.func.grad_and_value``.
        """
        student_logits_chunk, teacher_logits_chunk, hard_loss = (
            LigerFusedLinearDistillationBase.chunk_forward(
                student_input_chunk,
                student_weight,
                teacher_input_chunk,
                teacher_weight,
                target_chunk,
                student_bias=student_bias,
                teacher_bias=teacher_bias,
                ignore_index=ignore_index,
                compute_ce_loss=compute_ce_loss,
            )
        )

        # Both terms are normalised by the size of the *full* target so that
        # per-chunk contributions can simply be summed by the caller.
        num_rows = full_target.shape[0]
        hard_loss = hard_loss / num_rows

        # Forward the extra keyword arguments; they were previously accepted
        # but silently dropped before reaching the loss function.
        soft_loss = distillation_loss_fn(
            student_logits_chunk, teacher_logits_chunk, temperature, **loss_kwargs
        )
        soft_loss = soft_loss / num_rows

        loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
        return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)

    @staticmethod
    def forward(
        ctx,
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        target,
        student_bias=None,
        teacher_bias=None,
        loss_fn=None,
        chunk_size=1024,
        ignore_index=-100,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        temperature=1.0,
        compiled=True,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with distillation loss.
        Only need to compute gradients for student model.

        Args:
            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len).
            student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk.
            compute_ce_loss (bool): Whether to compute CE loss.
            ignore_index (int): Index to ignore for loss computation.
            weight_hard_loss (float): Weight for hard/task loss.
            weight_soft_loss (float): Weight for soft/distillation loss.
            temperature (float): Temperature handed to the distillation loss function.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            loss_kwargs (dict): Other possible arguments that a loss function might need

        Returns:
            torch.Tensor: Scalar tensor holding the loss accumulated over all chunks.

        Raises:
            TypeError: If ``loss_fn`` is not callable.
        """
        # Fail early with a clear message instead of an opaque
        # "'NoneType' object is not callable" from inside the chunk function.
        if not callable(loss_fn):
            raise TypeError(
                "loss_fn must be a callable distillation loss function, "
                f"got {type(loss_fn).__name__}."
            )

        CHUNK_SIZE = chunk_size
        grad_weight = torch.zeros_like(student_weight)
        grad_inputs = []
        grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
        loss_acc = torch.zeros((), device=student_input.device)

        loss_func_to_call = partial(
            LigerFusedLinearDistillationBase._compute_loss,
            distillation_loss_fn=loss_fn,
            full_target=target,
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            compute_ce_loss=compute_ce_loss,
            temperature=temperature,
            **loss_kwargs,
        )

        # Differentiate w.r.t. student input (0) and weight (1), plus the
        # student bias (positional index 5) only when one is supplied.
        has_bias = student_bias is not None
        argnums = (0, 1, 5) if has_bias else (0, 1)

        def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
            chunk_grads, (chunk_loss, _chunk_aux) = torch.func.grad_and_value(
                loss_func_to_call, argnums=argnums, has_aux=True
            )(
                student_input_chunk,
                student_weight,
                teacher_input_chunk,
                teacher_weight,
                target_chunk,
                student_bias,
                teacher_bias,
            )
            if has_bias:
                grad_bias.add_(chunk_grads[2])
            grad_weight.add_(chunk_grads[1])
            loss_acc.add_(chunk_loss)
            return chunk_grads[0]

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        # At least one chunk, even when the input is shorter than CHUNK_SIZE.
        num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
        _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
        _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
        _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)

        for student_input_chunk, teacher_input_chunk, target_chunk in zip(
            _student_input_chunks, _teacher_input_chunks, _target_chunks
        ):
            grad_input = accumulate_chunk(
                student_input_chunk, teacher_input_chunk, target_chunk
            )
            grad_inputs.append(grad_input)

        # Gradients are already computed; backward only rescales them.
        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return loss_acc

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the gradients stored by ``forward`` with the incoming gradient.

        Returns:
            tuple: ``(grad_input, grad_weight, None, grad_bias)``.

        NOTE(review): this 4-tuple does not line up one-to-one with the inputs
        of ``forward``; subclasses are presumably expected to slice/pad it to
        match their own ``forward`` signature -- TODO confirm before changing.
        """
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        # Skip the multiplications when the upstream gradient is exactly 1.
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            grad_bias = grad_bias * grad_output if grad_bias is not None else None

        return grad_input, grad_weight, None, grad_bias
|