liger-kernel-nightly 0.5.3.dev20250221230243__py3-none-any.whl → 0.5.3.dev20250221233257__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/utils.py +47 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.3.dev20250221230243.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.3.dev20250221230243.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.3.dev20250221230243.dist-info → liger_kernel_nightly-0.5.3.dev20250221233257.dist-info}/top_level.txt +0 -0
liger_kernel/utils.py
CHANGED
|
@@ -13,3 +13,50 @@ def infer_device():
|
|
|
13
13
|
return "hip"
|
|
14
14
|
else:
|
|
15
15
|
return "cpu"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def transformers_version_dispatch(
    required_version: str,
    before_fn,
    after_fn,
    before_args: tuple = (),
    after_args: tuple = (),
    before_kwargs: dict = None,
    after_kwargs: dict = None,
):
    """
    Call one of two functions depending on the installed ``transformers`` version.

    The installed ``transformers.__version__`` is compared against
    ``required_version`` with ``packaging.version.parse``; the "before" branch
    is taken when the installed version is strictly older, otherwise the
    "after" branch is taken.

    Args:
        required_version: Version threshold to compare against (e.g. "4.48.0").
        before_fn: Called when the installed transformers version is older
            than ``required_version``.
        after_fn: Called when the installed transformers version is equal to
            or newer than ``required_version``.
        before_args: Positional arguments forwarded to ``before_fn``.
        after_args: Positional arguments forwarded to ``after_fn``.
        before_kwargs: Keyword arguments forwarded to ``before_fn``; ``None``
            (the default) means no keyword arguments.
        after_kwargs: Keyword arguments forwarded to ``after_fn``; ``None``
            (the default) means no keyword arguments.

    Returns:
        Whatever the selected function returns.

    Note:
        Ordering follows PEP 440, so a pre-release such as "4.48.0.dev0"
        compares as older than "4.48.0" and is routed to ``before_fn``.

    Example:
        >>> rotary_emb = transformers_version_dispatch(
        ...     "4.48.0",
        ...     LlamaRotaryEmbedding,
        ...     LlamaRotaryEmbedding,
        ...     before_args=(head_dim,),
        ...     after_args=(LlamaConfig(head_dim=head_dim),),
        ...     before_kwargs={'device': device},
        ...     after_kwargs={'device': device}
        ... )
    """
    # Imported lazily so that importing this utility module does not require
    # transformers to be installed.
    from packaging.version import parse as parse_version
    from transformers import __version__ as installed_version

    is_older = parse_version(installed_version) < parse_version(required_version)

    # Pick the branch first, then make a single call.
    if is_older:
        target, call_args, call_kwargs = before_fn, before_args, before_kwargs
    else:
        target, call_args, call_kwargs = after_fn, after_args, after_kwargs

    return target(*call_args, **(call_kwargs or {}))
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
liger_kernel/env_report.py,sha256=uhdEC8OydxoZlb7B6YYcAaBF3crGFdIck-4cxaW4NJY,1728
|
|
3
|
-
liger_kernel/utils.py,sha256=
|
|
3
|
+
liger_kernel/utils.py,sha256=FtVUkCGBT1UNasTl6HMNycWwiwHayK6tx-ZDdA-sNX4,1884
|
|
4
4
|
liger_kernel/chunked_loss/README.md,sha256=0FmkFC3hKBqyoDT5uTlIYmrvRkF-EOCR1y-EBU1LpWU,2248
|
|
5
5
|
liger_kernel/chunked_loss/__init__.py,sha256=ATu-xX5Fc49Cr6yBOGBRNTo593ZrU5ZCsIuvoIbJWw4,603
|
|
6
6
|
liger_kernel/chunked_loss/cpo_loss.py,sha256=OdBR8WYdHTKpLI_c9DcuwqKSWPeAAeTyREz46Vu_cAY,3682
|
|
@@ -65,9 +65,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
|
65
65
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
|
|
66
66
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
|
67
67
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
|
68
|
-
liger_kernel_nightly-0.5.3.
|
|
69
|
-
liger_kernel_nightly-0.5.3.
|
|
70
|
-
liger_kernel_nightly-0.5.3.
|
|
71
|
-
liger_kernel_nightly-0.5.3.
|
|
72
|
-
liger_kernel_nightly-0.5.3.
|
|
73
|
-
liger_kernel_nightly-0.5.3.
|
|
68
|
+
liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
69
|
+
liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/METADATA,sha256=ZsBEdMtozVk3GNw6IRdb-wu7XjzLVefruhKVD7JJjdE,22093
|
|
70
|
+
liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
71
|
+
liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
72
|
+
liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
73
|
+
liger_kernel_nightly-0.5.3.dev20250221233257.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|