liger-kernel-nightly 0.4.2.dev20241122052539__tar.gz → 0.4.2.dev20241123040418__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. {liger_kernel_nightly-0.4.2.dev20241122052539/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241123040418}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/cross_entropy.py +12 -6
  4. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -11
  5. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/layer_norm.py +6 -1
  6. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/rms_norm.py +1 -0
  7. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/utils.py +5 -2
  8. liger_kernel_nightly-0.4.2.dev20241123040418/src/liger_kernel/transformers/model/__init__.py +0 -0
  9. liger_kernel_nightly-0.4.2.dev20241123040418/src/liger_kernel/utils.py +13 -0
  10. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  11. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel_nightly.egg-info/SOURCES.txt +2 -0
  12. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/LICENSE +0 -0
  13. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/NOTICE +0 -0
  14. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/README.md +0 -0
  15. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/setup.cfg +0 -0
  16. {liger_kernel_nightly-0.4.2.dev20241122052539/src/liger_kernel/ops → liger_kernel_nightly-0.4.2.dev20241123040418/src/liger_kernel}/__init__.py +0 -0
  17. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  18. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  19. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  20. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/chunked_loss/functional.py +0 -0
  21. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  22. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  23. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  24. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/env_report.py +0 -0
  25. {liger_kernel_nightly-0.4.2.dev20241122052539/src/liger_kernel/transformers/model → liger_kernel_nightly-0.4.2.dev20241123040418/src/liger_kernel/ops}/__init__.py +0 -0
  26. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  27. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  28. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  29. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/geglu.py +0 -0
  30. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/group_norm.py +0 -0
  31. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/jsd.py +0 -0
  32. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/kl_div.py +0 -0
  33. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  34. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/rope.py +0 -0
  35. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/ops/swiglu.py +0 -0
  36. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/__init__.py +0 -0
  37. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/auto_model.py +0 -0
  38. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  39. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  40. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/functional.py +0 -0
  41. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  42. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  43. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/geglu.py +0 -0
  44. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/group_norm.py +0 -0
  45. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/jsd.py +0 -0
  46. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/kl_div.py +0 -0
  47. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/layer_norm.py +0 -0
  48. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/model/gemma.py +0 -0
  49. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  50. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/model/llama.py +0 -0
  51. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/model/mistral.py +0 -0
  52. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  53. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/model/mllama.py +0 -0
  54. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/model/phi3.py +0 -0
  55. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  56. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  57. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  58. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  59. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/rms_norm.py +0 -0
  60. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/rope.py +0 -0
  61. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/swiglu.py +0 -0
  62. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  63. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/triton/__init__.py +0 -0
  64. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel/triton/monkey_patch.py +0 -0
  65. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  66. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  67. {liger_kernel_nightly-0.4.2.dev20241122052539 → liger_kernel_nightly-0.4.2.dev20241123040418}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241122052539
3
+ Version: 0.4.2.dev20241123040418
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.2.dev20241122052539"
7
+ version = "0.4.2.dev20241123040418"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -92,8 +92,8 @@ def liger_cross_entropy_kernel(
92
92
  # 3. [Online softmax] first pass: find max + sum
93
93
  m = float("-inf") # m is the max value. use the notation from the paper
94
94
  d = 0.0 # d is the sum. use the notation from the paper
95
- ori_X_y = tl.load(
96
- X_ptr + y
95
+ ori_X_y = tl.load(X_ptr + y).cast(
96
+ tl.float32
97
97
  ) # we need to store the original value of X_y for the loss calculation
98
98
  if HAS_SOFTCAPPING:
99
99
  ori_X_y = softcap * tanh(ori_X_y / softcap)
@@ -106,8 +106,11 @@ def liger_cross_entropy_kernel(
106
106
  for i in range(0, n_cols, BLOCK_SIZE):
107
107
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
108
108
  X_block = tl.load(
109
- X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
110
- )
109
+ X_ptr + X_offsets,
110
+ mask=X_offsets < n_cols,
111
+ other=float("-inf"),
112
+ # Ensure float32 precision for softmax calculation
113
+ ).cast(tl.float32)
111
114
  if HAS_SOFTCAPPING:
112
115
  X_block = softcap * tanh(X_block / softcap)
113
116
  block_max = tl.max(X_block)
@@ -141,8 +144,11 @@ def liger_cross_entropy_kernel(
141
144
  for i in range(0, n_cols, BLOCK_SIZE):
142
145
  X_offsets = i + tl.arange(0, BLOCK_SIZE)
143
146
  X_block = tl.load(
144
- X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
145
- )
147
+ X_ptr + X_offsets,
148
+ mask=X_offsets < n_cols,
149
+ other=float("-inf"),
150
+ # Ensure float32 precision for softmax calculation
151
+ ).cast(tl.float32)
146
152
  if HAS_SOFTCAPPING:
147
153
  intermediate = tanh(X_block / softcap)
148
154
  X_block = softcap * intermediate
@@ -26,7 +26,6 @@ def fused_linear_cross_entropy_forward(
26
26
  reduction="mean",
27
27
  softcap=None,
28
28
  ):
29
- dtype = _input.dtype
30
29
  device = _input.device
31
30
 
32
31
  # inputs have shape: BT x H
@@ -74,9 +73,6 @@ def fused_linear_cross_entropy_forward(
74
73
  loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size,
75
74
  n_non_ignore = (target_chunk != ignore_index).sum().item()
76
75
 
77
- # when doing CE, use the upcasted precision
78
- logits_chunk = logits_chunk.float()
79
-
80
76
  # ensure _input and target are contiguous
81
77
  logits_chunk = logits_chunk.contiguous()
82
78
  target_chunk = target_chunk.contiguous()
@@ -103,13 +99,6 @@ def fused_linear_cross_entropy_forward(
103
99
  num_warps=32 if not is_hip() else 16,
104
100
  )
105
101
 
106
- # gradient of logits_chunk is computed in-place by the above triton kernel.
107
- # Following HuggingFace model source code, we do the forward and backward
108
- # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
109
- # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
110
- # Propagating to lm_head's backward, we'll switch back to the original dtype.
111
- logits_chunk = logits_chunk.to(dtype)
112
-
113
102
  # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V
114
103
  # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H
115
104
  # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only
@@ -180,8 +180,13 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
180
180
  dY = dY.view(-1, dim)
181
181
  n_rows, n_cols = dY.shape
182
182
 
183
+ sm_count = 1
184
+ if X.device.type == "cuda":
185
+ sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
186
+ elif X.device.type == "xpu":
187
+ sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
188
+
183
189
  DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
184
- sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
185
190
  _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
186
191
  _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
187
192
 
@@ -264,6 +264,7 @@ def rms_norm_backward(
264
264
  dY = dY.view(-1, dim)
265
265
  n_rows, n_cols = dY.shape
266
266
 
267
+ sm_count = 1
267
268
  if X.device.type == "cuda":
268
269
  sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
269
270
  elif X.device.type == "xpu":
@@ -20,6 +20,8 @@ import triton
20
20
  import triton.language as tl
21
21
  from packaging.version import Version
22
22
 
23
+ from liger_kernel.utils import infer_device
24
+
23
25
 
24
26
  def is_hip() -> bool:
25
27
  return torch.version.hip is not None
@@ -69,10 +71,11 @@ def compare_version(package: str, operator: Callable, target: str):
69
71
 
70
72
 
71
73
  def get_amp_custom_fwd_bwd() -> Callable:
74
+ device = infer_device()
72
75
  if compare_version("torch", operator.ge, "2.4.0"):
73
76
  return (
74
- functools.partial(torch.amp.custom_fwd, device_type="cuda"),
75
- functools.partial(torch.amp.custom_bwd, device_type="cuda"),
77
+ functools.partial(torch.amp.custom_fwd, device_type=device),
78
+ functools.partial(torch.amp.custom_bwd, device_type=device),
76
79
  )
77
80
  return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
78
81
 
@@ -0,0 +1,13 @@
1
+ import torch
2
+
3
+
4
def infer_device():
    """
    Get current device name based on available devices.

    Returns ``"cuda"`` when a CUDA device is available, otherwise ``"xpu"``
    when an Intel XPU device is available, and ``"cpu"`` as the fallback.
    """
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.xpu`` is not present in every torch release this package may run
    # against; probe for the attribute instead of letting AttributeError escape.
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    else:
        return "cpu"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241122052539
3
+ Version: 0.4.2.dev20241123040418
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -2,7 +2,9 @@ LICENSE
2
2
  NOTICE
3
3
  README.md
4
4
  pyproject.toml
5
+ src/liger_kernel/__init__.py
5
6
  src/liger_kernel/env_report.py
7
+ src/liger_kernel/utils.py
6
8
  src/liger_kernel/chunked_loss/__init__.py
7
9
  src/liger_kernel/chunked_loss/cpo_loss.py
8
10
  src/liger_kernel/chunked_loss/dpo_loss.py