liger-kernel 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/__init__.py +2 -0
- liger_kernel/chunked_loss/cpo_loss.py +18 -8
- liger_kernel/chunked_loss/dpo_loss.py +20 -10
- liger_kernel/chunked_loss/functional.py +4 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
- liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
- liger_kernel/chunked_loss/jsd_loss.py +154 -0
- liger_kernel/chunked_loss/kto_loss.py +172 -0
- liger_kernel/chunked_loss/orpo_loss.py +8 -9
- liger_kernel/chunked_loss/simpo_loss.py +22 -8
- liger_kernel/env_report.py +5 -12
- liger_kernel/ops/cross_entropy.py +102 -51
- liger_kernel/ops/experimental/embedding.py +1 -3
- liger_kernel/ops/experimental/mm_int8int2.py +3 -9
- liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
- liger_kernel/ops/fused_linear_jsd.py +11 -29
- liger_kernel/ops/geglu.py +6 -17
- liger_kernel/ops/group_norm.py +11 -28
- liger_kernel/ops/jsd.py +2 -6
- liger_kernel/ops/kl_div.py +8 -11
- liger_kernel/ops/layer_norm.py +3 -5
- liger_kernel/ops/qwen2vl_mrope.py +21 -37
- liger_kernel/ops/rms_norm.py +14 -32
- liger_kernel/ops/rope.py +31 -33
- liger_kernel/ops/swiglu.py +4 -8
- liger_kernel/ops/utils.py +2 -0
- liger_kernel/transformers/__init__.py +16 -24
- liger_kernel/transformers/auto_model.py +6 -13
- liger_kernel/transformers/cross_entropy.py +4 -6
- liger_kernel/transformers/experimental/embedding.py +1 -3
- liger_kernel/transformers/functional.py +11 -7
- liger_kernel/transformers/fused_linear_cross_entropy.py +12 -7
- liger_kernel/transformers/geglu.py +1 -4
- liger_kernel/transformers/group_norm.py +3 -9
- liger_kernel/transformers/jsd.py +1 -3
- liger_kernel/transformers/kl_div.py +1 -3
- liger_kernel/transformers/layer_norm.py +3 -9
- liger_kernel/transformers/model/gemma.py +18 -40
- liger_kernel/transformers/model/gemma2.py +19 -41
- liger_kernel/transformers/model/llama.py +22 -48
- liger_kernel/transformers/model/mistral.py +14 -26
- liger_kernel/transformers/model/mixtral.py +24 -54
- liger_kernel/transformers/model/mllama.py +16 -36
- liger_kernel/transformers/model/phi3.py +18 -40
- liger_kernel/transformers/model/qwen2.py +18 -40
- liger_kernel/transformers/model/qwen2_vl.py +36 -32
- liger_kernel/transformers/monkey_patch.py +43 -117
- liger_kernel/transformers/qwen2vl_mrope.py +2 -2
- liger_kernel/transformers/rms_norm.py +4 -4
- liger_kernel/transformers/rope.py +2 -2
- liger_kernel/transformers/swiglu.py +2 -8
- liger_kernel/transformers/trainer/__init__.py +1 -3
- liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
- liger_kernel/triton/__init__.py +1 -3
- liger_kernel/triton/monkey_patch.py +1 -3
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/METADATA +38 -25
- liger_kernel-0.5.3.dist-info/RECORD +69 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/WHEEL +1 -1
- liger_kernel-0.5.1.dist-info/RECORD +0 -65
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/LICENSE +0 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/NOTICE +0 -0
- {liger_kernel-0.5.1.dist-info → liger_kernel-0.5.3.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/README.md

@@ -0,0 +1,25 @@
+# Liger FlexChunkLoss: Alignment and Distillation loss
+
+Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
+
+### User interface
+
+FlexChunkLoss offers two flexible usage options:
+
+1. **Via `Liger[Custom Loss]Trainer`**
+   For example, by simply replacing the HuggingFace `ORPOTrainer` with `LigerORPOTrainer` in your code, you can leverage our optimized ORPO implementation and immediately benefit from improved performance.
+
+2. **Using `nn.Module` Implementations of Custom Loss Functions**
+   Explore the [LigerORPOTrainer implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/orpo_trainer.py) to see how the modular design integrates custom loss functions seamlessly.
+
+### What's under the hood?
+
+We employ chunking and fused kernel optimizations to enhance performance. By fusing the final linear layer with loss computation and calculating backward gradients during the forward pass, we significantly reduce the need for storing intermediate activations. All operations are implemented in PyTorch, leveraging `torch.compile` to streamline kernel execution without relying on extensive low-level optimizations. Additionally, we minimize `torch.compile` recompilations to reduce overhead and ensure consistent performance gains.
+
+### Extending to custom loss functions
+
+We provide two base classes: `LigerFusedLinearPreferenceBase` for alignment use cases and `LigerFusedLinearDistillationBase` for distillation use cases. These base classes manage chunking, kernel fusions, and Torch compilation.
+
+To implement a custom loss function, you need to create a subclass that defines the custom preference or distillation loss function, capable of processing a given input chunk. The base class will take care of the optimizations, handling most of the heavy lifting for you.
+
+For a working example, refer to the [ORPO loss implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/orpo_loss.py).
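The "Extending to custom loss functions" section above is easiest to see with a concrete chunk-level loss. The sketch below is a minimal, standalone illustration in the spirit of the `preference_loss_fn` staticmethods that appear in the cpo_loss.py and dpo_loss.py diffs further down; the function name and dummy shapes are made up for the example, and the surrounding `LigerFusedLinearPreferenceBase` plumbing (chunking, fused forward/backward) is deliberately omitted.

```python
import torch
import torch.nn.functional as F


def sigmoid_preference_loss(chosen_logps, rejected_logps, full_target, beta=0.1):
    # Bradley-Terry style pairwise loss over one chunk of (chosen, rejected) avg log-probs.
    # Following the convention used by the chunked losses in this release, the chunk sum is
    # normalized by half of the full concatenated (chosen + rejected) batch size.
    logits = beta * (chosen_logps - rejected_logps)
    return -F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)


# Toy check: 4 preference pairs, so the concatenated target batch has 8 rows.
chosen = torch.randn(4)
rejected = torch.randn(4)
full_target = torch.zeros(8, 16, dtype=torch.long)
print(sigmoid_preference_loss(chosen, rejected, full_target))
```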
liger_kernel/chunked_loss/__init__.py

@@ -1,4 +1,6 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss  # noqa: F401
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss  # noqa: F401
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss  # noqa: F401
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss  # noqa: F401
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss  # noqa: F401
liger_kernel/chunked_loss/cpo_loss.py

@@ -1,15 +1,12 @@
 import torch
 import torch.nn.functional as F
 
-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
 class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
         """
         Paper: https://arxiv.org/pdf/2401.08417
 
@@ -30,10 +27,17 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
             full_target (torch.Tensor): Non chunked full target tensor
             beta (float): Weight for the CPO loss
+            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = F.logsigmoid(logits).sum() / (
-
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(
@@ -45,6 +49,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
         ignore_index=-100,
         beta=0.1,
         alpha=1.0,
+        label_smoothing=0.0,
         compute_nll_loss=True,
         compiled=True,
     ):
@@ -58,14 +63,16 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
+            label_smoothing=label_smoothing,
             compute_nll_loss=compute_nll_loss,
+            average_log_prob=False,
             compiled=compiled,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None
 
 
 class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -78,6 +85,7 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
     ):
@@ -90,6 +98,7 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
         self.ignore_index = ignore_index
         self.beta = beta
         self.alpha = alpha
+        self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
 
@@ -102,6 +111,7 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
             self.ignore_index,
             self.beta,
             self.alpha,
+            self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
         )
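The substantive change in cpo_loss.py is the label-smoothed pairwise term (plus returning `chosen_rewards`/`rejected_rewards` alongside the loss). A small standalone check, using only PyTorch, that the new expression reduces to the original `-logsigmoid(beta * (chosen_logps - rejected_logps))` when `label_smoothing` is 0:

```python
import torch
import torch.nn.functional as F


def cpo_pair_term(chosen_logps, rejected_logps, beta=0.1, label_smoothing=0.0):
    # Per-pair term from the diff above; the library sums this over the chunk and
    # divides by half of the concatenated batch size.
    logits = beta * (chosen_logps - rejected_logps)
    return -F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing


chosen = torch.tensor([-1.0, -2.0])
rejected = torch.tensor([-3.0, -1.5])
print(cpo_pair_term(chosen, rejected))                        # label_smoothing=0: plain CPO term
print(cpo_pair_term(chosen, rejected, label_smoothing=0.1))   # smoothed variant
```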
liger_kernel/chunked_loss/dpo_loss.py

@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F
 
-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
 class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(
         chosen_logps,
@@ -48,9 +45,12 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         chosen_logratios = chosen_logps - ref_chosen_logps
         rejected_logratios = rejected_logps - ref_rejected_logps
 
+        chosen_rewards = beta * chosen_logratios
+        rejected_rewards = beta * rejected_logratios
+
         logits_diff = beta * (chosen_logratios - rejected_logratios)
         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
-        return loss
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(
@@ -59,11 +59,12 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         weight,
         target,
         bias=None,
+        ref_input=None,
         ref_weight=None,
         ref_bias=None,
         ignore_index=-100,
         beta=0.1,
-        compute_nll_loss=
+        compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
     ):
@@ -79,6 +80,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             compute_nll_loss=compute_nll_loss,
             compiled=compiled,
             use_ref_model=use_ref_model,
+            ref_input=ref_input,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
         )
@@ -86,7 +88,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -98,9 +100,9 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self,
         ignore_index: int = -100,
         beta: float = 0.1,
-        compute_nll_loss: bool =
+        compute_nll_loss: bool = False,
         compiled: bool = True,
-        use_ref_model: bool =
+        use_ref_model: bool = True,
     ):
         """
         Args:
@@ -118,13 +120,21 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.use_ref_model = use_ref_model
 
     def forward(
-        self,
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
     ):
         return LigerFusedLinearDPOFunction.apply(
             _input,
             lin_weight,
             target,
             bias,
+            ref_input,
             ref_weight,
             ref_bias,
             self.ignore_index,
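The DPO changes thread a separate `ref_input` through the stack and return per-pair rewards alongside the loss. Below is a hedged usage sketch of the new `LigerFusedLinearDPOLoss.forward` signature shown above (`lin_weight, _input, target, bias, ref_input, ref_weight, ref_bias`); the tensor shapes are assumptions (chosen and rejected halves concatenated along the batch dimension), and the structure of the returned value is not asserted here.

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

B, T, H, V = 2, 8, 16, 32  # preference pairs, sequence length, hidden size, vocab size

policy_hidden = torch.randn(2 * B, T, H, requires_grad=True)  # chosen rows then rejected rows (assumed layout)
ref_hidden = torch.randn(2 * B, T, H)                         # reference-model hidden states
lm_head_weight = torch.randn(V, H, requires_grad=True)
ref_head_weight = torch.randn(V, H)
target = torch.randint(0, V, (2 * B, T))

# Constructor arguments taken from the diff above (ignore_index, beta, compute_nll_loss, compiled, use_ref_model).
dpo = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True, compiled=False)
out = dpo(
    lm_head_weight,
    policy_hidden,
    target,
    ref_input=ref_hidden,
    ref_weight=ref_head_weight,
)
```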
liger_kernel/chunked_loss/functional.py

@@ -1,9 +1,13 @@
 from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
+from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
 
 liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
 liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
+liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
 liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
 liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
+liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
liger_kernel/chunked_loss/fused_linear_distillation.py

@@ -2,18 +2,24 @@ from abc import abstractmethod
 from functools import partial
 
 import torch
+
 from torch.nn import functional as F
 
 
 class LigerFusedLinearDistillationBase(torch.autograd.Function):
-
     @abstractmethod
-    def distillation_loss_fn(
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+    ):
         """
         Compute distillation loss.
         Args:
-            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
-            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+        Returns:
+            torch.Tensor: Sum of distillation losses for the chunk. The class will handle
+            converting this to mean loss by dividing by the full batch size * sequence length in _compute_loss.
         """
         raise NotImplementedError("Distillation loss function must be implemented.")
 
@@ -65,14 +71,14 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         distillation_loss_fn=None,
         full_target=None,
         ignore_index=-100,
-        temperature=1.0,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
         compute_ce_loss=True,
+        temperature=1,
         **loss_kwargs,
     ):
         """
-        Compute the total loss for a chunk of input and target, while using an
+        Compute the total loss for a chunk of input and target, while using an knowledge distillation loss function.
         Args:
             distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
@@ -82,32 +88,36 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
             student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard loss.
             weight_soft_loss (float): Weight for soft loss.
             compute_ce_loss (bool): Whether to compute CE loss.
+            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             loss_kwargs (dict): Additional arguments for the loss function.
         """
-
-
-
-
-
-
-
-
-
-
-
-
+        (
+            student_logits_chunk,
+            teacher_logits_chunk,
+            hard_loss,
+        ) = LigerFusedLinearDistillationBase.chunk_forward(
+            student_input_chunk,
+            student_weight,
+            teacher_input_chunk,
+            teacher_weight,
+            target_chunk,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            ignore_index=ignore_index,
+            compute_ce_loss=compute_ce_loss,
         )
 
+        student_logits_chunk /= temperature
+        teacher_logits_chunk /= temperature
+
         hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(
-            student_logits_chunk, teacher_logits_chunk, temperature
-        )
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
         soft_loss /= full_target.shape[0]
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
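In `_compute_loss`, temperature scaling now happens once on the student and teacher logits before they reach `distillation_loss_fn`, so custom distillation losses no longer take a `temperature` argument. A small standalone illustration of what that division does to the distributions a soft loss sees:

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])
for temperature in (1.0, 2.0):
    # Dividing logits by the temperature flattens (T > 1) or sharpens (T < 1)
    # the probability distribution used by the soft/distillation loss.
    print(temperature, torch.softmax(logits / temperature, dim=-1))
```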
@@ -128,6 +138,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         ignore_index=-100,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
+        beta=0.5,
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
@@ -147,10 +158,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             chunk_size (int): Size of a chunk.
-            compute_ce_loss (bool): Whether to compute CE loss.
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard/task loss.
             weight_soft_loss (float): Weight for soft/distillation loss.
+            beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
+            compute_ce_loss (bool): Whether to compute CE loss.
+            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
@@ -167,6 +180,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
+            beta=beta,
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
             **loss_kwargs,
@@ -174,17 +188,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
         def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
             if student_bias is not None:
-                (
-
+                (
+                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                     (
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_soft_loss,
+                            chunk_hard_loss,
+                            chunk_student_logits,
+                            chunk_teacher_logits,
+                        ),
                     ),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1, 5), has_aux=True
-                )(
+                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)(
                     student_input_chunk,
                     student_weight,
                     teacher_input_chunk,
@@ -195,17 +210,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
                 )
                 grad_bias.add_(chunk_grad_bias)
             else:
-                (
-
+                (
+                    (chunk_grad_input, chunk_grad_weight),
                     (
-
-
-
-
+                        chunk_loss,
+                        (
+                            chunk_soft_loss,
+                            chunk_hard_loss,
+                            chunk_student_logits,
+                            chunk_teacher_logits,
+                        ),
                     ),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1), has_aux=True
-                )(
+                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)(
                     student_input_chunk,
                     student_weight,
                     teacher_input_chunk,
@@ -229,9 +245,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         for student_input_chunk, teacher_input_chunk, target_chunk in zip(
             _student_input_chunks, _teacher_input_chunks, _target_chunks
         ):
-            grad_input = accumulate_chunk(
-                student_input_chunk, teacher_input_chunk, target_chunk
-            )
+            grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk)
             grad_inputs.append(grad_input)
 
         ctx.save_for_backward(
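The chunk accumulation above relies on `torch.func.grad_and_value`, with `argnums` selecting which positional arguments receive gradients and `has_aux=True` letting the per-chunk loss return auxiliary values next to the scalar loss. Below is a self-contained sketch of that pattern; the toy `chunk_loss` is a stand-in for illustration, not the library's actual distillation loss.

```python
import torch
import torch.nn.functional as F


def chunk_loss(student_input, student_weight, teacher_input, teacher_weight, target):
    # Toy stand-in: project both inputs to logits, then mix a soft (MSE) and hard (CE) loss.
    student_logits = student_input @ student_weight.t()
    teacher_logits = teacher_input @ teacher_weight.t()
    soft = F.mse_loss(student_logits, teacher_logits)
    hard = F.cross_entropy(student_logits, target)
    loss = 0.5 * soft + 0.5 * hard
    # has_aux=True expects (output, aux); aux here carries the per-chunk components.
    return loss, (soft.detach(), hard.detach())


s_in, t_in = torch.randn(4, 8), torch.randn(4, 8)
s_w, t_w = torch.randn(16, 8), torch.randn(16, 8)
tgt = torch.randint(0, 16, (4,))

# argnums=(0, 1): gradients w.r.t. the student input chunk and the student weight only.
(grad_input, grad_weight), (loss, (soft, hard)) = torch.func.grad_and_value(
    chunk_loss, argnums=(0, 1), has_aux=True
)(s_in, s_w, t_in, t_w, tgt)
print(loss, grad_input.shape, grad_weight.shape)
```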