liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +3 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
  9. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  10. liger_kernel/chunked_loss/grpo_loss.py +160 -0
  11. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  12. liger_kernel/chunked_loss/kto_loss.py +172 -0
  13. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  14. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  15. liger_kernel/env_report.py +5 -12
  16. liger_kernel/ops/cross_entropy.py +102 -51
  17. liger_kernel/ops/experimental/embedding.py +1 -3
  18. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  19. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  20. liger_kernel/ops/fused_linear_jsd.py +14 -32
  21. liger_kernel/ops/geglu.py +6 -17
  22. liger_kernel/ops/group_norm.py +11 -28
  23. liger_kernel/ops/jsd.py +5 -9
  24. liger_kernel/ops/kl_div.py +8 -11
  25. liger_kernel/ops/layer_norm.py +23 -12
  26. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  27. liger_kernel/ops/rms_norm.py +14 -32
  28. liger_kernel/ops/rope.py +31 -33
  29. liger_kernel/ops/swiglu.py +4 -8
  30. liger_kernel/ops/tvd.py +207 -0
  31. liger_kernel/ops/utils.py +3 -2
  32. liger_kernel/transformers/__init__.py +19 -24
  33. liger_kernel/transformers/auto_model.py +6 -13
  34. liger_kernel/transformers/cross_entropy.py +7 -9
  35. liger_kernel/transformers/experimental/embedding.py +1 -3
  36. liger_kernel/transformers/functional.py +28 -7
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
  38. liger_kernel/transformers/geglu.py +1 -4
  39. liger_kernel/transformers/group_norm.py +9 -15
  40. liger_kernel/transformers/jsd.py +1 -3
  41. liger_kernel/transformers/kl_div.py +1 -3
  42. liger_kernel/transformers/layer_norm.py +3 -9
  43. liger_kernel/transformers/model/gemma.py +18 -40
  44. liger_kernel/transformers/model/gemma2.py +19 -41
  45. liger_kernel/transformers/model/llama.py +22 -48
  46. liger_kernel/transformers/model/mistral.py +14 -26
  47. liger_kernel/transformers/model/mixtral.py +24 -54
  48. liger_kernel/transformers/model/mllama.py +16 -36
  49. liger_kernel/transformers/model/olmo2.py +124 -0
  50. liger_kernel/transformers/model/phi3.py +18 -40
  51. liger_kernel/transformers/model/qwen2.py +18 -40
  52. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  53. liger_kernel/transformers/monkey_patch.py +214 -144
  54. liger_kernel/transformers/rms_norm.py +4 -4
  55. liger_kernel/transformers/rope.py +2 -2
  56. liger_kernel/transformers/swiglu.py +2 -8
  57. liger_kernel/transformers/trainer/__init__.py +1 -3
  58. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  59. liger_kernel/transformers/tvd.py +13 -0
  60. liger_kernel/triton/__init__.py +1 -3
  61. liger_kernel/triton/monkey_patch.py +1 -3
  62. liger_kernel/utils.py +49 -0
  63. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
  64. liger_kernel-0.5.4.dist-info/RECORD +74 -0
  65. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
  66. liger_kernel-0.5.2.dist-info/RECORD +0 -65
  67. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
  68. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
  69. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,25 @@
+ # Liger FlexChunkLoss: Alignment and Distillation loss
+
+ Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
+
+ ### User interface
+
+ FlexChunkLoss offers two flexible usage options:
+
+ 1. **Via `Liger[Custom Loss]Trainer`**
+ For example, by simply replacing the HuggingFace `ORPOTrainer` with `LigerORPOTrainer` in your code, you can leverage our optimized ORPO implementation and immediately benefit from improved performance.
+
+ 2. **Using `nn.Module` Implementations of Custom Loss Functions**
+ Explore the [LigerORPOTrainer implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/orpo_trainer.py) to see how the modular design integrates custom loss functions seamlessly.
+
+ ### What's under the hood?
+
+ We employ chunking and fused kernel optimizations to enhance performance. By fusing the final linear layer with loss computation and calculating backward gradients during the forward pass, we significantly reduce the need for storing intermediate activations. All operations are implemented in PyTorch, leveraging `torch.compile` to streamline kernel execution without relying on extensive low-level optimizations. Additionally, we minimize `torch.compile` recompilations to reduce overhead and ensure consistent performance gains.
+
+ ### Extending to custom loss functions
+
+ We provide two base classes: `LigerFusedLinearPreferenceBase` for alignment use cases and `LigerFusedLinearDistillationBase` for distillation use cases. These base classes manage chunking, kernel fusions, and Torch compilation.
+
+ To implement a custom loss function, you need to create a subclass that defines the custom preference or distillation loss function, capable of processing a given input chunk. The base class will take care of the optimizations, handling most of the heavy lifting for you.
+
+ For a working example, refer to the [ORPO loss implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/orpo_loss.py).
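The README above describes the `nn.Module` usage path only in prose. A minimal sketch of that path for ORPO is shown below; it assumes the `(lin_weight, _input, target)` calling convention visible in the CPO and DPO modules later in this diff, and the tensor shapes and batch layout are illustrative assumptions only.

```python
# Hedged sketch (not part of the package diff): using a chunked loss module directly.
import torch

from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

B, T, H, V = 4, 128, 4096, 32000                       # assumed shapes for illustration
hidden = torch.randn(B, T, H, requires_grad=True)      # last hidden states, before lm_head
lm_head_weight = torch.randn(V, H, requires_grad=True)
target = torch.randint(0, V, (B, T))                   # label ids; ignore_index defaults to -100

orpo_loss = LigerFusedLinearORPOLoss()
# The lm_head projection is fused with the loss, so full logits are never materialized.
out = orpo_loss(lm_head_weight, hidden, target)        # loss (plus any auxiliary outputs)
```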
@@ -1,4 +1,7 @@
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
+ from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
  from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
@@ -1,15 +1,12 @@
  import torch
  import torch.nn.functional as F

- from liger_kernel.chunked_loss.fused_linear_preference import (
- LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


  class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
-
  @staticmethod
- def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
  """
  Paper: https://arxiv.org/pdf/2401.08417

@@ -30,10 +27,17 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
  full_target (torch.Tensor): Non chunked full target tensor
  beta (float): Weight for the CPO loss
+ label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
  """
  logits = beta * (chosen_logps - rejected_logps)
- loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
- return loss
+ loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+ full_target.shape[0] // 2
+ )
+
+ chosen_rewards = beta * chosen_logps
+ rejected_rewards = beta * rejected_logps
+
+ return loss, chosen_rewards, rejected_rewards

  @staticmethod
  def forward(
@@ -45,6 +49,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
  ignore_index=-100,
  beta=0.1,
  alpha=1.0,
+ label_smoothing=0.0,
  compute_nll_loss=True,
  compiled=True,
  ):
@@ -58,14 +63,16 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
  ignore_index=ignore_index,
  alpha=alpha,
  beta=beta,
+ label_smoothing=label_smoothing,
  compute_nll_loss=compute_nll_loss,
+ average_log_prob=False,
  compiled=compiled,
  )

  @staticmethod
  def backward(ctx, *grad_output):
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None
+ return *grads, None, None, None, None, None, None


  class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -78,6 +85,7 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
  ignore_index: int = -100,
  beta: float = 0.1,
  alpha: float = 1.0,
+ label_smoothing: float = 0.0,
  compute_nll_loss: bool = True,
  compiled: bool = True,
  ):
@@ -90,6 +98,7 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
  self.ignore_index = ignore_index
  self.beta = beta
  self.alpha = alpha
+ self.label_smoothing = label_smoothing
  self.compute_nll_loss = compute_nll_loss
  self.compiled = compiled

@@ -102,6 +111,7 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
  self.ignore_index,
  self.beta,
  self.alpha,
+ self.label_smoothing,
  self.compute_nll_loss,
  self.compiled,
  )
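As a reading aid for the new `preference_loss_fn` (restating the code above, not adding behavior): with $z = \beta\,(\log p_\theta(y_{chosen}) - \log p_\theta(y_{rejected}))$ and smoothing factor $\epsilon$ = `label_smoothing`, each chosen/rejected pair contributes

$$-(1-\epsilon)\,\log\sigma(z) \;-\; \epsilon\,\log\sigma(-z),$$

summed and divided by the number of pairs (`full_target.shape[0] // 2`). Setting $\epsilon = 0$ reduces this to the standard $-\log\sigma(z)$ CPO objective, matching the docstring note, and the rewards returned alongside the loss are simply $\beta\log p_\theta(y_{chosen})$ and $\beta\log p_\theta(y_{rejected})$.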
@@ -1,13 +1,10 @@
  import torch
  import torch.nn.functional as F

- from liger_kernel.chunked_loss.fused_linear_preference import (
- LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


  class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
-
  @staticmethod
  def preference_loss_fn(
  chosen_logps,
@@ -48,9 +45,12 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  chosen_logratios = chosen_logps - ref_chosen_logps
  rejected_logratios = rejected_logps - ref_rejected_logps

+ chosen_rewards = beta * chosen_logratios
+ rejected_rewards = beta * rejected_logratios
+
  logits_diff = beta * (chosen_logratios - rejected_logratios)
  loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
- return loss
+ return loss, chosen_rewards, rejected_rewards

  @staticmethod
  def forward(
@@ -59,11 +59,12 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  weight,
  target,
  bias=None,
+ ref_input=None,
  ref_weight=None,
  ref_bias=None,
  ignore_index=-100,
  beta=0.1,
- compute_nll_loss=True,
+ compute_nll_loss=False,
  compiled=True,
  use_ref_model=True,
  ):
@@ -79,6 +80,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  compute_nll_loss=compute_nll_loss,
  compiled=compiled,
  use_ref_model=use_ref_model,
+ ref_input=ref_input,
  ref_weight=ref_weight,
  ref_bias=ref_bias,
  )
@@ -86,7 +88,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
  @staticmethod
  def backward(ctx, *grad_output):
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
- return *grads, None, None, None, None, None, None, None
+ return *grads, None, None, None, None, None, None, None, None


  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -98,9 +100,9 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  self,
  ignore_index: int = -100,
  beta: float = 0.1,
- compute_nll_loss: bool = True,
+ compute_nll_loss: bool = False,
  compiled: bool = True,
- use_ref_model: bool = False,
+ use_ref_model: bool = True,
  ):
  """
  Args:
@@ -118,13 +120,21 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
  self.use_ref_model = use_ref_model

  def forward(
- self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
+ self,
+ lin_weight,
+ _input,
+ target,
+ bias=None,
+ ref_input=None,
+ ref_weight=None,
+ ref_bias=None,
  ):
  return LigerFusedLinearDPOFunction.apply(
  _input,
  lin_weight,
  target,
  bias,
+ ref_input,
  ref_weight,
  ref_bias,
  self.ignore_index,
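Since `ref_input` is a new argument in 0.5.4, a hedged usage sketch follows: the reference model's hidden states are now passed explicitly alongside its `lm_head` weight rather than re-using the policy's hidden states. The keyword names match the `forward` signature above; shapes and the exact return value are illustrative assumptions.

```python
# Hedged sketch (not part of the package diff): DPO with an explicit reference input.
import torch

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

B, T, H, V = 4, 128, 4096, 32000                      # assumed shapes for illustration
policy_hidden = torch.randn(B, T, H, requires_grad=True)
policy_head = torch.randn(V, H, requires_grad=True)
ref_hidden = torch.randn(B, T, H)                     # hidden states from the frozen reference model
ref_head = torch.randn(V, H)
target = torch.randint(0, V, (B, T))

dpo_loss = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True)
out = dpo_loss(
    policy_head,
    policy_hidden,
    target,
    ref_input=ref_hidden,      # new in 0.5.4: reference hidden states passed separately
    ref_weight=ref_head,
)
```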
@@ -1,9 +1,13 @@
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+ from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
+ from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
  from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

  liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
  liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
+ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
+ liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
@@ -2,18 +2,24 @@ from abc import abstractmethod
  from functools import partial

  import torch
+
  from torch.nn import functional as F


  class LigerFusedLinearDistillationBase(torch.autograd.Function):
-
  @abstractmethod
- def distillation_loss_fn(student_logits, teacher_logits, temperature):
+ def distillation_loss_fn(
+ student_logits,
+ teacher_logits,
+ ):
  """
  Compute distillation loss.
  Args:
- student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
- teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+ student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+ teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+ Returns:
+ torch.Tensor: Sum of distillation losses for the chunk. The class will handle
+ converting this to mean loss by dividing by the full batch size * sequence length in _compute_loss.
  """
  raise NotImplementedError("Distillation loss function must be implemented.")

@@ -65,14 +71,14 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  distillation_loss_fn=None,
  full_target=None,
  ignore_index=-100,
- temperature=1.0,
  weight_hard_loss=0.5,
  weight_soft_loss=0.5,
  compute_ce_loss=True,
+ temperature=1,
  **loss_kwargs,
  ):
  """
- Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function.
+ Compute the total loss for a chunk of input and target, while using an knowledge distillation loss function.
  Args:
  distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
  student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
@@ -82,32 +88,36 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
  student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
  teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
- full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
+ full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
  ignore_index (int): Index to ignore for loss computation.
  weight_hard_loss (float): Weight for hard loss.
  weight_soft_loss (float): Weight for soft loss.
  compute_ce_loss (bool): Whether to compute CE loss.
+ temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
  loss_kwargs (dict): Additional arguments for the loss function.
  """
- student_logits_chunk, teacher_logits_chunk, hard_loss = (
- LigerFusedLinearDistillationBase.chunk_forward(
- student_input_chunk,
- student_weight,
- teacher_input_chunk,
- teacher_weight,
- target_chunk,
- student_bias=student_bias,
- teacher_bias=teacher_bias,
- ignore_index=ignore_index,
- compute_ce_loss=compute_ce_loss,
- )
+ (
+ student_logits_chunk,
+ teacher_logits_chunk,
+ hard_loss,
+ ) = LigerFusedLinearDistillationBase.chunk_forward(
+ student_input_chunk,
+ student_weight,
+ teacher_input_chunk,
+ teacher_weight,
+ target_chunk,
+ student_bias=student_bias,
+ teacher_bias=teacher_bias,
+ ignore_index=ignore_index,
+ compute_ce_loss=compute_ce_loss,
  )

+ student_logits_chunk /= temperature
+ teacher_logits_chunk /= temperature
+
  hard_loss /= full_target.shape[0]

- soft_loss = distillation_loss_fn(
- student_logits_chunk, teacher_logits_chunk, temperature
- )
+ soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
  soft_loss /= full_target.shape[0]

  loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -128,6 +138,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  ignore_index=-100,
  weight_hard_loss=0.5,
  weight_soft_loss=0.5,
+ beta=0.5,
  compute_ce_loss=True,
  temperature=1.0,
  compiled=True,
@@ -147,10 +158,12 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
  loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
  chunk_size (int): Size of a chunk.
- compute_ce_loss (bool): Whether to compute CE loss.
  ignore_index (int): Index to ignore for loss computation.
  weight_hard_loss (float): Weight for hard/task loss.
  weight_soft_loss (float): Weight for soft/distillation loss.
+ beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
+ compute_ce_loss (bool): Whether to compute CE loss.
+ temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
  compiled (bool): Whether to use torch compile for chunk accumulation.
  loss_kwargs (dict): Other possible arguments that a loss function might need
  """
@@ -167,6 +180,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  ignore_index=ignore_index,
  weight_hard_loss=weight_hard_loss,
  weight_soft_loss=weight_soft_loss,
+ beta=beta,
  compute_ce_loss=compute_ce_loss,
  temperature=temperature,
  **loss_kwargs,
@@ -174,17 +188,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):

  def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
  if student_bias is not None:
- (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
- chunk_loss,
+ (
+ (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
  (
- chunk_soft_loss,
- chunk_hard_loss,
- chunk_student_logits,
- chunk_teacher_logits,
+ chunk_loss,
+ (
+ chunk_soft_loss,
+ chunk_hard_loss,
+ chunk_student_logits,
+ chunk_teacher_logits,
+ ),
  ),
- ) = torch.func.grad_and_value(
- loss_func_to_call, argnums=(0, 1, 5), has_aux=True
- )(
+ ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)(
  student_input_chunk,
  student_weight,
  teacher_input_chunk,
@@ -195,17 +210,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  )
  grad_bias.add_(chunk_grad_bias)
  else:
- (chunk_grad_input, chunk_grad_weight), (
- chunk_loss,
+ (
+ (chunk_grad_input, chunk_grad_weight),
  (
- chunk_soft_loss,
- chunk_hard_loss,
- chunk_student_logits,
- chunk_teacher_logits,
+ chunk_loss,
+ (
+ chunk_soft_loss,
+ chunk_hard_loss,
+ chunk_student_logits,
+ chunk_teacher_logits,
+ ),
  ),
- ) = torch.func.grad_and_value(
- loss_func_to_call, argnums=(0, 1), has_aux=True
- )(
+ ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)(
  student_input_chunk,
  student_weight,
  teacher_input_chunk,
@@ -229,9 +245,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  for student_input_chunk, teacher_input_chunk, target_chunk in zip(
  _student_input_chunks, _teacher_input_chunks, _target_chunks
  ):
- grad_input = accumulate_chunk(
- student_input_chunk, teacher_input_chunk, target_chunk
- )
+ grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk)
  grad_inputs.append(grad_input)

  ctx.save_for_backward(
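The restructured `accumulate_chunk` above still revolves around one pattern: `torch.func.grad_and_value(..., has_aux=True)` returns, for each chunk, the gradients for the requested `argnums` together with `(loss, aux)` in a single pass, so weight gradients can be accumulated and per-chunk input gradients collected without a separate backward over the full batch. A self-contained sketch of that pattern follows (illustrative only; a plain cross-entropy stands in for the fused distillation loss, and PyTorch 2.x is assumed for `torch.func`).

```python
# Hedged sketch (not the library code): chunked grad_and_value accumulation.
import torch
import torch.nn.functional as F


def chunk_loss(input_chunk, weight, target_chunk):
    # Fused "lm_head + loss" on one chunk; full-batch logits are never materialized.
    logits = input_chunk @ weight.t()
    loss = F.cross_entropy(logits, target_chunk, reduction="sum")
    return loss, (logits,)                     # aux outputs ride along with the loss


H, V, N, chunk = 64, 100, 8, 4
inputs = torch.randn(N, H)
weight = torch.randn(V, H)
targets = torch.randint(0, V, (N,))

grad_weight = torch.zeros_like(weight)
grad_inputs, total_loss = [], torch.zeros(())
for x_chunk, t_chunk in zip(inputs.split(chunk), targets.split(chunk)):
    (g_x, g_w), (loss, _aux) = torch.func.grad_and_value(
        chunk_loss, argnums=(0, 1), has_aux=True
    )(x_chunk, weight, t_chunk)
    grad_weight.add_(g_w)                      # accumulate weight grads across chunks
    grad_inputs.append(g_x)                    # keep per-chunk input grads
    total_loss += loss / N                     # mean over the full batch, as in _compute_loss
grad_input = torch.cat(grad_inputs, dim=0)
```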