liger-kernel-nightly 0.4.2.dev20241209234352__tar.gz → 0.4.2.dev20241210002150__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. {liger_kernel_nightly-0.4.2.dev20241209234352/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241210002150}/PKG-INFO +26 -28
  2. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/README.md +25 -27
  3. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/cross_entropy.py +5 -4
  5. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150/src/liger_kernel_nightly.egg-info}/PKG-INFO +26 -28
  6. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/LICENSE +0 -0
  7. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/NOTICE +0 -0
  8. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/setup.cfg +0 -0
  9. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/__init__.py +0 -0
  10. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  11. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  12. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  13. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/chunked_loss/functional.py +0 -0
  14. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  15. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  16. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  17. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  18. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/env_report.py +0 -0
  19. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/__init__.py +0 -0
  20. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  21. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  22. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  23. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  24. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/geglu.py +0 -0
  25. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/group_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/jsd.py +0 -0
  27. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/kl_div.py +0 -0
  28. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/layer_norm.py +0 -0
  29. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  30. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/rms_norm.py +0 -0
  31. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/rope.py +0 -0
  32. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/swiglu.py +0 -0
  33. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/ops/utils.py +0 -0
  34. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/__init__.py +0 -0
  35. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/auto_model.py +0 -0
  36. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  38. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/functional.py +0 -0
  39. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  40. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  41. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/geglu.py +0 -0
  42. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/group_norm.py +0 -0
  43. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/jsd.py +0 -0
  44. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/kl_div.py +0 -0
  45. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/layer_norm.py +0 -0
  46. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/__init__.py +0 -0
  47. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/gemma.py +0 -0
  48. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  49. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/llama.py +0 -0
  50. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/mistral.py +0 -0
  51. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  52. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/mllama.py +0 -0
  53. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/phi3.py +0 -0
  54. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  55. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  56. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  57. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/orpo_trainer.py +0 -0
  58. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  59. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/rms_norm.py +0 -0
  60. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/rope.py +0 -0
  61. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/swiglu.py +0 -0
  62. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  63. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/triton/__init__.py +0 -0
  64. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/triton/monkey_patch.py +0 -0
  65. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel/utils.py +0 -0
  66. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  67. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  68. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  69. {liger_kernel_nightly-0.4.2.dev20241209234352 → liger_kernel_nightly-0.4.2.dev20241210002150}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.4.2.dev20241209234352
+ Version: 0.4.2.dev20241210002150
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -110,7 +110,7 @@ Requires-Dist: triton>=3.0.0; extra == "amd"
 
  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
 
- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
 
  <details>
  <summary>Latest News 🔥</summary>
@@ -266,7 +266,7 @@ loss = loss_fn(model.weight, input, target)
  loss.backward()
  ```
 
- ## APIs
+ ## High-level APIs
 
  ### AutoModel
 
@@ -290,8 +290,12 @@ loss.backward()
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
+ ## Low-level APIs
 
- ### Kernels
+ - `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
+ - Other kernels use fusion and in-place techniques for memory and performance optimization.
+
+ ### Model Kernels
 
  | **Kernel** | **API** |
  |---------------------------------|-------------------------------------------------------------|
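The new `Fused Linear` bullet says these kernels fold the linear layer into the loss to cut memory. Below is a minimal pure-Python sketch of the chunking idea. It is not Liger's Triton implementation: it computes only the forward loss (no fused gradient), and the helper names (`linear`, `cross_entropy_row`, `chunked_linear_cross_entropy`, `chunk_size`) are hypothetical. It shows that only `chunk_size` rows of logits exist at any time, while the mean loss still matches the version that materializes the full rows × vocab logits matrix.

```python
import math


def linear(rows, weight):
    # Project each input row; weight has shape (vocab, hidden), so this is rows @ weight^T.
    return [[sum(a * b for a, b in zip(row, w)) for w in weight] for row in rows]


def cross_entropy_row(logits, target):
    # Numerically stable cross entropy for one row: logsumexp(logits) - logits[target].
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]


def full_linear_cross_entropy(x, weight, targets):
    # Baseline: materialize the whole (rows x vocab) logits matrix, then average.
    logits = linear(x, weight)
    return sum(cross_entropy_row(r, t) for r, t in zip(logits, targets)) / len(x)


def chunked_linear_cross_entropy(x, weight, targets, chunk_size=2):
    # Hypothetical chunked variant: only chunk_size rows of logits are alive at a time.
    total = 0.0
    for start in range(0, len(x), chunk_size):
        chunk_logits = linear(x[start:start + chunk_size], weight)
        chunk_targets = targets[start:start + chunk_size]
        total += sum(cross_entropy_row(r, t) for r, t in zip(chunk_logits, chunk_targets))
    return total / len(x)


# Toy data: 5 rows, hidden size 2, vocab size 3.
x = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.4], [0.5, 0.5], [0.0, -0.3]]
weight = [[0.2, -0.1], [0.4, 0.3], [-0.5, 0.1]]
targets = [0, 2, 1, 1, 0]
print(full_linear_cross_entropy(x, weight, targets))
print(chunked_linear_cross_entropy(x, weight, targets, chunk_size=2))
```

The two printed values agree up to floating-point rounding; the memory saving comes from never holding more than `chunk_size × vocab` logits, which matters when the vocabulary is large.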
@@ -301,39 +305,33 @@ loss.backward()
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
- | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+
+
+ ### Alignment Kernels
+
+ | **Kernel** | **API** |
+ |---------------------------------|-------------------------------------------------------------|
+ | Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
+ | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
+ | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
+ | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+
+ ### Distillation Kernels
+
+ | **Kernel** | **API** |
+ |---------------------------------|-------------------------------------------------------------|
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
  | JSD | `liger_kernel.transformers.LigerJSD` |
- | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
-
- - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
- - **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
- $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
- , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
- $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
- , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
- - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
- <!-- TODO: verify vocab sizes are accurate -->
- - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
-
+ | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
 
  ### Experimental Kernels
 
  | **Kernel** | **API** |
  |---------------------------------|-------------------------------------------------------------|
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
- | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
 
- - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
- - **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
- <!-- TODO: be more specific about batch size -->
 
  ## Contributing, Acknowledgements, and License
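The removed RMSNorm entry says the kernel fuses the normalization and scaling steps. As a reference for what that fused step computes, here is a plain-Python sketch for a single row. It assumes the common convention of adding `eps` inside the square root; the `eps=1e-6` default and the function name are assumptions, not taken from this diff.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm for one row: divide by the root mean square, then scale.

    The RMS is computed first; normalization and the elementwise weight scaling
    then happen in the same loop, mirroring the two steps the README says the
    Triton kernel fuses. eps (assumed placement: inside the sqrt) guards
    against division by zero.
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


print(rms_norm([3.0, 4.0], [1.0, 1.0]))
```

Unlike LayerNorm, no mean is subtracted; only the root mean square is used to rescale the row.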
 
@@ -55,7 +55,7 @@
 
  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
 
- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
 
  <details>
  <summary>Latest News 🔥</summary>
@@ -211,7 +211,7 @@ loss = loss_fn(model.weight, input, target)
  loss.backward()
  ```
 
- ## APIs
+ ## High-level APIs
 
  ### AutoModel
 
@@ -235,8 +235,12 @@ loss.backward()
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
+ ## Low-level APIs
 
- ### Kernels
+ - `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
+ - Other kernels use fusion and in-place techniques for memory and performance optimization.
+
+ ### Model Kernels
 
  | **Kernel** | **API** |
  |---------------------------------|-------------------------------------------------------------|
@@ -246,39 +250,33 @@ loss.backward()
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
- | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+
+
+ ### Alignment Kernels
+
+ | **Kernel** | **API** |
+ |---------------------------------|-------------------------------------------------------------|
+ | Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
+ | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
+ | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
+ | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+
+ ### Distillation Kernels
+
+ | **Kernel** | **API** |
+ |---------------------------------|-------------------------------------------------------------|
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
  | JSD | `liger_kernel.transformers.LigerJSD` |
- | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
-
- - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
- - **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
- $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
- , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
- $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
- , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
- - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
- <!-- TODO: verify vocab sizes are accurate -->
- - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
-
+ | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
 
  ### Experimental Kernels
 
  | **Kernel** | **API** |
  |---------------------------------|-------------------------------------------------------------|
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
- | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
 
- - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
- - **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
- <!-- TODO: be more specific about batch size -->
 
  ## Contributing, Acknowledgements, and License
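The removed SwiGLU and GeGLU entries define both units as an activation of one projection multiplied elementwise by a second projection, $\text{act}(xW+b)\otimes(xV+c)$. The following pure-Python sketch of that gated linear unit takes the activation as a parameter: Swish with β = 1 by default (that is, SiLU) and the tanh-approximated GELU the README mentions. The helper names and the tiny list-based matrices are illustrative assumptions; the actual kernels operate on tensors, and the README describes them as fusing the elementwise multiplication step.

```python
import math


def sigmoid(t):
    # Numerically stable logistic function (avoids overflow in math.exp).
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def swish(z, beta=1.0):
    # Swish_beta(z) = z * sigmoid(beta * z); beta = 1 is SiLU.
    return z * sigmoid(beta * z)


def gelu_tanh(z):
    # tanh approximation of GELU, the form the README says GeGLU uses.
    return 0.5 * z * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (z + 0.044715 * z ** 3)))


def affine(x, W, b):
    # x has length d_in, W is d_in x d_out, b has length d_out; returns xW + b.
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def gated_linear_unit(x, W, b, V, c, activation):
    # activation(xW + b) ⊗ (xV + c); the final elementwise product is the fused step.
    gate = affine(x, W, b)
    up = affine(x, V, c)
    return [activation(g) * u for g, u in zip(gate, up)]


x = [1.0, -2.0]
W = [[0.5, 0.0], [0.0, 0.5]]
V = [[1.0, 0.0], [0.0, 1.0]]
zeros = [0.0, 0.0]
print(gated_linear_unit(x, W, zeros, V, zeros, swish))      # SwiGLU with beta = 1
print(gated_linear_unit(x, W, zeros, V, zeros, gelu_tanh))  # GeGLU
```

Passing the activation as a function keeps SwiGLU and GeGLU as one code path, which mirrors how the two README entries differ only in the activation applied to $xW+b$.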
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
  [project]
  name = "liger_kernel_nightly"
- version = "0.4.2.dev20241209234352"
+ version = "0.4.2.dev20241210002150"
  description = "Efficient Triton kernels for LLM Training"
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -285,11 +285,12 @@ def cross_entropy_forward(
          num_warps=32 if not is_hip() else 16,
      )
 
-     loss = torch.sum(loss_1d)
-     if return_z_loss == _TRUE.value:
-         z_loss = torch.sum(z_loss_1d)
+     if reduction == "none":
+         loss = loss_1d
+         z_loss = z_loss_1d if return_z_loss == _TRUE.value else None
      else:
-         z_loss = None
+         loss = torch.sum(loss_1d)
+         z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None
 
      return loss, z_loss, _input
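Before this change, `cross_entropy_forward` always summed `loss_1d` (and `z_loss_1d`); the new branch returns the unreduced per-row values when `reduction == "none"`. The sketch below restates that control flow with plain Python lists. `finalize_loss` is a hypothetical helper, and `return_z_loss` is treated as a plain bool, whereas the real code compares it with `_TRUE.value` and works on tensors.

```python
def finalize_loss(loss_1d, z_loss_1d, reduction, return_z_loss):
    """Plain-list restatement of the reduction branch added to cross_entropy_forward.

    reduction == "none" keeps one loss value per row; any other value sums the
    rows, as the diff does with torch.sum.
    """
    if reduction == "none":
        loss = list(loss_1d)
        z_loss = list(z_loss_1d) if return_z_loss else None
    else:
        loss = sum(loss_1d)
        z_loss = sum(z_loss_1d) if return_z_loss else None
    return loss, z_loss


per_row = [0.5, 1.25, 0.25]
z_per_row = [0.01, 0.02, 0.03]
print(finalize_loss(per_row, z_per_row, "none", True))  # ([0.5, 1.25, 0.25], [0.01, 0.02, 0.03])
print(finalize_loss(per_row, z_per_row, "sum", False))  # (2.0, None)
```

The practical effect is that a caller requesting `reduction="none"` now gets one value per row instead of a scalar, the same shape `torch.nn.CrossEntropyLoss(reduction="none")` returns for 2-D logits.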
 
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: liger_kernel_nightly
- Version: 0.4.2.dev20241209234352
+ Version: 0.4.2.dev20241210002150
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -110,7 +110,7 @@ Requires-Dist: triton>=3.0.0; extra == "amd"
 
  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
 
- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
 
  <details>
  <summary>Latest News 🔥</summary>
@@ -266,7 +266,7 @@ loss = loss_fn(model.weight, input, target)
  loss.backward()
  ```
 
- ## APIs
+ ## High-level APIs
 
  ### AutoModel
 
@@ -290,8 +290,12 @@ loss.backward()
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
+ ## Low-level APIs
 
- ### Kernels
+ - `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
+ - Other kernels use fusion and in-place techniques for memory and performance optimization.
+
+ ### Model Kernels
 
  | **Kernel** | **API** |
  |---------------------------------|-------------------------------------------------------------|
@@ -301,39 +305,33 @@ loss.backward()
301
305
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
302
306
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
303
307
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
304
- | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
308
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
309
+
310
+
311
+ ### Alignment Kernels
312
+
313
+ | **Kernel** | **API** |
314
+ |---------------------------------|-------------------------------------------------------------|
315
+ | Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
316
+ | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
317
+ | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
318
+ | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
319
+
320
+ ### Distillation Kernels
321
+
322
+ | **Kernel** | **API** |
323
+ |---------------------------------|-------------------------------------------------------------|
305
324
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
306
325
  | JSD | `liger_kernel.transformers.LigerJSD` |
307
- | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
308
-
309
- - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization, and scaling steps into a single Triton kernel, and achieves ~2X speedup.
- - **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which splits the channels of each sample into K groups and normalizes activations within each group, is implemented by fusing the centering, normalization, and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the rotary embedding of the query and key into a single kernel with in-place replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
- $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
- , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with in-place replacement, and matches the baseline speed with ~1.5X peak memory reduction.
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
- $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
- , is implemented by fusing the elementwise multiplication into a single kernel with in-place replacement, and matches the baseline speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
- - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and the gradient in the forward pass and replacing the input in place, which reduces peak memory by avoiding simultaneous materialization of both the input logits and their gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K).
- <!-- TODO: verify vocab sizes are accurate -->
- - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for a 128K vocab size. **This is highly effective for large batch sizes, long sequence lengths, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward pass into a single Triton kernel, with the reduction done outside the kernel. It achieves ~1.5X speedup and ~15% memory reduction for a 128K vocab size.
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and the gradient in the forward pass. It achieves ~1.5X speedup and ~54% memory reduction for a 128K vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1, respectively.
- - **FusedLinearJSD**: Peak memory usage of the JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for a 128K vocab size when batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1, respectively.
-
+ | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
 
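The CrossEntropy and FusedLinearCrossEntropy descriptions rest on one identity: for a single row, the gradient of cross entropy with respect to the logits is `softmax(logits) - onehot(target)`, so it can be produced in the same pass as the loss and written over the logits. The pure-Python sketch below illustrates that math and the chunking idea only; it is not the Triton implementation, and `cross_entropy_fwd_with_grad` and `fused_linear_ce` are hypothetical names. It assumes a bias-free LM head `weight` of shape V x H, hidden states of shape N x H, and mean reduction over rows.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def cross_entropy_fwd_with_grad(logits: List[float], target: int) -> Tuple[float, List[float]]:
    """Hypothetical helper: loss and d(loss)/d(logits) for one row in a single pass.

    The gradient overwrites `logits` in place, so the logits and their gradient
    are never held as two separate buffers.
    """
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    # -log softmax(z)[t] = logsumexp(z) - z[t]  (computed before overwriting)
    loss = math.log(total) + m - logits[target]
    for i, e in enumerate(exps):
        # d(loss)/d(z_i) = softmax(z)_i - 1[i == target]
        logits[i] = e / total - (1.0 if i == target else 0.0)
    return loss, logits


def fused_linear_ce(
    hidden: Matrix, weight: Matrix, targets: List[int], chunk_size: int
) -> Tuple[float, Matrix, Matrix]:
    """Hypothetical sketch: mean CE over the N rows of hidden @ weight.T.

    Only `chunk_size` rows of logits exist at once; gradients for `hidden`
    (N x H) and `weight` (V x H) are accumulated chunk by chunk.
    """
    n, vocab, dim = len(hidden), len(weight), len(weight[0])
    total_loss = 0.0
    grad_hidden = [[0.0] * dim for _ in range(n)]
    grad_weight = [[0.0] * dim for _ in range(vocab)]
    for start in range(0, n, chunk_size):
        rows = range(start, min(start + chunk_size, n))
        # Materialize chunk_size x V logits for this block only.
        chunk_logits = [
            [sum(w[k] * hidden[r][k] for k in range(dim)) for w in weight] for r in rows
        ]
        for r, logits in zip(rows, chunk_logits):
            loss, grad_logits = cross_entropy_fwd_with_grad(logits, targets[r])
            total_loss += loss
            for v in range(vocab):
                g = grad_logits[v] / n  # mean reduction over rows
                for k in range(dim):
                    grad_hidden[r][k] += g * weight[v][k]  # W^T @ dL/dz
                    grad_weight[v][k] += g * hidden[r][k]  # dL/dz outer x
    return total_loss / n, grad_hidden, grad_weight


# Usage: N=3 rows, H=2, V=4, processed two rows at a time.
hidden = [[0.5, -1.0], [2.0, 0.25], [-0.75, 1.5]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
targets = [2, 0, 3]
loss, grad_hidden, grad_weight = fused_linear_ce(hidden, weight, targets, chunk_size=2)
```

Changing `chunk_size` changes only how many logit rows are alive at a time (`chunk_size x V` instead of `N x V`); the loss and gradients are the same, which is why chunking pays off for large batch sizes, long sequences, and large vocabularies.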
  ### Experimental Kernels
 
  | **Kernel** | **API** |
  |---------------------------------|-------------------------------------------------------------|
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
- | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
 
- - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5X in the forward pass and an overall speedup of ~1.1X.
- - **Matmul int2xint8**: Implemented using cache-tiled matrix multiplication and by fusing the matmul with the unpacking step, which achieves a considerable speedup and performs on par with `@torch.compile`.
- <!-- TODO: be more specific about batch size -->
 
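The Matmul int2xint8 entry describes fusing the matmul with the unpacking step. As an illustration only, the pure-Python sketch below decodes each 2-bit weight inside the inner-product loop instead of first building an unpacked copy of the weight matrix. The packing layout (four values per byte, lowest bits first, signed int2 range [-2, 1] stored with an offset of 2), storing B column by column, and the names `pack_int2`, `int2_at`, and `matmul_int8_int2` are all assumptions, not the layout used by `liger_kernel.transformers.experimental.matmul`; the cache tiling is omitted.

```python
from typing import List

VALUES_PER_BYTE = 4  # assumed: four 2-bit values per uint8
OFFSET = 2  # assumed: int2 values in [-2, 1] are stored as codes 0..3


def pack_int2(values: List[int]) -> List[int]:
    """Hypothetical packer: 4 signed int2 values per byte, lowest bits first."""
    assert len(values) % VALUES_PER_BYTE == 0
    packed = []
    for i in range(0, len(values), VALUES_PER_BYTE):
        byte = 0
        for j, v in enumerate(values[i : i + VALUES_PER_BYTE]):
            assert -2 <= v <= 1, "int2 range is [-2, 1]"
            byte |= (v + OFFSET) << (2 * j)
        packed.append(byte)
    return packed


def int2_at(packed: List[int], k: int) -> int:
    """Decode the k-th int2 value directly from the packed bytes."""
    code = (packed[k // VALUES_PER_BYTE] >> (2 * (k % VALUES_PER_BYTE))) & 0b11
    return code - OFFSET


def matmul_int8_int2(a: List[List[int]], b_cols_packed: List[List[int]], k_dim: int) -> List[List[int]]:
    """C = A @ B with A as M x K int8 and B as K x N int2.

    Each column of B is packed along K, and every weight is decoded inside the
    inner loop, so no unpacked copy of B is ever built (the "fused" part).
    """
    result = []
    for row in a:
        out_row = []
        for col in b_cols_packed:
            acc = 0  # a GPU kernel would accumulate in int32; Python ints are unbounded
            for k in range(k_dim):
                acc += row[k] * int2_at(col, k)
            out_row.append(acc)
        result.append(out_row)
    return result


# Usage: M=2, K=4, N=2.
a = [[1, 2, 3, 4], [-128, 0, 5, 127]]  # int8 activations
b = [[-2, 1], [0, -1], [1, 1], [-1, 0]]  # int2 weights, K x N
b_cols = [pack_int2([b[k][n] for k in range(4)]) for n in range(2)]
c = matmul_int8_int2(a, b_cols, k_dim=4)  # [[-3, 2], [134, -123]]
```

Packing four weights per byte cuts weight memory traffic by 4X relative to int8; decoding in the inner loop keeps that saving, whereas unpacking first would materialize a full-size int8 copy of B.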
  ## Contributing, Acknowledgements, and License