liger-kernel-nightly 0.5.8.dev20250422210723__py3-none-any.whl → 0.5.8.dev20250428050809__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -201,14 +201,16 @@ def lce_forward(
 
     hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
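
The hunk above, and each one below, makes the same two-part change to a model forward path (lce_forward, or causal_forward for Gemma3): shift_labels is popped out of loss_kwargs before the loss gate, and the fused-loss branch now also fires when only shift_labels is supplied, so callers that pre-shift labels themselves (for example, trainers packing sequences) still skip logit materialization. A minimal standalone sketch of the pattern, with a hypothetical liger_loss_stub standing in for LigerForCausalLMLoss:

    # Standalone sketch of the new gating pattern; liger_loss_stub is a
    # hypothetical stand-in for LigerForCausalLMLoss.
    from typing import Optional

    import torch

    def liger_loss_stub(hidden_states, labels, shift_labels, **kwargs):
        # The real kernel fuses the lm_head projection with cross-entropy
        # so the full logits tensor is never materialized.
        return torch.tensor(0.0)

    def compute_loss(
        hidden_states: torch.Tensor,
        labels: Optional[torch.Tensor],
        training: bool,
        **loss_kwargs,
    ):
        # Pop shift_labels so it is not forwarded twice via **loss_kwargs.
        shift_labels = loss_kwargs.pop("shift_labels", None)
        loss = None
        # The fused path now also triggers when only shift_labels is given.
        if training and (labels is not None or shift_labels is not None):
            loss = liger_loss_stub(
                hidden_states,
                labels=labels,
                shift_labels=shift_labels,
                **loss_kwargs,
            )
        return loss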
@@ -213,14 +213,16 @@ def lce_forward(
 
     hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
@@ -104,13 +104,15 @@ def causal_forward(
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
@@ -213,14 +213,16 @@ def lce_forward(
     if self.config.pretraining_tp > 1:
         raise Exception("Liger Kernel does not support pretraining_tp!!")
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
@@ -92,14 +92,16 @@ def lce_forward(
 
     hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
@@ -226,14 +226,16 @@ def lce_forward(
 
     hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
@@ -216,14 +216,16 @@ def lce_forward(
 
     hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
@@ -89,14 +89,16 @@ def lce_forward(
 
     hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
@@ -214,14 +214,16 @@ def lce_forward(
 
     hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
@@ -200,14 +200,16 @@ def lce_forward(
 
     hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
@@ -163,14 +163,16 @@ def lce_forward(
 
     hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
@@ -167,14 +167,16 @@ def lce_forward(
 
    hidden_states = outputs[0]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
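
On the caller side, shift_labels is expected to hold targets already shifted by one position, with ignored slots set to the usual cross-entropy ignore index. A hedged sketch of how such a tensor could be built (the model invocation is left commented out, since it requires a Liger-patched checkpoint):

    # Hypothetical example: building pre-shifted labels a trainer could pass
    # through the model's loss kwargs.
    import torch

    input_ids = torch.randint(0, 32000, (2, 128))
    # Position i of shift_labels holds the target for step i (token i + 1);
    # -100 is the conventional ignore_index for cross-entropy.
    shift_labels = torch.full_like(input_ids, -100)
    shift_labels[:, :-1] = input_ids[:, 1:]

    # model.train()
    # outputs = model(input_ids=input_ids, shift_labels=shift_labels)
    # outputs.loss would then come from LigerForCausalLMLoss with labels=None.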
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.8.dev20250422210723
+Version: 0.5.8.dev20250428050809
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -55,28 +55,28 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
 liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
 liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
 liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/transformers/model/gemma.py,sha256=-JoHKWjtYPpxHQa6QbCwnzX_cctRZG2ZTsaUv-dmOt4,9816
-liger_kernel/transformers/model/gemma2.py,sha256=n4MZupFGDMvtnvkvkNhRrxXS3ZF341BVfyLjrOXp10g,10923
-liger_kernel/transformers/model/gemma3.py,sha256=ge3JYchiKvX1G1Zp00jX2zmQK2K7ymJoZAxbb2ggslw,16102
-liger_kernel/transformers/model/llama.py,sha256=UVXQLRW7rCU5vPab54dLNS3ER37eM446peHX00Yz6eA,10493
+liger_kernel/transformers/model/gemma.py,sha256=uoZvur13XSvtUfiBIP25ZJXEGh4hB5KlB-fq_wpbavY,9940
+liger_kernel/transformers/model/gemma2.py,sha256=4sPxsnFVywZiNsOoxFM4nEAKB5m5_efnJR7pCEVsQw4,11047
+liger_kernel/transformers/model/gemma3.py,sha256=wGSNqaLRRgIGQ_r9esyhDezm2SkAGZflopoWoWR-nYY,16226
+liger_kernel/transformers/model/llama.py,sha256=7AQROxICv2oKSrf5fGJifz_vyuPBkGRXbm0xipUwQew,10617
 liger_kernel/transformers/model/llava.py,sha256=b0pEagjUbu2-eS9xegjyfl1DwIXLwZcNpff55ibaMbA,17601
 liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
-liger_kernel/transformers/model/mistral.py,sha256=RacuKcckuDK6oSraCGD0R0bm-fE0K3q-lkYaAC56C2E,5481
-liger_kernel/transformers/model/mixtral.py,sha256=gLcqGabdv1XnuciS9b-TpkTDnGL8K32Hoq9j2vZMBRY,11502
-liger_kernel/transformers/model/mllama.py,sha256=75mxtmMsNd_q8KlKeawj2uMP6v2KjDuUi4nsUKM5jqA,11308
-liger_kernel/transformers/model/olmo2.py,sha256=rSzSALikEGkk0w3PLNQPrqg-ioN8TpWCXkAlg3LtCdI,5189
+liger_kernel/transformers/model/mistral.py,sha256=jxZOKrazvJFUHzHQIbacUN_G9MILxv8x_JkXRaybRX4,5605
+liger_kernel/transformers/model/mixtral.py,sha256=0gONJRzPDTpLhXg9x4c2woI6GkcmkMUUIuxcoayZU68,11626
+liger_kernel/transformers/model/mllama.py,sha256=mXXisoETXB1x9LqV1r6GUj6kRq6RBOZ6guT94Rllqco,11432
+liger_kernel/transformers/model/olmo2.py,sha256=KhSDSs3ay_zg7cWZDmS90KtA3E8WzrUFulPLCqwqD_g,5313
 liger_kernel/transformers/model/paligemma.py,sha256=GNReT6tVZt3ON6aaa9ovg8mnu1hYocSx9OhgC7b-_28,19191
-liger_kernel/transformers/model/phi3.py,sha256=ebITCrmwmb4z66CbSrZl1kD6BsP52IcSAR8uwUTp9nc,10455
-liger_kernel/transformers/model/qwen2.py,sha256=QaoTDrJv2wIuAM8QMoeWVvgNl0N5gHzIrew9QGG7kXc,9744
-liger_kernel/transformers/model/qwen2_5_vl.py,sha256=70BnHZjx6eQWTwi3zc5SMwxTeOOA4Tbdkfy6IYRcTaM,9289
-liger_kernel/transformers/model/qwen2_vl.py,sha256=zo4O9fShNHYqSLrzLGqQYWSMtJI6UHaSY7zvMCYWyD8,9685
+liger_kernel/transformers/model/phi3.py,sha256=vDSCW-e72-GV_Ip0_c1bmXBvfoqQ1EXlHap8bHMMEuY,10579
+liger_kernel/transformers/model/qwen2.py,sha256=RSdIDKqiTIyffevOD6aclbwqS9Vrmt0ibIIZfr1bnfY,9868
+liger_kernel/transformers/model/qwen2_5_vl.py,sha256=oACIsTpg9_GdoSvekCyXLhJkuCpQEiFOTzKj7cjgi2E,9413
+liger_kernel/transformers/model/qwen2_vl.py,sha256=F6DeQ65wPtcpeQJZ9a3SJZKkQ-e24SRLdYUgC-_jT-k,9809
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=pdekW7l6Qg_aqa5SYKYlSWUF8m3lkOFvFLcIMEHrz9s,8338
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
-liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/METADATA,sha256=aSh18zXYcQy1fb3OW8Q-Q9_DYczeWXULpNDET3PCbfg,23297
-liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
-liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
-liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
-liger_kernel_nightly-0.5.8.dev20250422210723.dist-info/RECORD,,
+liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/METADATA,sha256=JjME6LPHPZBmQ_lQlX3mWasYbAkKd0r6ZgEfsyeIGx8,23297
+liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.8.dev20250428050809.dist-info/RECORD,,