liger-kernel-nightly 0.5.2.dev20241223032630__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +8 -24
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/swiglu.py +2 -8
  46. liger_kernel/transformers/trainer/__init__.py +1 -3
  47. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  48. liger_kernel/triton/__init__.py +1 -3
  49. liger_kernel/triton/monkey_patch.py +1 -3
  50. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
  51. liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
  52. liger_kernel_nightly-0.5.2.dev20241223032630.dist-info/RECORD +0 -66
  53. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032630.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,12 @@
  import torch
  import torch.nn.functional as F
 
- from liger_kernel.chunked_loss.fused_linear_preference import (
- LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
  class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
-
  @staticmethod
- def preference_loss_fn(
- chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
- ):
+ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
  """
  Paper: https://arxiv.org/pdf/2401.08417
 
@@ -35,10 +30,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
  label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
  """
  logits = beta * (chosen_logps - rejected_logps)
- loss = (
- - F.logsigmoid(logits) * (1 - label_smoothing)
- - F.logsigmoid(-logits) * label_smoothing
- ).sum() / (full_target.shape[0] // 2)
+ loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+ full_target.shape[0] // 2
+ )
 
  return loss
 
@@ -1,13 +1,10 @@
  import torch
  import torch.nn.functional as F
 
- from liger_kernel.chunked_loss.fused_linear_preference import (
- LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
  class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
-
  @staticmethod
  def preference_loss_fn(
  chosen_logps,
@@ -2,11 +2,11 @@ from abc import abstractmethod
  from functools import partial
 
  import torch
+
  from torch.nn import functional as F
 
 
  class LigerFusedLinearDistillationBase(torch.autograd.Function):
-
  @abstractmethod
  def distillation_loss_fn(student_logits, teacher_logits, temperature):
  """
@@ -89,25 +89,25 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  compute_ce_loss (bool): Whether to compute CE loss.
  loss_kwargs (dict): Additional arguments for the loss function.
  """
- student_logits_chunk, teacher_logits_chunk, hard_loss = (
- LigerFusedLinearDistillationBase.chunk_forward(
- student_input_chunk,
- student_weight,
- teacher_input_chunk,
- teacher_weight,
- target_chunk,
- student_bias=student_bias,
- teacher_bias=teacher_bias,
- ignore_index=ignore_index,
- compute_ce_loss=compute_ce_loss,
- )
+ (
+ student_logits_chunk,
+ teacher_logits_chunk,
+ hard_loss,
+ ) = LigerFusedLinearDistillationBase.chunk_forward(
+ student_input_chunk,
+ student_weight,
+ teacher_input_chunk,
+ teacher_weight,
+ target_chunk,
+ student_bias=student_bias,
+ teacher_bias=teacher_bias,
+ ignore_index=ignore_index,
+ compute_ce_loss=compute_ce_loss,
  )
 
  hard_loss /= full_target.shape[0]
 
- soft_loss = distillation_loss_fn(
- student_logits_chunk, teacher_logits_chunk, temperature
- )
+ soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
  soft_loss /= full_target.shape[0]
 
  loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -174,17 +174,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
  def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
  if student_bias is not None:
- (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
- chunk_loss,
+ (
+ (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
  (
- chunk_soft_loss,
- chunk_hard_loss,
- chunk_student_logits,
- chunk_teacher_logits,
+ chunk_loss,
+ (
+ chunk_soft_loss,
+ chunk_hard_loss,
+ chunk_student_logits,
+ chunk_teacher_logits,
+ ),
  ),
- ) = torch.func.grad_and_value(
- loss_func_to_call, argnums=(0, 1, 5), has_aux=True
- )(
+ ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)(
  student_input_chunk,
  student_weight,
  teacher_input_chunk,
@@ -195,17 +196,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  )
  grad_bias.add_(chunk_grad_bias)
  else:
- (chunk_grad_input, chunk_grad_weight), (
- chunk_loss,
+ (
+ (chunk_grad_input, chunk_grad_weight),
  (
- chunk_soft_loss,
- chunk_hard_loss,
- chunk_student_logits,
- chunk_teacher_logits,
+ chunk_loss,
+ (
+ chunk_soft_loss,
+ chunk_hard_loss,
+ chunk_student_logits,
+ chunk_teacher_logits,
+ ),
  ),
- ) = torch.func.grad_and_value(
- loss_func_to_call, argnums=(0, 1), has_aux=True
- )(
+ ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)(
  student_input_chunk,
  student_weight,
  teacher_input_chunk,
@@ -229,9 +231,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
  for student_input_chunk, teacher_input_chunk, target_chunk in zip(
  _student_input_chunks, _teacher_input_chunks, _target_chunks
  ):
- grad_input = accumulate_chunk(
- student_input_chunk, teacher_input_chunk, target_chunk
- )
+ grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk)
  grad_inputs.append(grad_input)
 
  ctx.save_for_backward(
@@ -2,11 +2,11 @@ from abc import abstractmethod
  from functools import partial
 
  import torch
+
  from torch.nn import functional as F
 
 
  class LigerFusedLinearPreferenceBase(torch.autograd.Function):
-
  @abstractmethod
  def preference_loss_fn(*args, **kwargs):
  """
@@ -102,9 +102,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
  Fused forward and backward pass for a chunk of input and target.
  """
  if bias is not None:
- return torch.func.grad_and_value(
- compute_loss, argnums=(0, 1, 3), has_aux=True
- )(
+ return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)(
  input_chunk,
  weight,
  target_chunk,
@@ -112,43 +110,47 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
  ref_input_chunk=ref_input_chunk,
  )
  else:
- return torch.func.grad_and_value(
- compute_loss, argnums=(0, 1), has_aux=True
- )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk)
+ return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
+ input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
+ )
 
  def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
  if bias is not None:
- (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
- chunk_loss,
+ (
+ (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
  (
- chunk_chosen_logps,
- chunk_rejected_logps,
- chunk_chosen_logits_mean,
- chunk_rejected_logits_mean,
- chunk_nll_loss,
- *aux_outputs,
+ chunk_loss,
+ (
+ chunk_chosen_logps,
+ chunk_rejected_logps,
+ chunk_chosen_logits_mean,
+ chunk_rejected_logits_mean,
+ chunk_nll_loss,
+ *aux_outputs,
+ ),
  ),
  ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
  grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
  else:
- (chunk_grad_input, chunk_grad_weight), (
- chunk_loss,
+ (
+ (chunk_grad_input, chunk_grad_weight),
  (
- chunk_chosen_logps,
- chunk_rejected_logps,
- chunk_chosen_logits_mean,
- chunk_rejected_logits_mean,
- chunk_nll_loss,
- *aux_outputs,
+ chunk_loss,
+ (
+ chunk_chosen_logps,
+ chunk_rejected_logps,
+ chunk_chosen_logits_mean,
+ chunk_rejected_logits_mean,
+ chunk_nll_loss,
+ *aux_outputs,
+ ),
  ),
  ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
 
  # Accumulate gradients
  grad_weight.add_(chunk_grad_weight)
  grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
- grad_rejected_inputs.append(
- chunk_grad_input[chosen_target_chunk.shape[0] :]
- )
+ grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :])
 
  # Accumulate loss
  loss_acc.add_(chunk_loss)
@@ -165,9 +167,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
  if len(aggregated_aux_outputs) == 0:
  for aux in aux_outputs:
  if aux.ndim == 0:
- aggregated_aux_outputs.append(
- torch.zeros((), device=aux.device)
- )
+ aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
  else:
  aggregated_aux_outputs.append([])
 
@@ -189,12 +189,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
  _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
 
  if use_ref_model:
- _ref_chosen_input_chunks = torch.chunk(
- ref_input[:len_chosen], chunks=chunks, dim=0
- )
- _ref_rejected_input_chunks = torch.chunk(
- ref_input[len_chosen:], chunks=chunks, dim=0
- )
+ _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
+ _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
 
  for (
  chosen_input_chunk,
@@ -208,26 +204,15 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
  _rejected_input_chunks,
  _chosen_target_chunks,
  _rejected_target_chunks,
- (
- _ref_chosen_input_chunks
- if use_ref_model
- else [None] * len(_chosen_input_chunks)
- ),
- (
- _ref_rejected_input_chunks
- if use_ref_model
- else [None] * len(_rejected_input_chunks)
- ),
+ (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
+ (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+ strict=False,
  ):
  input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
  ref_input_chunk = (
- torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
- if use_ref_model
- else None
- )
- target_chunk = torch.cat(
- [chosen_target_chunk, rejected_target_chunk], dim=0
+ torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None
  )
+ target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)
 
  # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
  torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -265,9 +250,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
  @staticmethod
  def backward(ctx, *grad_output):
  grad_input, grad_weight, grad_bias = ctx.saved_tensors
- if torch.ne(
- grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
- ):
+ if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
  grad_input = grad_input * grad_output[0][0]
  grad_weight = grad_weight * grad_output[0][0]
  grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
@@ -301,9 +284,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
  loss_mask = target_chunk != ignore_index
  label_chunk = torch.where(loss_mask, target_chunk, 0)
 
- per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
- -1
- )
+ per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
  average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
 
  chosen_logps = average_log_prob[:len_chosen_chunk]
@@ -370,13 +351,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
  ignore_index=ignore_index,
  compute_nll_loss=compute_nll_loss,
  )
- chosen_nll_loss = (
- chosen_nll_loss
- / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
- )
- chosen_logits_mean = chosen_logits.sum() / (
- full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
- )
+ chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+ chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
  rejected_logits_mean = rejected_logits.sum() / (
  full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
  )
@@ -1,13 +1,10 @@
  import torch
  import torch.nn.functional as F
 
- from liger_kernel.chunked_loss.fused_linear_preference import (
- LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
  class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
-
  @staticmethod
  def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
  """
@@ -32,8 +29,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
  beta (float): Weight for the odds ratio loss.
  """
  log_odds = (chosen_logps - rejected_logps) - (
- torch.log1p(-torch.exp(chosen_logps))
- - torch.log1p(-torch.exp(rejected_logps))
+ torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
  )
  ratio = F.logsigmoid(log_odds)
  loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
@@ -1,13 +1,10 @@
  import torch
  import torch.nn.functional as F
 
- from liger_kernel.chunked_loss.fused_linear_preference import (
- LigerFusedLinearPreferenceBase,
- )
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
  class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
-
  @staticmethod
  def preference_loss_fn(
  chosen_logps,
@@ -41,10 +38,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
  label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
  """
  logits = beta * (chosen_logps - rejected_logps) - gamma
- loss = (
- - F.logsigmoid(logits) * (1 - label_smoothing)
- - F.logsigmoid(-logits) * label_smoothing
- ).sum() / (full_target.shape[0] // 2)
+ loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+ full_target.shape[0] // 2
+ )
 
  return loss
 
@@ -1,5 +1,6 @@
  import platform
  import sys
+
  from importlib.metadata import version
 
 
@@ -27,15 +28,9 @@ def print_env_report():
  import torch
 
  print(f"PyTorch version: {torch.__version__}")
- cuda_version = (
- torch.version.cuda if torch.cuda.is_available() else "Not available"
- )
+ cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
  print(f"CUDA version: {cuda_version}")
- hip_version = (
- torch.version.hip
- if torch.cuda.is_available() and torch.version.hip
- else "Not available"
- )
+ hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
  print(f"HIP(ROCm) version: {hip_version}")
 
  except ImportError:
@@ -58,9 +53,7 @@ def print_env_report():
  print("Transformers: Not installed")
 
  try:
- xpu_version = (
- torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
- )
+ xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
  print(f"XPU version: {xpu_version}")
  except ImportError:
  print("XPU version: Unable to query")
@@ -1,11 +1,14 @@
  import operator
+
  from typing import Optional
 
  import torch
  import triton
  import triton.language as tl
 
- from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip
+ from liger_kernel.ops.utils import compare_version
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip
 
  if compare_version("triton", operator.ge, "3.0.0"):
  try:
@@ -92,9 +95,7 @@ def liger_cross_entropy_kernel(
  # 3. [Online softmax] first pass: find max + sum
  m = float("-inf") # m is the max value. use the notation from the paper
  d = 0.0 # d is the sum. use the notation from the paper
- ori_X_y = tl.load(X_ptr + y).cast(
- tl.float32
- ) # we need to store the original value of X_y for the loss calculation
+ ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation
  if HAS_SOFTCAPPING:
  ori_X_y = softcap * tanh(ori_X_y / softcap)
 
@@ -232,14 +233,10 @@ def cross_entropy_forward(
  return_z_loss,
  ):
  if not isinstance(return_z_loss, int):
- assert (
- return_z_loss in _bool_to_return_z_loss
- ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+ assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
  return_z_loss = _bool_to_return_z_loss[return_z_loss]
  else:
- assert (
- return_z_loss in _bool_to_return_z_loss
- ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+ assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
 
  BT, V = _input.shape
  n_rows = BT
@@ -34,9 +34,7 @@ def embedding_forward_kernel(
  )
 
  output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
- tl.store(
- output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
- )
+ tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :])
 
 
  @triton.jit
@@ -37,9 +37,7 @@ def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
  else:
  packed_tensor_shape = (row_dim, *original_shape[1:])
 
- packed = torch.zeros(
- packed_tensor_shape, device=intweights.device, dtype=torch.uint8
- )
+ packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
  unpacked = intweights.to(torch.uint8)
 
  def lshift(t: torch.Tensor, bits: int):
@@ -327,17 +325,13 @@ def matmul_kernel(
 
 
  def matmul(a, b):
- assert (
- a.shape[1] == b.shape[0] * 4
- ), "Incompatible dimensions, the weight matrix need to be packed"
+ assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
  assert a.is_contiguous(), "Matrix A must be contiguous"
  M, K = a.shape
  _, N = b.shape
  # c is in int32 to avoid any overflows or underflows
  c = torch.empty((M, N), device=a.device, dtype=torch.int32)
- grid = lambda META: (
- triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
- )
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
  matmul_kernel[grid](
  a,
  b,
@@ -2,12 +2,10 @@ import torch
  import triton
 
  from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
- from liger_kernel.ops.utils import (
- amp_custom_bwd,
- amp_custom_fwd,
- element_mul_kernel,
- is_hip,
- )
+ from liger_kernel.ops.utils import amp_custom_bwd
+ from liger_kernel.ops.utils import amp_custom_fwd
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip
 
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -40,14 +38,10 @@ def fused_linear_cross_entropy_forward(
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
  inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
- chunk_size = triton.next_power_of_2(
- triton.cdiv(BT, inc_factor)
- ) # (BT + inc_factor - 1) // inc_factor
+ chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
 
- grad_weight = (
- torch.zeros_like(weight, device=device) if weight.requires_grad else None
- )
+ grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
  grad_input = torch.zeros_like(_input, device=device)
  grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
  # we use fp32 for loss accumulator
@@ -137,9 +131,7 @@ def fused_linear_cross_entropy_forward(
  return loss, grad_input, grad_weight, grad_bias
 
 
- def fused_linear_cross_entropy_backward(
- grad_output, grad_input, grad_weight, grad_bias
- ):
+ def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
  # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
  if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
  # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
@@ -4,12 +4,10 @@ import torch
  import triton
 
  from liger_kernel.ops.jsd import _jsd_kernel
- from liger_kernel.ops.utils import (
- amp_custom_bwd,
- amp_custom_fwd,
- element_mul_kernel,
- is_hip,
- )
+ from liger_kernel.ops.utils import amp_custom_bwd
+ from liger_kernel.ops.utils import amp_custom_fwd
+ from liger_kernel.ops.utils import element_mul_kernel
+ from liger_kernel.ops.utils import is_hip
 
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
  BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
  inc_factor = triton.cdiv(V, H) # (V + H - 1) // H
- chunk_size = triton.next_power_of_2(
- triton.cdiv(BT, inc_factor)
- ) # (BT + inc_factor - 1) // inc_factor
+ chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor
  num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size
 
- grad_weight = (
- torch.zeros_like(student_weight, device=device)
- if student_weight.requires_grad
- else None
- )
+ grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
  grad_input = torch.zeros_like(student_input)
  # we use fp32 for loss accumulator
  loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
  # shape: chunk_size x V
  # For anything starting from logits to the final JSD loss, we do computation
  # in FP32 to avoid losing numerical stability.
- student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
- torch.float32
- )
- teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
- torch.float32
- )
+ student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+ teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
  chunk_n_rows = student_logits_chunk.shape[0]
 
  # unreduced loss
@@ -104,9 +92,7 @@ def fused_linear_jsd_forward(
  dX_ptr=student_prob_chunk,
  dX_stride=student_prob_chunk.stride(-2),
  label_ptr=(
- shift_labels[start_idx:end_idx]
- if has_label
- else torch.empty(1, device=device)
+ shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
  ), # dummy ptr if no label
  beta=jsd_beta,
  n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@ def fused_linear_jsd_forward(
  student_logits_chunk = (
  student_prob_chunk
  - torch.softmax(student_logits_chunk, dim=-1)
- * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
- student_prob_chunk.shape
- )
+ * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
  ) / temperature
  # now we traverse back to grad w.r.t. input to `lm_head` and grad
  # w.r.t. `lm_head` which should be computed in original dtype
@@ -239,7 +223,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
  @amp_custom_bwd
  def backward(ctx, grad_output):
  (grad_input, grad_weight) = ctx.saved_tensors
- grad_input, grad_weight = fused_linear_jsd_backward(
- grad_output, grad_input, grad_weight
- )
+ grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
  return (grad_input, grad_weight, None, None, None, None, None, None)