liger-kernel-nightly 0.4.1.dev20241114155849__tar.gz → 0.4.1.dev20241115012952__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {liger_kernel_nightly-0.4.1.dev20241114155849/src/liger_kernel_nightly.egg-info → liger_kernel_nightly-0.4.1.dev20241115012952}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/pyproject.toml +1 -1
  3. liger_kernel_nightly-0.4.1.dev20241115012952/src/liger_kernel/chunked_loss/dpo_loss.py +57 -0
  4. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952/src/liger_kernel_nightly.egg-info}/PKG-INFO +1 -1
  5. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
  6. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/LICENSE +0 -0
  7. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/NOTICE +0 -0
  8. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/README.md +0 -0
  9. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/setup.cfg +0 -0
  10. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  11. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  12. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  13. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/env_report.py +0 -0
  14. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/__init__.py +0 -0
  15. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/cross_entropy.py +0 -0
  16. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  17. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  18. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  19. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  20. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/geglu.py +0 -0
  21. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/group_norm.py +0 -0
  22. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/jsd.py +0 -0
  23. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/kl_div.py +0 -0
  24. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/layer_norm.py +0 -0
  25. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/rms_norm.py +0 -0
  26. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/rope.py +0 -0
  27. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/swiglu.py +0 -0
  28. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/ops/utils.py +0 -0
  29. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/__init__.py +0 -0
  30. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/auto_model.py +0 -0
  31. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  33. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/functional.py +0 -0
  34. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  35. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  36. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/geglu.py +0 -0
  37. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/group_norm.py +0 -0
  38. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/jsd.py +0 -0
  39. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/kl_div.py +0 -0
  40. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/layer_norm.py +0 -0
  41. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/__init__.py +0 -0
  42. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/gemma.py +0 -0
  43. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  44. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/llama.py +0 -0
  45. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/mistral.py +0 -0
  46. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  47. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/mllama.py +0 -0
  48. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/phi3.py +0 -0
  49. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  50. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  51. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  52. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/rms_norm.py +0 -0
  53. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/rope.py +0 -0
  54. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/swiglu.py +0 -0
  55. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  56. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/triton/__init__.py +0 -0
  57. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel/triton/monkey_patch.py +0 -0
  58. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  59. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  60. {liger_kernel_nightly-0.4.1.dev20241114155849 → liger_kernel_nightly-0.4.1.dev20241115012952}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.1.dev20241114155849
3
+ Version: 0.4.1.dev20241115012952
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.4.1.dev20241114155849"
7
+ version = "0.4.1.dev20241115012952"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,57 @@
1
+ import torch.nn.functional as F
2
+
3
+ from liger_kernel.chunked_loss.fused_linear_preference import (
4
+ LigerFusedLinearPreferenceBase,
5
+ )
6
+
7
+
8
+ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
9
+
10
+ @staticmethod
11
+ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
12
+ """
13
+ Compute DPO loss (Direct Preference Optimization).
14
+ Args:
15
+ chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
16
+ rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
17
+ beta (float): Weight for the direct preference loss.
18
+ """
19
+ logits_diff = beta * (chosen_logps - rejected_logps)
20
+ losses = -F.logsigmoid(logits_diff)
21
+ return losses.sum()
22
+
23
+ @staticmethod
24
+ def forward(
25
+ ctx,
26
+ _input,
27
+ weight,
28
+ target,
29
+ bias=None,
30
+ ignore_index=-100,
31
+ beta=0.1,
32
+ compute_nll_loss=True,
33
+ compiled=True,
34
+ ):
35
+ """
36
+ Fused linear layer with DPO (Direct Preference Optimization) loss.
37
+ Handles both the forward and backward pass of the final linear layer with DPO loss.
38
+ """
39
+ return LigerFusedLinearPreferenceBase.forward(
40
+ ctx=ctx,
41
+ _input=_input,
42
+ weight=weight,
43
+ target=target,
44
+ bias=bias,
45
+ loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
46
+ compute_nll_loss=compute_nll_loss,
47
+ ignore_index=ignore_index,
48
+ beta=beta,
49
+ compiled=compiled,
50
+ )
51
+
52
+ @staticmethod
53
+ def backward(ctx, grad_output):
54
+ # Get gradients for _input, weight, bias, and target from the base class
55
+ grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
56
+ # Return these gradients, followed by None for the remaining inputs
57
+ return *grads, None, None, None, None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.4.1.dev20241114155849
3
+ Version: 0.4.1.dev20241115012952
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,6 +4,7 @@ README.md
4
4
  pyproject.toml
5
5
  src/liger_kernel/env_report.py
6
6
  src/liger_kernel/chunked_loss/__init__.py
7
+ src/liger_kernel/chunked_loss/dpo_loss.py
7
8
  src/liger_kernel/chunked_loss/fused_linear_preference.py
8
9
  src/liger_kernel/chunked_loss/orpo_loss.py
9
10
  src/liger_kernel/ops/__init__.py