liger-kernel-nightly 0.0.1.dev20240819184814__py3-none-any.whl → 0.6.4.dev20251212103629__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +366 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +307 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +63 -0
  18. liger_kernel/ops/__init__.py +141 -0
  19. liger_kernel/ops/backends/README.md +151 -0
  20. liger_kernel/ops/backends/__init__.py +13 -0
  21. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  22. liger_kernel/ops/backends/_ascend/ops/__init__.py +15 -0
  23. liger_kernel/ops/backends/registry.py +61 -0
  24. liger_kernel/ops/cross_entropy.py +383 -114
  25. liger_kernel/ops/dyt.py +160 -0
  26. liger_kernel/ops/experimental/embedding.py +141 -0
  27. liger_kernel/ops/experimental/mm_int8int2.py +349 -0
  28. liger_kernel/ops/fused_add_rms_norm.py +416 -0
  29. liger_kernel/ops/fused_linear_cross_entropy.py +346 -132
  30. liger_kernel/ops/fused_linear_jsd.py +228 -0
  31. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  32. liger_kernel/ops/geglu.py +66 -64
  33. liger_kernel/ops/group_norm.py +306 -0
  34. liger_kernel/ops/grpo_loss.py +312 -0
  35. liger_kernel/ops/jsd.py +201 -0
  36. liger_kernel/ops/kl_div.py +262 -0
  37. liger_kernel/ops/layer_norm.py +320 -0
  38. liger_kernel/ops/llama4_rope.py +225 -0
  39. liger_kernel/ops/multi_token_attention.py +207 -0
  40. liger_kernel/ops/poly_norm.py +390 -0
  41. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  42. liger_kernel/ops/rms_norm.py +484 -88
  43. liger_kernel/ops/rope.py +122 -117
  44. liger_kernel/ops/softmax.py +201 -0
  45. liger_kernel/ops/sparsemax.py +179 -0
  46. liger_kernel/ops/swiglu.py +68 -65
  47. liger_kernel/ops/tiled_mlp.py +136 -0
  48. liger_kernel/ops/tvd.py +207 -0
  49. liger_kernel/ops/utils.py +82 -3
  50. liger_kernel/transformers/__init__.py +218 -6
  51. liger_kernel/transformers/auto_model.py +38 -0
  52. liger_kernel/transformers/cross_entropy.py +52 -7
  53. liger_kernel/transformers/dyt.py +22 -0
  54. liger_kernel/transformers/experimental/__init__.py +5 -0
  55. liger_kernel/transformers/experimental/embedding.py +26 -0
  56. liger_kernel/transformers/fsdp.py +55 -0
  57. liger_kernel/transformers/functional.py +301 -0
  58. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  59. liger_kernel/transformers/fused_linear_cross_entropy.py +59 -10
  60. liger_kernel/transformers/fused_linear_jsd.py +95 -0
  61. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  62. liger_kernel/transformers/geglu.py +6 -7
  63. liger_kernel/transformers/group_norm.py +50 -0
  64. liger_kernel/transformers/grpo_loss.py +153 -0
  65. liger_kernel/transformers/jsd.py +70 -0
  66. liger_kernel/transformers/kl_div.py +12 -0
  67. liger_kernel/transformers/layer_norm.py +24 -0
  68. liger_kernel/transformers/llama4_rope.py +93 -0
  69. liger_kernel/transformers/model/falcon_h1.py +122 -0
  70. liger_kernel/transformers/model/gemma.py +261 -0
  71. liger_kernel/transformers/model/gemma2.py +283 -0
  72. liger_kernel/transformers/model/gemma3.py +332 -0
  73. liger_kernel/transformers/model/glm4.py +141 -0
  74. liger_kernel/transformers/model/glm4v.py +163 -0
  75. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  76. liger_kernel/transformers/model/gpt_oss.py +211 -0
  77. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  78. liger_kernel/transformers/model/internvl.py +157 -0
  79. liger_kernel/transformers/model/llama.py +221 -41
  80. liger_kernel/transformers/model/llama4.py +121 -0
  81. liger_kernel/transformers/model/llava.py +344 -0
  82. liger_kernel/transformers/model/loss_utils.py +95 -0
  83. liger_kernel/transformers/model/mistral.py +145 -0
  84. liger_kernel/transformers/model/mixtral.py +293 -0
  85. liger_kernel/transformers/model/mllama.py +269 -0
  86. liger_kernel/transformers/model/olmo2.py +141 -0
  87. liger_kernel/transformers/model/olmo3.py +142 -0
  88. liger_kernel/transformers/model/output_classes.py +147 -0
  89. liger_kernel/transformers/model/paligemma.py +433 -0
  90. liger_kernel/transformers/model/phi3.py +120 -0
  91. liger_kernel/transformers/model/qwen2.py +259 -0
  92. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  93. liger_kernel/transformers/model/qwen2_vl.py +159 -0
  94. liger_kernel/transformers/model/qwen3.py +136 -0
  95. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  96. liger_kernel/transformers/model/qwen3_next.py +146 -0
  97. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  98. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  99. liger_kernel/transformers/model/smollm3.py +199 -0
  100. liger_kernel/transformers/model/smolvlm.py +158 -0
  101. liger_kernel/transformers/monkey_patch.py +2816 -21
  102. liger_kernel/transformers/multi_token_attention.py +64 -0
  103. liger_kernel/transformers/poly_norm.py +42 -0
  104. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  105. liger_kernel/transformers/rms_norm.py +75 -5
  106. liger_kernel/transformers/rope.py +47 -3
  107. liger_kernel/transformers/softmax.py +12 -0
  108. liger_kernel/transformers/sparsemax.py +16 -0
  109. liger_kernel/transformers/swiglu.py +62 -6
  110. liger_kernel/transformers/tiled_mlp.py +133 -0
  111. liger_kernel/transformers/trainer/__init__.py +4 -0
  112. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  113. liger_kernel/transformers/trainer_integration.py +2 -45
  114. liger_kernel/transformers/tvd.py +13 -0
  115. liger_kernel/triton/__init__.py +1 -3
  116. liger_kernel/triton/monkey_patch.py +1 -5
  117. liger_kernel/utils.py +96 -0
  118. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/METADATA +447 -0
  119. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/NOTICE +58 -0
  120. liger_kernel_nightly-0.6.4.dev20251212103629.dist-info/RECORD +124 -0
  121. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/WHEEL +1 -1
  122. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/METADATA +0 -21
  123. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/NOTICE +0 -4
  124. liger_kernel_nightly-0.0.1.dev20240819184814.dist-info/RECORD +0 -27
  125. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/LICENSE +0 -0
  126. {liger_kernel_nightly-0.0.1.dev20240819184814.dist-info → liger_kernel_nightly-0.6.4.dev20251212103629.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/README.md
@@ -0,0 +1,25 @@
+ # Liger FlexChunkLoss: Alignment and Distillation loss
+
+ Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
+
+ ### User interface
+
+ FlexChunkLoss offers two flexible usage options:
+
+ 1. **Via `Liger[Custom Loss]Trainer`**
+ For example, by simply replacing the HuggingFace `ORPOTrainer` with `LigerORPOTrainer` in your code, you can leverage our optimized ORPO implementation and immediately benefit from improved performance.
+
+ 2. **Using `nn.Module` Implementations of Custom Loss Functions**
+ Explore the [LigerORPOTrainer implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/orpo_trainer.py) to see how the modular design integrates custom loss functions seamlessly.
+
+ ### What's under the hood?
+
+ We employ chunking and fused kernel optimizations to enhance performance. By fusing the final linear layer with loss computation and calculating backward gradients during the forward pass, we significantly reduce the need for storing intermediate activations. All operations are implemented in PyTorch, leveraging `torch.compile` to streamline kernel execution without relying on extensive low-level optimizations. Additionally, we minimize `torch.compile` recompilations to reduce overhead and ensure consistent performance gains.
+
+ ### Extending to custom loss functions
+
+ We provide two base classes: `LigerFusedLinearPreferenceBase` for alignment use cases and `LigerFusedLinearDistillationBase` for distillation use cases. These base classes manage chunking, kernel fusions, and Torch compilation.
+
+ To implement a custom loss function, you need to create a subclass that defines the custom preference or distillation loss function, capable of processing a given input chunk. The base class will take care of the optimizations, handling most of the heavy lifting for you.
+
+ For a working example, refer to the [ORPO loss implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/orpo_loss.py).
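As a rough illustration of the README's second usage option, the sketch below instantiates one of the chunked loss modules added in this release (`LigerFusedLinearCPOLoss`, defined in `cpo_loss.py` later in this diff) and calls it on flattened hidden states and an `lm_head` weight. The tensor sizes and hyperparameters are hypothetical placeholders, and the exact return structure is determined by `LigerFusedLinearPreferenceBase`, which is not part of this diff.

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearCPOLoss

# Hypothetical sizes; real values come from your model and batch.
batch, seq_len, hidden, vocab = 4, 128, 2048, 32000

# Flattened last hidden states and the policy model's lm_head weight.
hidden_states = torch.randn(batch * seq_len, hidden, requires_grad=True)
lm_head_weight = torch.randn(vocab, hidden, requires_grad=True)
labels = torch.randint(0, vocab, (batch * seq_len,))

# The batch is assumed to contain paired chosen/rejected sequences;
# the preference losses below normalize by full_target.shape[0] // 2.
loss_fn = LigerFusedLinearCPOLoss(beta=0.1, ignore_index=-100)

# Note: in this interface the weight comes before the input.
outputs = loss_fn(lm_head_weight, hidden_states, labels)
# The loss comes first; whether auxiliary outputs (e.g. rewards) are
# returned alongside it is decided by the base class.
```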
liger_kernel/chunked_loss/__init__.py
@@ -0,0 +1,8 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityLoss # noqa: F401
+ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
+ from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
+ from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
liger_kernel/chunked_loss/cosine_similarity_loss.py
@@ -0,0 +1,136 @@
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+ class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
+     @staticmethod
+     def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+         """
+         Compute Cosine loss (Cosine Similarity Loss).
+         Args:
+             student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+             teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+             beta (float): Coefficient beta of generalized Cosine Similarity in the interval [0, 1]. Default: `1.0`.
+         Returns:
+             torch.Tensor: cosine similarity loss
+         """
+         student_norm = F.normalize(student_logits, p=2, dim=-1)
+         teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
+
+         cosine_sim = F.cosine_similarity(student_norm, teacher_norm, dim=-1)
+         loss = beta * (1 - cosine_sim)
+         return loss.sum()
+
+     @classmethod
+     def forward(
+         cls,
+         ctx,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor,
+         teacher_bias: torch.Tensor,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+         return_soft_hard_loss: bool = False,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+         return super().forward(
+             cls=cls,
+             ctx=ctx,
+             student_input=student_input,
+             student_weight=student_weight,
+             teacher_input=teacher_input,
+             teacher_weight=teacher_weight,
+             target=true_labels,
+             student_bias=student_bias,
+             teacher_bias=teacher_bias,
+             chunk_size=chunk_size,
+             weight_hard_loss=weight_hard_loss,
+             weight_soft_loss=weight_soft_loss,
+             beta=beta,
+             ignore_index=ignore_index,
+             temperature=temperature,
+             compiled=compiled,
+             return_soft_hard_loss=return_soft_hard_loss,
+         )
+
+     @staticmethod
+     def backward(ctx, grad_output, *args):
+         grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+         return (
+             *grads,
+             None,  # teacher_bias
+             None,  # weight_hard_loss
+             None,  # weight_soft_loss
+             None,  # beta
+             None,  # ignore_index
+             None,  # temperature
+             None,  # compiled
+             None,  # chunk_size
+             None,  # return_soft_hard_loss
+         )
+
+
+ class LigerFusedLinearCosineSimilarityLoss(torch.nn.Module):
+     def __init__(
+         self,
+         weight_hard_loss: float = 0.5,
+         weight_soft_loss: float = 0.5,
+         beta: float = 0.5,
+         ignore_index: int = -100,
+         temperature: float = 1.0,
+         compiled: bool = True,
+         chunk_size: int = 1024,
+         return_soft_hard_loss: bool = False,
+     ):
+         super().__init__()
+         assert temperature != 0, "Temperature cannot be 0."
+         self.weight_hard_loss = weight_hard_loss
+         self.weight_soft_loss = weight_soft_loss
+         self.ignore_index = ignore_index
+         self.temperature = temperature
+         self.compiled = compiled
+         self.beta = beta
+         self.chunk_size = chunk_size
+         self.return_soft_hard_loss = return_soft_hard_loss
+
+     def forward(
+         self,
+         student_input: torch.Tensor,
+         student_weight: torch.Tensor,
+         teacher_input: torch.Tensor,
+         teacher_weight: torch.Tensor,
+         true_labels: torch.LongTensor,
+         student_bias: torch.Tensor = None,
+         teacher_bias: torch.Tensor = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+         return LigerFusedLinearCosineSimilarityFunction.apply(
+             student_input,
+             student_weight,
+             teacher_input,
+             teacher_weight,
+             true_labels,
+             student_bias,
+             teacher_bias,
+             self.weight_hard_loss,
+             self.weight_soft_loss,
+             self.beta,
+             self.ignore_index,
+             self.temperature,
+             self.compiled,
+             self.chunk_size,
+             self.return_soft_hard_loss,
+         )
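A minimal sketch of how the distillation module above might be called, with student and teacher hidden states plus their respective `lm_head` weights projecting to a shared vocabulary. The sizes are placeholders, and the split between the hard (label) term and the soft (cosine) term is handled by `LigerFusedLinearDistillationBase`, which is outside this diff; the comments reflect what the parameter names suggest rather than code shown here.

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearCosineSimilarityLoss

# Hypothetical sizes; student and teacher may have different hidden dims
# but must project to the same vocabulary.
batch_seq, student_hidden, teacher_hidden, vocab = 512, 1024, 4096, 32000

student_input = torch.randn(batch_seq, student_hidden, requires_grad=True)
student_weight = torch.randn(vocab, student_hidden, requires_grad=True)
teacher_input = torch.randn(batch_seq, teacher_hidden)   # teacher tensors need no grad
teacher_weight = torch.randn(vocab, teacher_hidden)
true_labels = torch.randint(0, vocab, (batch_seq,))

distill_loss = LigerFusedLinearCosineSimilarityLoss(
    weight_hard_loss=0.5,  # weight of the label-based term (per the parameter name)
    weight_soft_loss=0.5,  # weight of the cosine distillation term
    beta=1.0,
)
# Returns the loss, or (loss, soft, hard) when return_soft_hard_loss=True,
# per the return annotation above.
loss = distill_loss(student_input, student_weight, teacher_input, teacher_weight, true_labels)
```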
liger_kernel/chunked_loss/cpo_loss.py
@@ -0,0 +1,157 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
+
+
+ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
+     @staticmethod
+     def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
+         """
+         Paper: https://arxiv.org/pdf/2401.08417
+
+         Formula:
+         L(π_θ; U) = -E_(x,y_w,y_l)~D[log σ(β log π_θ(y_w|x) - β log π_θ(y_l|x))]
+
+         Where:
+         - π_θ(y|x): Policy (model) probability
+         - y_w: Chosen sequence
+         - y_l: Rejected sequence
+         - σ: Sigmoid function
+         - β: Temperature parameter
+         - E: Expected value over the dataset D
+         - D: Dataset of preferences
+
+         Args:
+             chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+             full_target (torch.Tensor): Non chunked full target tensor
+             beta (float): Weight for the CPO loss
+             label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
+         """
+         logits = beta * (chosen_logps - rejected_logps)
+         loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+             full_target.shape[0] // 2
+         )
+
+         chosen_rewards = beta * chosen_logps
+         rejected_rewards = beta * rejected_logps
+
+         return loss, chosen_rewards, rejected_rewards
+
+     @classmethod
+     def forward(
+         cls,
+         ctx,
+         _input,
+         weight,
+         target,
+         bias=None,
+         ignore_index=-100,
+         beta=0.1,
+         alpha=1.0,
+         label_smoothing=0.0,
+         compute_nll_loss=True,
+         compiled=True,
+         average_log_prob=False,
+         chunk_size=1,
+     ):
+         """
+         Fused linear layer with CPO loss.
+         Args:
+             _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+             target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+             ignore_index (int): Index to ignore in loss computation
+             beta (float): Weight for the CPO loss
+             alpha (float): Weight for the alpha parameter
+             label_smoothing (float): Label smoothing factor
+             compute_nll_loss (bool): Whether to compute the NLL loss
+             compiled (bool): Whether to use torch compile
+             average_log_prob (bool): Whether to average the log probability per non-masked token
+             chunk_size (int): Size of chunks for processing.
+         Returns:
+             torch.Tensor: Computed loss
+         """
+         return super().forward(
+             cls=cls,
+             ctx=ctx,
+             _input=_input,
+             weight=weight,
+             target=target,
+             bias=bias,
+             ignore_index=ignore_index,
+             alpha=alpha,
+             beta=beta,
+             label_smoothing=label_smoothing,
+             compute_nll_loss=compute_nll_loss,
+             average_log_prob=average_log_prob,
+             compiled=compiled,
+             chunk_size=chunk_size,
+         )
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+         return *grads, None, None, None, None, None, None, None, None
+
+
+ class LigerFusedLinearCPOLoss(torch.nn.Module):
+     """
+     Fused linear layer with CPO loss.
+     """
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         beta: float = 0.1,
+         alpha: float = 1.0,
+         label_smoothing: float = 0.0,
+         compute_nll_loss: bool = True,
+         compiled: bool = True,
+         average_log_prob: bool = False,
+         chunk_size: int = 1,
+     ):
+         """
+         Args:
+             ignore_index (int): Index to ignore in the loss.
+             beta (float): Weight for the CPO loss.
+             alpha (float): Weight for the alpha parameter.
+             label_smoothing (float): Label smoothing factor.
+             compute_nll_loss (bool): Whether to compute the NLL loss.
+             compiled (bool): Whether to use the torch compiled kernel.
+             average_log_prob (bool): Whether to average the log probability per non-masked token.
+             chunk_size (int): Size of chunks for processing.
+         """
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.beta = beta
+         self.alpha = alpha
+         self.label_smoothing = label_smoothing
+         self.compute_nll_loss = compute_nll_loss
+         self.compiled = compiled
+         self.average_log_prob = average_log_prob
+         self.chunk_size = chunk_size
+
+     def forward(
+         self,
+         lin_weight,
+         _input,
+         target,
+         bias=None,
+     ):
+         return LigerFusedLinearCPOFunction.apply(
+             _input,
+             lin_weight,
+             target,
+             bias,
+             self.ignore_index,
+             self.beta,
+             self.alpha,
+             self.label_smoothing,
+             self.compute_nll_loss,
+             self.compiled,
+             self.average_log_prob,
+             self.chunk_size,
+         )
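The CPO class above also doubles as a template for the README's "Extending to custom loss functions" workflow: a custom alignment loss mainly needs a `preference_loss_fn` with the same contract (per-pair log-probabilities in, `(loss, chosen_rewards, rejected_rewards)` out), while the base class supplies chunking, the fused `lm_head` projection, and `torch.compile`. The subclass below is a hypothetical minimal sketch, not anything shipped in the package; a complete implementation would also wire its hyperparameters through `forward` and `backward` the way `cpo_loss.py` does.

```python
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase


class MyFusedLinearPreferenceFunction(LigerFusedLinearPreferenceBase):
    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
        # Plain sigmoid preference loss on the per-sequence log-prob margin.
        logits = beta * (chosen_logps - rejected_logps)
        loss = -F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)

        # Rewards reported alongside the loss, mirroring the CPO convention above.
        chosen_rewards = beta * chosen_logps
        rejected_rewards = beta * rejected_logps
        return loss, chosen_rewards, rejected_rewards
```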
liger_kernel/chunked_loss/dpo_loss.py
@@ -0,0 +1,229 @@
+ import torch
+ import torch.nn.functional as F
+
+ from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
+
+
+ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
+     @staticmethod
+     def preference_loss_fn(
+         chosen_logps,
+         rejected_logps,
+         full_target,
+         ref_chosen_logps=None,
+         ref_rejected_logps=None,
+         beta=0.1,
+         loss_type="sigmoid",
+     ):
+         """
+         Paper: https://arxiv.org/pdf/2305.18290
+
+         Formula:
+         L_DPO = -E[ log_sigmoid( β * (log(π(y_w|x)/π_ref(y_w|x)) - log(π(y_l|x)/π_ref(y_l|x))) ) ]
+
+         Where:
+         - π(y|x): Policy (model) probability
+         - π_ref(y|x): Reference model probability
+         - y_w: Chosen sequence
+         - y_l: Rejected sequence
+         - β: Weight for the direct preference loss
+         - E: Expected value over the dataset
+
+         Args:
+             chosen_logps: Log probabilities of chosen tokens (batch_size,)
+             rejected_logps: Log probabilities of rejected tokens (batch_size,)
+             full_target: Non chunked full target tensor
+             ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
+             ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+             beta: Weight for the direct preference loss
+         """
+
+         if ref_chosen_logps is None:
+             ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
+         if ref_rejected_logps is None:
+             ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
+
+         chosen_logratios = chosen_logps - ref_chosen_logps
+         rejected_logratios = rejected_logps - ref_rejected_logps
+
+         chosen_rewards = beta * chosen_logratios
+         rejected_rewards = beta * rejected_logratios
+
+         if loss_type == "sigmoid":
+             logits_diff = beta * (chosen_logratios - rejected_logratios)
+             loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "apo_zero":
+             # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+             # Use this loss when you believe the chosen outputs are better than your model's default output
+             losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+             losses_rejected = F.sigmoid(beta * rejected_logratios)
+             losses = losses_chosen + losses_rejected
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "apo_down":
+             # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+             # Use this loss when you believe the chosen outputs are worse than your model's default output.
+             # Decrease chosen likelihood and decrease rejected likelihood more
+             losses_chosen = F.sigmoid(beta * chosen_logratios)
+             losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+             losses = losses_chosen + losses_rejected
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "sppo_hard":
+             # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+             # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+             # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+             # set to 1 for the winner and 0 for the loser.
+             a = chosen_logps - ref_chosen_logps
+             b = rejected_logps - ref_rejected_logps
+             losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         elif loss_type == "nca_pair":
+             losses = (
+                 -F.logsigmoid(chosen_rewards)
+                 - 0.5 * F.logsigmoid(-chosen_rewards)
+                 - 0.5 * F.logsigmoid(-rejected_rewards)
+             )
+             loss = losses.sum() / (full_target.shape[0] // 2)
+
+         else:
+             raise ValueError(
+                 f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+             )
+
+         return loss, chosen_rewards, rejected_rewards
+
+     @classmethod
+     def forward(
+         cls,
+         ctx,
+         _input,
+         weight,
+         target,
+         bias=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+         ignore_index=-100,
+         beta=0.1,
+         compute_nll_loss=False,
+         compiled=True,
+         use_ref_model=True,
+         average_log_prob=False,
+         chunk_size=1,
+         loss_type="sigmoid",
+     ):
+         """
+         Fused linear layer with DPO loss.
+         Args:
+             _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+             weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+             target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+             bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+             ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+             ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+             ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+             ignore_index (int): Index to ignore in loss computation
+             beta (float): Weight for the direct preference loss
+             compute_nll_loss (bool): Whether to compute the NLL loss
+             compiled (bool): Whether to use torch compile
+             use_ref_model (bool): Whether to use a reference model
+             average_log_prob (bool): Whether to average the log probability per non-masked token
+             chunk_size (int): Size of chunks for processing.
+         Returns:
+             torch.Tensor: Computed loss
+         """
+         return super().forward(
+             cls=cls,
+             ctx=ctx,
+             _input=_input,
+             weight=weight,
+             target=target,
+             bias=bias,
+             ignore_index=ignore_index,
+             beta=beta,
+             compute_nll_loss=compute_nll_loss,
+             compiled=compiled,
+             use_ref_model=use_ref_model,
+             ref_input=ref_input,
+             ref_weight=ref_weight,
+             ref_bias=ref_bias,
+             average_log_prob=average_log_prob,
+             chunk_size=chunk_size,
+             loss_type=loss_type,
+         )
+
+     @staticmethod
+     def backward(ctx, *grad_output):
+         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+         return *grads, None, None, None, None, None, None, None, None, None, None, None
+
+
+ class LigerFusedLinearDPOLoss(torch.nn.Module):
+     """
+     Fused linear layer with DPO loss.
+     """
+
+     def __init__(
+         self,
+         ignore_index: int = -100,
+         beta: float = 0.1,
+         compute_nll_loss: bool = False,
+         compiled: bool = True,
+         use_ref_model: bool = True,
+         average_log_prob: bool = False,
+         chunk_size: int = 1,
+         loss_type: str = "sigmoid",
+     ):
+         """
+         Args:
+             ignore_index (int): Index to ignore in the loss.
+             beta (float): Weight for the direct preference loss.
+             compute_nll_loss (bool): Whether to compute the NLL loss.
+             compiled (bool): Whether to use the torch compiled kernel.
+             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+             average_log_prob (bool): Whether to average the log probability per non-masked token.
+             chunk_size (int): Size of chunks for processing.
+         """
+         super().__init__()
+         self.ignore_index = ignore_index
+         self.beta = beta
+         self.compute_nll_loss = compute_nll_loss
+         self.compiled = compiled
+         self.use_ref_model = use_ref_model
+         self.average_log_prob = average_log_prob
+         self.chunk_size = chunk_size
+         self.loss_type = loss_type
+         supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+         if self.loss_type not in supported_loss_types:
+             raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")
+
+     def forward(
+         self,
+         lin_weight,
+         _input,
+         target,
+         bias=None,
+         ref_input=None,
+         ref_weight=None,
+         ref_bias=None,
+     ):
+         return LigerFusedLinearDPOFunction.apply(
+             _input,
+             lin_weight,
+             target,
+             bias,
+             ref_input,
+             ref_weight,
+             ref_bias,
+             self.ignore_index,
+             self.beta,
+             self.compute_nll_loss,
+             self.compiled,
+             self.use_ref_model,
+             self.average_log_prob,
+             self.chunk_size,
+             self.loss_type,
+         )
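A minimal sketch of calling the DPO module above with a reference model's hidden states and `lm_head` weight, following the `forward` signature it defines. Shapes and hyperparameters are placeholders, and the exact return structure again comes from `LigerFusedLinearPreferenceBase`, which is outside this diff.

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

# Hypothetical shapes; the policy and reference models share the vocabulary.
batch_seq, hidden, vocab = 512, 2048, 32000

policy_hidden = torch.randn(batch_seq, hidden, requires_grad=True)
policy_lm_head = torch.randn(vocab, hidden, requires_grad=True)
ref_hidden = torch.randn(batch_seq, hidden)   # reference model hidden states (no grad)
ref_lm_head = torch.randn(vocab, hidden)      # reference model lm_head weight
labels = torch.randint(0, vocab, (batch_seq,))

dpo_loss = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True, loss_type="sigmoid")
outputs = dpo_loss(
    policy_lm_head,
    policy_hidden,
    labels,
    ref_input=ref_hidden,
    ref_weight=ref_lm_head,
)
# The loss comes first; any aggregated auxiliary outputs (e.g. rewards)
# are determined by the base class.
```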
liger_kernel/chunked_loss/functional.py
@@ -0,0 +1,17 @@
+ from liger_kernel.chunked_loss.cosine_similarity_loss import LigerFusedLinearCosineSimilarityFunction
+ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
+ from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
+ from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
+ from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
+ from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
+ from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
+
+ liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
+ liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
+ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
+ liger_fused_linear_cosine = LigerFusedLinearCosineSimilarityFunction.apply
+ liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
+ liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
+ liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
+ liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
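For completeness, a hedged sketch of the functional entry points defined above: each alias is simply the corresponding `Function.apply`, so arguments are positional and follow that function's `forward` signature (CPO's is shown earlier in this diff), with omitted trailing arguments falling back to their defaults. The tensors here are placeholders.

```python
import torch

from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo

# Placeholder tensors matching the shapes documented in cpo_loss.py above.
hidden_states = torch.randn(512, 2048, requires_grad=True)
lm_head_weight = torch.randn(32000, 2048, requires_grad=True)
labels = torch.randint(0, 32000, (512,))

# Positional args: _input, weight, target (bias and the remaining
# hyperparameters keep the defaults from LigerFusedLinearCPOFunction.forward).
outputs = liger_fused_linear_cpo(hidden_states, lm_head_weight, labels)
```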