liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +3 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
  9. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  10. liger_kernel/chunked_loss/grpo_loss.py +160 -0
  11. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  12. liger_kernel/chunked_loss/kto_loss.py +172 -0
  13. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  14. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  15. liger_kernel/env_report.py +5 -12
  16. liger_kernel/ops/cross_entropy.py +102 -51
  17. liger_kernel/ops/experimental/embedding.py +1 -3
  18. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  19. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  20. liger_kernel/ops/fused_linear_jsd.py +14 -32
  21. liger_kernel/ops/geglu.py +6 -17
  22. liger_kernel/ops/group_norm.py +11 -28
  23. liger_kernel/ops/jsd.py +5 -9
  24. liger_kernel/ops/kl_div.py +8 -11
  25. liger_kernel/ops/layer_norm.py +23 -12
  26. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  27. liger_kernel/ops/rms_norm.py +14 -32
  28. liger_kernel/ops/rope.py +31 -33
  29. liger_kernel/ops/swiglu.py +4 -8
  30. liger_kernel/ops/tvd.py +207 -0
  31. liger_kernel/ops/utils.py +3 -2
  32. liger_kernel/transformers/__init__.py +19 -24
  33. liger_kernel/transformers/auto_model.py +6 -13
  34. liger_kernel/transformers/cross_entropy.py +7 -9
  35. liger_kernel/transformers/experimental/embedding.py +1 -3
  36. liger_kernel/transformers/functional.py +28 -7
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
  38. liger_kernel/transformers/geglu.py +1 -4
  39. liger_kernel/transformers/group_norm.py +9 -15
  40. liger_kernel/transformers/jsd.py +1 -3
  41. liger_kernel/transformers/kl_div.py +1 -3
  42. liger_kernel/transformers/layer_norm.py +3 -9
  43. liger_kernel/transformers/model/gemma.py +18 -40
  44. liger_kernel/transformers/model/gemma2.py +19 -41
  45. liger_kernel/transformers/model/llama.py +22 -48
  46. liger_kernel/transformers/model/mistral.py +14 -26
  47. liger_kernel/transformers/model/mixtral.py +24 -54
  48. liger_kernel/transformers/model/mllama.py +16 -36
  49. liger_kernel/transformers/model/olmo2.py +124 -0
  50. liger_kernel/transformers/model/phi3.py +18 -40
  51. liger_kernel/transformers/model/qwen2.py +18 -40
  52. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  53. liger_kernel/transformers/monkey_patch.py +214 -144
  54. liger_kernel/transformers/rms_norm.py +4 -4
  55. liger_kernel/transformers/rope.py +2 -2
  56. liger_kernel/transformers/swiglu.py +2 -8
  57. liger_kernel/transformers/trainer/__init__.py +1 -3
  58. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  59. liger_kernel/transformers/tvd.py +13 -0
  60. liger_kernel/triton/__init__.py +1 -3
  61. liger_kernel/triton/monkey_patch.py +1 -3
  62. liger_kernel/utils.py +49 -0
  63. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
  64. liger_kernel-0.5.4.dist-info/RECORD +74 -0
  65. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
  66. liger_kernel-0.5.2.dist-info/RECORD +0 -65
  67. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
  68. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
  69. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py
@@ -0,0 +1,246 @@
+from abc import abstractmethod
+from functools import partial
+
+import torch
+
+from torch.nn import functional as F
+
+
+class LigerFusedLinearUnpairedPreferenceBase(torch.autograd.Function):
+    @abstractmethod
+    def preference_loss_fn(*args, **kwargs):
+        """
+        To be extended by subclasses.
+        """
+        raise NotImplementedError("Preference loss function must be implemented.")
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        loss_fn=None,
+        chunk_size=1,
+        ignore_index=-100,
+        compiled=True,
+        use_ref_model=False,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Base class for fused linear layer with unpaired preference loss like KTO
+        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+        The mental model is:
+
+        forward()
+        ├── Loop over chunks
+            └── compute_loss()
+                ├── chunk_forward()  # Compute logits and log probs
+                └── prefer_loss()    # Calculate preference loss
+
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            ignore_index (int): Index to ignore for loss computation.
+            beta (float): Weight for the preference loss.
+            compiled (bool): Whether to use torch compile for chunk accumulation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            preference_labels (torch.Tensor): Boolean tensor indicating chosen (True) vs rejected (False) examples.
+                Shape: (batch_size,).
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Other possible arguments that a loss function might need
+        """
+        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+        CHUNK_SIZE = chunk_size
+
+        # Gradients to be accumulated
+        grad_inputs = []
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+        # Loss to be accumulated
+        loss_acc = torch.zeros((), device=_input.device)
+
+        compute_loss = partial(
+            LigerFusedLinearUnpairedPreferenceBase._compute_loss,
+            preference_loss_fn=loss_fn,
+            full_target=target,
+            ignore_index=ignore_index,
+            use_ref_model=use_ref_model,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            **loss_kwargs,
+        )
+
+        def fused_fwd_bwd(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk):
+            """
+            Fused forward and backward pass for a chunk of input and target.
+            """
+            argnums = (0, 1, 4) if bias is not None else (0, 1)
+            return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=False)(
+                input_chunk,
+                weight,
+                target_chunk,
+                preference_labels_chunk,
+                bias,
+                ref_input_chunk=ref_input_chunk,
+            )
+
+        def accumulate_chunk(
+            input_chunk,
+            target_chunk,
+            preference_labels_chunk=None,
+            ref_input_chunk=None,
+        ):
+            (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
+                input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk
+            )
+            if bias is not None:
+                grad_bias.add_(chunk_grad_bias[0])  # accumulate bias gradient
+
+            # Accumulate gradients
+            grad_weight.add_(chunk_grad_weight)
+            grad_inputs.append(chunk_grad_input)
+
+            # Accumulate loss
+            loss_acc.add_(chunk_loss)
+
+        if compiled:
+            fused_fwd_bwd = torch.compile(fused_fwd_bwd)
+
+        # When not paired, use labels to separate chosen and rejected
+        assert preference_labels is not None, "preference_labels must be provided for unpaired preference loss"
+
+        chunks = max(1, _input.shape[0] // CHUNK_SIZE)
+        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
+        _target_chunks = torch.chunk(target, chunks=chunks, dim=0)
+        _preference_labels_chunks = torch.chunk(preference_labels, chunks=chunks, dim=0)
+
+        if use_ref_model:
+            _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0)
+
+        for (
+            input_chunk,
+            target_chunk,
+            ref_input_chunk,
+            preference_labels_chunk,
+        ) in zip(
+            _input_chunks,
+            _target_chunks,
+            (_ref_input_chunks if use_ref_model else [None] * len(_input_chunks)),
+            _preference_labels_chunks,
+        ):
+            # mark input_chunk, target_chunk, and target dimension 1 (sequence length) as dynamic to prevent torch.compile recompilation
+            torch._dynamo.mark_dynamic(input_chunk, 1)
+            torch._dynamo.mark_dynamic(target_chunk, 1)
+            torch._dynamo.mark_dynamic(target, 1)
+            torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
+            torch._dynamo.mark_dynamic(preference_labels_chunk, 1)
+
+            # accumulate loss, gradients, and metrics
+            accumulate_chunk(input_chunk, target_chunk, preference_labels_chunk, ref_input_chunk)
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
+            grad_input = grad_input * grad_output[0][0]
+            grad_weight = grad_weight * grad_output[0][0]
+            grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
+
+        return grad_input, grad_weight, None, None, grad_bias
+
+    @staticmethod
+    def chunk_forward(
+        input_chunk,
+        weight,
+        target_chunk,
+        bias=None,
+        ignore_index=-100,
+    ):
+        logits_chunk = input_chunk @ weight.t()
+        if bias is not None:
+            logits_chunk = logits_chunk + bias
+        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+        loss_mask_chunk = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
+
+        per_token_logps_chunk = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+        average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(-1) / loss_mask_chunk.sum(-1)
+
+        return average_log_prob_chunk
+
+    @staticmethod
+    def _compute_loss(
+        input_chunk,
+        weight,
+        target_chunk,
+        preference_labels_chunk,
+        bias=None,
+        preference_loss_fn=None,
+        full_target=None,
+        ignore_index=-100,
+        use_ref_model=False,
+        ref_input_chunk=None,
+        ref_weight=None,
+        ref_bias=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+        Args:
+            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+            ignore_index (int): Index to ignore for loss computation.
+            use_ref_model (bool): Whether to use a reference model for the alignment loss.
+            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+            input_chunk,
+            weight,
+            target_chunk,
+            bias=bias,
+            ignore_index=ignore_index,
+        )
+
+        if use_ref_model:
+            with torch.no_grad():
+                ref_average_log_prob_chunk = LigerFusedLinearUnpairedPreferenceBase.chunk_forward(
+                    ref_input_chunk,
+                    ref_weight,
+                    target_chunk,
+                    ref_bias,
+                    ignore_index=ignore_index,
+                )
+            loss_kwargs["ref_average_log_prob_chunk"] = ref_average_log_prob_chunk
+
+        preference_loss_chunk = preference_loss_fn(
+            average_log_prob_chunk, preference_labels_chunk, full_target, **loss_kwargs
+        )
+
+        return preference_loss_chunk
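
The base class above only supplies the chunked forward/backward machinery; a concrete loss plugs in by implementing `preference_loss_fn` and passing it through `loss_fn`. The sketch below illustrates that contract with a hypothetical `ToyUnpairedPreferenceFunction` using a simplified KTO-flavored objective. It is not the `kto_loss.py` shipped in this release; the class name, the `beta` default, and the loss formula are illustrative assumptions.

```python
# Minimal sketch of a subclass of LigerFusedLinearUnpairedPreferenceBase.
# Not the shipped kto_loss.py; the loss formula below is a simplified stand-in.
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_unpaired_preference import (
    LigerFusedLinearUnpairedPreferenceBase,
)


class ToyUnpairedPreferenceFunction(LigerFusedLinearUnpairedPreferenceBase):
    @staticmethod
    def preference_loss_fn(
        average_log_prob_chunk,
        preference_labels_chunk,
        full_target,
        ref_average_log_prob_chunk=None,
        beta=0.1,
    ):
        # Implicit reward: average log prob, optionally shifted by the reference model
        # (the base injects ref_average_log_prob_chunk when use_ref_model=True).
        rewards = average_log_prob_chunk
        if ref_average_log_prob_chunk is not None:
            rewards = rewards - ref_average_log_prob_chunk
        # Push chosen examples up and rejected examples down (simplified KTO-style term).
        chosen_term = F.logsigmoid(beta * rewards)
        rejected_term = F.logsigmoid(-beta * rewards)
        per_example_loss = -torch.where(preference_labels_chunk.bool(), chosen_term, rejected_term)
        # Normalize by the full batch size so summing per-chunk losses yields a batch mean.
        return per_example_loss.sum() / full_target.shape[0]

    @staticmethod
    def forward(ctx, _input, weight, target, preference_labels, bias=None, beta=0.1):
        return LigerFusedLinearUnpairedPreferenceBase.forward(
            ctx,
            _input,
            weight,
            target,
            preference_labels,
            bias=bias,
            loss_fn=ToyUnpairedPreferenceFunction.preference_loss_fn,
            beta=beta,  # forwarded to preference_loss_fn via **loss_kwargs
        )

    @staticmethod
    def backward(ctx, *grad_output):
        grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, *grad_output)
        return *grads, None  # one extra None for the non-tensor `beta` argument
```

Note that a subclass whose `forward` takes extra arguments (here `beta`) must also extend `backward` so the number of returned gradients matches, and callers of `.apply(...)` should pass every argument so that count lines up.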
--- /dev/null
+++ b/liger_kernel/chunked_loss/grpo_loss.py
@@ -0,0 +1,160 @@
+import torch
+
+from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase
+
+
+class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
+    @staticmethod
+    def rlhf_loss_fn(
+        log_probs,
+        attention_mask,
+        rewards,
+        ref_log_probs=None,
+        beta=0.1,
+        **kwargs,
+    ):
+        """GRPO Loss Function matching GRPOTrainer implementation."""
+        # Get chosen token probabilities
+        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
+        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
+            -1
+        )  # (batch_size, seq_len)
+
+        # Get reference model probabilities
+        if ref_log_probs is not None:
+            with torch.no_grad():
+                ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
+        else:
+            ref_token_logprobs = chosen_token_logprobs.detach()
+
+        # Compute advantages per batch entry in a grouped fashion
+        mean_grouped_rewards = rewards.mean()  # [batch_size,]
+        std_grouped_rewards = rewards.std()  # [batch_size,]
+
+        # Calculate advantages using the same epsilon as in GRPOTrainer
+        eps = 1e-4
+        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + eps)
+
+        # Compute policy gradient loss with importance sampling ratio
+        ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
+        policy_loss = -ratio * advantages.unsqueeze(1)
+
+        # Compute KL penalty
+        kl_div = (
+            torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
+        )
+
+        # Combine losses
+        per_token_loss = policy_loss + beta * kl_div
+
+        # Apply masking and normalize
+        masked_loss = per_token_loss * attention_mask
+        seq_lengths = attention_mask.sum()
+        seq_lengths = torch.clamp(seq_lengths, min=1.0)
+        loss = masked_loss.sum() / seq_lengths
+
+        # Calculate metrics
+        metrics = (
+            chosen_token_logprobs.mean(),  # mean log prob
+            chosen_token_logprobs.std(),  # std log prob
+            log_probs.mean(),  # mean all log probs
+            ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(),  # mean KL div
+        )
+
+        return loss, metrics
+
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        attention_mask,
+        rewards,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        beta=0.1,
+        compiled=True,
+        use_ref_model=True,
+        num_generations=1,
+    ):
+        return LigerFusedLinearRLHFBase.forward(
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            attention_mask=attention_mask,
+            loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn,
+            rewards=rewards,
+            bias=bias,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            beta=beta,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            num_generations=num_generations,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *grad_metrics):
+        """Backward pass for GRPO loss.
+
+        Args:
+            grad_output: Gradient of the loss (scalar)
+            grad_metrics: Gradients of the metrics (not used in backward computation)
+        """
+        grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
+        return (
+            *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
+            None,  # grad_ref_input
+            None,  # grad_ref_weight
+            None,  # grad_ref_bias
+            None,  # grad_beta
+            None,  # grad_compiled
+            None,  # grad_use_ref_model
+            None,  # grad_num_generations
+        )
+
+
+class LigerFusedLinearGRPOLoss(torch.nn.Module):
+    """Fused linear layer with GRPO loss."""
+
+    def __init__(
+        self,
+        beta: float = 0.1,
+        compiled: bool = True,
+        use_ref_model: bool = True,
+        num_generations: int = 1,
+    ):
+        super().__init__()
+        self.beta = beta
+        self.compiled = compiled
+        self.use_ref_model = use_ref_model
+        self.num_generations = num_generations
+
+    def forward(
+        self,
+        _input,
+        lin_weight,
+        attention_mask,
+        rewards,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+    ):
+        return LigerFusedLinearGRPOFunction.apply(
+            _input,
+            lin_weight,
+            attention_mask,
+            rewards,
+            bias,
+            ref_input,
+            ref_weight,
+            ref_bias,
+            self.beta,
+            self.compiled,
+            self.use_ref_model,
+            self.num_generations,
+        )
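
For orientation, here is a hedged usage sketch of the new `LigerFusedLinearGRPOLoss` wrapper. The tensor shapes are assumptions inferred from the wrapper's signature (the `fused_linear_rlhf.py` base that does the chunked work is not reproduced in this excerpt), so treat it as an illustration of the call shape rather than a verified recipe.

```python
import torch

from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

# Assumed shapes: last hidden states (batch, seq_len, hidden), lm_head weight
# (vocab, hidden), a 0/1 attention mask, and one scalar reward per sequence.
batch, seq_len, hidden, vocab = 4, 16, 64, 128
hidden_states = torch.randn(batch, seq_len, hidden, requires_grad=True)
lm_head_weight = torch.randn(vocab, hidden, requires_grad=True)
attention_mask = torch.ones(batch, seq_len)
rewards = torch.randn(batch)

# Stand-ins for a frozen reference policy (detached copies, for illustration only).
ref_hidden_states = hidden_states.detach().clone()
ref_lm_head_weight = lm_head_weight.detach().clone()

grpo_loss = LigerFusedLinearGRPOLoss(beta=0.1, compiled=False, use_ref_model=True)
out = grpo_loss(
    hidden_states,
    lm_head_weight,
    attention_mask,
    rewards,
    None,  # bias
    ref_hidden_states,
    ref_lm_head_weight,
)
# Judging from LigerFusedLinearGRPOFunction.backward(ctx, grad_output, *grad_metrics),
# the forward output is the loss followed by auxiliary metrics.
loss = out[0] if isinstance(out, tuple) else out
loss.backward()
```

The wrapper simply forwards `num_generations` (default 1) to the RLHF base, so the exact grouping semantics for reward normalization live there.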
--- /dev/null
+++ b/liger_kernel/chunked_loss/jsd_loss.py
@@ -0,0 +1,154 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+        """
+        Compute JSD loss (Jensen-Shannon Divergence Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        Returns:
+            torch.Tensor: Jensen-Shannon Divergence loss
+        """
+        student_log_probs = F.log_softmax(student_logits, dim=-1)
+        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+        # Compute probabilities (only required for mean calculation)
+        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
+        log_mean_probs = mean_probs.log()
+
+        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+        # JSD is the weighted average of the KL divergences
+        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+        return jsd_loss
+
+    @staticmethod
+    def forward(
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Fused linear layer with JSD distillation loss.
+        Args:
+            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+            true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            ignore_index (int): Index to ignore in loss computation
+            temperature (float): Temperature for softening/sharpening distributions
+            compiled (bool): Whether to use torch compile
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearDistillationBase.forward(
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
+            chunk_size=1,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
+
+        return (*grads, None, None, None, None, None, None, None)
+
+
+class LigerFusedLinearJSDLoss(torch.nn.Module):
+    """
+    Fused linear layer with JSD distillation loss.
+    """
+
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Args:
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            ignore_index (int): Index to ignore in the loss
+            temperature (float): Temperature for softening distributions
+            compiled (bool): Whether to use torch compile
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        """
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+    ) -> torch.Tensor:
+        """
+        Compute the JSD distillation loss.
+
+        Args:
+            student_input (torch.Tensor): Student input tensor
+            student_weight (torch.Tensor): Student weight tensor
+            teacher_input (torch.Tensor): Teacher input tensor
+            teacher_weight (torch.Tensor): Teacher weight tensor
+            true_labels (torch.LongTensor): Target labels tensor
+
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+        )
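
To close, a short usage sketch of `LigerFusedLinearJSDLoss` following the shapes documented in its own docstring: inputs are already flattened to `(batch_size * seq_len, hidden)` and the projection to vocab-size logits happens inside the fused loss. The dimensions and the comments on the hard/soft split are illustrative assumptions (the `fused_linear_distillation.py` base is not reproduced in this excerpt); only the call signature is taken from the code above.

```python
import torch

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

# Arbitrary illustrative dimensions; student and teacher may have different hidden sizes.
bt, hidden_student, hidden_teacher, vocab = 32, 64, 96, 128
student_input = torch.randn(bt, hidden_student, requires_grad=True)
student_weight = torch.randn(vocab, hidden_student, requires_grad=True)
teacher_input = torch.randn(bt, hidden_teacher)
teacher_weight = torch.randn(vocab, hidden_teacher)
true_labels = torch.randint(0, vocab, (bt,))

jsd = LigerFusedLinearJSDLoss(
    weight_hard_loss=0.5,  # weight on the hard-label term
    weight_soft_loss=0.5,  # weight on the generalized JSD term against the teacher
    beta=0.5,              # beta = 0.5 recovers the symmetric Jensen-Shannon divergence
    temperature=1.0,
    compiled=False,
)
loss = jsd(student_input, student_weight, teacher_input, teacher_weight, true_labels)
loss.backward()
```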