liger-kernel-nightly 0.4.2.dev20241209195823__tar.gz → 0.4.2.dev20241209224333__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. {liger_kernel_nightly-0.4.2.dev20241209195823/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241209224333}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/chunked_loss/fused_linear_preference.py +181 -164
  4. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  5. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/LICENSE +0 -0
  6. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/NOTICE +0 -0
  7. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/README.md +0 -0
  8. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/setup.cfg +0 -0
  9. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/__init__.py +0 -0
  10. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  11. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  12. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  13. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/chunked_loss/functional.py +0 -0
  14. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  15. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  16. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  17. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/env_report.py +0 -0
  18. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/__init__.py +0 -0
  19. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/cross_entropy.py +0 -0
  20. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  21. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  22. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  23. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  24. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/geglu.py +0 -0
  25. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/group_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/jsd.py +0 -0
  27. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/kl_div.py +0 -0
  28. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/layer_norm.py +0 -0
  29. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  30. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/rms_norm.py +0 -0
  31. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/rope.py +0 -0
  32. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/swiglu.py +0 -0
  33. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/ops/utils.py +0 -0
  34. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/__init__.py +0 -0
  35. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/auto_model.py +0 -0
  36. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  38. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/functional.py +0 -0
  39. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  40. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  41. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/geglu.py +0 -0
  42. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/group_norm.py +0 -0
  43. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/jsd.py +0 -0
  44. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/kl_div.py +0 -0
  45. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/layer_norm.py +0 -0
  46. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/__init__.py +0 -0
  47. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/gemma.py +0 -0
  48. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  49. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/llama.py +0 -0
  50. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/mistral.py +0 -0
  51. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  52. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/mllama.py +0 -0
  53. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/phi3.py +0 -0
  54. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  55. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  56. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  57. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/orpo_trainer.py +0 -0
  58. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  59. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/rms_norm.py +0 -0
  60. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/rope.py +0 -0
  61. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/swiglu.py +0 -0
  62. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  63. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/triton/__init__.py +0 -0
  64. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/triton/monkey_patch.py +0 -0
  65. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel/utils.py +0 -0
  66. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  67. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  68. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  69. {liger_kernel_nightly-0.4.2.dev20241209195823 → liger_kernel_nightly-0.4.2.dev20241209224333}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.4.2.dev20241209195823
+ Version: 0.4.2.dev20241209224333
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "liger_kernel_nightly"
- version = "0.4.2.dev20241209195823"
+ version = "0.4.2.dev20241209224333"
  description = "Efficient Triton kernels for LLM Training"
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
  readme = { file = "README.md", content-type = "text/markdown" }
src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -8,159 +8,12 @@ from torch.nn import functional as F
  class LigerFusedLinearPreferenceBase(torch.autograd.Function):

      @abstractmethod
-     def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
+     def preference_loss_fn(*args, **kwargs):
          """
-         Compute preference loss.
-         Args:
-             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-             beta (float): Weight for the odds ratio loss.
+         To be extended by subclasses.
          """
          raise NotImplementedError("Preference loss function must be implemented.")

-     @staticmethod
-     def chunk_forward(
-         input_chunk,
-         weight,
-         target_chunk,
-         bias=None,
-         ignore_index=-100,
-         compute_nll_loss=True,
-     ):
-         len_chosen_chunk = target_chunk.shape[0] // 2
-         logits_chunk = input_chunk @ weight.t()
-         if bias is not None:
-             logits_chunk = logits_chunk + bias
-         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
-
-         chosen_nll_loss = 0.0
-         if compute_nll_loss:
-             chosen_nll_loss = F.nll_loss(
-                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
-                 target_chunk[:len_chosen_chunk].view(-1),
-                 reduction="sum",
-                 ignore_index=ignore_index,
-             )
-
-         loss_mask = target_chunk != ignore_index
-         label_chunk = torch.where(loss_mask, target_chunk, 0)
-
-         per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-             -1
-         )
-         average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-
-         chosen_logps = average_log_prob[:len_chosen_chunk]
-         rejected_logps = average_log_prob[len_chosen_chunk:]
-
-         chosen_logits = logits_chunk[:len_chosen_chunk]
-         rejected_logits = logits_chunk[len_chosen_chunk:]
-
-         return (
-             chosen_logps,
-             rejected_logps,
-             chosen_logits,
-             rejected_logits,
-             chosen_nll_loss,
-         )
-
-     @staticmethod
-     def _compute_loss(
-         input_chunk,
-         weight,
-         target_chunk,
-         bias=None,
-         preference_loss_fn=None,
-         full_target=None,
-         ignore_index=-100,
-         alpha=1.0,
-         beta=0.1,
-         compute_nll_loss=True,
-         use_ref_model=False,
-         ref_weight=None,
-         ref_bias=None,
-         **loss_kwargs,
-     ):
-         """
-         Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
-         Args:
-             preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
-             input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
-             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
-             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
-             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-             full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
-             ignore_index (int): Index to ignore for loss computation.
-             alpha (float): Weight for the NLL loss.
-             beta (float): Weight for the odds ratio loss.
-             compute_nll_loss (bool): Whether to compute NLL loss.
-             use_ref_model (bool): Whether to use a reference model for the alignment loss.
-             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
-             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
-             loss_kwargs (dict): Additional arguments for the loss function.
-         """
-         (
-             chosen_logps,
-             rejected_logps,
-             chosen_logits,
-             rejected_logits,
-             chosen_nll_loss,
-         ) = LigerFusedLinearPreferenceBase.chunk_forward(
-             input_chunk,
-             weight,
-             target_chunk,
-             bias=bias,
-             ignore_index=ignore_index,
-             compute_nll_loss=compute_nll_loss,
-         )
-         chosen_nll_loss = (
-             chosen_nll_loss
-             / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-         )
-         chosen_logits_mean = chosen_logits.sum() / (
-             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-         )
-         rejected_logits_mean = rejected_logits.sum() / (
-             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-         )
-
-         if use_ref_model:
-             with torch.no_grad():
-                 (
-                     ref_chosen_logps,
-                     ref_rejected_logps,
-                     ref_chosen_logits,
-                     ref_rejected_logits,
-                     ref_chosen_nll_loss,
-                 ) = LigerFusedLinearPreferenceBase.chunk_forward(
-                     input_chunk,
-                     ref_weight,
-                     target_chunk,
-                     ref_bias,
-                     ignore_index=ignore_index,
-                     compute_nll_loss=False,  # We don't need NLL loss for the reference model
-                 )
-             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
-             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
-
-         preference_loss_outputs = preference_loss_fn(
-             chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
-         )
-         if isinstance(preference_loss_outputs, tuple):
-             preference_loss, *aux_outputs = preference_loss_outputs
-         else:
-             preference_loss, aux_outputs = preference_loss_outputs, []
-
-         loss = alpha * chosen_nll_loss - preference_loss
-         return_vars = (
-             chosen_logps,
-             rejected_logps,
-             chosen_logits_mean,
-             rejected_logits_mean,
-             chosen_nll_loss,
-         )
-         return loss, (*return_vars, *aux_outputs)
-
      @staticmethod
      def forward(
          ctx,
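Note: with the abstract signature loosened to *args, **kwargs, each concrete loss (see cpo_loss.py, dpo_loss.py, orpo_loss.py, simpo_loss.py in the file list above) supplies its own preference_loss_fn matching the call site in _compute_loss, which invokes preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs). A minimal sketch of such a subclass, assuming a plain sigmoid pairwise loss (the class name and loss choice are illustrative, not code from this package):

import torch.nn.functional as F

class MySigmoidPreferenceLoss(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
        # Margin between average chosen and rejected log-probs, scaled by beta.
        logits = beta * (chosen_logps - rejected_logps)
        # Return the summed log-likelihood of preferring chosen, normalized by
        # the number of pairs; _compute_loss subtracts this value
        # (loss = alpha * chosen_nll_loss - preference_loss), so minimizing
        # the total loss pushes this term up.
        return F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)

The base class's forward then handles chunking, torch.compile, and gradient accumulation, so a subclass only needs to supply this function.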
@@ -176,6 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
          compute_nll_loss=True,
          compiled=True,
          use_ref_model=False,
+         # TODO: ref input
          ref_weight=None,
          ref_bias=None,
          **loss_kwargs,
@@ -184,6 +38,14 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
          Base class for fused linear layer with preference loss.
          Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

+         The mental model is:
+
+         forward()
+         ├── Loop over chunks
+             └── compute_loss()
+                 ├── chunk_forward()  # Compute logits and log probs
+                 └── prefer_loss()    # Calculate preference loss
+
          Args:
              _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
              weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
@@ -191,10 +53,9 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
              bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
              loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
              chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
-             compute_nll_loss (bool): Whether to compute NLL loss.
              ignore_index (int): Index to ignore for loss computation.
              alpha (float): Weight for the NLL loss.
-             beta (float): Weight for the odds ratio loss.
+             beta (float): Weight for the preference loss.
              compute_nll_loss (bool): Whether to compute NLL loss.
              compiled (bool): Whether to use torch compile for chunk accumulation.
              use_ref_model (bool): Whether to use a reference model for the alignment loss.
@@ -205,11 +66,16 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
          # TODO: Tune CHUNK_SIZE to fully utilize the GPU
          CHUNK_SIZE = chunk_size

+         # Gradients to be accumulated
          grad_weight = torch.zeros_like(weight)
          grad_chosen_inputs = []
          grad_rejected_inputs = []
          grad_bias = torch.zeros_like(bias) if bias is not None else None
+
+         # Loss to be accumulated
          loss_acc = torch.zeros((), device=_input.device)
+
+         # Metrics to be recorded
          policy_chosen_logps = []
          policy_rejected_logps = []
          policy_chosen_logits_mean = torch.zeros((), device=_input.device)
@@ -217,7 +83,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
          policy_nll_loss = torch.zeros((), device=_input.device)
          aggregated_aux_outputs = []  # aggregated aux outputs from all chunks

-         loss_func_to_call = partial(
+         compute_loss = partial(
              LigerFusedLinearPreferenceBase._compute_loss,
              preference_loss_fn=loss_fn,
              ignore_index=ignore_index,
@@ -231,14 +97,17 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
              **loss_kwargs,
          )

-         def accumulate_core(input_chunk, target_chunk):
+         def fused_fwd_bwd(input_chunk, target_chunk):
+             """
+             Fused forward and backward pass for a chunk of input and target.
+             """
              if bias is not None:
                  return torch.func.grad_and_value(
-                     loss_func_to_call, argnums=(0, 1, 3), has_aux=True
+                     compute_loss, argnums=(0, 1, 3), has_aux=True
                  )(input_chunk, weight, target_chunk, bias)
              else:
                  return torch.func.grad_and_value(
-                     loss_func_to_call, argnums=(0, 1), has_aux=True
+                     compute_loss, argnums=(0, 1), has_aux=True
                  )(input_chunk, weight, target_chunk)

          def accumulate_chunk(input_chunk, target_chunk):
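Note: fused_fwd_bwd builds on torch.func.grad_and_value, which wraps a function so that a single call returns both the gradients with respect to the positions named in argnums and the function's value; with has_aux=True the wrapped function must return a (loss, aux) pair, and the call yields (grads, (loss, aux)) — exactly the nested tuples unpacked in accumulate_chunk below. A self-contained toy example of the pattern (shapes and names are illustrative):

import torch

def toy_loss(x, w):
    loss = (x @ w).pow(2).sum()  # scalar output to differentiate
    aux = loss.detach()          # extra value surfaced alongside the loss
    return loss, aux

x, w = torch.randn(4, 3), torch.randn(3, 2)
(grad_x, grad_w), (loss, aux) = torch.func.grad_and_value(
    toy_loss, argnums=(0, 1), has_aux=True
)(x, w)
assert grad_x.shape == x.shape and grad_w.shape == w.shape

Computing each chunk's gradients inside the loss call is what lets the surrounding loop discard that chunk's activations immediately instead of keeping a full-batch autograd graph alive.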
@@ -253,7 +122,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                          chunk_nll_loss,
                          *aux_outputs,
                      ),
-                 ) = accumulate_core(input_chunk, target_chunk)
+                 ) = fused_fwd_bwd(input_chunk, target_chunk)
                  grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
              else:
                  (chunk_grad_input, chunk_grad_weight), (
@@ -266,16 +135,26 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                          chunk_nll_loss,
                          *aux_outputs,
                      ),
-                 ) = accumulate_core(input_chunk, target_chunk)
+                 ) = fused_fwd_bwd(input_chunk, target_chunk)

+             # Accumulate gradients
              grad_weight.add_(chunk_grad_weight)
+             grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
+             grad_rejected_inputs.append(
+                 chunk_grad_input[chosen_target_chunk.shape[0] :]
+             )
+
+             # Accumulate loss
              loss_acc.add_(chunk_loss)
+
+             # Accumulate metrics
              policy_chosen_logps.append(chunk_chosen_logps)
              policy_rejected_logps.append(chunk_rejected_logps)
              policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
              policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
              policy_nll_loss.add_(chunk_nll_loss)

+             # aux_outputs
              # Initialize storage for aux_outputs
              if len(aggregated_aux_outputs) == 0:
                  for aux in aux_outputs:
@@ -293,10 +172,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                  else:
                      aggregated_aux_outputs[i].append(aux)

-             return chunk_grad_input
-
          if compiled:
-             accumulate_core = torch.compile(accumulate_core)
+             fused_fwd_bwd = torch.compile(fused_fwd_bwd)

          len_chosen = target.shape[0] // 2
          chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
@@ -327,10 +204,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
              torch._dynamo.mark_dynamic(target, 1)

              # accumulate loss, gradients, and metrics
-             grad_input = accumulate_chunk(input_chunk, target_chunk)
-
-             grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
-             grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])
+             accumulate_chunk(input_chunk, target_chunk)

          # combine grad_chosen_inputs and grad_rejected_inputs
          grad_inputs = grad_chosen_inputs + grad_rejected_inputs
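Note: the loop these hunks sit in (unchanged by the diff, paraphrased here from the surrounding context rather than quoted) keeps chosen and rejected rows paired per chunk: both halves of the stacked batch are split into the same number of chunks and re-concatenated chunk by chunk, so every input_chunk passed to accumulate_chunk preserves the chosen-then-rejected layout that chunk_forward splits on. Roughly:

len_chosen = target.shape[0] // 2
chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))

# Chunk the chosen and rejected halves separately, then re-stack pairwise.
chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

for chosen_input_chunk, rejected_input_chunk, chosen_target_chunk, rejected_target_chunk in zip(
    chosen_input_chunks, rejected_input_chunks, chosen_target_chunks, rejected_target_chunks
):
    # Each call sees a (2 * chunk_size, seq_len, ...) stacked pair.
    input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
    target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)
    accumulate_chunk(input_chunk, target_chunk)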
@@ -367,3 +241,146 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
          grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None

          return grad_input, grad_weight, None, grad_bias, None, None, None
+
+     @staticmethod
+     def chunk_forward(
+         input_chunk,
+         weight,
+         target_chunk,
+         bias=None,
+         ignore_index=-100,
+         compute_nll_loss=True,
+     ):
+         len_chosen_chunk = target_chunk.shape[0] // 2
+         logits_chunk = input_chunk @ weight.t()
+         if bias is not None:
+             logits_chunk = logits_chunk + bias
+         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+         chosen_nll_loss = 0.0
+         if compute_nll_loss:
+             chosen_nll_loss = F.nll_loss(
+                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+                 target_chunk[:len_chosen_chunk].view(-1),
+                 reduction="sum",
+                 ignore_index=ignore_index,
+             )
+
+         loss_mask = target_chunk != ignore_index
+         label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+         per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+             -1
+         )
+         average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+         chosen_logps = average_log_prob[:len_chosen_chunk]
+         rejected_logps = average_log_prob[len_chosen_chunk:]
+
+         chosen_logits = logits_chunk[:len_chosen_chunk]
+         rejected_logits = logits_chunk[len_chosen_chunk:]
+
+         return (
+             chosen_logps,
+             rejected_logps,
+             chosen_logits,
+             rejected_logits,
+             chosen_nll_loss,
+         )
+
+     @staticmethod
+     def _compute_loss(
+         input_chunk,
+         weight,
+         target_chunk,
+         bias=None,
+         preference_loss_fn=None,
+         full_target=None,
+         ignore_index=-100,
+         alpha=1.0,
+         beta=0.1,
+         compute_nll_loss=True,
+         use_ref_model=False,
+         ref_weight=None,
+         ref_bias=None,
+         **loss_kwargs,
+     ):
+         """
+         Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
+         Args:
+             preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+             input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+             full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+             ignore_index (int): Index to ignore for loss computation.
+             alpha (float): Weight for the NLL loss.
+             beta (float): Weight for the preference loss.
+             compute_nll_loss (bool): Whether to compute NLL loss.
+             use_ref_model (bool): Whether to use a reference model for the alignment loss.
+             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
+             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+             loss_kwargs (dict): Additional arguments for the loss function.
+         """
+         (
+             chosen_logps,
+             rejected_logps,
+             chosen_logits,
+             rejected_logits,
+             chosen_nll_loss,
+         ) = LigerFusedLinearPreferenceBase.chunk_forward(
+             input_chunk,
+             weight,
+             target_chunk,
+             bias=bias,
+             ignore_index=ignore_index,
+             compute_nll_loss=compute_nll_loss,
+         )
+         chosen_nll_loss = (
+             chosen_nll_loss
+             / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+         )
+         chosen_logits_mean = chosen_logits.sum() / (
+             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+         )
+         rejected_logits_mean = rejected_logits.sum() / (
+             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
+         )
+
+         if use_ref_model:
+             with torch.no_grad():
+                 (
+                     ref_chosen_logps,
+                     ref_rejected_logps,
+                     ref_chosen_logits,
+                     ref_rejected_logits,
+                     ref_chosen_nll_loss,
+                 ) = LigerFusedLinearPreferenceBase.chunk_forward(
+                     input_chunk,
+                     ref_weight,
+                     target_chunk,
+                     ref_bias,
+                     ignore_index=ignore_index,
+                     compute_nll_loss=False,  # We don't need NLL loss for the reference model
+                 )
+             loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
+             loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
+
+         preference_loss_outputs = preference_loss_fn(
+             chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
+         )
+         if isinstance(preference_loss_outputs, tuple):
+             preference_loss, *aux_outputs = preference_loss_outputs
+         else:
+             preference_loss, aux_outputs = preference_loss_outputs, []
+
+         loss = alpha * chosen_nll_loss - preference_loss
+         return_vars = (
+             chosen_logps,
+             rejected_logps,
+             chosen_logits_mean,
+             rejected_logits_mean,
+             chosen_nll_loss,
+         )
+         return loss, (*return_vars, *aux_outputs)
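Note: the gather-and-mask arithmetic chunk_forward uses for per-token log-probabilities can be seen in isolation in the toy snippet below (hypothetical shapes and values; -100 plays the role of ignore_index):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)                 # (batch, seq_len, vocab)
target = torch.tensor([[1, 2, 3, -100, -100],
                       [4, 5, -100, -100, -100]])

log_probs = F.log_softmax(logits.float(), dim=-1)
loss_mask = target != -100
label = torch.where(loss_mask, target, 0)      # safe placeholder index for masked slots
per_token_logps = log_probs.gather(-1, label.unsqueeze(-1)).squeeze(-1)
# Masked positions contribute nothing to the sum and are excluded from the count.
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
print(average_log_prob.shape)                  # torch.Size([2])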
src/liger_kernel_nightly.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.4.2.dev20241209195823
+ Version: 0.4.2.dev20241209224333
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation