liger-kernel-nightly 0.5.3.dev20250221162633__py3-none-any.whl → 0.5.3.dev20250221230243__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

liger_kernel/ops/tvd.py CHANGED
@@ -1,4 +1,5 @@
1
- from typing import Literal, Optional
1
+ from typing import Literal
2
+ from typing import Optional
2
3
 
3
4
  import torch
4
5
  import triton
@@ -178,15 +179,13 @@ class LigerTVDLossFunction(torch.autograd.Function):
178
179
  """
179
180
  has_label = False
180
181
  if shift_labels is not None:
181
- assert shift_labels.shape == (
182
- p.shape[0],
183
- ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
182
+ assert shift_labels.shape == (p.shape[0],), (
183
+ f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
184
+ )
184
185
  shift_labels = shift_labels.contiguous()
185
186
  has_label = True
186
187
 
187
- loss, grads = tv_distance_forward_triton(
188
- p, q, shift_labels, reduction, ignore_index, has_label
189
- )
188
+ loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
190
189
  ctx.save_for_backward(grads)
191
190
  return loss
192
191
 
liger_kernel/transformers/__init__.py CHANGED
@@ -19,7 +19,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2
19
19
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
20
20
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
21
21
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
22
- from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
23
22
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
24
23
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
25
24
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
25
+ from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
liger_kernel/transformers/functional.py CHANGED
@@ -14,6 +14,7 @@ from liger_kernel.ops.rope import LigerRopeFunction
14
14
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
15
15
  from liger_kernel.ops.tvd import LigerTVDLossFunction
16
16
 
17
+
17
18
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
18
19
  # `weight` and `size_average` are placeholders and not implemented yet
19
20
  def liger_cross_entropy(
@@ -156,6 +157,7 @@ def liger_kl_div(
156
157
  eps,
157
158
  )
158
159
 
160
+
159
161
  def liger_tvd(
160
162
  input,
161
163
  target,
@@ -169,7 +171,8 @@ def liger_tvd(
169
171
  shift_labels,
170
172
  reduction,
171
173
  ignore_index,
172
- )
174
+ )
175
+
173
176
 
174
177
  def liger_layer_norm(X, W, B, eps):
175
178
  return LigerLayerNormFunction.apply(X, W, B, eps)
liger_kernel/transformers/tvd.py CHANGED
@@ -10,6 +10,4 @@ class LigerTVDLoss(nn.Module):
10
10
  self.ignore_index = ignore_index
11
11
 
12
12
  def forward(self, p, q, shift_labels=None):
13
- return LigerTVDLossFunction.apply(
14
- p, q, shift_labels, self.reduction, self.ignore_index
15
- )
13
+ return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index)
liger_kernel/utils.py CHANGED
@@ -9,5 +9,7 @@ def infer_device():
9
9
  return "cuda"
10
10
  elif torch.xpu.is_available():
11
11
  return "xpu"
12
+ elif torch.hip.is_available():
13
+ return "hip"
12
14
  else:
13
15
  return "cpu"
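{liger_kernel_nightly-0.5.3.dev20250221162633.dist-info → liger_kernel_nightly-0.5.3.dev20250221230243.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@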
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221162633
3
+ Version: 0.5.3.dev20250221230243
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
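{liger_kernel_nightly-0.5.3.dev20250221162633.dist-info → liger_kernel_nightly-0.5.3.dev20250221230243.dist-info}/RECORD CHANGED
@@ -1,6 +1,6 @@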
1
1
  liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
3
- liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
3
+ liger_kernel/utils.py,sha256=Wh9TkveQY4snwiyKWAvWXUpVQKX1ARX2tL0T6qzEoIQ,305
4
4
  liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
5
5
  liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
6
6
  liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
@@ -28,14 +28,14 @@ liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML
28
28
  liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
29
29
  liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
30
30
  liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
31
- liger_kernel/ops/tvd.py,sha256=9wVCijj2vBtgiLeUHhl7hy_LAiJ3liPIYOGMSU3P1ro,6407
31
+ liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
32
32
  liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
33
33
  liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
34
34
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
35
- liger_kernel/transformers/__init__.py,sha256=i6GPkP5-esFBh205nF4MluNrL7KNugseGiUKdSHGW70,2172
35
+ liger_kernel/transformers/__init__.py,sha256=MGgdJkohu0tQS6owEBHuRVYhRUPXRFP9OiVc1fcjkjc,2172
36
36
  liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
37
37
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
38
- liger_kernel/transformers/functional.py,sha256=zahXVCjA2NxcVFpAgajILIRN0GO6mrbfLPgONUkTrY8,4940
38
+ liger_kernel/transformers/functional.py,sha256=ShLD3eb--XKNtllznCrOYTbo4f-1KVwzi0KLMICdrn4,4942
39
39
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
40
40
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
41
41
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -49,7 +49,7 @@ liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUS
49
49
  liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
50
50
  liger_kernel/transformers/swiglu.py,sha256=i9WTqcNRqReU4XJs391IPbl-I5X0wG4T72D4pqGFfJg,2422
51
51
  liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
52
- liger_kernel/transformers/tvd.py,sha256=lt730XDR4IYEkn-HeWS7WU6AGssg90ubg8pqPAX2lbE,472
52
+ liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
53
53
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
54
54
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
55
  liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
@@ -65,9 +65,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
65
65
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
66
66
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
67
67
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
68
- liger_kernel_nightly-0.5.3.dev20250221162633.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
69
- liger_kernel_nightly-0.5.3.dev20250221162633.dist-info/METADATA,sha256=9kT0AmFeMP_D-n-dGhoOSBX8ThMPM8ypnFT4f3O1FGc,22093
70
- liger_kernel_nightly-0.5.3.dev20250221162633.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
71
- liger_kernel_nightly-0.5.3.dev20250221162633.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
72
- liger_kernel_nightly-0.5.3.dev20250221162633.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
73
- liger_kernel_nightly-0.5.3.dev20250221162633.dist-info/RECORD,,
68
+ liger_kernel_nightly-0.5.3.dev20250221230243.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
69
+ liger_kernel_nightly-0.5.3.dev20250221230243.dist-info/METADATA,sha256=WdkWsUQstDqFp6VlaycZn_D5hm4tuHc_4NA6cAo8Gl4,22093
70
+ liger_kernel_nightly-0.5.3.dev20250221230243.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
71
+ liger_kernel_nightly-0.5.3.dev20250221230243.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
72
+ liger_kernel_nightly-0.5.3.dev20250221230243.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
73
+ liger_kernel_nightly-0.5.3.dev20250221230243.dist-info/RECORD,,