liger-kernel-nightly 0.4.1.dev20241114155849__tar.gz → 0.4.1.dev20241115191733__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {liger_kernel_nightly-0.4.1.dev20241114155849/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.1.dev20241115191733}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/pyproject.toml +1 -1
  3. liger_kernel_nightly-0.4.1.dev20241115191733/src/liger_kernel/chunked_loss/dpo_loss.py +57 -0
  4. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/fused_linear_cross_entropy.py +1 -0
  5. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  6. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
  7. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/LICENSE +0 -0
  8. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/NOTICE +0 -0
  9. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/README.md +0 -0
  10. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/setup.cfg +0 -0
  11. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  12. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  13. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  14. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/env_report.py +0 -0
  15. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/__init__.py +0 -0
  16. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/cross_entropy.py +0 -0
  17. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  18. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  19. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  20. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/geglu.py +0 -0
  21. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/group_norm.py +0 -0
  22. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/jsd.py +0 -0
  23. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/kl_div.py +0 -0
  24. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/layer_norm.py +0 -0
  25. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/rms_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/rope.py +0 -0
  27. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/swiglu.py +0 -0
  28. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/ops/utils.py +0 -0
  29. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/__init__.py +0 -0
  30. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/auto_model.py +0 -0
  31. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  33. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/functional.py +0 -0
  34. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  35. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  36. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/geglu.py +0 -0
  37. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/group_norm.py +0 -0
  38. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/jsd.py +0 -0
  39. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/kl_div.py +0 -0
  40. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/layer_norm.py +0 -0
  41. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/__init__.py +0 -0
  42. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/gemma.py +0 -0
  43. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  44. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/llama.py +0 -0
  45. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/mistral.py +0 -0
  46. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  47. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/mllama.py +0 -0
  48. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/phi3.py +0 -0
  49. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  50. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  51. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  52. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/rms_norm.py +0 -0
  53. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/rope.py +0 -0
  54. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/swiglu.py +0 -0
  55. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  56. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/triton/__init__.py +0 -0
  57. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel/triton/monkey_patch.py +0 -0
  58. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  59. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  60. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115191733}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.1.dev20241114155849
3
+ Version: 0.4.1.dev20241115191733
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.1.dev20241114155849"
7
+ version = "0.4.1.dev20241115191733"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,57 @@
1
+ import torch.nn.functional as F
2
+
3
+ from liger_kernel.chunked_loss.fused_linear_preference import (
4
+ LigerFusedLinearPreferenceBase,
5
+ )
6
+
7
+
8
class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
    """Fused final-linear-layer + DPO (Direct Preference Optimization) loss.

    All chunking / gradient bookkeeping is delegated to
    ``LigerFusedLinearPreferenceBase``; this subclass only supplies the
    DPO-specific preference loss and adapts ``forward``/``backward`` to its
    own argument list. (The ``ctx``-taking static ``forward``/``backward``
    suggest the base is a ``torch.autograd.Function`` — NOTE(review): confirm
    against the base class.)
    """

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """Return the summed DPO loss for a batch of preference pairs.

        Each pair contributes ``-log(sigmoid(beta * (chosen - rejected)))``.
        No reference-model log-probabilities are subtracted here; the margin
        is taken directly between the chosen and rejected inputs.

        Args:
            chosen_logps (torch.Tensor): Average log-probabilities of the
                chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Average log-probabilities of the
                rejected tokens. Shape: (batch_size,).
            beta (float): Scale applied to the log-probability margin before
                the sigmoid.

        Returns:
            torch.Tensor: Scalar sum (not mean) of the per-pair losses.
        """
        scaled_margin = beta * (chosen_logps - rejected_logps)
        per_pair_loss = -F.logsigmoid(scaled_margin)
        return per_pair_loss.sum()

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=True,
        compiled=True,
    ):
        """Run the fused linear layer followed by the DPO preference loss.

        Every argument is forwarded unchanged to
        ``LigerFusedLinearPreferenceBase.forward``, together with
        ``preference_loss_fn`` above as the loss callback. The meaning of
        ``ignore_index``, ``compute_nll_loss`` and ``compiled`` is defined by
        the base class (``ignore_index`` is presumably the target value to
        skip, as in cross-entropy — TODO confirm).
        """
        return LigerFusedLinearPreferenceBase.forward(
            ctx=ctx,
            _input=_input,
            weight=weight,
            target=target,
            bias=bias,
            ignore_index=ignore_index,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Map the base-class gradients onto this class's ``forward`` inputs.

        The first four gradients from the base line up with ``_input``,
        ``weight``, ``target`` and ``bias`` (in that order). The four
        remaining ``forward`` inputs (``ignore_index``, ``beta``,
        ``compute_nll_loss``, ``compiled``) are non-differentiable
        hyperparameters and receive ``None``.
        """
        base_grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)
        hyperparam_grads = (None,) * 4
        return tuple(base_grads[:4]) + hyperparam_grads
@@ -229,6 +229,7 @@ class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
229
229
  label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
230
230
  reduction: reduction to apply
231
231
  """
232
+
232
233
  loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
233
234
  _input,
234
235
  weight,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.1.dev20241114155849
3
+ Version: 0.4.1.dev20241115191733
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,6 +4,7 @@ README.md
4
4
  pyproject.toml
5
5
  src/liger_kernel/env_report.py
6
6
  src/liger_kernel/chunked_loss/__init__.py
7
+ src/liger_kernel/chunked_loss/dpo_loss.py
7
8
  src/liger_kernel/chunked_loss/fused_linear_preference.py
8
9
  src/liger_kernel/chunked_loss/orpo_loss.py
9
10
  src/liger_kernel/ops/__init__.py