liger-kernel-nightly 0.4.2.dev20241117192137__tar.gz → 0.4.2.dev20241119054334__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. {liger_kernel_nightly-0.4.2.dev20241117192137/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.2.dev20241119054334}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/pyproject.toml +1 -1
  3. liger_kernel_nightly-0.4.2.dev20241119054334/src/liger_kernel/chunked_loss/cpo_loss.py +61 -0
  4. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/chunked_loss/fused_linear_preference.py +6 -1
  5. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  6. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
  7. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/LICENSE +0 -0
  8. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/NOTICE +0 -0
  9. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/README.md +0 -0
  10. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/setup.cfg +0 -0
  11. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  12. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  13. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  14. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/env_report.py +0 -0
  15. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/__init__.py +0 -0
  16. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/cross_entropy.py +0 -0
  17. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  18. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  19. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  20. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  21. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/geglu.py +0 -0
  22. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/group_norm.py +0 -0
  23. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/jsd.py +0 -0
  24. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/kl_div.py +0 -0
  25. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/layer_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/rms_norm.py +0 -0
  27. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/rope.py +0 -0
  28. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/swiglu.py +0 -0
  29. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/ops/utils.py +0 -0
  30. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/__init__.py +0 -0
  31. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/auto_model.py +0 -0
  32. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  33. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  34. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/functional.py +0 -0
  35. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  36. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  37. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/geglu.py +0 -0
  38. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/group_norm.py +0 -0
  39. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/jsd.py +0 -0
  40. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/kl_div.py +0 -0
  41. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/layer_norm.py +0 -0
  42. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/__init__.py +0 -0
  43. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/gemma.py +0 -0
  44. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  45. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/llama.py +0 -0
  46. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/mistral.py +0 -0
  47. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  48. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/mllama.py +0 -0
  49. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/phi3.py +0 -0
  50. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  51. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  52. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  53. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/rms_norm.py +0 -0
  54. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/rope.py +0 -0
  55. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/swiglu.py +0 -0
  56. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  57. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/triton/__init__.py +0 -0
  58. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel/triton/monkey_patch.py +0 -0
  59. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  60. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  61. {liger_kernel_nightly-0.4.2.dev20241117192137 → liger_kernel_nightly-0.4.2.dev20241119054334}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241117192137
3
+ Version: 0.4.2.dev20241119054334
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.2.dev20241117192137"
7
+ version = "0.4.2.dev20241119054334"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,61 @@
1
+ import torch.nn.functional as F
2
+
3
+ from liger_kernel.chunked_loss.fused_linear_preference import (
4
+ LigerFusedLinearPreferenceBase,
5
+ )
6
+
7
+
8
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
    """Fused final linear layer + CPO (Contrastive Preference Optimization) loss.

    Thin wrapper over ``LigerFusedLinearPreferenceBase``: it supplies the CPO
    preference term via ``preference_loss_fn`` and delegates the fused
    forward/backward computation to the base class.
    """

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """
        Compute the (un-normalized) CPO preference term for one chunk.

        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Temperature scaling the chosen-vs-rejected log-probability margin.

        Returns:
            torch.Tensor: Scalar ``sum(logsigmoid(beta * (chosen_logps - rejected_logps)))``
            over the chunk. The base class divides this by the number of
            chosen/rejected pairs in the full batch and subtracts it from the
            (alpha-weighted) NLL term, so the sign and normalization are applied there.
        """
        logits = beta * (chosen_logps - rejected_logps)
        # Sum rather than mean: the base class already divides by the total
        # number of pairs (full_target.shape[0] // 2). A per-chunk mean would
        # normalize twice whenever a chunk holds more than one pair.
        loss = F.logsigmoid(logits).sum()
        return loss

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        compute_nll_loss=True,
        compiled=True,
    ):
        """
        Fused linear layer with CPO (Contrastive Preference Optimization) loss.
        Handles both the forward and backward pass of the final linear layer with CPO loss.
        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.

        Args:
            ignore_index (int): Target index excluded from loss computation.
            beta (float): Temperature for the CPO preference term.
            alpha (float): Weight applied to the chosen-sequence NLL loss.
            compute_nll_loss (bool): Whether to compute the NLL loss term.
            compiled (bool): Whether to use torch compile for chunk accumulation.
        """
        return LigerFusedLinearPreferenceBase.forward(
            ctx,
            _input,
            weight,
            target,
            bias,
            loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
            compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compiled=compiled,
        )

    @staticmethod
    def backward(ctx, grad_output):
        # Gradients for _input, weight, target, and bias (forward's first four
        # inputs), in that order, come from the base class.
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # autograd expects one entry per forward input: None for the
        # non-tensor args ignore_index, beta, alpha, compute_nll_loss, compiled.
        return *grads, None, None, None, None, None
@@ -29,6 +29,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
29
29
  chunk_size=1,
30
30
  compute_nll_loss=True,
31
31
  ignore_index=-100,
32
+ alpha=1.0,
32
33
  beta=0.1,
33
34
  compiled=True,
34
35
  ):
@@ -45,6 +46,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
45
46
  chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
46
47
  compute_nll_loss (bool): Whether to compute NLL loss.
47
48
  ignore_index (int): Index to ignore for loss computation.
49
+ alpha (float): Weight for the NLL loss.
48
50
  beta (float): Weight for the odds ratio loss.
49
51
  compiled (bool): Whether to use torch compile for chunk accumulation.
50
52
  """
@@ -62,6 +64,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
62
64
  LigerFusedLinearPreferenceBase._compute_loss,
63
65
  preference_loss_fn=loss_fn,
64
66
  ignore_index=ignore_index,
67
+ alpha=alpha,
65
68
  beta=beta,
66
69
  compute_nll_loss=compute_nll_loss,
67
70
  full_target=target,
@@ -149,6 +152,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
149
152
  preference_loss_fn=None,
150
153
  full_target=None,
151
154
  ignore_index=-100,
155
+ alpha=1.0,
152
156
  beta=0.1,
153
157
  compute_nll_loss=True,
154
158
  **loss_kwargs,
@@ -163,6 +167,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
163
167
  bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
164
168
  full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
165
169
  ignore_index (int): Index to ignore for loss computation.
170
+ alpha (float): Weight for the NLL loss.
166
171
  beta (float): Weight for the odds ratio loss.
167
172
  loss_kwargs (dict): Additional arguments for the loss function.
168
173
  """
@@ -202,5 +207,5 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
202
207
  )
203
208
  alignment_loss = alignment_loss / (full_target.shape[0] // 2)
204
209
 
205
- loss = chosen_nll_loss - alignment_loss
210
+ loss = alpha * chosen_nll_loss - alignment_loss
206
211
  return loss, (alignment_loss, chosen_logps, rejected_logps)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.2.dev20241117192137
3
+ Version: 0.4.2.dev20241119054334
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,6 +4,7 @@ README.md
4
4
  pyproject.toml
5
5
  src/liger_kernel/env_report.py
6
6
  src/liger_kernel/chunked_loss/__init__.py
7
+ src/liger_kernel/chunked_loss/cpo_loss.py
7
8
  src/liger_kernel/chunked_loss/dpo_loss.py
8
9
  src/liger_kernel/chunked_loss/fused_linear_preference.py
9
10
  src/liger_kernel/chunked_loss/orpo_loss.py