liger-kernel-nightly 0.3.1.dev20241031232423__py3-none-any.whl → 0.3.1.dev20241101201851__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (12)
  1. liger_kernel/ops/fused_linear_jsd.py +18 -11
  2. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/METADATA +2 -2
  3. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/RECORD +12 -12
  4. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE +0 -0
  5. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-Apache-2.0 +0 -0
  6. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-MIT-AutoAWQ +0 -0
  7. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  8. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-MIT-llmc +0 -0
  9. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/LICENSE-MIT-triton +0 -0
  10. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/NOTICE +0 -0
  11. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/WHEEL +0 -0
  12. {liger_kernel_nightly-0.3.1.dev20241031232423.dist-info → liger_kernel_nightly-0.3.1.dev20241101201851.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@ import torch
4
4
  import triton
5
5
 
6
6
  from liger_kernel.ops.jsd import _jsd_kernel
7
- from liger_kernel.ops.utils import element_mul_kernel
7
+ from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
8
8
 
9
9
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
10
10
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -24,6 +24,7 @@ def fused_linear_jsd_forward(
24
24
  temperature,
25
25
  ):
26
26
  device = student_input.device
27
+ dtype = student_input.dtype
27
28
 
28
29
  # inputs have shape: BT x H
29
30
  # materialized activations will have shape: BT x V
@@ -64,9 +65,15 @@ def fused_linear_jsd_forward(
64
65
  student_input_chunk = student_input[start_idx:end_idx]
65
66
  teacher_input_chunk = teacher_input[start_idx:end_idx]
66
67
 
67
- # when doing matmul, use the original precision, shape: chunk_size x V
68
- student_logits_chunk = student_input_chunk @ student_weight.t()
69
- teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
68
+ # shape: chunk_size x V
69
+ # For anything starting from logits to the final JSD loss, we do computation
70
+ # in FP32 to avoid losing numerical stability.
71
+ student_logits_chunk = (student_input_chunk @ student_weight.t()).to(
72
+ torch.float32
73
+ )
74
+ teacher_logits_chunk = (teacher_input_chunk @ teacher_weight.t()).to(
75
+ torch.float32
76
+ )
70
77
  chunk_n_rows = student_logits_chunk.shape[0]
71
78
 
72
79
  # unreduced loss
@@ -113,18 +120,16 @@ def fused_linear_jsd_forward(
113
120
  student_prob_chunk.shape
114
121
  )
115
122
  ) / temperature
123
+ # now we traverse back to grad w.r.t. input to `lm_head` and grad
124
+ # w.r.t. `lm_head` which should be computed in original dtype
125
+ student_logits_chunk = student_logits_chunk.to(dtype)
116
126
  grad_input[start_idx:end_idx] = student_logits_chunk @ student_weight
117
127
 
118
128
  if grad_weight is not None:
119
- torch.addmm(
120
- input=grad_weight,
121
- mat1=student_logits_chunk.t(), # gradients of logits_chunk
122
- mat2=student_input_chunk,
123
- out=grad_weight,
124
- )
129
+ grad_weight.add_(student_logits_chunk.t() @ student_input_chunk)
125
130
 
126
131
  loss = torch.sum(loss_1d)
127
- return loss.to(student_input.dtype), grad_input, grad_weight
132
+ return loss, grad_input, grad_weight
128
133
 
129
134
 
130
135
  def fused_linear_jsd_backward(grad_output, grad_input, grad_weight):
@@ -172,6 +177,7 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
172
177
  """
173
178
 
174
179
  @staticmethod
180
+ @amp_custom_fwd
175
181
  def forward(
176
182
  ctx,
177
183
  student_input: torch.Tensor,
@@ -225,6 +231,7 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
225
231
  return loss
226
232
 
227
233
  @staticmethod
234
+ @amp_custom_bwd
228
235
  def backward(ctx, grad_output):
229
236
  (grad_input, grad_weight) = ctx.saved_tensors
230
237
  grad_input, grad_weight = fused_linear_jsd_backward(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.3.1.dev20241031232423
3
+ Version: 0.3.1.dev20241101201851
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -36,7 +36,7 @@ License-File: LICENSE-MIT-llmc
36
36
  License-File: LICENSE-MIT-triton
37
37
  License-File: NOTICE
38
38
  Requires-Dist: torch>=2.1.2
39
- Requires-Dist: triton>=2.3.0
39
+ Requires-Dist: triton>=2.3.1
40
40
  Provides-Extra: dev
41
41
  Requires-Dist: transformers>=4.44.2; extra == "dev"
42
42
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
@@ -2,7 +2,7 @@ liger_kernel/env_report.py,sha256=LFUJ6UMkFFGPBYXBlqHFGy4bhsemEpSI-_1edSazlHI,11
2
2
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  liger_kernel/ops/cross_entropy.py,sha256=OB3nvIONLB_sj9LO6UQv1qLnf861k-pR58RtwgoiyYA,11192
4
4
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=qg7qBQFLDJClnkUOGhFFHPSW_x7rPvQekbm_4OOYxys,9331
5
- liger_kernel/ops/fused_linear_jsd.py,sha256=eZ8y4GPtPRE0QcNNMLX8l4gSEMPzA3ZuzknfbAbiREA,9234
5
+ liger_kernel/ops/fused_linear_jsd.py,sha256=ZQUxNqm3yOokZhUId0sIfob_3e43rGbMTDxeCk9A92o,9549
6
6
  liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
7
7
  liger_kernel/ops/jsd.py,sha256=anWfdioucxZy4JQfTvbHBR-IQrZKeH-gBF1MHwwTuTQ,5781
8
8
  liger_kernel/ops/kl_div.py,sha256=qnmtFQwuO3FR7Ovup_DDzpkD1A1LpwOaWlcO6K9ysHk,8342
@@ -40,14 +40,14 @@ liger_kernel/transformers/model/qwen2.py,sha256=3inWFXGHYT7wA10OR6bq3mDUBrr10AS5
40
40
  liger_kernel/transformers/model/qwen2_vl.py,sha256=ymsm9aQpSUiSU12GY8FO608p9dSHOz4TCnNI1htX5bk,6975
41
41
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
42
42
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
43
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
44
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/LICENSE-Apache-2.0,sha256=NRaCIsL9eblGS35gk4WKTC0usNYnR_mgRHJTKqz2_UE,11348
45
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/LICENSE-MIT-AutoAWQ,sha256=pfiOyInrAPY3xQbvV1i-gOqNZK7QEyIepT1IbqOYYYo,1067
46
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/LICENSE-MIT-Efficient-Cross-Entropy,sha256=PaC9HqyFYTy-ClS0H8Zfa2motJuTppjECXmjHwJcaOk,1063
47
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/LICENSE-MIT-llmc,sha256=kyFLt_XUcXS88CuxQt5-PjOcLjpJP2m-T4gtqZf3GLc,1071
48
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/LICENSE-MIT-triton,sha256=wL6W8IwsKiyHtzXubg8TCXhRZuo8S83EPdqXffYtqWg,1131
49
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/METADATA,sha256=LPXFEjmB3kvrLlkGnL0jwPSUNPXavpb5c_LLwSpXZ78,27717
50
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
51
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
52
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
53
- liger_kernel_nightly-0.3.1.dev20241031232423.dist-info/RECORD,,
43
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
44
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-Apache-2.0,sha256=NRaCIsL9eblGS35gk4WKTC0usNYnR_mgRHJTKqz2_UE,11348
45
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-MIT-AutoAWQ,sha256=pfiOyInrAPY3xQbvV1i-gOqNZK7QEyIepT1IbqOYYYo,1067
46
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-MIT-Efficient-Cross-Entropy,sha256=PaC9HqyFYTy-ClS0H8Zfa2motJuTppjECXmjHwJcaOk,1063
47
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-MIT-llmc,sha256=kyFLt_XUcXS88CuxQt5-PjOcLjpJP2m-T4gtqZf3GLc,1071
48
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/LICENSE-MIT-triton,sha256=wL6W8IwsKiyHtzXubg8TCXhRZuo8S83EPdqXffYtqWg,1131
49
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/METADATA,sha256=qx1HwgDXyy5RdAAuvpVrmpnLA4CJ6b6A7zGVQrU-rpg,27717
50
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
51
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
52
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
53
+ liger_kernel_nightly-0.3.1.dev20241101201851.dist-info/RECORD,,