liger-kernel-nightly 0.3.1.dev20241031232423__py3-none-any.whl → 0.3.1.dev20241101201851__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of liger-kernel-nightly might be problematic. Click here for more details.
- liger_kernel/ops/fused_linear_jsd.py +18 -11
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/METADATA +2 -2
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/RECORD +12 -12
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-Apache-2.0 +0 -0
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-MIT-AutoAWQ +0 -0
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-MIT-llmc +0 -0
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-MIT-triton +0 -0
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/top_level.txt +0 -0
|
@@ -4,7 +4,7 @@ import torch
|
|
|
4
4
|
import triton
|
|
5
5
|
|
|
6
6
|
from liger_kernel.ops.jsd import _jsd_kernel
|
|
7
|
-
from liger_kernel.ops.utils import element_mul_kernel
|
|
7
|
+
from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
|
|
8
8
|
|
|
9
9
|
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
|
|
10
10
|
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
|
|
@@ -24,6 +24,7 @@ def fused_linear_jsd_forward(
|
|
|
24
24
|
temperature,
|
|
25
25
|
):
|
|
26
26
|
device = student_input.device
|
|
27
|
+
dtype = student_input.dtype
|
|
27
28
|
|
|
28
29
|
# inputs have shape: BT x H
|
|
29
30
|
# materialized activations will have shape: BT x V
|
|
@@ -64,9 +65,15 @@ def fused_linear_jsd_forward(
|
|
|
64
65
|
student_input_chunk = student_input[start_idx:end_idx]
|
|
65
66
|
teacher_input_chunk = teacher_input[start_idx:end_idx]
|
|
66
67
|
|
|
67
|
-
#
|
|
68
|
-
|
|
69
|
-
|
|
68
|
+
# shape: chunk_size x V
|
|
69
|
+
# For anything starting from logits to the final JSD loss, we do computation
|
|
70
|
+
# in FP32 to avoid losing numerical stability.
|
|
71
|
+
student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
|
|
72
|
+
torch.float32
|
|
73
|
+
)
|
|
74
|
+
teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
|
|
75
|
+
torch.float32
|
|
76
|
+
)
|
|
70
77
|
chunk_n_rows = student_logits_chunk.shape[0]
|
|
71
78
|
|
|
72
79
|
# unreduced loss
|
|
@@ -113,18 +120,16 @@ def fused_linear_jsd_forward(
|
|
|
113
120
|
student_prob_chunk.shape
|
|
114
121
|
)
|
|
115
122
|
) / temperature
|
|
123
|
+
# now we traverse back to grad w.r.t. input to `lm_head` and grad
|
|
124
|
+
# w.r.t. `lm_head` which should be computed in original dtype
|
|
125
|
+
student_logits_chunk = student_logits_chunk.to(dtype)
|
|
116
126
|
grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight
|
|
117
127
|
|
|
118
128
|
if grad_weight is not None:
|
|
119
|
-
torch.addmm(
|
|
120
|
-
input=grad_weight,
|
|
121
|
-
mat1=student_logits_chunk.t(), # gradients of logits_chunk
|
|
122
|
-
mat2=student_input_chunk,
|
|
123
|
-
out=grad_weight,
|
|
124
|
-
)
|
|
129
|
+
grad_weight.add_(student_logits_chunk.t() @ student_input_chunk)
|
|
125
130
|
|
|
126
131
|
loss = torch.sum(loss_1d)
|
|
127
|
-
return loss
|
|
132
|
+
return loss, grad_input, grad_weight
|
|
128
133
|
|
|
129
134
|
|
|
130
135
|
def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
|
|
@@ -172,6 +177,7 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
|
|
|
172
177
|
"""
|
|
173
178
|
|
|
174
179
|
@staticmethod
|
|
180
|
+
@amp_custom_fwd
|
|
175
181
|
def forward(
|
|
176
182
|
ctx,
|
|
177
183
|
student_input: torch.Tensor,
|
|
@@ -225,6 +231,7 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
|
|
|
225
231
|
return loss
|
|
226
232
|
|
|
227
233
|
@staticmethod
|
|
234
|
+
@amp_custom_bwd
|
|
228
235
|
def backward(ctx, grad_output):
|
|
229
236
|
(grad_input, grad_weight) = ctx.saved_tensors
|
|
230
237
|
grad_input, grad_weight = fused_linear_jsd_backward(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: liger_kernel_nightly
|
|
3
|
-
Version: 0.3.1.dev20241031232423
|
|
3
|
+
Version: 0.3.1.dev20241101201851
|
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
|
@@ -36,7 +36,7 @@ License-File: LICENSE-MIT-llmc
|
|
|
36
36
|
License-File: LICENSE-MIT-triton
|
|
37
37
|
License-File: NOTICE
|
|
38
38
|
Requires-Dist: torch>=2.1.2
|
|
39
|
-
Requires-Dist: triton>=2.3.
|
|
39
|
+
Requires-Dist: triton>=2.3.1
|
|
40
40
|
Provides-Extra: dev
|
|
41
41
|
Requires-Dist: transformers>=4.44.2; extra == "dev"
|
|
42
42
|
Requires-Dist: matplotlib>=3.7.2; extra == "dev"
|
|
@@ -2,7 +2,7 @@ liger_kernel/env_report.py,sha256=LFUJ6UMkFFGPBYXBlqHFGy4bhsemEpSI-_1edSazlHI,11
|
|
|
2
2
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
liger_kernel/ops/cross_entropy.py,sha256=OB3nvIONLB_sj9LO6UQv1qLnf861k-pR58RtwgoiyYA,11192
|
|
4
4
|
liger_kernel/ops/fused_linear_cross_entropy.py,sha256=qg7qBQFLDJClnkUOGhFFHPSW_x7rPvQekbm_4OOYxys,9331
|
|
5
|
-
liger_kernel/ops/fused_linear_jsd.py,sha256=
|
|
5
|
+
liger_kernel/ops/fused_linear_jsd.py,sha256=ZQUxNqm3yOokZhUId0sIfob_3e43rGbMTDxeCk9A92o,9549
|
|
6
6
|
liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
|
|
7
7
|
liger_kernel/ops/jsd.py,sha256=anWfdioucxZy4JQfTvbHBR-IQrZKeH-gBF1MHwwTuTQ,5781
|
|
8
8
|
liger_kernel/ops/kl_div.py,sha256=qnmtFQwuO3FR7Ovup_DDzpkD1A1LpwOaWlcO6K9ysHk,8342
|
|
@@ -40,14 +40,14 @@ liger_kernel/transformers/model/qwen2.py,sha256=3inWFXGHYT7wA10OR6bq3mDUBrr10AS5
|
|
|
40
40
|
liger_kernel/transformers/model/qwen2_vl.py,sha256=ymsm9aQpSUiSU12GY8FO608p9dSHOz4TCnNI1htX5bk,6975
|
|
41
41
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
|
42
42
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
|
43
|
-
liger_kernel_nightly-0.3.1.
|
|
44
|
-
liger_kernel_nightly-0.3.1.
|
|
45
|
-
liger_kernel_nightly-0.3.1.
|
|
46
|
-
liger_kernel_nightly-0.3.1.
|
|
47
|
-
liger_kernel_nightly-0.3.1.
|
|
48
|
-
liger_kernel_nightly-0.3.1.
|
|
49
|
-
liger_kernel_nightly-0.3.1.
|
|
50
|
-
liger_kernel_nightly-0.3.1.
|
|
51
|
-
liger_kernel_nightly-0.3.1.
|
|
52
|
-
liger_kernel_nightly-0.3.1.
|
|
53
|
-
liger_kernel_nightly-0.3.1.
|
|
43
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
|
44
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-Apache-2.0,sha256=NRaCIsL9eblGS35gk4WKTC0usNYnR_mgRHJTKqz2_UE,11348
|
|
45
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-MIT-AutoAWQ,sha256=pfiOyInrAPY3xQbvV1i-gOqNZK7QEyIepT1IbqOYYYo,1067
|
|
46
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-MIT-Efficient-Cross-Entropy,sha256=PaC9HqyFYTy-ClS0H8Zfa2motJuTppjECXmjHwJcaOk,1063
|
|
47
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-MIT-llmc,sha256=kyFLt_XUcXS88CuxQt5-PjOcLjpJP2m-T4gtqZf3GLc,1071
|
|
48
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-MIT-triton,sha256=wL6W8IwsKiyHtzXubg8TCXhRZuo8S83EPdqXffYtqWg,1131
|
|
49
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/METADATA,sha256=qx1HwgDXyy5RdAAuvpVrmpnLA4CJ6b6A7zGVQrU-rpg,27717
|
|
50
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
|
51
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
52
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
|
53
|
+
liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|