liger-kernel-nightly 0.3.1.dev20241025165359__py3-none-any.whl → 0.3.1.dev20241030045958__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (15)
  1. liger_kernel/ops/fused_linear_cross_entropy.py +4 -4
  2. liger_kernel/ops/fused_linear_jsd.py +3 -1
  3. liger_kernel/ops/jsd.py +1 -1
  4. liger_kernel/ops/utils.py +13 -0
  5. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/METADATA +1 -1
  6. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/RECORD +15 -15
  7. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/WHEEL +1 -1
  8. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE +0 -0
  9. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-Apache-2.0 +0 -0
  10. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-MIT-AutoAWQ +0 -0
  11. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-MIT-Efficient Cross Entropy +0 -0
  12. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-MIT-llmc +0 -0
  13. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/LICENSE-MIT-triton +0 -0
  14. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/NOTICE +0 -0
  15. {liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/top_level.txt +0 -0
liger_kernel/ops/fused_linear_cross_entropy.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
2
  import triton
3
3
 
4
4
  from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
5
- from liger_kernel.ops.utils import element_mul_kernel
5
+ from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
6
6
 
7
7
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
8
8
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -19,9 +19,7 @@ def fused_linear_cross_entropy_forward(
19
19
  label_smoothing=0.0,
20
20
  reduction="mean",
21
21
  ):
22
- dtype = (
23
- torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
24
- )
22
+ dtype = _input.dtype
25
23
  device = _input.device
26
24
 
27
25
  # inputs have shape: BT x H
@@ -189,6 +187,7 @@ def fused_linear_cross_entropy_backward(
189
187
 
190
188
  class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
191
189
  @staticmethod
190
+ @amp_custom_fwd
192
191
  def forward(
193
192
  ctx,
194
193
  _input,
@@ -228,6 +227,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
228
227
  return loss
229
228
 
230
229
  @staticmethod
230
+ @amp_custom_bwd
231
231
  def backward(ctx, grad_output):
232
232
  (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
233
233
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
liger_kernel/ops/fused_linear_jsd.py CHANGED
@@ -92,7 +92,9 @@ def fused_linear_jsd_forward(
92
92
  dX_ptr=student_prob_chunk,
93
93
  dX_stride=student_prob_chunk.stride(-2),
94
94
  label_ptr=(
95
- shift_labels if has_label else torch.empty(1, device=device)
95
+ shift_labels[start_idx:end_idx]
96
+ if has_label
97
+ else torch.empty(1, device=device)
96
98
  ), # dummy ptr if no label
97
99
  beta=jsd_beta,
98
100
  n_non_ignore=n_non_ignore,
liger_kernel/ops/jsd.py CHANGED
@@ -19,7 +19,7 @@ def _jsd_kernel(
19
19
  dX_stride,
20
20
  label_ptr,
21
21
  beta,
22
- n_non_ignore,
22
+ n_non_ignore: int,
23
23
  ignore_index: tl.constexpr,
24
24
  n_cols,
25
25
  BLOCK_SIZE: tl.constexpr,
liger_kernel/ops/utils.py CHANGED
@@ -12,6 +12,7 @@ Modifications made by Yanning Chen, 2024.
12
12
 
13
13
  import functools
14
14
  import importlib
15
+ import operator
15
16
  from typing import Callable
16
17
 
17
18
  import torch
@@ -63,6 +64,18 @@ def compare_version(package: str, operator: Callable, target: str):
63
64
  return operator(pkg_version, Version(target))
64
65
 
65
66
 
67
+ def get_amp_custom_fwd_bwd() -> Callable:
68
+ if compare_version("torch", operator.ge, "2.4.0"):
69
+ return (
70
+ functools.partial(torch.amp.custom_fwd, device_type="cuda"),
71
+ functools.partial(torch.amp.custom_bwd, device_type="cuda"),
72
+ )
73
+ return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
74
+
75
+
76
+ amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
77
+
78
+
66
79
  torch_to_triton_dtype = {
67
80
  torch.float32: tl.float32,
68
81
  torch.float16: tl.float16,
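{liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@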
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.3.1.dev20241025165359
3
+ Version: 0.3.1.dev20241030045958
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
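{liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/RECORD CHANGED
@@ -1,16 +1,16 @@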
1
1
  liger_kernel/env_report.py,sha256=LFUJ6UMkFFGPBYXBlqHFGy4bhsemEpSI-_1edSazlHI,1130
2
2
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  liger_kernel/ops/cross_entropy.py,sha256=OB3nvIONLB_sj9LO6UQv1qLnf861k-pR58RtwgoiyYA,11192
4
- liger_kernel/ops/fused_linear_cross_entropy.py,sha256=NgkdAghxAX2F8EIi0HxTs2ODCF875g9-WXK3tfyqx84,9342
5
- liger_kernel/ops/fused_linear_jsd.py,sha256=4Lt-ffWWChqXdsVg5Q8RDZkBXhOjk3DLYeaSeUrTRrg,9183
4
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=qg7qBQFLDJClnkUOGhFFHPSW_x7rPvQekbm_4OOYxys,9331
5
+ liger_kernel/ops/fused_linear_jsd.py,sha256=eZ8y4GPtPRE0QcNNMLX8l4gSEMPzA3ZuzknfbAbiREA,9234
6
6
  liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
7
- liger_kernel/ops/jsd.py,sha256=iash2cEG5zIdm73_XLCtbNWjLrIIxXlBNYbQOAAANzk,5776
7
+ liger_kernel/ops/jsd.py,sha256=anWfdioucxZy4JQfTvbHBR-IQrZKeH-gBF1MHwwTuTQ,5781
8
8
  liger_kernel/ops/kl_div.py,sha256=qnmtFQwuO3FR7Ovup_DDzpkD1A1LpwOaWlcO6K9ysHk,8342
9
9
  liger_kernel/ops/layer_norm.py,sha256=unGMYMOPqtkM9aTrokhcqgPmsV2AUN7Yzv86isVB9OI,7422
10
10
  liger_kernel/ops/rms_norm.py,sha256=9S9wyZLmzNyJlBxV4vbv4p5es7bGP-m_5wK9JC6JIdA,10911
11
11
  liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
12
12
  liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
13
- liger_kernel/ops/utils.py,sha256=6JQH4ULilZiuig7kqiY5LaQQhZHKcf_NAxU0FBo_gjE,3276
13
+ liger_kernel/ops/utils.py,sha256=w0QT3ynUK2vYUAgsVvfoENTUu5L-2TuB3IYt8JaXlNA,3688
14
14
  liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
15
15
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
16
16
  liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
@@ -40,14 +40,14 @@ liger_kernel/transformers/model/qwen2.py,sha256=3inWFXGHYT7wA10OR6bq3mDUBrr10AS5
40
40
  liger_kernel/transformers/model/qwen2_vl.py,sha256=ymsm9aQpSUiSU12GY8FO608p9dSHOz4TCnNI1htX5bk,6975
41
41
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
42
42
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
43
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
44
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/LICENSE-Apache-2.0,sha256=NRaCIsL9eblGS35gk4WKTC0usNYnR_mgRHJTKqz2_UE,11348
45
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/LICENSE-MIT-AutoAWQ,sha256=pfiOyInrAPY3xQbvV1i-gOqNZK7QEyIepT1IbqOYYYo,1067
46
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/LICENSE-MIT-Efficient Cross Entropy,sha256=PaC9HqyFYTy-ClS0H8Zfa2motJuTppjECXmjHwJcaOk,1063
47
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/LICENSE-MIT-llmc,sha256=kyFLt_XUcXS88CuxQt5-PjOcLjpJP2m-T4gtqZf3GLc,1071
48
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/LICENSE-MIT-triton,sha256=wL6W8IwsKiyHtzXubg8TCXhRZuo8S83EPdqXffYtqWg,1131
49
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/METADATA,sha256=TptnGUbkLj4f4i38XU0IsPlYbV-2KAIsIWZlC-h3PWE,27717
50
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
51
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
52
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
53
- liger_kernel_nightly-0.3.1.dev20241025165359.dist-info/RECORD,,
43
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
44
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-Apache-2.0,sha256=NRaCIsL9eblGS35gk4WKTC0usNYnR_mgRHJTKqz2_UE,11348
45
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-MIT-AutoAWQ,sha256=pfiOyInrAPY3xQbvV1i-gOqNZK7QEyIepT1IbqOYYYo,1067
46
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-MIT-Efficient Cross Entropy,sha256=PaC9HqyFYTy-ClS0H8Zfa2motJuTppjECXmjHwJcaOk,1063
47
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-MIT-llmc,sha256=kyFLt_XUcXS88CuxQt5-PjOcLjpJP2m-T4gtqZf3GLc,1071
48
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/LICENSE-MIT-triton,sha256=wL6W8IwsKiyHtzXubg8TCXhRZuo8S83EPdqXffYtqWg,1131
49
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/METADATA,sha256=aSchnUTDzM__RwWSMgJNr8BqjwjNzTd-ja3hz2NTBKc,27717
50
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
51
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
52
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
53
+ liger_kernel_nightly-0.3.1.dev20241030045958.dist-info/RECORD,,
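{liger_kernel_nightly-0.3.1.dev20241025165359.dist-info → liger_kernel_nightly-0.3.1.dev20241030045958.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@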
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.2.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5