liger-kernel-nightly 0.4.0.dev20241107052928-py3-none-any.whl → 0.6.3.dev20251121010306-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (114)
  1. liger_kernel/__init__.py +0 -0
  2. liger_kernel/chunked_loss/README.md +25 -0
  3. liger_kernel/chunked_loss/__init__.py +8 -0
  4. liger_kernel/chunked_loss/cosine_similarity_loss.py +136 -0
  5. liger_kernel/chunked_loss/cpo_loss.py +157 -0
  6. liger_kernel/chunked_loss/dpo_loss.py +229 -0
  7. liger_kernel/chunked_loss/functional.py +17 -0
  8. liger_kernel/chunked_loss/fused_linear_distillation.py +292 -0
  9. liger_kernel/chunked_loss/fused_linear_ppo.py +350 -0
  10. liger_kernel/chunked_loss/fused_linear_preference.py +433 -0
  11. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +341 -0
  12. liger_kernel/chunked_loss/grpo_loss.py +304 -0
  13. liger_kernel/chunked_loss/jsd_loss.py +200 -0
  14. liger_kernel/chunked_loss/kto_loss.py +210 -0
  15. liger_kernel/chunked_loss/orpo_loss.py +144 -0
  16. liger_kernel/chunked_loss/simpo_loss.py +165 -0
  17. liger_kernel/env_report.py +21 -4
  18. liger_kernel/ops/cross_entropy.py +235 -84
  19. liger_kernel/ops/dyt.py +157 -0
  20. liger_kernel/ops/experimental/embedding.py +1 -3
  21. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  22. liger_kernel/ops/fused_add_rms_norm.py +412 -0
  23. liger_kernel/ops/fused_linear_cross_entropy.py +197 -75
  24. liger_kernel/ops/fused_linear_jsd.py +17 -34
  25. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  26. liger_kernel/ops/geglu.py +7 -18
  27. liger_kernel/ops/group_norm.py +305 -0
  28. liger_kernel/ops/grpo_loss.py +310 -0
  29. liger_kernel/ops/jsd.py +46 -21
  30. liger_kernel/ops/kl_div.py +23 -19
  31. liger_kernel/ops/layer_norm.py +150 -86
  32. liger_kernel/ops/llama4_rope.py +225 -0
  33. liger_kernel/ops/multi_token_attention.py +207 -0
  34. liger_kernel/ops/poly_norm.py +386 -0
  35. liger_kernel/ops/qwen2vl_mrope.py +222 -0
  36. liger_kernel/ops/rms_norm.py +314 -84
  37. liger_kernel/ops/rope.py +32 -34
  38. liger_kernel/ops/softmax.py +201 -0
  39. liger_kernel/ops/sparsemax.py +179 -0
  40. liger_kernel/ops/swiglu.py +5 -9
  41. liger_kernel/ops/tiled_mlp.py +136 -0
  42. liger_kernel/ops/tvd.py +207 -0
  43. liger_kernel/ops/utils.py +8 -4
  44. liger_kernel/transformers/__init__.py +199 -24
  45. liger_kernel/transformers/auto_model.py +6 -13
  46. liger_kernel/transformers/cross_entropy.py +33 -20
  47. liger_kernel/transformers/dyt.py +22 -0
  48. liger_kernel/transformers/experimental/__init__.py +5 -0
  49. liger_kernel/transformers/experimental/embedding.py +1 -3
  50. liger_kernel/transformers/fsdp.py +55 -0
  51. liger_kernel/transformers/functional.py +291 -13
  52. liger_kernel/transformers/fused_add_rms_norm.py +39 -0
  53. liger_kernel/transformers/fused_linear_cross_entropy.py +43 -14
  54. liger_kernel/transformers/fused_linear_jsd.py +1 -4
  55. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  56. liger_kernel/transformers/geglu.py +1 -4
  57. liger_kernel/transformers/group_norm.py +50 -0
  58. liger_kernel/transformers/grpo_loss.py +98 -0
  59. liger_kernel/transformers/jsd.py +2 -7
  60. liger_kernel/transformers/kl_div.py +1 -3
  61. liger_kernel/transformers/layer_norm.py +3 -9
  62. liger_kernel/transformers/llama4_rope.py +93 -0
  63. liger_kernel/transformers/model/falcon_h1.py +122 -0
  64. liger_kernel/transformers/model/gemma.py +77 -77
  65. liger_kernel/transformers/model/gemma2.py +283 -0
  66. liger_kernel/transformers/model/gemma3.py +331 -0
  67. liger_kernel/transformers/model/glm4.py +141 -0
  68. liger_kernel/transformers/model/glm4v.py +163 -0
  69. liger_kernel/transformers/model/glm4v_moe.py +172 -0
  70. liger_kernel/transformers/model/internvl.py +157 -0
  71. liger_kernel/transformers/model/llama.py +128 -79
  72. liger_kernel/transformers/model/llama4.py +121 -0
  73. liger_kernel/transformers/model/llava.py +344 -0
  74. liger_kernel/transformers/model/loss_utils.py +95 -0
  75. liger_kernel/transformers/model/mistral.py +68 -64
  76. liger_kernel/transformers/model/mixtral.py +75 -91
  77. liger_kernel/transformers/model/mllama.py +63 -68
  78. liger_kernel/transformers/model/olmo2.py +141 -0
  79. liger_kernel/transformers/model/output_classes.py +147 -0
  80. liger_kernel/transformers/model/paligemma.py +432 -0
  81. liger_kernel/transformers/model/phi3.py +59 -213
  82. liger_kernel/transformers/model/qwen2.py +75 -72
  83. liger_kernel/transformers/model/qwen2_5_vl.py +163 -0
  84. liger_kernel/transformers/model/qwen2_vl.py +78 -98
  85. liger_kernel/transformers/model/qwen3.py +136 -0
  86. liger_kernel/transformers/model/qwen3_moe.py +152 -0
  87. liger_kernel/transformers/model/qwen3_next.py +146 -0
  88. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  89. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  90. liger_kernel/transformers/model/smollm3.py +199 -0
  91. liger_kernel/transformers/model/smolvlm.py +158 -0
  92. liger_kernel/transformers/monkey_patch.py +2106 -289
  93. liger_kernel/transformers/multi_token_attention.py +64 -0
  94. liger_kernel/transformers/poly_norm.py +42 -0
  95. liger_kernel/transformers/qwen2vl_mrope.py +20 -0
  96. liger_kernel/transformers/rms_norm.py +57 -6
  97. liger_kernel/transformers/rope.py +45 -2
  98. liger_kernel/transformers/softmax.py +12 -0
  99. liger_kernel/transformers/sparsemax.py +16 -0
  100. liger_kernel/transformers/swiglu.py +23 -8
  101. liger_kernel/transformers/tiled_mlp.py +133 -0
  102. liger_kernel/transformers/trainer/__init__.py +4 -0
  103. liger_kernel/transformers/trainer/orpo_trainer.py +130 -0
  104. liger_kernel/transformers/tvd.py +13 -0
  105. liger_kernel/triton/__init__.py +1 -3
  106. liger_kernel/triton/monkey_patch.py +1 -3
  107. liger_kernel/utils.py +71 -0
  108. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/METADATA +150 -137
  109. liger_kernel_nightly-0.6.3.dev20251121010306.dist-info/RECORD +116 -0
  110. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/WHEEL +1 -1
  111. liger_kernel_nightly-0.4.0.dev20241107052928.dist-info/RECORD +0 -48
  112. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/LICENSE +0 -0
  113. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/NOTICE +0 -0
  114. {liger_kernel_nightly-0.4.0.dev20241107052928.dist-info → liger_kernel_nightly-0.6.3.dev20251121010306.dist-info}/top_level.txt +0 -0
liger_kernel/chunked_loss/jsd_loss.py (new file)
@@ -0,0 +1,200 @@
+import math
+
+from typing import Tuple
+from typing import Union
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+        """
+        Compute JSD loss (Jensen-Shannon Divergence Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        Returns:
+            torch.Tensor: Jensen-Shannon Divergence loss
+        """
+        student_log_probs = F.log_softmax(student_logits, dim=-1)
+        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            log_mean_probs = torch.logsumexp(
+                torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
+            )
+
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+        return jsd_loss
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor,
+        teacher_bias: torch.Tensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        """
+        Fused linear layer with JSD distillation loss.
+        Args:
+            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+            true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            ignore_index (int): Index to ignore in loss computation
+            temperature (float): Temperature for softening/sharpening distributions
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
+        Returns:
+            torch.Tensor: Computed loss, or tuple (loss, soft_loss, hard_loss) if return_soft_hard_loss=True
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            student_bias=student_bias,
+            teacher_bias=teacher_bias,
+            chunk_size=chunk_size,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+            return_soft_hard_loss=return_soft_hard_loss,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output, *args):
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output, *args)[:6]
+
+        return (
+            *grads,
+            None,  # teacher_bias
+            None,  # weight_hard_loss
+            None,  # weight_soft_loss
+            None,  # beta
+            None,  # ignore_index
+            None,  # temperature
+            None,  # compiled
+            None,  # chunk_size
+            None,  # return_soft_hard_loss
+        )
+
+
+class LigerFusedLinearJSDLoss(torch.nn.Module):
+    """
+    Fused linear layer with JSD distillation loss.
+    """
+
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+        chunk_size: int = 1024,
+        return_soft_hard_loss: bool = False,
+    ):
+        """
+        Args:
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            ignore_index (int): Index to ignore in the loss
+            temperature (float): Temperature for softening distributions
+            compiled (bool): Whether to use torch compile
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            chunk_size (int): Size of chunks for processing.
+            return_soft_hard_loss (bool): Whether to return soft and hard losses separately. Default: False.
+        """
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+        self.chunk_size = chunk_size
+        self.return_soft_hard_loss = return_soft_hard_loss
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        student_bias: torch.Tensor = None,
+        teacher_bias: torch.Tensor = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        """
+        Compute the JSD distillation loss.
+
+        Args:
+            student_input (torch.Tensor): Student input tensor
+            student_weight (torch.Tensor): Student weight tensor
+            teacher_input (torch.Tensor): Teacher input tensor
+            teacher_weight (torch.Tensor): Teacher weight tensor
+            true_labels (torch.LongTensor): Target labels tensor
+
+        Returns:
+            torch.Tensor or Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                If return_soft_hard_loss is False: Computed combined loss
+                If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
+        """
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            student_bias,
+            teacher_bias,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+            self.chunk_size,
+            self.return_soft_hard_loss,
+        )
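
For reference, the generalized JSD computed by distillation_loss_fn above is beta * KL(teacher || M) + (1 - beta) * KL(student || M), where M is the beta-weighted mixture of the teacher and student distributions. The standalone snippet below is illustrative only (toy tensor sizes, not the fused chunked-kernel path); it checks that the log-space logsumexp formulation used in the kernel matches a direct computation from the mixture probabilities.

import math
import torch
import torch.nn.functional as F

beta = 0.5
student_logits = torch.randn(4, 10)  # (tokens, vocab), toy sizes
teacher_logits = torch.randn(4, 10)

student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

# Log of the mixture M = beta * P_teacher + (1 - beta) * P_student, via logsumexp as in the kernel.
log_mean = torch.logsumexp(
    torch.stack([student_log_probs + math.log(1 - beta), teacher_log_probs + math.log(beta)], dim=0), dim=0
)
jsd = beta * F.kl_div(log_mean, teacher_log_probs, reduction="sum", log_target=True) + (1 - beta) * F.kl_div(
    log_mean, student_log_probs, reduction="sum", log_target=True
)

# Same quantity computed directly from the probabilities.
mean_probs = beta * teacher_log_probs.exp() + (1 - beta) * student_log_probs.exp()
jsd_direct = beta * (teacher_log_probs.exp() * (teacher_log_probs - mean_probs.log())).sum() + (1 - beta) * (
    student_log_probs.exp() * (student_log_probs - mean_probs.log())
).sum()

assert torch.allclose(jsd, jsd_direct, atol=1e-4)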
liger_kernel/chunked_loss/kto_loss.py (new file)
@@ -0,0 +1,210 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_unpaired_preference import LigerFusedLinearUnpairedPreferenceBase
+
+
+class LigerFusedLinearKTOFunction(LigerFusedLinearUnpairedPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(
+        log_prob_chunk,
+        preference_labels_chunk,
+        full_target,
+        ref_log_prob_chunk=None,
+        beta=0.1,
+        kl=None,
+    ):
+        """
+        Implements the Kahneman-Tversky Optimization (KTO) loss function.
+        Paper: "KTO: Model Alignment as Prospect Theory-Guided Optimization"
+        https://arxiv.org/abs/2402.01306
+
+        KTO loss is inspired by prospect theory (https://en.wikipedia.org/wiki/Prospect_theory)
+        from behavioral economics, which models how humans make decisions under uncertainty.
+        The loss function is asymmetric, treating gains and losses differently, similar to
+        human decision-making patterns.
+
+        Formula:
+        When y is chosen:
+            L_KTO = 1 - σ(β * (log[π(x)/π₀(x)] - KL(π||π₀)_y))
+        When y is rejected:
+            L_KTO = 1 - σ(β * (KL(π||π₀)_y - log[π(x)/π₀(x)]))
+
+        Where:
+        - σ: Sigmoid function
+        - β: Temperature parameter controlling the strength of the preference signal
+        - π(x): Policy (current model)
+        - π₀(x): Reference policy (reference model)
+        - KL(π||π₀)_y: KL divergence estimated using the rejected response y
+
+        The loss encourages the model to:
+        1. Assign higher probability to chosen responses
+        2. Assign lower probability to rejected responses
+        3. Maintain reasonable distance from the reference model
+
+        Args:
+            log_prob_chunk: Log probabilities for the chunk (batch_size,)
+            preference_labels_chunk: Preference labels for the chunk (batch_size,)
+            full_target: Non chunked full target tensor
+            ref_log_prob_chunk: Reference log probs for the chunk (batch_size,)
+            beta: Weight for the KTO loss
+            kl: KL divergence between the policy model and the reference model for the chosen responses. Shape: (batch_size,)
+        Returns:
+            - loss: The KTO loss value
+        """
+        if ref_log_prob_chunk is not None:
+            logratios_chunk = log_prob_chunk - ref_log_prob_chunk
+        else:
+            logratios_chunk = log_prob_chunk
+        multiplier_chunk = torch.where(preference_labels_chunk, 1, -1)
+        if kl is not None:
+            losses = 1 - F.sigmoid(beta * (logratios_chunk - kl) * multiplier_chunk)
+        else:
+            losses = 1 - F.sigmoid(beta * logratios_chunk * multiplier_chunk)
+
+        rewards = beta * logratios_chunk
+        chosen_rewards_sum = (rewards * preference_labels_chunk.unsqueeze(1)).sum()
+        rejected_rewards_sum = (rewards * (~preference_labels_chunk).unsqueeze(1)).sum()
+
+        return losses.sum() / (full_target.shape[0]), chosen_rewards_sum, rejected_rewards_sum
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        target,
+        preference_labels,
+        bias=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        kl=None,
+        ignore_index=-100,
+        beta=0.1,
+        compiled=True,
+        use_ref_model=True,
+        average_log_prob=False,
+        chunk_size=1,
+    ):
+        """
+        Fused linear layer with KTO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            preference_labels (torch.Tensor): Preference labels tensor. Shape: (batch_size,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
+            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
+            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
+            kl (torch.Tensor, optional): KL divergence tensor. Shape: (batch_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use torch compile
+            use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            preference_labels=preference_labels,
+            bias=bias,
+            ignore_index=ignore_index,
+            beta=beta,
+            compiled=compiled,
+            use_ref_model=use_ref_model,
+            ref_input=ref_input,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
+            kl=kl,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearUnpairedPreferenceBase.backward(ctx, grad_output)[:5]
+        return (
+            *grads,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+class LigerFusedLinearKTOLoss(torch.nn.Module):
+    """
+    Fused linear layer with Kahneman-Tversky Optimization (KTO) loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compiled: bool = True,
+        use_ref_model: bool = False,
+        average_log_prob: bool = False,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss calculation
+            beta (float): Temperature parameter for the KTO loss
+            compiled (bool): Whether to use compiled operations
+            use_ref_model (bool): Whether to use a reference model for the KTO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token
+            chunk_size (int): Size of chunks for processing
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.compiled = compiled
+        self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        _input,
+        lin_weight,
+        target,
+        bias=None,
+        preference_labels=None,
+        ref_input=None,
+        ref_weight=None,
+        ref_bias=None,
+        kl=None,
+    ):
+        return LigerFusedLinearKTOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            preference_labels,
+            bias,
+            ref_input,
+            ref_weight,
+            ref_bias,
+            kl,
+            self.ignore_index,
+            self.beta,
+            self.compiled,
+            self.use_ref_model,
+            self.average_log_prob,
+            self.chunk_size,
+        )
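
As a companion to the preference_loss_fn docstring above, the asymmetric per-example KTO term can be reproduced in a few lines of plain PyTorch: chosen examples contribute 1 - sigmoid(beta * (logratio - kl)) while rejected examples negate the sigmoid's argument. The snippet below is a toy illustration of that formula only, not a call into the fused chunked kernel; all tensor values are made up, and the normalization here simply divides by the toy batch size (the kernel divides by the full, un-chunked batch size).

import torch
import torch.nn.functional as F

beta = 0.1
log_prob = torch.tensor([-12.0, -15.0, -9.0])          # policy per-sequence log-probs
ref_log_prob = torch.tensor([-13.0, -14.0, -10.0])     # reference per-sequence log-probs
preference_labels = torch.tensor([True, False, True])  # True = chosen, False = rejected
kl = torch.tensor(0.5)                                 # shared KL estimate

logratios = log_prob - ref_log_prob
multiplier = torch.where(preference_labels, 1, -1)     # +1 for chosen, -1 for rejected
losses = 1 - F.sigmoid(beta * (logratios - kl) * multiplier)
loss = losses.sum() / log_prob.shape[0]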
liger_kernel/chunked_loss/orpo_loss.py (new file)
@@ -0,0 +1,144 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import LigerFusedLinearPreferenceBase
+
+
+class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
+    @staticmethod
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+        """
+        Paper: https://arxiv.org/pdf/2403.07691
+
+        Formula:
+        Compute odds-ratio loss: L_OR = -log(σ(log(odds_θ(y_w|x) / odds_θ(y_l|x))))
+        where odds_θ(y|x) = P_θ(y|x) / (1 - P_θ(y|x))
+
+        Where:
+        - P_θ(y|x): Policy (model) probability
+        - y_w: Chosen sequence
+        - y_l: Rejected sequence
+        - σ: Sigmoid function
+        - β: Weight for the odds ratio loss
+        - odds_θ: Odds function for the policy
+
+        Args:
+            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            full_target (torch.Tensor): Non chunked full target tensor
+            beta (float): Weight for the odds ratio loss.
+        """
+        log_odds = (chosen_logps - rejected_logps) - (
+            torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
+        )
+        ratio = F.logsigmoid(log_odds)
+        loss = -beta * ratio.sum() / (full_target.shape[0] // 2)
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        log_odds_ratio = torch.sum(ratio) / (full_target.shape[0] // 2)
+        log_odds_chosen = torch.sum(log_odds) / (full_target.shape[0] // 2)
+
+        return loss, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen
+
+    @classmethod
+    def forward(
+        cls,
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        nll_target=None,
+        compiled=True,
+        chunk_size=1,
+    ):
+        """
+        Fused linear layer with ORPO loss.
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
+            target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
+            ignore_index (int): Index to ignore in loss computation
+            beta (float): Weight for the odds ratio loss
+            compute_nll_loss (bool): Whether to compute the NLL loss
+            nll_target (torch.LongTensor, optional): Target tensor for NLL loss. Shape: (batch_size * seq_len,)
+            compiled (bool): Whether to use torch compile
+            chunk_size (int): Size of chunks for processing
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return super().forward(
+            cls=cls,
+            ctx=ctx,
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+            nll_target=nll_target,
+            compiled=compiled,
+            chunk_size=chunk_size,
+        )
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        return *grads, None, None, None, None, None, None
+
+
+class LigerFusedLinearORPOLoss(torch.nn.Module):
+    """
+    Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        compute_nll_loss: bool = True,
+        compiled: bool = True,
+        chunk_size: int = 1,
+    ):
+        """
+        Args:
+            ignore_index (int): Index to ignore in the loss.
+            beta (float): Weight for the odds ratio loss.
+            compute_nll_loss (bool): Whether to compute the NLL loss.
+            compiled (bool): Whether to use the torch compiled kernel.
+            chunk_size (int): Size of chunks for processing.
+        """
+        super().__init__()
+        self.ignore_index = ignore_index
+        self.beta = beta
+        self.compute_nll_loss = compute_nll_loss
+        self.compiled = compiled
+        self.chunk_size = chunk_size
+
+    def forward(
+        self,
+        lin_weight,
+        _input,
+        target,
+        bias=None,
+        nll_target=None,
+    ):
+        return LigerFusedLinearORPOFunction.apply(
+            _input,
+            lin_weight,
+            target,
+            bias,
+            self.ignore_index,
+            self.beta,
+            self.compute_nll_loss,
+            nll_target,
+            self.compiled,
+            self.chunk_size,
+        )
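
The odds-ratio term that preference_loss_fn computes above can likewise be sanity-checked standalone: log(odds(y_w) / odds(y_l)) with odds(y) = P(y) / (1 - P(y)), evaluated from average sequence log-probabilities and kept numerically stable via log1p, exactly as in the kernel. The snippet below is a toy illustration of that formula, not the fused kernel path; the log-probability values are made up (they must be strictly negative for the odds to be well defined).

import torch
import torch.nn.functional as F

beta = 0.1
chosen_logps = torch.tensor([-0.4, -0.7])    # average log-prob per chosen sequence
rejected_logps = torch.tensor([-1.2, -1.5])  # average log-prob per rejected sequence

# log(odds(chosen) / odds(rejected)), with odds(y) = P(y) / (1 - P(y)).
log_odds = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
loss = -beta * F.logsigmoid(log_odds).mean()  # the kernel sums and divides by the number of pairs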