liger-kernel-nightly 0.5.10.dev20250605210201__py3-none-any.whl → 0.5.10.dev20250605224739__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/transformers/functional.py +28 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/model/gemma.py +5 -4
- liger_kernel/transformers/model/gemma2.py +7 -4
- liger_kernel/transformers/model/glm4.py +5 -4
- liger_kernel/transformers/model/llama.py +5 -4
- liger_kernel/transformers/model/mistral.py +5 -4
- liger_kernel/transformers/model/mixtral.py +5 -4
- liger_kernel/transformers/model/mllama.py +5 -4
- liger_kernel/transformers/model/olmo2.py +5 -4
- liger_kernel/transformers/model/phi3.py +5 -4
- liger_kernel/transformers/model/qwen2.py +5 -4
- liger_kernel/transformers/model/qwen2_5_vl.py +4 -3
- liger_kernel/transformers/model/qwen2_vl.py +4 -3
- liger_kernel/transformers/model/qwen3_moe.py +5 -4
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/METADATA +1 -1
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/RECORD +22 -20
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/LICENSE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/NOTICE +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/WHEEL +0 -0
- {liger_kernel_nightly-0.5.10.dev20250605210201.dist-info → liger_kernel_nightly-0.5.10.dev20250605224739.dist-info}/top_level.txt +0 -0
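The two largest additions in the list above are a fused neighborhood attention kernel (liger_kernel/ops/fused_neighborhood_attention.py) and its transformers-facing module. For orientation only, the sketch below is a plain-PyTorch reference of what neighborhood attention computes: each query attends only to keys within a fixed local window around its own position. It is not the fused Triton kernel and not this package's API; the function name, the banded mask, and the boundary handling are illustrative assumptions (NATTEN-style neighborhood attention additionally shifts the window at sequence edges, which is omitted here).

    import torch
    import torch.nn.functional as F

    def neighborhood_attention_reference(q, k, v, window: int = 3):
        # q, k, v: (batch, heads, seq_len, head_dim); `window` is the one-sided radius,
        # so each query sees at most 2 * window + 1 keys centred on its own position.
        scale = q.size(-1) ** -0.5
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale        # (B, H, S, S)
        seq_len = q.size(-2)
        idx = torch.arange(seq_len, device=q.device)
        inside = (idx[None, :] - idx[:, None]).abs() <= window       # (S, S) band mask
        scores = scores.masked_fill(~inside, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, v)

    # Example: 1 batch, 2 heads, 16 tokens, 8-dim heads -> output of the same shape.
    q = k = v = torch.randn(1, 2, 16, 8)
    out = neighborhood_attention_reference(q, k, v, window=3)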
@@ -136,7 +136,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
 Args:
@@ -189,6 +189,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]
@@ -196,7 +197,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -214,7 +215,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )

 else:
@@ -224,7 +225,7 @@ def lce_forward(
 logits=logits,
 labels=labels,
 vocab_size=self.config.vocab_size,
-**
+**kwargs,
 )

 return CausalLMOutputWithPast(
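The model hunks above and below all apply one pattern: lce_forward gains a **kwargs parameter, forwards it to the backbone call, pops the trainer-supplied shift_labels, and passes whatever remains on to the loss computation. The toy module below is a hedged, self-contained sketch of that control flow; the class, the backbone, and the naive cross-entropy stand-in are illustrative assumptions, not the package's actual implementation (which fuses the lm_head projection with the loss so the full logits tensor is never materialized, and which shifts unshifted labels internally).

    import torch
    import torch.nn.functional as F
    from torch import nn

    class TinyCausalLM(nn.Module):
        """Toy stand-in mirroring the control flow of the patched forwards, not the real classes."""

        def __init__(self, vocab_size: int = 100, hidden_size: int = 32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            self.vocab_size = vocab_size

        def backbone(self, input_ids, **kwargs):
            # The real models forward **kwargs into self.model(...); the toy backbone ignores them.
            return self.embed(input_ids)

        def forward(self, input_ids, labels=None, skip_logits=False, **kwargs):
            hidden_states = self.backbone(input_ids, **kwargs)    # extra kwargs now reach the backbone
            shift_labels = kwargs.pop("shift_labels", None)       # trainer-supplied, pre-shifted labels
            logits, loss = None, None

            if skip_logits and (labels is not None or shift_labels is not None):
                # Stands in for the fused linear + cross-entropy path; computed naively here
                # so the sketch runs, whereas the fused kernel avoids materializing logits.
                target = shift_labels if shift_labels is not None else labels
                flat_logits = self.lm_head(hidden_states).view(-1, self.vocab_size)
                loss = F.cross_entropy(flat_logits, target.view(-1), ignore_index=-100)
            else:
                logits = self.lm_head(hidden_states)
                if labels is not None:
                    loss = F.cross_entropy(logits.view(-1, self.vocab_size), labels.view(-1), ignore_index=-100)

            return loss, logits

    # The extra kwarg travels through the forward signature instead of being dropped.
    model = TinyCausalLM()
    ids = torch.randint(0, 100, (2, 8))
    loss, _ = model(ids, labels=ids, skip_logits=True, shift_labels=ids)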
@@ -31,7 +31,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 second_per_grid_ts: Optional[torch.Tensor] = None,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
 r"""
 Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -154,11 +154,12 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 loss = None
 logits = None

@@ -175,7 +176,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )
 else:
 logits = self.lm_head(hidden_states)
@@ -32,7 +32,7 @@ def lce_forward(
 rope_deltas: Optional[torch.LongTensor] = None,
 cache_position: Optional[torch.LongTensor] = None,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
 r"""
 Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -158,11 +158,12 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 return_dict=return_dict,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs[0]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 loss = None
 logits = None

@@ -179,7 +180,7 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )
 else:
 logits = self.lm_head(hidden_states)
@@ -26,7 +26,7 @@ def lce_forward(
 cache_position: Optional[torch.LongTensor] = None,
 logits_to_keep: Union[int, torch.Tensor] = 0,
 skip_logits: Optional[bool] = None,
-**
+**kwargs,
 ) -> MoeCausalLMOutputWithPast:
 r"""
 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -81,6 +81,7 @@ def lce_forward(
 output_hidden_states=output_hidden_states,
 output_router_logits=output_router_logits,
 cache_position=cache_position,
+**kwargs,
 )

 hidden_states = outputs.last_hidden_state
@@ -88,7 +89,7 @@ def lce_forward(
 slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
 kept_hidden_states = hidden_states[:, slice_indices, :]

-shift_labels =
+shift_labels = kwargs.pop("shift_labels", None)
 logits = None
 loss = None

@@ -102,12 +103,12 @@ def lce_forward(
 labels=labels,
 shift_labels=shift_labels,
 hidden_size=self.config.hidden_size,
-**
+**kwargs,
 )
 else:  # if in inference model materialize logits
 logits = self.lm_head(kept_hidden_states)
 if labels is not None:
-loss = self.loss_function(logits, labels, self.vocab_size, **
+loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

 aux_loss = None
 if output_router_logits:
@@ -20,6 +20,7 @@ liger_kernel/ops/cross_entropy.py,sha256=e8THGnhOcy_0SbOLABx67HEM7-B8a8pG7nDKbCR
 liger_kernel/ops/dyt.py,sha256=Y180EIvtUc2z83mhyub0EVOCQHJmWX3JnscqkOJqswk,5467
 liger_kernel/ops/fused_linear_cross_entropy.py,sha256=5fbGhN85n3zf0uIdJ7PYHWIRzTf0VTFiS0ARtOmqIP0,11020
 liger_kernel/ops/fused_linear_jsd.py,sha256=CSoprxb-YcJy-YUKiTcYkxN8sb9h2kdk_iHuncvSV5c,9683
+liger_kernel/ops/fused_neighborhood_attention.py,sha256=vPi5xbnh6wxyZehaqo6Tuilqo2fN5SGDiONjnNmIKqs,35556
 liger_kernel/ops/geglu.py,sha256=axGvCIvlBzuluoAIrWTsp2iZM4BFKNInkPov8YVvH9E,4126
 liger_kernel/ops/group_norm.py,sha256=qD4D4lSjSgVtO52EBNLC2iTseALRgPgqXE50U2woggk,10837
 liger_kernel/ops/grpo_loss.py,sha256=anRnv7k1-AV3pCC6_TqP0GMg78YYUfRAJrbpx6PVhl0,9448
@@ -42,9 +43,10 @@ liger_kernel/transformers/auto_model.py,sha256=0qCTRZt280Bj_LcFdzo9hlaR-BWNazawX
 liger_kernel/transformers/cross_entropy.py,sha256=z3KTWQnFxr_IZaVjtYt0ZNEWQdDdYThN35xWkHlDGH0,1683
 liger_kernel/transformers/dyt.py,sha256=i-4GPaMrl-jab9TVI5qN0-H9qycn_mCbV82ozU4nbmU,723
 liger_kernel/transformers/fsdp.py,sha256=CUiyjTmjkjY7pLXQv8ly9rnzgXw6529csd9pvtJNMYc,3096
-liger_kernel/transformers/functional.py,sha256=
+liger_kernel/transformers/functional.py,sha256=7Emw7D6VPMg8hfasC33NiolvKmQVF1gV6VayKQCEWJM,7446
 liger_kernel/transformers/fused_linear_cross_entropy.py,sha256=O8Sg5BT81nTaY9fSGoOY9dOD9ekibwwiuXhdUHaxntQ,1742
 liger_kernel/transformers/fused_linear_jsd.py,sha256=bZ4otCvWBuOnA5XdQL-FzZVItJlDt-ht9e_pG7PG93E,3999
+liger_kernel/transformers/fused_neighborhood_attention.py,sha256=TxYDUAt9B6WSP14aJP66C_2Mbds2sSIPGnamhUSTrC8,7957
 liger_kernel/transformers/geglu.py,sha256=mrgqzIUVd6lN7fkDKLkw5YaESDxDtFgbot430WwPVOQ,1107
 liger_kernel/transformers/group_norm.py,sha256=6qMAWOprr4SzP0YhNVNGQIBpM5aUHplUD2VuGJrMBz0,2173
 liger_kernel/transformers/grpo_loss.py,sha256=uAkUNKSnUGEOqa82L9w2e6AI1kcmG8K45-QxyaT8zhM,3897
@@ -63,31 +65,31 @@ liger_kernel/transformers/trainer_integration.py,sha256=W3ON51O5GkyzNJsItz0y5rKx
 liger_kernel/transformers/tvd.py,sha256=XrRfyJIqN6HFxXk8MYyFVZM1OLz3mtSbRZvWfZ_JerQ,450
 liger_kernel/transformers/experimental/embedding.py,sha256=2P0QYdlFyFrG5OqTzTa1wcRgDSyjBMv5i1a7BrDPDQw,881
 liger_kernel/transformers/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-liger_kernel/transformers/model/gemma.py,sha256=
-liger_kernel/transformers/model/gemma2.py,sha256=
+liger_kernel/transformers/model/gemma.py,sha256=gvP-9zZ1e-DQD06qltWmRhiJClJDtkMQL1UrPMMZZGQ,9730
+liger_kernel/transformers/model/gemma2.py,sha256=ORmzklEAMpk93nToRo4d_ZJbM4ScVE2szczsEL4hw7w,11019
 liger_kernel/transformers/model/gemma3.py,sha256=JI4jj9K660HeRsofB6cpkCHBQ0OsazElArRtKUehUmw,15945
-liger_kernel/transformers/model/glm4.py,sha256=
-liger_kernel/transformers/model/llama.py,sha256=
+liger_kernel/transformers/model/glm4.py,sha256=GlnEhdGJuDIqp2R9qC54biY3HwV1tWmfpJm6ijoAsrM,5257
+liger_kernel/transformers/model/llama.py,sha256=LcIxVfF0PXXWHBVJa6Ody_5fAtIpxQcI4jC_j-o51fU,12503
 liger_kernel/transformers/model/llava.py,sha256=ONdpx96AVbbL8QDQvHSm08jMJPz3tzkbeO92IRbAb1A,19270
 liger_kernel/transformers/model/loss_utils.py,sha256=WWAMdiONPaXpIvxyOim_0igLrYh0yyOok5Q9_L9xvZw,1787
-liger_kernel/transformers/model/mistral.py,sha256=
-liger_kernel/transformers/model/mixtral.py,sha256=
-liger_kernel/transformers/model/mllama.py,sha256=
-liger_kernel/transformers/model/olmo2.py,sha256=
+liger_kernel/transformers/model/mistral.py,sha256=okKkyashfFLfhjIT--f3JY6JHOslOtDI8U1dlpBC2Zs,5565
+liger_kernel/transformers/model/mixtral.py,sha256=VY-y73IyjcCyWyI7ahxXLw0fJrhgjYfr1xwRYtsHX0o,11396
+liger_kernel/transformers/model/mllama.py,sha256=my29NXk-p6ckQaP8qDIN8e318yI_9mQZHt38MV3SqLY,11280
+liger_kernel/transformers/model/olmo2.py,sha256=6L_bo-ZUgO1lYppdJneOtYxNIylQKS6BiGp13g7Uq9E,5259
 liger_kernel/transformers/model/paligemma.py,sha256=xuIx3oOwTgftU3jqLfWOxUxgCLBNJh0yNC21an9qDjo,18773
-liger_kernel/transformers/model/phi3.py,sha256=
-liger_kernel/transformers/model/qwen2.py,sha256=
-liger_kernel/transformers/model/qwen2_5_vl.py,sha256=
-liger_kernel/transformers/model/qwen2_vl.py,sha256=
+liger_kernel/transformers/model/phi3.py,sha256=m-MD_OuTaYMGZhHOvl-RHOVEObrL8tL5cBv3VTNd4F0,10376
+liger_kernel/transformers/model/qwen2.py,sha256=SdN7V-MI3eX9s2DAFRvC1g-G146uG_5n1fnNdY9QwYk,9658
+liger_kernel/transformers/model/qwen2_5_vl.py,sha256=k6jt1bTCJsKsZVGhTxqIbDzmnL8-B3CpWJOjLazswbo,9203
+liger_kernel/transformers/model/qwen2_vl.py,sha256=Cgs7-nPlKFifiDO9gqSI6np4vRUVCKiqoospT_vIi_M,9614
 liger_kernel/transformers/model/qwen3.py,sha256=w2jBHuK9kK9EmOr5dnEIXNQXUgUSV_sJUkXSEwxLPHs,4885
-liger_kernel/transformers/model/qwen3_moe.py,sha256=
+liger_kernel/transformers/model/qwen3_moe.py,sha256=BkpfFH3fOH0yRfA7LF-AoHTLut2GV0Y4MOlkiIYewfU,5511
 liger_kernel/transformers/trainer/__init__.py,sha256=p7yQfklV8-467qSz_ZMimkbDF7HHWHwku25A-GYL0WU,193
 liger_kernel/transformers/trainer/orpo_trainer.py,sha256=tX0h63aOFe3rNqTmk6JpMf75UPo981yzEa6TghnjS0Q,5370
 liger_kernel/triton/__init__.py,sha256=qCiCamzCRv6lpV8IqpAc9YMdNKC7GKurClWceQPnlis,92
 liger_kernel/triton/monkey_patch.py,sha256=Rd0hUHAzDkFfHvnX7-PBaNK5EKnZhtfM_h-fgQH9HPY,1568
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
-liger_kernel_nightly-0.5.10.
+liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/LICENSE,sha256=OhzLDHJ0to4a8sodVLELZiCFylZ1NAAYLs-HrjPy0ag,1312
+liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/METADATA,sha256=fJJZbkI2vH7QV5qhJouSk17zKPSUuZNWCWY2kXjDYPQ,24309
+liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/NOTICE,sha256=njwnoPZLh9AN8SJQzxvCGLHi-8X__AvWRze6joNXIY8,2066
+liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/WHEEL,sha256=iAkIy5fosb7FzIOwONchHf19Qu7_1wCWyFNR5gu9nU0,91
+liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/top_level.txt,sha256=2eghu4hA3LnkM7ElW92tQ8zegWKgSbeo-k-aGe1YnvY,13
+liger_kernel_nightly-0.5.10.dev20250605224739.dist-info/RECORD,,