liger-kernel 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/__init__.py +4 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +107 -0
  4. liger_kernel/chunked_loss/dpo_loss.py +95 -17
  5. liger_kernel/chunked_loss/functional.py +9 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +252 -0
  7. liger_kernel/chunked_loss/fused_linear_preference.py +245 -65
  8. liger_kernel/chunked_loss/orpo_loss.py +63 -13
  9. liger_kernel/chunked_loss/simpo_loss.py +115 -0
  10. liger_kernel/env_report.py +22 -0
  11. liger_kernel/ops/cross_entropy.py +17 -10
  12. liger_kernel/ops/fused_linear_cross_entropy.py +0 -11
  13. liger_kernel/ops/fused_linear_jsd.py +1 -1
  14. liger_kernel/ops/jsd.py +19 -10
  15. liger_kernel/ops/layer_norm.py +6 -1
  16. liger_kernel/ops/qwen2vl_mrope.py +238 -0
  17. liger_kernel/ops/rms_norm.py +6 -1
  18. liger_kernel/ops/utils.py +5 -2
  19. liger_kernel/transformers/functional.py +128 -11
  20. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  21. liger_kernel/transformers/jsd.py +1 -4
  22. liger_kernel/transformers/monkey_patch.py +6 -4
  23. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  24. liger_kernel/transformers/trainer/__init__.py +6 -0
  25. liger_kernel/transformers/trainer/orpo_trainer.py +169 -0
  26. liger_kernel/utils.py +13 -0
  27. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/METADATA +71 -47
  28. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/RECORD +32 -22
  29. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/WHEEL +1 -1
  30. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/LICENSE +0 -0
  31. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/NOTICE +0 -0
  32. {liger_kernel-0.4.2.dist-info → liger_kernel-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,20 @@
1
+ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
2
+
3
+
4
def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """
    Apply Multimodal Rotary Positional Embedding (M-RoPE) to the query and key states.

    Thin wrapper around ``LigerQwen2VLMRopeFunction.apply``. Every argument is
    forwarded positionally, in the order listed below.

    Args:
        q (torch.Tensor): Query tensor of shape (bsz, n_q_head, seq_len, head_dim).
        k (torch.Tensor): Key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
        cos (torch.Tensor): Cosine tensor of shape (3, 1, seq_len, head_dim).
        sin (torch.Tensor): Sine tensor of shape (3, 1, seq_len, head_dim).
        mrope_section (List[int]): How the rope channel dimension is split between
            the temporal, height and width sections.
        unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after M-RoPE.
    """
    rope_args = (q, k, cos, sin, mrope_section, unsqueeze_dim)
    return LigerQwen2VLMRopeFunction.apply(*rope_args)
@@ -0,0 +1,6 @@
1
try:
    from liger_kernel.transformers.trainer.orpo_trainer import (  # noqa: F401
        LigerORPOTrainer,
    )
except ImportError as e:
    # Chain the original error. Without it, a broken or incompatible `trl`
    # install (or any other import failure inside orpo_trainer) would be hidden
    # behind the generic install hint below.
    raise ImportError("Please `pip install trl` to use LigerORPOTrainer") from e
@@ -0,0 +1,169 @@
1
+ from typing import Any, Callable, Dict, List, Literal, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.distributed.fsdp import FullyShardedDataParallel
6
+ from trl.trainer import ORPOTrainer
7
+
8
+ from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
9
+
10
+
11
class _FSDPForwardRedirection:
    """
    Modified based on
    https://github.com/Lightning-AI/pytorch-lightning/blob/d3f9c83d6efa4f1def36aa6c199600946cdb9117/src/lightning/pytorch/strategies/strategy.py#L601-L648
    Redirect a method call through FullyShardedDataParallel.forward so that the FSDP module's root pre-forward and
    post-forward can be properly executed around the method call.
    This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
    the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
    GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
    the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
    its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
    the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
    """

    def __call__(
        self,
        wrapper_module: FullyShardedDataParallel,
        method: Callable,
        *args: Any,
        **kwargs: Any,
    ):
        """Reroute a call of `method` through `wrapper_module`'s `forward`.

        The module wrapped inside `wrapper_module` (its `_fsdp_wrapped_module`)
        has its `forward` temporarily replaced. When FSDP invokes that forward,
        `method` runs instead. The patch is removed whether or not the call
        succeeds.

        Args:
            wrapper_module: The FSDP module to route the call through.
            method: Callable to run from inside the FSDP forward. It receives the
                arguments FSDP passes on to the wrapped module's `forward`.
            *args: Positional arguments passed to `wrapper_module(...)` and, via
                the patched `forward`, on to `method`.
            **kwargs: Keyword arguments passed to `wrapper_module(...)` and, via
                the patched `forward`, on to `method`.

        Returns:
            Whatever `wrapper_module(*args, **kwargs)` returns, i.e. the result of
            `method` as passed back through FSDP's forward.

        Raises:
            TypeError: If `wrapper_module` is not a `FullyShardedDataParallel`.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(wrapper_module, FullyShardedDataParallel):
            raise TypeError(
                "_FSDPForwardRedirection requires a FullyShardedDataParallel module, "
                f"got {type(wrapper_module).__name__}"
            )
        original_module = wrapper_module._fsdp_wrapped_module
        original_forward = original_module.forward

        def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
            # Unpatch before calling `method`, because it may itself need the
            # real `forward` of the wrapped module.
            original_module.forward = original_forward  # type: ignore[method-assign]
            return method(*_args, **_kwargs)

        # Patch the wrapped module's forward so FSDP's forward dispatches to `method`.
        original_module.forward = wrapped_forward  # type: ignore[method-assign]
        try:
            return wrapper_module(*args, **kwargs)
        finally:
            # If FSDP's forward raised before reaching `wrapped_forward`, the patch
            # would otherwise remain installed and hijack every later forward call.
            original_module.forward = original_forward  # type: ignore[method-assign]
60
+
61
+
62
class LigerORPOTrainer(ORPOTrainer):
    """`trl` ORPOTrainer whose loss is computed by `LigerFusedLinearORPOLoss`.

    Rather than computing logits itself, the trainer runs only the decoder body
    (`model.model`). It then hands the last hidden state and the `lm_head`
    weight/bias to the fused Liger ORPO loss.
    """

    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Works for FSDP-wrapped, DataParallel-wrapped and plain models. Only the
        FSDP path is routed through `_FSDPForwardRedirection`, which is needed
        there so sharded parameters are gathered.

        Returns:
            A `(loss, aux_outputs)` pair as produced by `LigerFusedLinearORPOLoss`.
            This differs from the parent trainer's return shape and is consumed by
            `get_batch_loss_metrics` below.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )

        model_kwargs = (
            {
                "decoder_input_ids": self._shift_right(
                    concatenated_batch["concatenated_labels"]
                ),
            }
            if self.is_encoder_decoder
            else {}
        )

        if self.aux_loss_enabled:
            # NOTE(review): router logits are requested, but no router auxiliary
            # loss is folded into the loss returned by this method — confirm intended.
            model_kwargs["output_router_logits"] = True

        is_fsdp = isinstance(model, FullyShardedDataParallel)
        if is_fsdp:
            outputs = _FSDPForwardRedirection()(
                model,
                model._fsdp_wrapped_module.model,
                concatenated_batch["concatenated_input_ids"],
                attention_mask=concatenated_batch["concatenated_attention_mask"],
                use_cache=False,
                **model_kwargs,
            )
        else:
            if isinstance(model, torch.nn.DataParallel):
                model = model.module
            outputs = model.model(
                concatenated_batch["concatenated_input_ids"],
                attention_mask=concatenated_batch["concatenated_attention_mask"],
                use_cache=False,
                **model_kwargs,
            )

        orpo_loss_fn = LigerFusedLinearORPOLoss(
            ignore_index=self.label_pad_token_id, beta=self.beta
        )

        def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
            return orpo_loss_fn(
                lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias
            )

        loss_args = (
            model.lm_head,
            outputs.last_hidden_state,
            concatenated_batch["concatenated_labels"],
        )
        if is_fsdp:
            # Route through the FSDP root forward so lm_head params are all-gathered.
            orpo_loss, aux_outputs = _FSDPForwardRedirection()(
                model, orpo_partial, *loss_args
            )
        else:
            # _FSDPForwardRedirection rejects non-FSDP modules, so plain (or
            # DataParallel-unwrapped) models compute the loss directly.
            orpo_loss, aux_outputs = orpo_partial(*loss_args)
        return orpo_loss, aux_outputs

    def get_batch_loss_metrics(
        self,
        model,
        batch: Dict[str, Union[List, torch.LongTensor]],
        train_eval: Literal["train", "eval"] = "train",
    ):
        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.

        Returns:
            `(loss, metrics)`, where `metrics` maps metric names (prefixed with
            `eval_` when `train_eval == "eval"`) to Python scalars.
        """
        metrics = {}
        loss, aux_outputs = self.concatenated_forward(model, batch)
        # The first five aux outputs are policy statistics, the remaining four are
        # reward / odds-ratio statistics from the ORPO loss.
        (
            policy_chosen_logps,
            policy_rejected_logps,
            policy_chosen_logits,
            policy_rejected_logits,
            policy_nll_loss,
        ) = aux_outputs[:5]

        chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = aux_outputs[
            5:
        ]

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
        metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
        metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
        # Convert every tensor metric to a Python scalar for logging.
        metrics = {k: v.item() for k, v in metrics.items()}

        return loss, metrics
liger_kernel/utils.py ADDED
@@ -0,0 +1,13 @@
1
+ import torch
2
+
3
+
4
def infer_device():
    """
    Get current device name based on available devices.

    Checks CUDA first, then Intel XPU, and falls back to CPU.

    Returns:
        str: One of ``"cuda"``, ``"xpu"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    # `torch.xpu` may be missing on older PyTorch builds (the package allows
    # torch>=2.1.2). Guard the lookup so those fall back to CPU instead of
    # raising AttributeError.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    return "cpu"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel
3
- Version: 0.4.2
3
+ Version: 0.5.1
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -32,6 +32,10 @@ License-File: LICENSE
32
32
  License-File: NOTICE
33
33
  Requires-Dist: torch>=2.1.2
34
34
  Requires-Dist: triton>=2.3.1
35
+ Provides-Extra: transformers
36
+ Requires-Dist: transformers~=4.0; extra == "transformers"
37
+ Provides-Extra: trl
38
+ Requires-Dist: trl>=0.11.0; extra == "trl"
35
39
  Provides-Extra: dev
36
40
  Requires-Dist: transformers>=4.44.2; extra == "dev"
37
41
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
@@ -39,11 +43,16 @@ Requires-Dist: flake8>=4.0.1.1; extra == "dev"
39
43
  Requires-Dist: black>=24.4.2; extra == "dev"
40
44
  Requires-Dist: isort>=5.13.2; extra == "dev"
41
45
  Requires-Dist: pytest>=7.1.2; extra == "dev"
46
+ Requires-Dist: pytest-xdist; extra == "dev"
47
+ Requires-Dist: pytest-rerunfailures; extra == "dev"
42
48
  Requires-Dist: datasets>=2.19.2; extra == "dev"
43
49
  Requires-Dist: torchvision>=0.16.2; extra == "dev"
44
50
  Requires-Dist: seaborn; extra == "dev"
45
- Provides-Extra: transformers
46
- Requires-Dist: transformers~=4.0; extra == "transformers"
51
+ Provides-Extra: amd
52
+ Requires-Dist: torch>=2.6.0.dev; extra == "amd"
53
+ Requires-Dist: setuptools-scm>=8; extra == "amd"
54
+ Requires-Dist: torchvision>=0.20.0.dev; extra == "amd"
55
+ Requires-Dist: triton>=3.0.0; extra == "amd"
47
56
 
48
57
  <a name="readme-top"></a>
49
58
 
@@ -55,7 +64,7 @@ Requires-Dist: transformers~=4.0; extra == "transformers"
55
64
  <th style="padding: 10px;" colspan="2">Stable</th>
56
65
  <th style="padding: 10px;" colspan="2">Nightly</th>
57
66
  <th style="padding: 10px;">Discord</th>
58
- <th style="padding: 10px;">Gurubase (experimental)</th>
67
+ <th style="padding: 10px;">Build</th>
59
68
  </tr>
60
69
  <tr>
61
70
  <td style="padding: 10px;">
@@ -84,9 +93,16 @@ Requires-Dist: transformers~=4.0; extra == "transformers"
84
93
  </a>
85
94
  </td>
86
95
  <td style="padding: 10px;">
87
- <a href="https://gurubase.io/g/liger-kernel">
88
- <img src="https://img.shields.io/badge/Gurubase-Ask%20Liger%20Kernel%20Guru-006BFF" alt="Ask Liger Kernel Guru">
89
- </a>
96
+ <div style="display: block;">
97
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
98
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
99
+ </a>
100
+ </div>
101
+ <div style="display: block;">
102
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
103
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
104
+ </a>
105
+ </div>
90
106
  </td>
91
107
  </tr>
92
108
  </table>
@@ -95,13 +111,14 @@ Requires-Dist: transformers~=4.0; extra == "transformers"
95
111
 
96
112
  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
97
113
 
98
- [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
114
+ [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
99
115
 
100
116
  <details>
101
117
  <summary>Latest News 🔥</summary>
102
118
 
119
+ - [2024/12/15] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
103
120
  - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
104
- - [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
121
+ - [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
105
122
  - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
106
123
  - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
107
124
  - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
@@ -111,6 +128,8 @@ Requires-Dist: transformers~=4.0; extra == "transformers"
111
128
 
112
129
  **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
113
130
 
131
+ We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more.
132
+
114
133
  ## Supercharge Your Model with Liger Kernel
115
134
 
116
135
  ![Banner](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/banner.GIF)
@@ -128,12 +147,13 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
128
147
 
129
148
  ## Examples
130
149
 
131
-
132
150
  | **Use Case** | **Description** |
133
151
  |------------------------------------------------|---------------------------------------------------------------------------------------------------|
134
152
  | [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
135
153
  | [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase 15% throughput and reduce memory usage by 40% with LLaMA3-8B on MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
136
- | [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP | |
154
+ | [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
155
+ | [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
156
+ | [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using the Liger ORPO Trainer with FSDP, achieving 50% memory reduction |
137
157
 
138
158
  ## Key Features
139
159
 
@@ -146,7 +166,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
146
166
 
147
167
  ## Installation
148
168
 
149
- ### Dependencies
169
+ ### Dependencies
150
170
 
151
171
  #### CUDA
152
172
 
@@ -183,6 +203,8 @@ To install from source:
183
203
  git clone https://github.com/linkedin/Liger-Kernel.git
184
204
  cd Liger-Kernel
185
205
  pip install -e .
206
+ # or if installing on amd platform
207
+ pip install -e .[amd] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2 # rocm6.2
186
208
  # or if using transformers
187
209
  pip install -e .[transformers]
188
210
  ```
@@ -249,7 +271,7 @@ loss = loss_fn(model.weight, input, target)
249
271
  loss.backward()
250
272
  ```
251
273
 
252
- ## APIs
274
+ ## High-level APIs
253
275
 
254
276
  ### AutoModel
255
277
 
@@ -268,13 +290,17 @@ loss.backward()
268
290
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
269
291
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
270
292
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
271
- | Qwen2 & Qwen2.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
293
+ | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
272
294
  | Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
273
295
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
274
296
 
275
297
 
298
+ ## Low-level APIs
299
+
300
+ - `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
301
+ - Other kernels use fusion and in-place techniques for memory and performance optimization.
276
302
 
277
- ### Kernels
303
+ ### Model Kernels
278
304
 
279
305
  | **Kernel** | **API** |
280
306
  |---------------------------------|-------------------------------------------------------------|
@@ -284,39 +310,33 @@ loss.backward()
284
310
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
285
311
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
286
312
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
287
- | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
313
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
314
+
315
+
316
+ ### Alignment Kernels
317
+
318
+ | **Kernel** | **API** |
319
+ |---------------------------------|-------------------------------------------------------------|
320
+ | Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
321
+ | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
322
+ | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
323
+ | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
324
+
325
+ ### Distillation Kernels
326
+
327
+ | **Kernel** | **API** |
328
+ |---------------------------------|-------------------------------------------------------------|
288
329
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
289
330
  | JSD | `liger_kernel.transformers.LigerJSD` |
290
- | FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
291
-
292
- - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
293
- - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
294
- - **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across the group dimension for a given sample. Channels are grouped in K groups over which the normalization is performed, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
295
- - **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
296
- - **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
297
- $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
298
- , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
299
- - **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
300
- $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
301
- , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
302
- - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
303
- <!-- TODO: verify vocab sizes are accurate -->
304
- - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
305
- - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size.
306
- - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size.
307
- - **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
308
-
331
+ | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
309
332
 
310
333
  ### Experimental Kernels
311
334
 
312
335
  | **Kernel** | **API** |
313
336
  |---------------------------------|-------------------------------------------------------------|
314
337
  | Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
315
- | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
338
+ | Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |
316
339
 
317
- - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
318
- - **Matmul int2xint8**: is implemented by using the cache tiled matrix multiplication and by fusing the matmul with the unpacking process which achieves a considerable speed up and performs on par with @torch.compile
319
- <!-- TODO: be more specific about batch size -->
320
340
 
321
341
  ## Contributing, Acknowledgements, and License
322
342
 
@@ -324,6 +344,17 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
324
344
  - [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
325
345
  - [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)
326
346
 
347
+ ## Sponsorship and Collaboration
348
+
349
+ - [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
350
+ - [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
351
+ - [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
352
+ - [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
353
+ - [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
354
+ - [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
355
+ - [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
356
+ - [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
357
+
327
358
  ## Contact
328
359
 
329
360
  - For issues, create a Github ticket in this repository
@@ -335,7 +366,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
335
366
  Biblatex entry:
336
367
  ```bib
337
368
  @article{hsu2024ligerkernelefficienttriton,
338
- title={Liger Kernel: Efficient Triton Kernels for LLM Training},
369
+ title={Liger Kernel: Efficient Triton Kernels for LLM Training},
339
370
  author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
340
371
  year={2024},
341
372
  eprint={2410.10989},
@@ -349,15 +380,8 @@ Biblatex entry:
349
380
  ## Star History
350
381
  [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
351
382
 
352
- ## Contributors
353
-
354
- <a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
355
- <img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
356
- </a>
357
-
358
383
  <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
359
384
  <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
360
385
  ↑ Back to Top ↑
361
386
  </a>
362
387
  </p>
363
-
@@ -1,35 +1,43 @@
1
- liger_kernel/env_report.py,sha256=jye8RvUkmhqaIshdeIpoUABoAu7FPKJUib4FnAfvkpw,1132
2
- liger_kernel/chunked_loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- liger_kernel/chunked_loss/dpo_loss.py,sha256=_sftycUsxypLiQaCIoqMEwtc425Kxiq97YI6DvFvscc,1943
4
- liger_kernel/chunked_loss/fused_linear_preference.py,sha256=ayx-dmAx1TW9sThHJ_wUU1MqpZeJ4-SooGh0ZgVFlOA,8420
5
- liger_kernel/chunked_loss/orpo_loss.py,sha256=DNifPpzGV_t3dfOPlPy2XKDM6M1Qne0kCbIPztvFY9U,2179
1
+ liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ liger_kernel/env_report.py,sha256=FViyPju795lB6z4k2TZldvBSmQdcS0A2hcnDxepJrDo,1822
3
+ liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
4
+ liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
5
+ liger_kernel/chunked_loss/cpo_loss.py,sha256=Qu1Ul2A12sp6CqIT-atPbHWFb_LLtINEA9mOpIRx_0g,3097
6
+ liger_kernel/chunked_loss/dpo_loss.py,sha256=H9_RRhclckHYM2sd75tgbnf8IxC_PU2JCALbgtPQvwc,4222
7
+ liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
8
+ liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
9
+ liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vlWfaaIECWvCQhY9PM7zRI0vKThIrydMf6P44bXn1EE,15114
10
+ liger_kernel/chunked_loss/orpo_loss.py,sha256=ZuKGjbkIYzV4UzvupNdq6vyxCp7-BztQkUt8ZnFvKos,3531
11
+ liger_kernel/chunked_loss/simpo_loss.py,sha256=Wa4LOlDG9PbJkOOkKg8hbKvnKgg7OTBz6-qIkwPK1yw,3275
6
12
  liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- liger_kernel/ops/cross_entropy.py,sha256=sfUb7-jIZp0EKXjg1DYy2Wdzw_Mg-mHmGoR5bpdm4tw,15526
8
- liger_kernel/ops/fused_linear_cross_entropy.py,sha256=ib7M3AjJE164yMfuS9R39k-5qnDgYOXptIT146lqYbg,9964
9
- liger_kernel/ops/fused_linear_jsd.py,sha256=5D_obamh08lGGTMyh85kBJD_aNjPhOYf4-TmCZ6m4s4,9626
13
+ liger_kernel/ops/cross_entropy.py,sha256=oG5hfrlmnlF5lOoZRhHRglObxgH4B0KadjWMJj9EWPM,15860
14
+ liger_kernel/ops/fused_linear_cross_entropy.py,sha256=Tnw4gyAYVVdnCOqhOuLEzbUQ3goOTnoAfk3pqSIM5ac,9301
15
+ liger_kernel/ops/fused_linear_jsd.py,sha256=nOv4zwfxHqqepKEmMsQuz-B3H-gRjyo8uClpmqSGLYA,9693
10
16
  liger_kernel/ops/geglu.py,sha256=MQL4zyzneZqZYUGPvb1QjI_EYT9_pKfSDgR25WD9jrI,4127
11
17
  liger_kernel/ops/group_norm.py,sha256=VaRErVJGR4JqgXXvuIjNGTn3E2egjLtU1y3ymwIf4d8,10961
12
- liger_kernel/ops/jsd.py,sha256=anWfdioucxZy4JQfTvbHBR-IQrZKeH-gBF1MHwwTuTQ,5781
18
+ liger_kernel/ops/jsd.py,sha256=Ap2b0_geCl6fqBXLI1IS6Yn6GlO-8LgPmnOW3y47dus,6151
13
19
  liger_kernel/ops/kl_div.py,sha256=03FNXfvCb6M-56hhFepAFV9p6brArPR6KOKkdGD34mw,8374
14
- liger_kernel/ops/layer_norm.py,sha256=unGMYMOPqtkM9aTrokhcqgPmsV2AUN7Yzv86isVB9OI,7422
15
- liger_kernel/ops/rms_norm.py,sha256=LAxCiFjpBbb7TDh9pOzsVmDGAR7eEbTDnEhjSd6TX_M,11583
20
+ liger_kernel/ops/layer_norm.py,sha256=_CZggw3GNEIUx5weDzadFit5I-Lzosoo8prgeJzcViY,7589
21
+ liger_kernel/ops/qwen2vl_mrope.py,sha256=xZvQnhkSTjU-k6KiiRn9e0SYO1ESs1jmuZFMICduLpc,8552
22
+ liger_kernel/ops/rms_norm.py,sha256=g7OXwuYI8-LXudDwvXuiupVjjOsbu8c4wwv83VaHa54,11750
16
23
  liger_kernel/ops/rope.py,sha256=jrzaA9-6Orn44y_IIam9_YNPQxOFK2FrIRNfFea4EtU,8513
17
24
  liger_kernel/ops/swiglu.py,sha256=Fwxtd76rhHKT9ShQAGca9RsnASplAVxtYKHmiT73_yA,2994
18
- liger_kernel/ops/utils.py,sha256=3JSF--O7KT5Wa5BuO70M4h0XetxoZ_e9IoW9GRlxlBg,3777
25
+ liger_kernel/ops/utils.py,sha256=_VQvd1PX5JXm5xaiBrk2gANp3qr4kM7qYG3ypkBwkMs,3850
19
26
  liger_kernel/ops/experimental/embedding.py,sha256=LYR66dB-jhvhtUjeV4PnNro-n77J1mdlmpSLSxB3Y6U,4186
20
27
  liger_kernel/ops/experimental/mm_int8int2.py,sha256=JpGVZCgRC6T8XMUJ_QbZRS2XU1bh0urIZphs5DTc1mY,13358
21
28
  liger_kernel/transformers/__init__.py,sha256=gia-eBxr7TLxU0GdDf8AfCY4WgDlFLqIGSt7EoQGsBA,1336
22
29
  liger_kernel/transformers/auto_model.py,sha256=RMIwQHSiXoksXFTIqFZ4PLBgoqkxJJAT3q1Qh47bGN8,1552
23
30
  liger_kernel/transformers/cross_entropy.py,sha256=yEm_YQ7oa3_BzT3hdW6KrAslduhSqWcJQVNZZDcWCg4,1758
24
- liger_kernel/transformers/functional.py,sha256=Hd4WvxNqOJHM9HmRfAQueRnmOy5WU9nFsFygB5Iv8Xs,2000
31
+ liger_kernel/transformers/functional.py,sha256=sUBoU8Vb4pLpr9G6IdkRsToYgh-rCXL4OLYat7Tv_GU,4450
25
32
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=_i0PXSp5iZ9pKXdEeZ4lvHCENJYjV4y74yz3ZRG5XQg,1484
26
- liger_kernel/transformers/fused_linear_jsd.py,sha256=MJ-KjmLZnakuoVpnbDGkd95DQgvESniyrRWYzollVZM,4066
33
+ liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
27
34
  liger_kernel/transformers/geglu.py,sha256=QcrME_8ooIn0xa59LaC0aoOdRrBIFd11Y0bAyF0NfCw,1130
28
35
  liger_kernel/transformers/group_norm.py,sha256=FJ9R7mS9G1wO-GRIQ6QKSmIhnZ6nQ6GIkE4NnX_hnn0,2241
29
- liger_kernel/transformers/jsd.py,sha256=W-5CypO2mx4-bUWOxq1KScfCdoXlLoYbtt5xBnRzMs4,3056
36
+ liger_kernel/transformers/jsd.py,sha256=sbr8DnKSYZJH9pv2rpmboNijYGpZKbhb2-WSGp5_v6g,3001
30
37
  liger_kernel/transformers/kl_div.py,sha256=qVhjBg6tjRyue5iZ3NFxo8uySY4JuIFJyv0IM_50F24,431
31
38
  liger_kernel/transformers/layer_norm.py,sha256=fd6o4kSHJWolQMWxh-l1qObfgL08ruNbUoBiANKX1ow,972
32
- liger_kernel/transformers/monkey_patch.py,sha256=L1IuGmFMWYgf-u3OXCg43BUxbZKTpd7ATjjDjYoFkEM,38268
39
+ liger_kernel/transformers/monkey_patch.py,sha256=Fk2v4GZQDJzfh3Cpc6BHNJbs_tungDyWmqS9nuG9Lc4,38406
40
+ liger_kernel/transformers/qwen2vl_mrope.py,sha256=SfSQVwOe7ArrVfpmIdfZrdzCxmcj7V-YQp9zDu17-ao,1043
33
41
  liger_kernel/transformers/rms_norm.py,sha256=AHstklNIO1PLHjjCBU-TPuUD-Fl_pycJUTLlJNojbV8,1189
34
42
  liger_kernel/transformers/rope.py,sha256=m-ah8vZBYW8tfplTXCiAPMHJWlB1tdp_JPXJeWE-Boo,943
35
43
  liger_kernel/transformers/swiglu.py,sha256=0-tVJ8xEYfhxnduc16PflXFj8sZPxdx9sHUn3hfwCI4,2468
@@ -45,11 +53,13 @@ liger_kernel/transformers/model/mllama.py,sha256=mesNCgj0Ea1O-fqRD4LVxDJ1CR2abY_
45
53
  liger_kernel/transformers/model/phi3.py,sha256=xUZPlaPKwknLjHc3uUW3EPodm1h0vD3G7Qnhh51v-Io,10332
46
54
  liger_kernel/transformers/model/qwen2.py,sha256=EyhSSzQOskGjSnCsKMZpd1s5IAIlHd5PBO3q0MoCs00,9619
47
55
  liger_kernel/transformers/model/qwen2_vl.py,sha256=bIQe2bWiY--G84FhCD29Gdi64_qHP6vbcGsK6vKysQE,8547
56
+ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBbzGWILfaowUR1hmRw,210
57
+ liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
48
58
  liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
49
59
  liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
50
- liger_kernel-0.4.2.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
51
- liger_kernel-0.4.2.dist-info/METADATA,sha256=P9d8zHay6rXNyu58aJbTbgxtfRDXX61A15Dp-paNKpg,21530
52
- liger_kernel-0.4.2.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
53
- liger_kernel-0.4.2.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
54
- liger_kernel-0.4.2.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
55
- liger_kernel-0.4.2.dist-info/RECORD,,
60
+ liger_kernel-0.5.1.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
61
+ liger_kernel-0.5.1.dist-info/METADATA,sha256=eVl4SHm0MlnvS0v9sbJxuoJMe5yVCE3DP-Z_KPzKbio,20695
62
+ liger_kernel-0.5.1.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
63
+ liger_kernel-0.5.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
64
+ liger_kernel-0.5.1.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
65
+ liger_kernel-0.5.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.5.0)
2
+ Generator: setuptools (75.6.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5