liger-kernel-nightly 0.4.2.dev20241209224333__tar.gz → 0.4.2.dev20241210001927__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. {liger_kernel_nightly-0.4.2.dev20241209224333/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241210001927}/PKG-INFO +26 -28
  2. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/README.md +25 -27
  3. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/chunked_loss/cpo_loss.py +16 -10
  5. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/chunked_loss/dpo_loss.py +20 -12
  6. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/chunked_loss/orpo_loss.py +15 -9
  7. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/chunked_loss/simpo_loss.py +17 -11
  8. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927/src/liger_kernel_nightly.egg-info}/PKG-INFO +26 -28
  9. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/LICENSE +0 -0
  10. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/NOTICE +0 -0
  11. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/setup.cfg +0 -0
  12. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/__init__.py +0 -0
  13. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  14. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/chunked_loss/functional.py +0 -0
  15. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  16. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  17. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/env_report.py +0 -0
  18. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/__init__.py +0 -0
  19. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/cross_entropy.py +0 -0
  20. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  21. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  22. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  23. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  24. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/geglu.py +0 -0
  25. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/group_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/jsd.py +0 -0
  27. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/kl_div.py +0 -0
  28. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/layer_norm.py +0 -0
  29. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  30. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/rms_norm.py +0 -0
  31. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/rope.py +0 -0
  32. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/swiglu.py +0 -0
  33. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/ops/utils.py +0 -0
  34. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/__init__.py +0 -0
  35. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/auto_model.py +0 -0
  36. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  38. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/functional.py +0 -0
  39. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  40. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  41. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/geglu.py +0 -0
  42. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/group_norm.py +0 -0
  43. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/jsd.py +0 -0
  44. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/kl_div.py +0 -0
  45. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/layer_norm.py +0 -0
  46. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/__init__.py +0 -0
  47. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/gemma.py +0 -0
  48. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  49. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/llama.py +0 -0
  50. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/mistral.py +0 -0
  51. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  52. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/mllama.py +0 -0
  53. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/phi3.py +0 -0
  54. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  55. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  56. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  57. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/orpo_trainer.py +0 -0
  58. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  59. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/rms_norm.py +0 -0
  60. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/rope.py +0 -0
  61. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/swiglu.py +0 -0
  62. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  63. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/triton/__init__.py +0 -0
  64. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/triton/monkey_patch.py +0 -0
  65. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel/utils.py +0 -0
  66. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  67. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  68. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  69. {liger_kernel_nightly-0.4.2.dev20241209224333 → liger_kernel_nightly-0.4.2.dev20241210001927}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241209224333
3
+ Version: 0.4.2.dev20241210001927
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -110,7 +110,7 @@ Requires-Dist: triton>=3.0.0; extra == "amd"
110
110
 
111
111
  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
112
112
 
113
- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
113
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
114
114
 
115
115
  <details>
116
116
  <summary>Latest News 🔥</summary>
@@ -266,7 +266,7 @@ loss = loss_fn(model.weight, input, target)
266
266
  loss.backward()
267
267
  ```
268
268
 
269
- ## APIs
269
+ ## High-level APIs
270
270
 
271
271
  ### AutoModel
272
272
 
@@ -290,8 +290,12 @@ loss.backward()
290
290
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
291
291
 
292
292
 
293
+ ## Low-level APIs
293
294
 
294
- ### Kernels
295
+ - `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
296
+ - Other kernels use fusion and in-place techniques for memory and performance optimization.
297
+
298
+ ### Model Kernels
295
299
 
296
300
  | **Kernel** | **API** |
297
301
  |---------------------------------|-------------------------------------------------------------|
@@ -301,39 +305,33 @@ loss.backward()
301
305
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
302
306
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
303
307
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
304
- | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
308
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
309
+
310
+
311
+ ### Alignment Kernels
312
+
313
+ | **Kernel** | **API** |
314
+ |---------------------------------|-------------------------------------------------------------|
315
+ | Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
316
+ | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
317
+ | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
318
+ | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
319
+
320
+ ### Distillation Kernels
321
+
322
+ | **Kernel** | **API** |
323
+ |---------------------------------|-------------------------------------------------------------|
305
324
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
306
325
  | JSD | `liger_kernel.transformers.LigerJSD` |
307
- | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
308
-
309
- - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
310
- - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
311
- - **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
312
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
313
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
314
- $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
315
- , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
316
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
317
- $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
318
- , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
319
- - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
320
- <!-- TODO: verify vocab sizes are accurate -->
321
- - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
322
- - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
323
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
324
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
325
-
326
+ | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
326
327
 
327
328
  ### Experimental Kernels
328
329
 
329
330
  | **Kernel** | **API** |
330
331
  |---------------------------------|-------------------------------------------------------------|
331
332
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
332
- | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
333
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
333
334
 
334
- - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
335
- - **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
336
- <!-- TODO: be more specific about batch size -->
337
335
 
338
336
  ## Contributing, Acknowledgements, and License
339
337
 
@@ -55,7 +55,7 @@
55
55
 
56
56
  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
57
57
 
58
- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
58
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
59
59
 
60
60
  <details>
61
61
  <summary>Latest News 🔥</summary>
@@ -211,7 +211,7 @@ loss = loss_fn(model.weight, input, target)
211
211
  loss.backward()
212
212
  ```
213
213
 
214
- ## APIs
214
+ ## High-level APIs
215
215
 
216
216
  ### AutoModel
217
217
 
@@ -235,8 +235,12 @@ loss.backward()
235
235
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
236
236
 
237
237
 
238
+ ## Low-level APIs
238
239
 
239
- ### Kernels
240
+ - `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
241
+ - Other kernels use fusion and in-place techniques for memory and performance optimization.
242
+
243
+ ### Model Kernels
240
244
 
241
245
  | **Kernel** | **API** |
242
246
  |---------------------------------|-------------------------------------------------------------|
@@ -246,39 +250,33 @@ loss.backward()
246
250
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
247
251
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
248
252
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
249
- | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
253
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
254
+
255
+
256
+ ### Alignment Kernels
257
+
258
+ | **Kernel** | **API** |
259
+ |---------------------------------|-------------------------------------------------------------|
260
+ | Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
261
+ | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
262
+ | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
263
+ | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
264
+
265
+ ### Distillation Kernels
266
+
267
+ | **Kernel** | **API** |
268
+ |---------------------------------|-------------------------------------------------------------|
250
269
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
251
270
  | JSD | `liger_kernel.transformers.LigerJSD` |
252
- | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
253
-
254
- - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
255
- - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
256
- - **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
257
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
258
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
259
- $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
260
- , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
261
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
262
- $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
263
- , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
264
- - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
265
- <!-- TODO: verify vocab sizes are accurate -->
266
- - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
267
- - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
268
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
269
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
270
-
271
+ | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
271
272
 
272
273
  ### Experimental Kernels
273
274
 
274
275
  | **Kernel** | **API** |
275
276
  |---------------------------------|-------------------------------------------------------------|
276
277
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
277
- | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
278
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
278
279
 
279
- - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
280
- - **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
281
- <!-- TODO: be more specific about batch size -->
282
280
 
283
281
  ## Contributing, Acknowledgements, and License
284
282
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.2.dev20241209224333"
7
+ version = "0.4.2.dev20241210001927"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -11,11 +11,25 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
11
11
  @staticmethod
12
12
  def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
13
13
  """
14
- Compute odds-ratio loss.
14
+ Paper: https://arxiv.org/pdf/2401.08417
15
+
16
+ Formula:
17
+ L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
18
+
19
+ Where:
20
+ - π_θ(y|x): Policy (model) probability
21
+ - y_w: Chosen sequence
22
+ - y_l: Rejected sequence
23
+ - σ: Sigmoid function
24
+ - β: Temperature parameter
25
+ - E: Expected value over the dataset D
26
+ - D: Dataset of preferences
27
+
15
28
  Args:
16
29
  chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
17
30
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
18
- beta (float): Weight for the odds ratio loss.
31
+ full_target (torch.Tensor): Non chunked full target tensor
32
+ beta (float): Weight for the CPO loss
19
33
  """
20
34
  logits = beta * (chosen_logps - rejected_logps)
21
35
  loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
@@ -34,12 +48,6 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
34
48
  compute_nll_loss=True,
35
49
  compiled=True,
36
50
  ):
37
- """
38
- Fused linear layer with CPO (Odds-Ratio Preference Optimization) loss.
39
- Handles both the forward and backward pass of the final linear layer with CPO loss.
40
- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
41
- """
42
-
43
51
  return LigerFusedLinearPreferenceBase.forward(
44
52
  ctx,
45
53
  _input,
@@ -56,9 +64,7 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
56
64
 
57
65
  @staticmethod
58
66
  def backward(ctx, *grad_output):
59
- # Get gradients for _input, weight, bias, and target from the base class
60
67
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
61
- # Return these gradients, followed by None for the remaining inputs
62
68
  return *grads, None, None, None, None, None
63
69
 
64
70
 
@@ -18,14 +18,28 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
18
18
  beta=0.1,
19
19
  ):
20
20
  """
21
- Compute DPO loss (Direct Preference Optimization).
21
+ Paper: https://arxiv.org/pdf/2305.18290
22
+
23
+ Formula:
24
+ L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
25
+
26
+ Where:
27
+ - π(y|x): Policy (model) probability
28
+ - π_ref(y|x): Reference model probability
29
+ - y_w: Chosen sequence
30
+ - y_l: Rejected sequence
31
+ - β: Weight for the direct preference loss
32
+ - E: Expected value over the dataset
33
+
22
34
  Args:
23
- chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
24
- rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
25
- ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,).
26
- ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,).
27
- beta (float): Weight for the direct preference loss.
35
+ chosen_logps: Log probabilities of chosen tokens (batch_size,)
36
+ rejected_logps: Log probabilities of rejected tokens (batch_size,)
37
+ full_target: Non chunked full target tensor
38
+ ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
39
+ ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
40
+ beta: Weight for the direct preference loss
28
41
  """
42
+
29
43
  if ref_chosen_logps is None:
30
44
  ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
31
45
  if ref_rejected_logps is None:
@@ -53,10 +67,6 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
53
67
  compiled=True,
54
68
  use_ref_model=True,
55
69
  ):
56
- """
57
- Fused linear layer with DPO (Direct Preference Optimization) loss.
58
- Handles both the forward and backward pass of the final linear layer with DPO loss.
59
- """
60
70
  return LigerFusedLinearPreferenceBase.forward(
61
71
  ctx=ctx,
62
72
  _input=_input,
@@ -75,9 +85,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
75
85
 
76
86
  @staticmethod
77
87
  def backward(ctx, *grad_output):
78
- # Get gradients for _input, weight, bias, and target from the base class
79
88
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
80
- # Return these gradients, followed by None for the remaining inputs
81
89
  return *grads, None, None, None, None, None, None, None
82
90
 
83
91
 
@@ -11,10 +11,24 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
11
11
  @staticmethod
12
12
  def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
13
13
  """
14
- Compute odds-ratio loss.
14
+ Paper: https://arxiv.org/pdf/2403.07691
15
+
16
+ Formula:
17
+ Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
18
+ where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
19
+
20
+ Where:
21
+ - P_θ(y|x): Policy (model) probability
22
+ - y_w: Chosen sequence
23
+ - y_l: Rejected sequence
24
+ - σ: Sigmoid function
25
+ - β: Weight for the odds ratio loss
26
+ - odds_θ: Odds function for the policy
27
+
15
28
  Args:
16
29
  chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
17
30
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
31
+ full_target (torch.Tensor): Non-chunked full target tensor
18
32
  beta (float): Weight for the odds ratio loss.
19
33
  """
20
34
  log_odds = (chosen_logps - rejected_logps) - (
@@ -44,12 +58,6 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
44
58
  compute_nll_loss=True,
45
59
  compiled=True,
46
60
  ):
47
- """
48
- Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
49
- Handles both the forward and backward pass of the final linear layer with ORPO loss.
50
- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
51
- """
52
-
53
61
  return LigerFusedLinearPreferenceBase.forward(
54
62
  ctx=ctx,
55
63
  _input=_input,
@@ -65,9 +73,7 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
65
73
 
66
74
  @staticmethod
67
75
  def backward(ctx, *grad_output):
68
- # Get gradients for _input, weight, bias, and target from the base class
69
76
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
70
- # Return these gradients, followed by None for the remaining inputs
71
77
  return *grads, None, None, None, None
72
78
 
73
79
 
@@ -13,12 +13,26 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
13
13
  chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
14
14
  ):
15
15
  """
16
- Compute odds-ratio loss.
16
+ Paper: https://arxiv.org/pdf/2405.14734
17
+
18
+ Formula:
19
+ L_SimPO(π_θ) = -E [log σ(β/|y_w| log π_θ(y_w|x) - β/|y_l| log π_θ(y_l|x) - γ)]
20
+
21
+ Where:
22
+ - π_θ(y|x): Policy (model) probability
23
+ - y_w: Chosen sequence
24
+ - y_l: Rejected sequence
25
+ - |y_w|, |y_l|: Sequence lengths
26
+ - σ: Sigmoid function
27
+ - β: beta weight
28
+ - γ: gamma margin term
29
+
17
30
  Args:
18
31
  chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
19
32
  rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
20
- beta (float): Weight for the odds ratio loss.
21
- gamma (float): The simpo gamma, margin term.
33
+ full_target: Non-chunked full target tensor
34
+ beta (float): beta weight
35
+ gamma (float): gamma margin term
22
36
  """
23
37
  logits = beta * (chosen_logps - rejected_logps) - gamma
24
38
  loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
@@ -38,12 +52,6 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
38
52
  compiled=True,
39
53
  gamma=0.5,
40
54
  ):
41
- """
42
- Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
43
- Handles both the forward and backward pass of the final linear layer with SimPO loss.
44
- Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
45
- """
46
-
47
55
  return LigerFusedLinearPreferenceBase.forward(
48
56
  ctx,
49
57
  _input,
@@ -61,9 +69,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
61
69
 
62
70
  @staticmethod
63
71
  def backward(ctx, *grad_output):
64
- # Get gradients for _input, weight, bias, and target from the base class
65
72
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
66
- # Return these gradients, followed by None for the remaining inputs
67
73
  return *grads, None, None, None, None, None, None
68
74
 
69
75
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241209224333
3
+ Version: 0.4.2.dev20241210001927
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -110,7 +110,7 @@ Requires-Dist: triton>=3.0.0; extra == "amd"
110
110
 
111
111
  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
112
112
 
113
- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
113
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
114
114
 
115
115
  <details>
116
116
  <summary>Latest News 🔥</summary>
@@ -266,7 +266,7 @@ loss = loss_fn(model.weight, input, target)
266
266
  loss.backward()
267
267
  ```
268
268
 
269
- ## APIs
269
+ ## High-level APIs
270
270
 
271
271
  ### AutoModel
272
272
 
@@ -290,8 +290,12 @@ loss.backward()
290
290
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
291
291
 
292
292
 
293
+ ## Low-level APIs
293
294
 
294
- ### Kernels
295
+ - `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
296
+ - Other kernels use fusion and in-place techniques for memory and performance optimization.
297
+
298
+ ### Model Kernels
295
299
 
296
300
  | **Kernel** | **API** |
297
301
  |---------------------------------|-------------------------------------------------------------|
@@ -301,39 +305,33 @@ loss.backward()
301
305
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
302
306
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
303
307
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
304
- | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
308
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
309
+
310
+
311
+ ### Alignment Kernels
312
+
313
+ | **Kernel** | **API** |
314
+ |---------------------------------|-------------------------------------------------------------|
315
+ | Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
316
+ | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
317
+ | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
318
+ | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
319
+
320
+ ### Distillation Kernels
321
+
322
+ | **Kernel** | **API** |
323
+ |---------------------------------|-------------------------------------------------------------|
305
324
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
306
325
  | JSD | `liger_kernel.transformers.LigerJSD` |
307
- | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
308
-
309
- - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
310
- - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
311
- - **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
312
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
313
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
314
- $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
315
- , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
316
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
317
- $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
318
- , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
319
- - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
320
- <!-- TODO: verify vocab sizes are accurate -->
321
- - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
322
- - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
323
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
324
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
325
-
326
+ | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
326
327
 
327
328
  ### Experimental Kernels
328
329
 
329
330
  | **Kernel** | **API** |
330
331
  |---------------------------------|-------------------------------------------------------------|
331
332
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
332
- | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
333
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
333
334
 
334
- - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
335
- - **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
336
- <!-- TODO: be more specific about batch size -->
337
335
 
338
336
  ## Contributing, Acknowledgements, and License
339
337