liger-kernel-nightly 0.5.3.dev20250221162633__py3-none-any.whl → 0.5.3.dev20250221233257__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of liger-kernel-nightly may be problematic.
- liger_kernel/ops/tvd.py +6 -7
- liger_kernel/transformers/__init__.py +1 -1
- liger_kernel/transformers/functional.py +4 -1
- liger_kernel/transformers/tvd.py +1 -3
- liger_kernel/utils.py +49 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221162633.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/RECORD +11 -11
- {liger_kernel_nightly-0.5.3.dev20250221162633.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221162633.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/top_level.txt +0 -0
liger_kernel/ops/tvd.py
CHANGED

@@ -1,4 +1,5 @@
-from typing import Literal
+from typing import Literal
+from typing import Optional
 
 import torch
 import triton
@@ -178,15 +179,13 @@ class LigerTVDLossFunction(torch.autograd.Function):
         """
         has_label = False
         if shift_labels is not None:
-            assert shift_labels.shape == (
-                p.shape[0],
-            )
+            assert shift_labels.shape == (p.shape[0],), (
+                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
+            )
             shift_labels = shift_labels.contiguous()
             has_label = True
 
-        loss, grads = tv_distance_forward_triton(
-            p, q, shift_labels, reduction, ignore_index, has_label
-        )
+        loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
         ctx.save_for_backward(grads)
         return loss
 
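Note: the reflowed assertion keeps the same contract as before. When shift_labels is provided it must be a 1-D tensor with one label per flattened position, i.e. shape (BT,) matching the first dimension of p. A small shape sketch (sizes are illustrative, not taken from the diff):

    import torch

    BT, V = 8, 128  # 8 flattened batch*sequence positions, V-way distributions
    p = torch.softmax(torch.randn(BT, V), dim=-1)   # predicted distribution, shape (BT, V)
    q = torch.softmax(torch.randn(BT, V), dim=-1)   # target distribution, shape (BT, V)
    shift_labels = torch.randint(0, V, (BT,))       # shape (BT,), as the assertion requires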
liger_kernel/transformers/__init__.py
CHANGED

@@ -19,7 +19,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
 from liger_kernel.transformers.rms_norm import LigerRMSNorm  # noqa: F401
 from liger_kernel.transformers.rope import liger_rotary_pos_emb  # noqa: F401
-from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP  # noqa: F401
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP  # noqa: F401
+from liger_kernel.transformers.tvd import LigerTVDLoss  # noqa: F401
liger_kernel/transformers/functional.py
CHANGED

@@ -14,6 +14,7 @@ from liger_kernel.ops.rope import LigerRopeFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
 from liger_kernel.ops.tvd import LigerTVDLossFunction
 
+
 # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
 # `weight` and `size_average` are placeholders and not implemented yet
 def liger_cross_entropy(
@@ -156,6 +157,7 @@ def liger_kl_div(
         eps,
     )
 
+
 def liger_tvd(
     input,
     target,
@@ -169,7 +171,8 @@ def liger_tvd(
         shift_labels,
         reduction,
         ignore_index,
-    )
+    )
+
 
 def liger_layer_norm(X, W, B, eps):
     return LigerLayerNormFunction.apply(X, W, B, eps)
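For reference, a minimal sketch of calling the functional wrapper touched above. It assumes a CUDA device (the op is a Triton kernel) and that the optional arguments (shift_labels, reduction, ignore_index) keep their library defaults; neither assumption is part of this diff:

    import torch
    from liger_kernel.transformers.functional import liger_tvd

    p = torch.softmax(torch.randn(8, 128, device="cuda"), dim=-1)
    q = torch.softmax(torch.randn(8, 128, device="cuda"), dim=-1)
    loss = liger_tvd(p, q)  # total variation distance between the two distributions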
liger_kernel/transformers/tvd.py
CHANGED

@@ -10,6 +10,4 @@ class LigerTVDLoss(nn.Module):
         self.ignore_index = ignore_index
 
     def forward(self, p, q, shift_labels=None):
-        return LigerTVDLossFunction.apply(
-            p, q, shift_labels, self.reduction, self.ignore_index
-        )
+        return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index)
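The joined apply(...) call is purely cosmetic; the module is used as before. A minimal usage sketch, assuming the constructor's default reduction and ignore_index and a CUDA device for the Triton kernel:

    import torch
    from liger_kernel.transformers import LigerTVDLoss

    tvd = LigerTVDLoss()
    p = torch.softmax(torch.randn(8, 128, device="cuda"), dim=-1)
    q = torch.softmax(torch.randn(8, 128, device="cuda"), dim=-1)
    loss = tvd(p, q)  # shift_labels defaults to None, so no positions are ignored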
liger_kernel/utils.py
CHANGED

@@ -9,5 +9,54 @@ def infer_device():
         return "cuda"
     elif torch.xpu.is_available():
         return "xpu"
+    elif torch.hip.is_available():
+        return "hip"
     else:
         return "cpu"
+
+
+def transformers_version_dispatch(
+    required_version: str,
+    before_fn,
+    after_fn,
+    before_args: tuple = (),
+    after_args: tuple = (),
+    before_kwargs: dict = None,
+    after_kwargs: dict = None,
+):
+    """
+    Dispatches to different functions based on package version comparison.
+
+    Args:
+        required_version: Version to compare against (e.g. "4.48.0")
+        before_fn: Function to call if package_version < required_version
+        after_fn: Function to call if package_version >= required_version
+        before_args: Positional arguments for before_fn
+        after_args: Positional arguments for after_fn
+        before_kwargs: Keyword arguments for before_fn
+        after_kwargs: Keyword arguments for after_fn
+
+    Returns:
+        Result from either before_fn or after_fn
+
+    Example:
+        >>> rotary_emb = transformers_version_dispatch(
+        ...     "4.48.0",
+        ...     LlamaRotaryEmbedding,
+        ...     LlamaRotaryEmbedding,
+        ...     before_args=(head_dim,),
+        ...     after_args=(LlamaConfig(head_dim=head_dim),),
+        ...     before_kwargs={'device': device},
+        ...     after_kwargs={'device': device}
+        ... )
+    """
+    from packaging import version
+    from transformers import __version__ as transformers_version
+
+    before_kwargs = before_kwargs or {}
+    after_kwargs = after_kwargs or {}
+
+    if version.parse(transformers_version) < version.parse(required_version):
+        return before_fn(*before_args, **before_kwargs)
+    else:
+        return after_fn(*after_args, **after_kwargs)
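Beyond the doctest in the docstring, here is a self-contained sketch of the new helper. The builder functions are illustrative stand-ins (not part of the package), and the call assumes transformers and packaging are installed, since transformers_version_dispatch imports both:

    from liger_kernel.utils import infer_device, transformers_version_dispatch

    def old_builder(dim):          # used when transformers < 4.48.0
        return {"dim": dim}

    def new_builder(config):      # used when transformers >= 4.48.0
        return {"dim": config["dim"]}

    obj = transformers_version_dispatch(
        "4.48.0",
        old_builder,
        new_builder,
        before_args=(64,),
        after_args=({"dim": 64},),
    )
    print(infer_device(), obj)    # e.g. "cuda {'dim': 64}" on a CUDA machine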
{liger_kernel_nightly-0.5.3.dev20250221162633.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/RECORD
CHANGED

@@ -1,6 +1,6 @@
 liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
-liger_kernel/utils.py,sha256=
+liger_kernel/utils.py,sha256=FtVUkCGBT1UNasTl6HMNycWwiwHayK6tx-ZDdA-sNX4,1884
 liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
 liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
 liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
@@ -28,14 +28,14 @@ liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML
 liger_kernel/ops/rms_norm.py,sha256=PWLJcdIKU5e-8BuYFHd9Cqlq6wmr6fUXKi9zQD4LetU,11727
 liger_kernel/ops/rope.py,sha256=ofmBOkUpZZO-Q8Z5B_LOFYYLD-YT-8WnJ4vGOrDYouI,8943
 liger_kernel/ops/swiglu.py,sha256=KmgMjaJQnbLLgZn2nEpbwHU_xpnYRweCyrLQSVvM1vA,3015
-liger_kernel/ops/tvd.py,sha256=
+liger_kernel/ops/tvd.py,sha256=FHJtLQI95ijqgg9UtaHpMAjSCiPxB6CduPwPMcGxelc,6405
 liger_kernel/ops/utils.py,sha256=uoFKQqo-34N2TWQNvXMFywqGiOMMXNEVBxVojzlUAa0,3836
 liger_kernel/ops/experimental/embedding.py,sha256=tolj3tItkzpSb30zWqDN2_yX4ectflaQ8HMyKyFIQc8,4172
 liger_kernel/ops/experimental/mm_int8int2.py,sha256=TrS9lpwekrik_w5qE7AhMJD1bcq-OidjtbsW80oZ6IM,13314
-liger_kernel/transformers/__init__.py,sha256=
+liger_kernel/transformers/__init__.py,sha256=MGgdJkohu0tQS6owEBHuRVYhRUPXRFP9OiVc1fcjkjc,2172
 liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawXOGgoCZjgEg,1545
 liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
-liger_kernel/transformers/functional.py,sha256=
+liger_kernel/transformers/functional.py,sha256=ShLD3eb--XKNtllznCrOYTbo4f-1KVwzi0KLMICdrn4,4942
 liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=09Rt7FZzLH42VOcIbQ4dlQd0o3Rlb4vk6fqiOQ7WTD8,1778
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
@@ -49,7 +49,7 @@ liger_kernel/transformers/rms_norm.py,sha256=GqCEJuGt0YdqqlMcToE0Wp4A8YFquDa4UUS
 liger_kernel/transformers/rope.py,sha256=ZTrTORSAyfcFIKjk6XEeYmk4ROH7xXED9L4g2NFntlE,999
 liger_kernel/transformers/swiglu.py,sha256=i9WTqcNRqReU4XJs391IPbl-I5X0wG4T72D4pqGFfJg,2422
 liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx-uy2f2cFfveZpqbUdhw,123
-liger_kernel/transformers/tvd.py,sha256=
+liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
 liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
 liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 liger_kernel/transformers/model/gemma.py,sha256=ky89b3aWPaeTGRMC-745KgixtQIRXzNAiCORAMLn9yo,9654
@@ -65,9 +65,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
-liger_kernel_nightly-0.5.3.
+liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/METADATA,sha256=ZsBEdMtozVk3GNw6IRdb-wu7XjzLVefruhKVD7JJjdE,22093
+liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/RECORD,,
File without changes
File without changes
File without changes