liger-kernel-nightly 0.5.6.dev20250408223717__py3-none-any.whl → 0.5.6.dev20250411201510__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/kl_div.py +13 -6
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411201510.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411201510.dist-info}/RECORD +7 -7
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411201510.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411201510.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411201510.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.6.dev20250408223717.dist-info → liger_kernel_nightly-0.5.6.dev20250411201510.dist-info}/top_level.txt +0 -0
liger_kernel/ops/kl_div.py
CHANGED
@@ -6,6 +6,7 @@ import triton.language as tl
|
|
6
6
|
|
7
7
|
from liger_kernel.ops.utils import ensure_contiguous
|
8
8
|
from liger_kernel.ops.utils import is_hip
|
9
|
+
from liger_kernel.utils import infer_device
|
9
10
|
|
10
11
|
|
11
12
|
def get_num_warps(BLOCK_SIZE):
|
@@ -115,9 +116,12 @@ def _kldiv_kernel_backward(
|
|
115
116
|
|
116
117
|
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
117
118
|
BT, V = y_pred.shape
|
118
|
-
|
119
|
-
|
120
|
-
|
119
|
+
BLOCK_SIZE = (
|
120
|
+
min(8192, triton.next_power_of_2(V))
|
121
|
+
if infer_device() == "xpu"
|
122
|
+
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
123
|
+
)
|
124
|
+
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
121
125
|
|
122
126
|
grid = (BT,)
|
123
127
|
reduction = _str_to_reduction_mode[reduction]
|
@@ -155,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
|
|
155
159
|
|
156
160
|
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
|
157
161
|
BT, V = target.shape
|
158
|
-
|
159
|
-
|
160
|
-
|
162
|
+
BLOCK_SIZE = (
|
163
|
+
min(8192, triton.next_power_of_2(V))
|
164
|
+
if infer_device() == "xpu"
|
165
|
+
else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
|
166
|
+
)
|
167
|
+
num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
|
161
168
|
|
162
169
|
grid = (BT,)
|
163
170
|
|
@@ -23,7 +23,7 @@ liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHu
|
|
23
23
|
liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
|
24
24
|
liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
|
25
25
|
liger_kernel/ops/jsd.py,sha256=onHp5T3MbvJaVz5Vup7Ww6EQp_HTaZeayTjJk6FgQMY,7042
|
26
|
-
liger_kernel/ops/kl_div.py,sha256=
|
26
|
+
liger_kernel/ops/kl_div.py,sha256=ZjGdDLKWksHT9dZ0xF_TDgAkj5cuMTwwT5tr9E-_24o,8734
|
27
27
|
liger_kernel/ops/layer_norm.py,sha256=vWCyOm-F2GMAilB-ozJcFeUQQLCJoTE_uiXq-_0uYuI,8356
|
28
28
|
liger_kernel/ops/qwen2vl_mrope.py,sha256=3GExhYpLgB4VUtyZyjRk8XjEur3W4EWF6HQ67ML5vBU,8481
|
29
29
|
liger_kernel/ops/rms_norm.py,sha256=PP27OIBmV9By63i13jot9ylDowW0nuxY_JFIkaPLgL4,12078
|
@@ -74,9 +74,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7H
|
|
74
74
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
|
75
75
|
liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
|
76
76
|
liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
|
77
|
-
liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
78
|
-
liger_kernel_nightly-0.5.6.
|
79
|
-
liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
80
|
-
liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
81
|
-
liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
82
|
-
liger_kernel_nightly-0.5.6.dev20250408223717.dist-info/RECORD,,
|
77
|
+
liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
78
|
+
liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/METADATA,sha256=exdcHfLuKkUQ2NIene0sQ5hEn8mB98YKJ43XfirrGwM,23297
|
79
|
+
liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
80
|
+
liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
|
81
|
+
liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
82
|
+
liger_kernel_nightly-0.5.6.dev20250411201510.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|