liger-kernel-nightly 0.3.1.dev20241025165359__tar.gz → 0.3.1.dev20241030045958__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly may be problematic; further details were linked from the original page.

Files changed (58)
  1. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/fused_linear_cross_entropy.py +4 -4
  4. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/fused_linear_jsd.py +3 -1
  5. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/jsd.py +1 -1
  6. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/utils.py +13 -0
  7. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  8. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/LICENSE +0 -0
  9. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/LICENSE-Apache-2.0 +0 -0
  10. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/LICENSE-MIT-AutoAWQ +0 -0
  11. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/LICENSE-MIT-Efficient Cross Entropy +0 -0
  12. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/LICENSE-MIT-llmc +0 -0
  13. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/LICENSE-MIT-triton +0 -0
  14. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/NOTICE +0 -0
  15. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/README.md +0 -0
  16. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/setup.cfg +0 -0
  17. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/env_report.py +0 -0
  18. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/__init__.py +0 -0
  19. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/cross_entropy.py +0 -0
  20. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  21. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  22. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/geglu.py +0 -0
  23. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/kl_div.py +0 -0
  24. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/layer_norm.py +0 -0
  25. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/rms_norm.py +0 -0
  26. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/rope.py +0 -0
  27. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/ops/swiglu.py +0 -0
  28. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/__init__.py +0 -0
  29. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/auto_model.py +0 -0
  30. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  31. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  32. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/functional.py +0 -0
  33. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  34. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  35. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/geglu.py +0 -0
  36. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/jsd.py +0 -0
  37. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/kl_div.py +0 -0
  38. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/layer_norm.py +0 -0
  39. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/model/__init__.py +0 -0
  40. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/model/gemma.py +0 -0
  41. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/model/llama.py +0 -0
  42. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/model/mistral.py +0 -0
  43. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  44. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/model/mllama.py +0 -0
  45. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/model/phi3.py +0 -0
  46. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  47. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  48. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  49. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/rms_norm.py +0 -0
  50. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/rope.py +0 -0
  51. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/swiglu.py +0 -0
  52. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  53. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/triton/__init__.py +0 -0
  54. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel/triton/monkey_patch.py +0 -0
  55. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  56. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  57. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  58. {liger_kernel_nightly-0.3.1.dev20241025165359 → liger_kernel_nightly-0.3.1.dev20241030045958}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.3.1.dev20241025165359
3
+ Version: 0.3.1.dev20241030045958
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.3.1.dev20241025165359"
7
+ version = "0.3.1.dev20241030045958"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -2,7 +2,7 @@ import torch
2
2
  import triton
3
3
 
4
4
  from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
5
- from liger_kernel.ops.utils import element_mul_kernel
5
+ from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
6
6
 
7
7
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
8
8
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -19,9 +19,7 @@ def fused_linear_cross_entropy_forward(
19
19
  label_smoothing=0.0,
20
20
  reduction="mean",
21
21
  ):
22
- dtype = (
23
- torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
24
- )
22
+ dtype = _input.dtype
25
23
  device = _input.device
26
24
 
27
25
  # inputs have shape: BT x H
@@ -189,6 +187,7 @@ def fused_linear_cross_entropy_backward(
189
187
 
190
188
  class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
191
189
  @staticmethod
190
+ @amp_custom_fwd
192
191
  def forward(
193
192
  ctx,
194
193
  _input,
@@ -228,6 +227,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
228
227
  return loss
229
228
 
230
229
  @staticmethod
230
+ @amp_custom_bwd
231
231
  def backward(ctx, grad_output):
232
232
  (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
233
233
  grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
@@ -92,7 +92,9 @@ def fused_linear_jsd_forward(
92
92
  dX_ptr=student_prob_chunk,
93
93
  dX_stride=student_prob_chunk.stride(-2),
94
94
  label_ptr=(
95
- shift_labels if has_label else torch.empty(1, device=device)
95
+ shift_labels[start_idx:end_idx]
96
+ if has_label
97
+ else torch.empty(1, device=device)
96
98
  ), # dummy ptr if no label
97
99
  beta=jsd_beta,
98
100
  n_non_ignore=n_non_ignore,
@@ -19,7 +19,7 @@ def _jsd_kernel(
19
19
  dX_stride,
20
20
  label_ptr,
21
21
  beta,
22
- n_non_ignore,
22
+ n_non_ignore: int,
23
23
  ignore_index: tl.constexpr,
24
24
  n_cols,
25
25
  BLOCK_SIZE: tl.constexpr,
@@ -12,6 +12,7 @@ Modifications made by Yanning Chen, 2024.
12
12
 
13
13
  import functools
14
14
  import importlib
15
+ import operator
15
16
  from typing import Callable
16
17
 
17
18
  import torch
@@ -63,6 +64,18 @@ def compare_version(package: str, operator: Callable, target: str):
63
64
  return operator(pkg_version, Version(target))
64
65
 
65
66
 
67
+ def get_amp_custom_fwd_bwd() -> Callable:
68
+ if compare_version("torch", operator.ge, "2.4.0"):
69
+ return (
70
+ functools.partial(torch.amp.custom_fwd, device_type="cuda"),
71
+ functools.partial(torch.amp.custom_bwd, device_type="cuda"),
72
+ )
73
+ return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
74
+
75
+
76
+ amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
77
+
78
+
66
79
  torch_to_triton_dtype = {
67
80
  torch.float32: tl.float32,
68
81
  torch.float16: tl.float16,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.3.1.dev20241025165359
3
+ Version: 0.3.1.dev20241030045958
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation