liger-kernel-nightly 0.5.2.dev20241211213024__py3-none-any.whl → 0.5.2.dev20241212000548__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/README.md +25 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +2 -6
- {liger_kernel_nightly-0.5.2.dev20241211213024.dist-info → liger_kernel_nightly-0.5.2.dev20241212000548.dist-info}/METADATA +11 -15
- {liger_kernel_nightly-0.5.2.dev20241211213024.dist-info → liger_kernel_nightly-0.5.2.dev20241212000548.dist-info}/RECORD +8 -7
- {liger_kernel_nightly-0.5.2.dev20241211213024.dist-info → liger_kernel_nightly-0.5.2.dev20241212000548.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241211213024.dist-info → liger_kernel_nightly-0.5.2.dev20241212000548.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.2.dev20241211213024.dist-info → liger_kernel_nightly-0.5.2.dev20241212000548.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.2.dev20241211213024.dist-info → liger_kernel_nightly-0.5.2.dev20241212000548.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,25 @@
|
|
1
|
+
# Liger FlexChunkLoss: Alignment and Distillation loss
|
2
|
+
|
3
|
+
Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO) and, very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
|
4
|
+
|
5
|
+
### User interface
|
6
|
+
|
7
|
+
FlexChunkLoss offers two flexible usage options:
|
8
|
+
|
9
|
+
1. **Via `Liger[Custom Loss]Trainer`**
|
10
|
+
For example, by simply replacing the HuggingFace `ORPOTrainer` with `LigerORPOTrainer` in your code, you can leverage our optimized ORPO implementation and immediately benefit from improved performance.
|
11
|
+
|
12
|
+
2. **Using `nn.Module` Implementations of Custom Loss Functions**
|
13
|
+
Explore the [LigerORPOTrainer implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/trainer/orpo_trainer.py) to see how the modular design integrates custom loss functions seamlessly.
|
14
|
+
|
15
|
+
### What's under the hood?
|
16
|
+
|
17
|
+
We employ chunking and fused kernel optimizations to enhance performance. By fusing the final linear layer with loss computation and calculating backward gradients during the forward pass, we significantly reduce the need for storing intermediate activations. All operations are implemented in PyTorch, leveraging `torch.compile` to streamline kernel execution without relying on extensive low-level optimizations. Additionally, we minimize `torch.compile` recompilations to reduce overhead and ensure consistent performance gains.
|
18
|
+
|
19
|
+
### Extending to custom loss functions
|
20
|
+
|
21
|
+
We provide two base classes: `LigerFusedLinearPreferenceBase` for alignment use cases and `LigerFusedLinearDistillationBase` for distillation use cases. These base classes manage chunking, kernel fusions, and Torch compilation.
|
22
|
+
|
23
|
+
To implement a custom loss function, you need to create a subclass that defines the custom preference or distillation loss function, capable of processing a given input chunk. The base class will take care of the optimizations, handling most of the heavy lifting for you.
|
24
|
+
|
25
|
+
For a working example, refer to the [ORPO loss implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/orpo_loss.py).
|
@@ -29,7 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
29
29
|
compute_nll_loss=True,
|
30
30
|
compiled=True,
|
31
31
|
use_ref_model=False,
|
32
|
-
|
32
|
+
# TODO: ref input
|
33
33
|
ref_weight=None,
|
34
34
|
ref_bias=None,
|
35
35
|
**loss_kwargs,
|
@@ -59,7 +59,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
59
59
|
compute_nll_loss (bool): Whether to compute NLL loss.
|
60
60
|
compiled (bool): Whether to use torch compile for chunk accumulation.
|
61
61
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
62
|
-
ref_input (torch.Tensor): Reference input tensor. Shape: (batch_size, seq_len, hidden_size).
|
63
62
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
64
63
|
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
65
64
|
loss_kwargs (dict): Other possible arguments that a loss function might need
|
@@ -93,7 +92,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
93
92
|
compute_nll_loss=compute_nll_loss,
|
94
93
|
full_target=target,
|
95
94
|
use_ref_model=use_ref_model,
|
96
|
-
ref_input=ref_input,
|
97
95
|
ref_weight=ref_weight,
|
98
96
|
ref_bias=ref_bias,
|
99
97
|
**loss_kwargs,
|
@@ -303,7 +301,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
303
301
|
beta=0.1,
|
304
302
|
compute_nll_loss=True,
|
305
303
|
use_ref_model=False,
|
306
|
-
ref_input=None,
|
307
304
|
ref_weight=None,
|
308
305
|
ref_bias=None,
|
309
306
|
**loss_kwargs,
|
@@ -322,7 +319,6 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
322
319
|
beta (float): Weight for the preference loss.
|
323
320
|
compute_nll_loss (bool): Whether to compute NLL loss.
|
324
321
|
use_ref_model (bool): Whether to use a reference model for the alignment loss.
|
325
|
-
ref_input (torch.Tensor): Reference input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
|
326
322
|
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
|
327
323
|
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
|
328
324
|
loss_kwargs (dict): Additional arguments for the loss function.
|
@@ -361,7 +357,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
|
|
361
357
|
ref_rejected_logits,
|
362
358
|
ref_chosen_nll_loss,
|
363
359
|
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
364
|
-
|
360
|
+
input_chunk,
|
365
361
|
ref_weight,
|
366
362
|
target_chunk,
|
367
363
|
ref_bias,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: liger_kernel_nightly
|
3
|
-
Version: 0.5.2.
|
3
|
+
Version: 0.5.2.dev20241212000548
|
4
4
|
Summary: Efficient Triton kernels for LLM Training
|
5
5
|
License: BSD 2-CLAUSE LICENSE
|
6
6
|
Copyright 2024 LinkedIn Corporation
|
@@ -32,11 +32,6 @@ License-File: LICENSE
|
|
32
32
|
License-File: NOTICE
|
33
33
|
Requires-Dist: torch>=2.1.2
|
34
34
|
Requires-Dist: triton>=2.3.1
|
35
|
-
Provides-Extra: amd
|
36
|
-
Requires-Dist: torch>=2.6.0.dev; extra == "amd"
|
37
|
-
Requires-Dist: setuptools-scm>=8; extra == "amd"
|
38
|
-
Requires-Dist: torchvision>=0.20.0.dev; extra == "amd"
|
39
|
-
Requires-Dist: triton>=3.0.0; extra == "amd"
|
40
35
|
Provides-Extra: dev
|
41
36
|
Requires-Dist: transformers>=4.44.2; extra == "dev"
|
42
37
|
Requires-Dist: matplotlib>=3.7.2; extra == "dev"
|
@@ -47,12 +42,11 @@ Requires-Dist: pytest>=7.1.2; extra == "dev"
|
|
47
42
|
Requires-Dist: pytest-xdist; extra == "dev"
|
48
43
|
Requires-Dist: pytest-rerunfailures; extra == "dev"
|
49
44
|
Requires-Dist: datasets>=2.19.2; extra == "dev"
|
50
|
-
Requires-Dist: torchvision>=0.16.2; extra == "dev"
|
51
45
|
Requires-Dist: seaborn; extra == "dev"
|
52
|
-
Provides-Extra:
|
53
|
-
Requires-Dist:
|
54
|
-
|
55
|
-
Requires-Dist:
|
46
|
+
Provides-Extra: fmt
|
47
|
+
Requires-Dist: flake8; extra == "fmt"
|
48
|
+
Requires-Dist: isort; extra == "fmt"
|
49
|
+
Requires-Dist: black; extra == "fmt"
|
56
50
|
|
57
51
|
<a name="readme-top"></a>
|
58
52
|
|
@@ -202,11 +196,13 @@ To install from source:
|
|
202
196
|
```bash
|
203
197
|
git clone https://github.com/linkedin/Liger-Kernel.git
|
204
198
|
cd Liger-Kernel
|
199
|
+
|
200
|
+
# Install Default Dependencies
|
201
|
+
# Setup.py will detect whether you are using AMD or NVIDIA
|
205
202
|
pip install -e .
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
pip install -e .[transformers]
|
203
|
+
|
204
|
+
# Setup Development Dependencies
|
205
|
+
pip install -e ".[dev]"
|
210
206
|
```
|
211
207
|
|
212
208
|
|
@@ -1,12 +1,13 @@
|
|
1
1
|
liger_kernel/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
liger_kernel/env_report.py,sha256=FViyPju795lB6z4k2TZldvBSmQdcS0A2hcnDxepJrDo,1822
|
3
3
|
liger_kernel/utils.py,sha256=HJa-xVKOohDn6pLVIx-Fv0V9h0QAL3qZGQNRICI-OpI,249
|
4
|
+
liger_kernel/chunked_loss/README.md,sha256=K6rucm6nqHpWCmxUOhBYcE3apwQxAy0TfRUippR7Icw,2243
|
4
5
|
liger_kernel/chunked_loss/__init__.py,sha256=R2wCcz4Y0kTAve926DH3k182XKezpXeACMHj05g9Mm8,346
|
5
6
|
liger_kernel/chunked_loss/cpo_loss.py,sha256=Qu1Ul2A12sp6CqIT-atPbHWFb_LLtINEA9mOpIRx_0g,3097
|
6
7
|
liger_kernel/chunked_loss/dpo_loss.py,sha256=H9_RRhclckHYM2sd75tgbnf8IxC_PU2JCALbgtPQvwc,4222
|
7
8
|
liger_kernel/chunked_loss/functional.py,sha256=9Gr-YXIuEzEJkBUhDx3G2fuQayckLor7cC7svhmPML4,549
|
8
9
|
liger_kernel/chunked_loss/fused_linear_distillation.py,sha256=2BH6DCPjsR2zS6zcwFPcIIZRhLF8SohjGdKsAJ_301o,10222
|
9
|
-
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=
|
10
|
+
liger_kernel/chunked_loss/fused_linear_preference.py,sha256=vlWfaaIECWvCQhY9PM7zRI0vKThIrydMf6P44bXn1EE,15114
|
10
11
|
liger_kernel/chunked_loss/orpo_loss.py,sha256=ZuKGjbkIYzV4UzvupNdq6vyxCp7-BztQkUt8ZnFvKos,3531
|
11
12
|
liger_kernel/chunked_loss/simpo_loss.py,sha256=Wa4LOlDG9PbJkOOkKg8hbKvnKgg7OTBz6-qIkwPK1yw,3275
|
12
13
|
liger_kernel/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -57,9 +58,9 @@ liger_kernel/transformers/trainer/__init__.py,sha256=c4OQVJmhNOloj0JYSEc0j_cQuBb
|
|
57
58
|
liger_kernel/transformers/trainer/orpo_trainer.py,sha256=jko6oq_XQdBSmXubp05E-_YXOyhtB5Bj75dg5YNwOsE,7517
|
58
59
|
liger_kernel/triton/__init__.py,sha256=yfRe0zMb47QnqjecZWG7LnanfCTzeku7SgWRAwNVmzU,101
|
59
60
|
liger_kernel/triton/monkey_patch.py,sha256=5BcGKTtdqeYchypBIBopGIWPx1-cFALz7sOKoEsqXJ0,1584
|
60
|
-
liger_kernel_nightly-0.5.2.
|
61
|
-
liger_kernel_nightly-0.5.2.
|
62
|
-
liger_kernel_nightly-0.5.2.
|
63
|
-
liger_kernel_nightly-0.5.2.
|
64
|
-
liger_kernel_nightly-0.5.2.
|
65
|
-
liger_kernel_nightly-0.5.2.
|
61
|
+
liger_kernel_nightly-0.5.2.dev20241212000548.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
|
62
|
+
liger_kernel_nightly-0.5.2.dev20241212000548.dist-info/METADATA,sha256=NfFECBU1FHBc34_9Ybi5h4iFRUTmKUeNCcdqvPzhbR4,20392
|
63
|
+
liger_kernel_nightly-0.5.2.dev20241212000548.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
|
64
|
+
liger_kernel_nightly-0.5.2.dev20241212000548.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
65
|
+
liger_kernel_nightly-0.5.2.dev20241212000548.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
|
66
|
+
liger_kernel_nightly-0.5.2.dev20241212000548.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|