liger-kernel-nightly 0.5.10.dev20250605210201__py3-none-any.whl → 0.5.10.dev20250605224739__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22):
  1. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  2. liger_kernel/transformers/functional.py +28 -0
  3. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  4. liger_kernel/transformers/model/gemma.py +5 -4
  5. liger_kernel/transformers/model/gemma2.py +7 -4
  6. liger_kernel/transformers/model/glm4.py +5 -4
  7. liger_kernel/transformers/model/llama.py +5 -4
  8. liger_kernel/transformers/model/mistral.py +5 -4
  9. liger_kernel/transformers/model/mixtral.py +5 -4
  10. liger_kernel/transformers/model/mllama.py +5 -4
  11. liger_kernel/transformers/model/olmo2.py +5 -4
  12. liger_kernel/transformers/model/phi3.py +5 -4
  13. liger_kernel/transformers/model/qwen2.py +5 -4
  14. liger_kernel/transformers/model/qwen2_5_vl.py +4 -3
  15. liger_kernel/transformers/model/qwen2_vl.py +4 -3
  16. liger_kernel/transformers/model/qwen3_moe.py +5 -4
  17. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/METADATA +1 -1
  18. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/RECORD +22 -20
  19. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/LICENSE +0 -0
  20. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/NOTICE +0 -0
  21. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/WHEEL +0 -0
  22. {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/top_level.txt +0 -0
@@ -136,7 +136,7 @@ def lce_forward(
136
136
  cache_position: Optional[torch.LongTensor] = None,
137
137
  logits_to_keep: Union[int, torch.Tensor] = 0,
138
138
  skip_logits: Optional[bool] = None,
139
- **loss_kwargs,
139
+ **kwargs,
140
140
  ) -> Union[Tuple, CausalLMOutputWithPast]:
141
141
  r"""
142
142
  Args:
@@ -189,6 +189,7 @@ def lce_forward(
189
189
  output_hidden_states=output_hidden_states,
190
190
  return_dict=return_dict,
191
191
  cache_position=cache_position,
192
+ **kwargs,
192
193
  )
193
194
 
194
195
  hidden_states = outputs[0]
@@ -196,7 +197,7 @@ def lce_forward(
196
197
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
197
198
  kept_hidden_states = hidden_states[:, slice_indices, :]
198
199
 
199
- shift_labels = loss_kwargs.pop("shift_labels", None)
200
+ shift_labels = kwargs.pop("shift_labels", None)
200
201
  logits = None
201
202
  loss = None
202
203
 
@@ -214,7 +215,7 @@ def lce_forward(
214
215
  labels=labels,
215
216
  shift_labels=shift_labels,
216
217
  hidden_size=self.config.hidden_size,
217
- **loss_kwargs,
218
+ **kwargs,
218
219
  )
219
220
 
220
221
  else:
@@ -224,7 +225,7 @@ def lce_forward(
224
225
  logits=logits,
225
226
  labels=labels,
226
227
  vocab_size=self.config.vocab_size,
227
- **loss_kwargs,
228
+ **kwargs,
228
229
  )
229
230
 
230
231
  return CausalLMOutputWithPast(
@@ -31,7 +31,7 @@ def lce_forward(
31
31
  cache_position: Optional[torch.LongTensor] = None,
32
32
  second_per_grid_ts: Optional[torch.Tensor] = None,
33
33
  skip_logits: Optional[bool] = None,
34
- **loss_kwargs,
34
+ **kwargs,
35
35
  ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
36
36
  r"""
37
37
  Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -154,11 +154,12 @@ def lce_forward(
154
154
  output_hidden_states=output_hidden_states,
155
155
  return_dict=return_dict,
156
156
  cache_position=cache_position,
157
+ **kwargs,
157
158
  )
158
159
 
159
160
  hidden_states = outputs[0]
160
161
 
161
- shift_labels = loss_kwargs.pop("shift_labels", None)
162
+ shift_labels = kwargs.pop("shift_labels", None)
162
163
  loss = None
163
164
  logits = None
164
165
 
@@ -175,7 +176,7 @@ def lce_forward(
175
176
  labels=labels,
176
177
  shift_labels=shift_labels,
177
178
  hidden_size=self.config.hidden_size,
178
- **loss_kwargs,
179
+ **kwargs,
179
180
  )
180
181
  else:
181
182
  logits = self.lm_head(hidden_states)
@@ -32,7 +32,7 @@ def lce_forward(
32
32
  rope_deltas: Optional[torch.LongTensor] = None,
33
33
  cache_position: Optional[torch.LongTensor] = None,
34
34
  skip_logits: Optional[bool] = None,
35
- **loss_kwargs,
35
+ **kwargs,
36
36
  ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
37
37
  r"""
38
38
  Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -158,11 +158,12 @@ def lce_forward(
158
158
  output_hidden_states=output_hidden_states,
159
159
  return_dict=return_dict,
160
160
  cache_position=cache_position,
161
+ **kwargs,
161
162
  )
162
163
 
163
164
  hidden_states = outputs[0]
164
165
 
165
- shift_labels = loss_kwargs.pop("shift_labels", None)
166
+ shift_labels = kwargs.pop("shift_labels", None)
166
167
  loss = None
167
168
  logits = None
168
169
 
@@ -179,7 +180,7 @@ def lce_forward(
179
180
  labels=labels,
180
181
  shift_labels=shift_labels,
181
182
  hidden_size=self.config.hidden_size,
182
- **loss_kwargs,
183
+ **kwargs,
183
184
  )
184
185
  else:
185
186
  logits = self.lm_head(hidden_states)
@@ -26,7 +26,7 @@ def lce_forward(
26
26
  cache_position: Optional[torch.LongTensor] = None,
27
27
  logits_to_keep: Union[int, torch.Tensor] = 0,
28
28
  skip_logits: Optional[bool] = None,
29
- **loss_kwargs,
29
+ **kwargs,
30
30
  ) -> MoeCausalLMOutputWithPast:
31
31
  r"""
32
32
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -81,6 +81,7 @@ def lce_forward(
81
81
  output_hidden_states=output_hidden_states,
82
82
  output_router_logits=output_router_logits,
83
83
  cache_position=cache_position,
84
+ **kwargs,
84
85
  )
85
86
 
86
87
  hidden_states = outputs.last_hidden_state
@@ -88,7 +89,7 @@ def lce_forward(
88
89
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
89
90
  kept_hidden_states = hidden_states[:, slice_indices, :]
90
91
 
91
- shift_labels = loss_kwargs.pop("shift_labels", None)
92
+ shift_labels = kwargs.pop("shift_labels", None)
92
93
  logits = None
93
94
  loss = None
94
95
 
@@ -102,12 +103,12 @@ def lce_forward(
102
103
  labels=labels,
103
104
  shift_labels=shift_labels,
104
105
  hidden_size=self.config.hidden_size,
105
- **loss_kwargs,
106
+ **kwargs,
106
107
  )
107
108
  else: # if in inference model materialize logits
108
109
  logits = self.lm_head(kept_hidden_states)
109
110
  if labels is not None:
110
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
111
+ loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
111
112
 
112
113
  aux_loss = None
113
114
  if output_router_logits:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.10.dev20250605210201
3
+ Version: 0.5.10.dev20250605224739
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -20,6 +20,7 @@ liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCR
20
20
  liger_kernel/ops/dyt.py,sha256=Y180EIvtUc2z83mhyub0EVOCQHJmWX3JnscqkOJqswk,5467
21
21
  liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
22
22
  liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
23
+ liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
23
24
  liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
24
25
  liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
25
26
  liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
@@ -42,9 +43,10 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
42
43
  liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
43
44
  liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
44
45
  liger_kernel/transformers/fsdp.py,sha256=CUiyjTmjkjY7pLXQv8ly9rnzgXw6529csd9pvtJNMYc,3096
45
- liger_kernel/transformers/functional.py,sha256=QmnAFpRgIbp9Rzlfp8QibwiEbf5BUcANxfY68an7o8c,6444
46
+ liger_kernel/transformers/functional.py,sha256=7Emw7D6VPMg8hfasC33NiolvKmQVF1gV6VayKQCEWJM,7446
46
47
  liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
47
48
  liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
49
+ liger_kernel/transformers/fused_neighborhood_attention.py,sha256=TxYDUAt9B6WSP14aJP66C_2Mbds2sSIPGnamhUSTrC8,7957
48
50
  liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
49
51
  liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
50
52
  liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
@@ -63,31 +65,31 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
63
65
  liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
64
66
  liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
65
67
  liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
66
- liger_kernel/transformers/model/gemma.py,sha256=LUifPVeiVpadKwAoi0g0lplhaE5OMdx-k4pSg4g_y8A,9732
67
- liger_kernel/transformers/model/gemma2.py,sha256=JxPTXVkuFtiaZYkaBM8bZF-ObyatHmAiOG_gzRe_ElU,10989
68
+ liger_kernel/transformers/model/gemma.py,sha256=gvP-9zZ1e-DQD06qltWmRhiJClJDtkMQL1UrPMMZZGQ,9730
69
+ liger_kernel/transformers/model/gemma2.py,sha256=ORmzklEAMpk93nToRo4d_ZJbM4ScVE2szczsEL4hw7w,11019
68
70
  liger_kernel/transformers/model/gemma3.py,sha256=JI4jj9K660HeRsofB6cpkCHBQ0OsazElArRtKUehUmw,15945
69
- liger_kernel/transformers/model/glm4.py,sha256=3YJiGdZ0nNSdZidPFlXdUad8mlFwyfq44yd11OcdNns,5259
70
- liger_kernel/transformers/model/llama.py,sha256=cAWTCY0bk67lFXNtAVEXIWl9WNgn4JyU25Q7nhpKjE0,12505
71
+ liger_kernel/transformers/model/glm4.py,sha256=GlnEhdGJuDIqp2R9qC54biY3HwV1tWmfpJm6ijoAsrM,5257
72
+ liger_kernel/transformers/model/llama.py,sha256=LcIxVfF0PXXWHBVJa6Ody_5fAtIpxQcI4jC_j-o51fU,12503
71
73
  liger_kernel/transformers/model/llava.py,sha256=ONdpx96AVbbL8QDQvHSm08jMJPz3tzkbeO92IRbAb1A,19270
72
74
  liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
73
- liger_kernel/transformers/model/mistral.py,sha256=vFFZD5VAwpx6Bs4gXoXDRmyU9-7Dp50w3jIcj0q0sIo,5567
74
- liger_kernel/transformers/model/mixtral.py,sha256=vSmgBc91WMu9_iWkAHUJPzo0-WDkTJK5SEVYNaDRT_Y,11398
75
- liger_kernel/transformers/model/mllama.py,sha256=XhxU5r9v5TfTa4NJpg1EeYC999Q1e0CZwkVE86iaymU,11282
76
- liger_kernel/transformers/model/olmo2.py,sha256=4IwhP-TBck0dijY5gdLvoQnXO2M80gfLUV2fsK6wpiY,5261
75
+ liger_kernel/transformers/model/mistral.py,sha256=okKkyashfFLfhjIT--f3JY6JHOslOtDI8U1dlpBC2Zs,5565
76
+ liger_kernel/transformers/model/mixtral.py,sha256=VY-y73IyjcCyWyI7ahxXLw0fJrhgjYfr1xwRYtsHX0o,11396
77
+ liger_kernel/transformers/model/mllama.py,sha256=my29NXk-p6ckQaP8qDIN8e318yI_9mQZHt38MV3SqLY,11280
78
+ liger_kernel/transformers/model/olmo2.py,sha256=6L_bo-ZUgO1lYppdJneOtYxNIylQKS6BiGp13g7Uq9E,5259
77
79
  liger_kernel/transformers/model/paligemma.py,sha256=xuIx3oOwTgftU3jqLfWOxUxgCLBNJh0yNC21an9qDjo,18773
78
- liger_kernel/transformers/model/phi3.py,sha256=UslJ1gbyRhVmj5fzq_uizhDY5wYEoK_EwPamNotsUVs,10378
79
- liger_kernel/transformers/model/qwen2.py,sha256=WTKFPAp_R4aSRLQgMKygX6pmptcHeLWGCdjH42SxXVk,9660
80
- liger_kernel/transformers/model/qwen2_5_vl.py,sha256=OFPaELlVi1UdkjSXxVWGnNc32CdcQ74KV_3Dc8-uCe4,9200
81
- liger_kernel/transformers/model/qwen2_vl.py,sha256=s973gNrFT74FYAYiRvorxtK15CpZJnlbbhfk_wk-tag,9611
80
+ liger_kernel/transformers/model/phi3.py,sha256=m-MD_OuTaYMGZhHOvl-RHOVEObrL8tL5cBv3VTNd4F0,10376
81
+ liger_kernel/transformers/model/qwen2.py,sha256=SdN7V-MI3eX9s2DAFRvC1g-G146uG_5n1fnNdY9QwYk,9658
82
+ liger_kernel/transformers/model/qwen2_5_vl.py,sha256=k6jt1bTCJsKsZVGhTxqIbDzmnL8-B3CpWJOjLazswbo,9203
83
+ liger_kernel/transformers/model/qwen2_vl.py,sha256=Cgs7-nPlKFifiDO9gqSI6np4vRUVCKiqoospT_vIi_M,9614
82
84
  liger_kernel/transformers/model/qwen3.py,sha256=w2jBHuK9kK9EmOr5dnEIXNQXUgUSV_sJUkXSEwxLPHs,4885
83
- liger_kernel/transformers/model/qwen3_moe.py,sha256=CbLP4eltlmPPkcSJ2WMe61P7_n-ksKF-bzyWfEMNXFg,5513
85
+ liger_kernel/transformers/model/qwen3_moe.py,sha256=BkpfFH3fOH0yRfA7LF-AoHTLut2GV0Y4MOlkiIYewfU,5511
84
86
  liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
85
87
  liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
86
88
  liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
87
89
  liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
88
- liger_kernel_nightly-0.5.10.dev20250605210201.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
89
- liger_kernel_nightly-0.5.10.dev20250605210201.dist-info/METADATA,sha256=AHzrxsgoWvM8SMrpFXF1LHayfA--Wmjdh9PdVUcpCSE,24309
90
- liger_kernel_nightly-0.5.10.dev20250605210201.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
91
- liger_kernel_nightly-0.5.10.dev20250605210201.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
92
- liger_kernel_nightly-0.5.10.dev20250605210201.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
93
- liger_kernel_nightly-0.5.10.dev20250605210201.dist-info/RECORD,,
90
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
91
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/METADATA,sha256=fJJZbkI2vH7QV5qhJouSk17zKPSUuZNWCWY2kXjDYPQ,24309
92
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
93
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
94
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
95
+ liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/RECORD,,