liger-kernel-nightly 0.5.2.dev20241223032015__py3-none-any.whl → 0.5.2.dev20241223042135__py3-none-any.whl

Files changed (57)
  1. liger_kernel/chunked_loss/cpo_loss.py +5 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +1 -4
  3. liger_kernel/chunked_loss/fused_linear_distillation.py +37 -37
  4. liger_kernel/chunked_loss/fused_linear_preference.py +40 -64
  5. liger_kernel/chunked_loss/orpo_loss.py +2 -6
  6. liger_kernel/chunked_loss/simpo_loss.py +4 -8
  7. liger_kernel/env_report.py +4 -11
  8. liger_kernel/ops/cross_entropy.py +7 -10
  9. liger_kernel/ops/experimental/embedding.py +1 -3
  10. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  11. liger_kernel/ops/fused_linear_cross_entropy.py +7 -15
  12. liger_kernel/ops/fused_linear_jsd.py +11 -29
  13. liger_kernel/ops/geglu.py +6 -17
  14. liger_kernel/ops/group_norm.py +11 -28
  15. liger_kernel/ops/jsd.py +2 -6
  16. liger_kernel/ops/kl_div.py +4 -7
  17. liger_kernel/ops/layer_norm.py +3 -5
  18. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  19. liger_kernel/ops/rms_norm.py +11 -29
  20. liger_kernel/ops/rope.py +31 -33
  21. liger_kernel/ops/swiglu.py +4 -8
  22. liger_kernel/ops/utils.py +2 -0
  23. liger_kernel/transformers/__init__.py +16 -24
  24. liger_kernel/transformers/auto_model.py +6 -13
  25. liger_kernel/transformers/cross_entropy.py +1 -3
  26. liger_kernel/transformers/experimental/embedding.py +1 -3
  27. liger_kernel/transformers/functional.py +2 -6
  28. liger_kernel/transformers/fused_linear_cross_entropy.py +2 -6
  29. liger_kernel/transformers/geglu.py +1 -4
  30. liger_kernel/transformers/group_norm.py +3 -9
  31. liger_kernel/transformers/jsd.py +1 -3
  32. liger_kernel/transformers/kl_div.py +1 -3
  33. liger_kernel/transformers/layer_norm.py +3 -9
  34. liger_kernel/transformers/model/gemma.py +18 -40
  35. liger_kernel/transformers/model/gemma2.py +19 -41
  36. liger_kernel/transformers/model/llama.py +22 -48
  37. liger_kernel/transformers/model/mistral.py +14 -26
  38. liger_kernel/transformers/model/mixtral.py +23 -53
  39. liger_kernel/transformers/model/mllama.py +16 -36
  40. liger_kernel/transformers/model/phi3.py +18 -40
  41. liger_kernel/transformers/model/qwen2.py +18 -40
  42. liger_kernel/transformers/model/qwen2_vl.py +16 -30
  43. liger_kernel/transformers/monkey_patch.py +43 -117
  44. liger_kernel/transformers/rms_norm.py +4 -4
  45. liger_kernel/transformers/rope.py +2 -2
  46. liger_kernel/transformers/swiglu.py +2 -8
  47. liger_kernel/transformers/trainer/__init__.py +1 -3
  48. liger_kernel/transformers/trainer/orpo_trainer.py +13 -16
  49. liger_kernel/triton/__init__.py +1 -3
  50. liger_kernel/triton/monkey_patch.py +1 -3
  51. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/METADATA +1 -1
  52. liger_kernel_nightly-0.5.2.dev20241223042135.dist-info/RECORD +66 -0
  53. liger_kernel_nightly-0.5.2.dev20241223032015.dist-info/RECORD +0 -66
  54. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/LICENSE +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/NOTICE +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/WHEEL +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20241223032015.dist-info → liger_kernel_nightly-0.5.2.dev20241223042135.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/cpo_loss.py
@@ -1,17 +1,12 @@
 import torch
 import torch.nn.functional as F
 
-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
 class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
-    def preference_loss_fn(
-        chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
-    ):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
         """
         Paper: https://arxiv.org/pdf/2401.08417
 
@@ -35,10 +30,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
         label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = (
-            - F.logsigmoid(logits) * (1 - label_smoothing)
-            - F.logsigmoid(-logits) * label_smoothing
-        ).sum() / (full_target.shape[0] // 2)
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )
 
         return loss
 
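For reference, the single-line expression on the + side above computes the same label-smoothed CPO sigmoid loss as the multi-line version it replaces; restated in math (our paraphrase of the code, not text from the diff), with $\epsilon$ = label_smoothing and $N$ = full_target.shape[0] // 2 chosen/rejected pairs:

$$\mathcal{L}_{\mathrm{CPO}} = \frac{1}{N}\sum\Big[-(1-\epsilon)\,\log\sigma\big(\beta(\log p_{\mathrm{chosen}}-\log p_{\mathrm{rejected}})\big)\;-\;\epsilon\,\log\sigma\big(-\beta(\log p_{\mathrm{chosen}}-\log p_{\mathrm{rejected}})\big)\Big]$$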
liger_kernel/chunked_loss/dpo_loss.py
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F
 
-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
 class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(
         chosen_logps,
liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -2,11 +2,11 @@ from abc import abstractmethod
 from functools import partial
 
 import torch
+
 from torch.nn import functional as F
 
 
 class LigerFusedLinearDistillationBase(torch.autograd.Function):
-
     @abstractmethod
     def distillation_loss_fn(student_logits, teacher_logits, temperature):
         """
@@ -89,25 +89,25 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         compute_ce_loss (bool): Whether to compute CE loss.
         loss_kwargs (dict): Additional arguments for the loss function.
         """
-        student_logits_chunk, teacher_logits_chunk, hard_loss = (
-            LigerFusedLinearDistillationBase.chunk_forward(
-                student_input_chunk,
-                student_weight,
-                teacher_input_chunk,
-                teacher_weight,
-                target_chunk,
-                student_bias=student_bias,
-                teacher_bias=teacher_bias,
-                ignore_index=ignore_index,
-                compute_ce_loss=compute_ce_loss,
-            )
+        (
+            student_logits_chunk,
+            teacher_logits_chunk,
+            hard_loss,
+        ) = LigerFusedLinearDistillationBase.chunk_forward(
+            student_input_chunk,
+            student_weight,
+            teacher_input_chunk,
+            teacher_weight,
+            target_chunk,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            ignore_index=ignore_index,
+            compute_ce_loss=compute_ce_loss,
         )
 
         hard_loss /= full_target.shape[0]
 
-        soft_loss = distillation_loss_fn(
-            student_logits_chunk, teacher_logits_chunk, temperature
-        )
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
         soft_loss /= full_target.shape[0]
 
         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
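As the final context line of this hunk shows, the per-chunk objective is a weighted sum of the hard cross-entropy loss and the soft distillation loss, each normalized by the full batch dimension; restated in math (our paraphrase, not diff content), with $B$ = full_target.shape[0] and $T$ the temperature:

$$\mathcal{L} = w_{\mathrm{hard}}\cdot\frac{\mathrm{CE}(z_{\mathrm{student}},\,y)}{B} \;+\; w_{\mathrm{soft}}\cdot\frac{\ell_{\mathrm{distill}}(z_{\mathrm{student}},\,z_{\mathrm{teacher}},\,T)}{B}$$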
@@ -174,17 +174,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
         def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
             if student_bias is not None:
-                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
-                    chunk_loss,
+                (
+                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                     (
-                        chunk_soft_loss,
-                        chunk_hard_loss,
-                        chunk_student_logits,
-                        chunk_teacher_logits,
+                        chunk_loss,
+                        (
+                            chunk_soft_loss,
+                            chunk_hard_loss,
+                            chunk_student_logits,
+                            chunk_teacher_logits,
+                        ),
                     ),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1, 5), has_aux=True
-                )(
+                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1, 5), has_aux=True)(
                     student_input_chunk,
                     student_weight,
                     teacher_input_chunk,
@@ -195,17 +196,18 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
                 )
                 grad_bias.add_(chunk_grad_bias)
             else:
-                (chunk_grad_input, chunk_grad_weight), (
-                    chunk_loss,
+                (
+                    (chunk_grad_input, chunk_grad_weight),
                     (
-                        chunk_soft_loss,
-                        chunk_hard_loss,
-                        chunk_student_logits,
-                        chunk_teacher_logits,
+                        chunk_loss,
+                        (
+                            chunk_soft_loss,
+                            chunk_hard_loss,
+                            chunk_student_logits,
+                            chunk_teacher_logits,
+                        ),
                     ),
-                ) = torch.func.grad_and_value(
-                    loss_func_to_call, argnums=(0, 1), has_aux=True
-                )(
+                ) = torch.func.grad_and_value(loss_func_to_call, argnums=(0, 1), has_aux=True)(
                     student_input_chunk,
                     student_weight,
                     teacher_input_chunk,
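The re-nested tuple unpacking in the two hunks above mirrors the return structure of torch.func.grad_and_value with has_aux=True: calling the transformed function yields (gradients for argnums, (loss, aux)). A minimal, self-contained sketch of that structure (the toy function, names, and shapes here are ours, not liger-kernel code):

import torch

def toy_loss(weight, x):
    # with has_aux=True the wrapped function must return (scalar_loss, aux)
    pred = x @ weight
    loss = (pred**2).mean()
    return loss, (pred.detach(),)

weight = torch.randn(4, 2)
x = torch.randn(8, 4)

# gradients come back as a tuple matching argnums; the value slot is (loss, aux)
(grad_weight,), (loss, (pred,)) = torch.func.grad_and_value(toy_loss, argnums=(0,), has_aux=True)(weight, x)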
@@ -229,9 +231,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
         for student_input_chunk, teacher_input_chunk, target_chunk in zip(
             _student_input_chunks, _teacher_input_chunks, _target_chunks
         ):
-            grad_input = accumulate_chunk(
-                student_input_chunk, teacher_input_chunk, target_chunk
-            )
+            grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk)
             grad_inputs.append(grad_input)
 
         ctx.save_for_backward(
liger_kernel/chunked_loss/fused_linear_preference.py
@@ -2,11 +2,11 @@ from abc import abstractmethod
 from functools import partial
 
 import torch
+
 from torch.nn import functional as F
 
 
 class LigerFusedLinearPreferenceBase(torch.autograd.Function):
-
     @abstractmethod
     def preference_loss_fn(*args, **kwargs):
         """
@@ -102,9 +102,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             Fused forward and backward pass for a chunk of input and target.
             """
             if bias is not None:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1, 3), has_aux=True
-                )(
+                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 3), has_aux=True)(
                     input_chunk,
                     weight,
                     target_chunk,
@@ -112,43 +110,47 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
                     ref_input_chunk=ref_input_chunk,
                 )
             else:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1), has_aux=True
-                )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk)
+                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
+                    input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
+                )
 
         def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
             if bias is not None:
-                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
-                    chunk_loss,
+                (
+                    (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
                     (
-                        chunk_chosen_logps,
-                        chunk_rejected_logps,
-                        chunk_chosen_logits_mean,
-                        chunk_rejected_logits_mean,
-                        chunk_nll_loss,
-                        *aux_outputs,
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
                     ),
                 ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
                 grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
             else:
-                (chunk_grad_input, chunk_grad_weight), (
-                    chunk_loss,
+                (
+                    (chunk_grad_input, chunk_grad_weight),
                     (
-                        chunk_chosen_logps,
-                        chunk_rejected_logps,
-                        chunk_chosen_logits_mean,
-                        chunk_rejected_logits_mean,
-                        chunk_nll_loss,
-                        *aux_outputs,
+                        chunk_loss,
+                        (
+                            chunk_chosen_logps,
+                            chunk_rejected_logps,
+                            chunk_chosen_logits_mean,
+                            chunk_rejected_logits_mean,
+                            chunk_nll_loss,
+                            *aux_outputs,
+                        ),
                     ),
                 ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
 
             # Accumulate gradients
             grad_weight.add_(chunk_grad_weight)
             grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
-            grad_rejected_inputs.append(
-                chunk_grad_input[chosen_target_chunk.shape[0] :]
-            )
+            grad_rejected_inputs.append(chunk_grad_input[chosen_target_chunk.shape[0] :])
 
             # Accumulate loss
             loss_acc.add_(chunk_loss)
@@ -165,9 +167,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             if len(aggregated_aux_outputs) == 0:
                 for aux in aux_outputs:
                     if aux.ndim == 0:
-                        aggregated_aux_outputs.append(
-                            torch.zeros((), device=aux.device)
-                        )
+                        aggregated_aux_outputs.append(torch.zeros((), device=aux.device))
                     else:
                         aggregated_aux_outputs.append([])
 
@@ -189,12 +189,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
 
         if use_ref_model:
-            _ref_chosen_input_chunks = torch.chunk(
-                ref_input[:len_chosen], chunks=chunks, dim=0
-            )
-            _ref_rejected_input_chunks = torch.chunk(
-                ref_input[len_chosen:], chunks=chunks, dim=0
-            )
+            _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
+            _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
 
         for (
             chosen_input_chunk,
@@ -208,26 +204,15 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             _rejected_input_chunks,
             _chosen_target_chunks,
             _rejected_target_chunks,
-            (
-                _ref_chosen_input_chunks
-                if use_ref_model
-                else [None] * len(_chosen_input_chunks)
-            ),
-            (
-                _ref_rejected_input_chunks
-                if use_ref_model
-                else [None] * len(_rejected_input_chunks)
-            ),
+            (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
+            (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
+            strict=False,
         ):
             input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
             ref_input_chunk = (
-                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0)
-                if use_ref_model
-                else None
-            )
-            target_chunk = torch.cat(
-                [chosen_target_chunk, rejected_target_chunk], dim=0
+                torch.cat([ref_chosen_input_chunk, ref_rejected_input_chunk], dim=0) if use_ref_model else None
             )
+            target_chunk = torch.cat([chosen_target_chunk, rejected_target_chunk], dim=0)
 
             # mark input_chunk, target_chunk, and target dimension 1 as dynamic to prevent torch.compile recompilation
             torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -265,9 +250,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *grad_output):
         grad_input, grad_weight, grad_bias = ctx.saved_tensors
-        if torch.ne(
-            grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)
-        ):
+        if torch.ne(grad_output[0][0], torch.tensor(1.0, device=grad_output[0][0].device)):
             grad_input = grad_input * grad_output[0][0]
             grad_weight = grad_weight * grad_output[0][0]
             grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None
@@ -301,9 +284,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
         loss_mask = target_chunk != ignore_index
         label_chunk = torch.where(loss_mask, target_chunk, 0)
 
-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-            -1
-        )
+        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
         average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
 
         chosen_logps = average_log_prob[:len_chosen_chunk]
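The gather/squeeze line joined above picks out the log-probability of each target token and then averages over the non-ignored positions. A small self-contained illustration of the same pattern (the toy shapes and variable names are ours, not taken from the diff):

import torch

ignore_index = -100
log_probs = torch.log_softmax(torch.randn(2, 5, 7), dim=-1)  # (batch, seq_len, vocab)
target = torch.randint(0, 7, (2, 5))
target[0, -2:] = ignore_index  # pretend some positions are padding

loss_mask = target != ignore_index
label = torch.where(loss_mask, target, 0)  # swap ignored labels for a valid index before gather
per_token_logps = log_probs.gather(-1, label.unsqueeze(-1)).squeeze(-1)  # (batch, seq_len)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)  # length-normalized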
@@ -370,13 +351,8 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
         )
-        chosen_nll_loss = (
-            chosen_nll_loss
-            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
-        )
-        chosen_logits_mean = chosen_logits.sum() / (
-            full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
-        )
+        chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
         rejected_logits_mean = rejected_logits.sum() / (
             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
         )
liger_kernel/chunked_loss/orpo_loss.py
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F
 
-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
 class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         """
@@ -32,8 +29,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
             beta (float): Weight for the odds ratio loss.
         """
         log_odds = (chosen_logps - rejected_logps) - (
-            torch.log1p(-torch.exp(chosen_logps))
-            - torch.log1p(-torch.exp(rejected_logps))
+            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
         loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
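Since torch.log1p(-torch.exp(logp)) equals log(1 - p), the expression above is the ORPO log odds ratio followed by a log-sigmoid penalty; restated (our paraphrase, not diff content), with $p_c$, $p_r$ the average chosen/rejected token probabilities and $N$ = full_target.shape[0] // 2:

$$\log\mathrm{OR} = (\log p_c - \log p_r) - \big(\log(1-p_c) - \log(1-p_r)\big), \qquad \mathcal{L}_{\mathrm{OR}} = -\frac{\beta}{N}\sum\log\sigma(\log\mathrm{OR})$$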
liger_kernel/chunked_loss/simpo_loss.py
@@ -1,13 +1,10 @@
 import torch
 import torch.nn.functional as F
 
-from liger_kernel.chunked_loss.fused_linear_preference import (
-    LigerFusedLinearPreferenceBase,
-)
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
 
 
 class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
-
     @staticmethod
     def preference_loss_fn(
         chosen_logps,
@@ -41,10 +38,9 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
         label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = (
-            - F.logsigmoid(logits) * (1 - label_smoothing)
-            - F.logsigmoid(-logits) * label_smoothing
-        ).sum() / (full_target.shape[0] // 2)
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )
 
         return loss
 
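SimPO differs from the CPO expression earlier in this diff only by the reward margin gamma subtracted from the scaled log-probability difference before the same label-smoothed log-sigmoid loss; restated (our paraphrase, not diff content):

$$\mathcal{L}_{\mathrm{SimPO}} = \frac{1}{N}\sum\Big[-(1-\epsilon)\log\sigma\big(\beta(\log p_c-\log p_r)-\gamma\big) \;-\; \epsilon\log\sigma\big(\gamma-\beta(\log p_c-\log p_r)\big)\Big]$$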
liger_kernel/env_report.py
@@ -1,5 +1,6 @@
 import platform
 import sys
+
 from importlib.metadata import version
 
 
@@ -27,15 +28,9 @@ def print_env_report():
         import torch
 
         print(f"PyTorch version: {torch.__version__}")
-        cuda_version = (
-            torch.version.cuda if torch.cuda.is_available() else "Not available"
-        )
+        cuda_version = torch.version.cuda if torch.cuda.is_available() else "Not available"
         print(f"CUDA version: {cuda_version}")
-        hip_version = (
-            torch.version.hip
-            if torch.cuda.is_available() and torch.version.hip
-            else "Not available"
-        )
+        hip_version = torch.version.hip if torch.cuda.is_available() and torch.version.hip else "Not available"
         print(f"HIP(ROCm) version: {hip_version}")
 
     except ImportError:
@@ -58,9 +53,7 @@ def print_env_report():
         print("Transformers: Not installed")
 
     try:
-        xpu_version = (
-            torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
-        )
+        xpu_version = torch.version.xpu if torch.xpu.is_available() else "XPU Not Available"
         print(f"XPU version: {xpu_version}")
     except ImportError:
         print("XPU version: Unable to query")
liger_kernel/ops/cross_entropy.py
@@ -1,11 +1,14 @@
 import operator
+
 from typing import Optional
 
 import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import compare_version, element_mul_kernel, is_hip
+from liger_kernel.ops.utils import compare_version
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
 
 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -92,9 +95,7 @@ def liger_cross_entropy_kernel(
     # 3. [Online softmax] first pass: find max + sum
     m = float("-inf")  # m is the max value. use the notation from the paper
    d = 0.0  # d is the sum. use the notation from the paper
-    ori_X_y = tl.load(X_ptr + y).cast(
-        tl.float32
-    )  # we need to store the original value of X_y for the loss calculation
+    ori_X_y = tl.load(X_ptr + y).cast(tl.float32)  # we need to store the original value of X_y for the loss calculation
     if HAS_SOFTCAPPING:
         ori_X_y = softcap * tanh(ori_X_y / softcap)
 
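The m ("running max") and d ("running sum") trackers in this hunk follow the usual online-softmax recurrence; for each new logit $x_i$ (a restatement of the comments above, not diff content):

$$m_i = \max(m_{i-1}, x_i), \qquad d_i = d_{i-1}\,e^{\,m_{i-1}-m_i} + e^{\,x_i - m_i},$$

so that after a single pass $\operatorname{softmax}(x_j) = e^{\,x_j - m}/d$ without materializing unshifted exponentials.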
@@ -232,14 +233,10 @@ def cross_entropy_forward(
     return_z_loss,
 ):
     if not isinstance(return_z_loss, int):
-        assert (
-            return_z_loss in _bool_to_return_z_loss
-        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
         return_z_loss = _bool_to_return_z_loss[return_z_loss]
     else:
-        assert (
-            return_z_loss in _bool_to_return_z_loss
-        ), f"return_z_loss must be True or False. Got: {return_z_loss}"
+        assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
 
     BT, V = _input.shape
     n_rows = BT
liger_kernel/ops/experimental/embedding.py
@@ -34,9 +34,7 @@ def embedding_forward_kernel(
     )
 
     output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
-    tl.store(
-        output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
-    )
+    tl.store(output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :])
 
 
 @triton.jit
liger_kernel/ops/experimental/mm_int8int2.py
@@ -37,9 +37,7 @@ def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
     else:
         packed_tensor_shape = (row_dim, *original_shape[1:])
 
-    packed = torch.zeros(
-        packed_tensor_shape, device=intweights.device, dtype=torch.uint8
-    )
+    packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
     unpacked = intweights.to(torch.uint8)
 
     def lshift(t: torch.Tensor, bits: int):
@@ -327,17 +325,13 @@ def matmul_kernel(
 
 
 def matmul(a, b):
-    assert (
-        a.shape[1] == b.shape[0] * 4
-    ), "Incompatible dimensions, the weight matrix need to be packed"
+    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
     assert a.is_contiguous(), "Matrix A must be contiguous"
     M, K = a.shape
     _, N = b.shape
     # c is in int32 to avoid any overflows or underflows
     c = torch.empty((M, N), device=a.device, dtype=torch.int32)
-    grid = lambda META: (
-        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-    )
+    grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)
     matmul_kernel[grid](
         a,
         b,
liger_kernel/ops/fused_linear_cross_entropy.py
@@ -2,12 +2,10 @@ import torch
 import triton
 
 from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -40,14 +38,10 @@ def fused_linear_cross_entropy_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
 
-    grad_weight = (
-        torch.zeros_like(weight, device=device) if weight.requires_grad else None
-    )
+    grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None
     grad_input = torch.zeros_like(_input, device=device)
     grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
     # we use fp32 for loss accumulator
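To make the chunking arithmetic above concrete, here is a worked example; the sizes (BT tokens, hidden size H, vocabulary V) are illustrative numbers of ours, not values from the diff:

import math

def cdiv(a, b):  # ceiling division, what triton.cdiv computes
    return (a + b - 1) // b

BT, H, V = 4096, 4096, 128256  # assumed example sizes
inc_factor = cdiv(V, H)  # 32: how much larger the logits row is than the input row
chunk_size = 1 << math.ceil(math.log2(cdiv(BT, inc_factor)))  # next power of 2 of 128 -> 128
num_chunks = cdiv(BT, chunk_size)  # 32 chunks of rows
# chunk_size * V is roughly BT * H, so each materialized logits chunk stays about the size of the input,
# which appears to be the intent of the inc_factor scaling.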
@@ -137,9 +131,7 @@ def fused_linear_cross_entropy_forward(
     return loss, grad_input, grad_weight, grad_bias
 
 
-def fused_linear_cross_entropy_backward(
-    grad_output, grad_input, grad_weight, grad_bias
-):
+def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
         # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
liger_kernel/ops/fused_linear_jsd.py
@@ -4,12 +4,10 @@ import torch
 import triton
 
 from liger_kernel.ops.jsd import _jsd_kernel
-from liger_kernel.ops.utils import (
-    amp_custom_bwd,
-    amp_custom_fwd,
-    element_mul_kernel,
-    is_hip,
-)
+from liger_kernel.ops.utils import amp_custom_bwd
+from liger_kernel.ops.utils import amp_custom_fwd
+from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import is_hip
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -43,16 +41,10 @@ def fused_linear_jsd_forward(
     BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
 
     inc_factor = triton.cdiv(V, H)  # (V + H - 1) // H
-    chunk_size = triton.next_power_of_2(
-        triton.cdiv(BT, inc_factor)
-    )  # (BT + inc_factor - 1) // inc_factor
+    chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor))  # (BT + inc_factor - 1) // inc_factor
     num_chunks = triton.cdiv(BT, chunk_size)  # (BT + chunk_size - 1) // chunk_size
 
-    grad_weight = (
-        torch.zeros_like(student_weight, device=device)
-        if student_weight.requires_grad
-        else None
-    )
+    grad_weight = torch.zeros_like(student_weight, device=device) if student_weight.requires_grad else None
     grad_input = torch.zeros_like(student_input)
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)
@@ -73,12 +65,8 @@ def fused_linear_jsd_forward(
         # shape: chunk_size x V
         # For anything starting from logits to the final JSD loss, we do computation
         # in FP32 to avoid losing numerical stability.
-        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
-            torch.float32
-        )
-        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
-            torch.float32
-        )
+        student_logits_chunk = (student_input_chunk @ student_weight.t()).to(torch.float32)
+        teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(torch.float32)
         chunk_n_rows = student_logits_chunk.shape[0]
 
         # unreduced loss
@@ -104,9 +92,7 @@ def fused_linear_jsd_forward(
             dX_ptr=student_prob_chunk,
             dX_stride=student_prob_chunk.stride(-2),
             label_ptr=(
-                shift_labels[start_idx:end_idx]
-                if has_label
-                else torch.empty(1, device=device)
+                shift_labels[start_idx:end_idx] if has_label else torch.empty(1, device=device)
             ),  # dummy ptr if no label
             beta=jsd_beta,
             n_non_ignore=n_non_ignore,
@@ -121,9 +107,7 @@ def fused_linear_jsd_forward(
         student_logits_chunk = (
             student_prob_chunk
             - torch.softmax(student_logits_chunk, dim=-1)
-            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(
-                student_prob_chunk.shape
-            )
+            * student_prob_chunk.sum(dim=-1, keepdim=True).broadcast_to(student_prob_chunk.shape)
         ) / temperature
         # now we traverse back to grad w.r.t. input to `lm_head` and grad
         # w.r.t. `lm_head` which should be computed in original dtype
@@ -239,7 +223,5 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
     @amp_custom_bwd
     def backward(ctx, grad_output):
         (grad_input, grad_weight) = ctx.saved_tensors
-        grad_input, grad_weight = fused_linear_jsd_backward(
-            grad_output, grad_input, grad_weight
-        )
+        grad_input, grad_weight = fused_linear_jsd_backward(grad_output, grad_input, grad_weight)
         return (grad_input, grad_weight, None, None, None, None, None, None)