liger-kernel-nightly 0.5.5.dev20250402184001__py3-none-any.whl → 0.5.5.dev20250402185606__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/cross_entropy.py +3 -2
- {liger_kernel_nightly-0.5.5.dev20250402184001.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.5.dev20250402184001.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.5.dev20250402184001.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402184001.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402184001.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.5.dev20250402184001.dist-info → liger_kernel_nightly-0.5.5.dev20250402185606.dist-info}/top_level.txt +0 -0
@@ -9,6 +9,7 @@ import triton.language as tl
|
|
9
9
|
from liger_kernel.ops.utils import compare_version
|
10
10
|
from liger_kernel.ops.utils import element_mul_kernel
|
11
11
|
from liger_kernel.ops.utils import is_hip
|
12
|
+
from liger_kernel.utils import infer_device
|
12
13
|
|
13
14
|
if compare_version("triton", operator.ge, "3.0.0"):
|
14
15
|
try:
|
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
|
|
59
60
|
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
|
60
61
|
loss_stride (int): The stride of the loss tensor.
|
61
62
|
n_cols (int): The number of columns in the input tensor.
|
62
|
-
n_non_ignore (
|
63
|
+
n_non_ignore (float): The number of non-ignored elements in the batch.
|
63
64
|
sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
|
64
65
|
weight_sum (float): The sum of weight tensor.
|
65
66
|
ignore_index (int): The index to ignore in the target.
|
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
|
|
258
259
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
259
260
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
260
261
|
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
|
261
|
-
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
|
262
|
+
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
|
262
263
|
|
263
264
|
|
264
265
|
def cross_entropy_forward(
|
@@ -16,7 +16,7 @@ liger_kernel/chunked_loss/kto_loss.py,sha256=llVCe6DkcpCo57seGWoMikaQVFApx764jsm
|
|
16
16
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=nu9UYG16dcMw93lvHi4_hYs3Q0FK1KnlmMRj7OpYU8s,4872
|
17
17
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=fy2w8KbhMrBv7b1jdIeH3bBFxY52bPQPZb3KwBvmurM,5385
|
18
18
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
liger_kernel/ops/cross_entropy.py,sha256=
|
19
|
+
liger_kernel/ops/cross_entropy.py,sha256=T5oSsqOS1y-Iea5o9v_BSU-_mIEXqWAT1oX_m59NcA4,18941
|
20
20
|
liger_kernel/ops/dyt.py,sha256=YD1-buHz9VmIX838VKzLc-lm5CeUQ4LAskGDWBUMQHA,6187
|
21
21
|
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=1Y3Uk_TCSjqKgoG2eot1ptnWXJXXQESqGvOmqAW1gsM,10912
|
22
22
|
liger_kernel/ops/fused_linear_jsd.py,sha256=Seshez2qaM6HiTQ8_HEqSwhaeVruNT1SvIM4ZrAPBEU,9602
|
@@ -72,9 +72,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
72
72
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
|
73
73
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
74
74
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
75
|
-
liger_kernel_nightly-0.5.5.
|
76
|
-
liger_kernel_nightly-0.5.5.
|
77
|
-
liger_kernel_nightly-0.5.5.
|
78
|
-
liger_kernel_nightly-0.5.5.
|
79
|
-
liger_kernel_nightly-0.5.5.
|
80
|
-
liger_kernel_nightly-0.5.5.
|
75
|
+
liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
76
|
+
liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/METADATA,sha256=XQaGc9bnsEFdwtLh1Mv5_fX-TIejLbcHk1SP-FEY5ew,22959
|
77
|
+
liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
78
|
+
liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
79
|
+
liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
80
|
+
liger_kernel_nightly-0.5.5.dev20250402185606.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|