liger-kernel 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
- liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
- liger_kernel/chunked_loss/grpo_loss.py +134 -60
- liger_kernel/chunked_loss/jsd_loss.py +12 -7
- liger_kernel/ops/cross_entropy.py +3 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +32 -12
- liger_kernel/ops/kl_div.py +15 -8
- liger_kernel/ops/layer_norm.py +14 -1
- liger_kernel/ops/rms_norm.py +12 -1
- liger_kernel/transformers/__init__.py +133 -15
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/gema3_rms.py +8 -0
- liger_kernel/transformers/model/gemma.py +17 -20
- liger_kernel/transformers/model/gemma2.py +17 -21
- liger_kernel/transformers/model/gemma3.py +335 -0
- liger_kernel/transformers/model/llama.py +17 -19
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +64 -0
- liger_kernel/transformers/model/mistral.py +28 -25
- liger_kernel/transformers/model/mixtral.py +20 -26
- liger_kernel/transformers/model/mllama.py +17 -19
- liger_kernel/transformers/model/olmo2.py +17 -20
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +17 -19
- liger_kernel/transformers/model/qwen2.py +17 -19
- liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +392 -13
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/METADATA +11 -6
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/RECORD +38 -31
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.5.dist-info → liger_kernel-0.5.7.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/qwen2.py

@@ -11,8 +11,10 @@ from transformers.models.qwen2.modeling_qwen2 import _CONFIG_FOR_DOC
 from transformers.models.qwen2.modeling_qwen2 import QWEN2_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
@@ -124,6 +126,7 @@ def lce_forward_deprecated(
     )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
@@ -139,7 +142,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -149,10 +152,12 @@ def lce_forward(
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-
-
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
     Returns:
 
@@ -199,24 +204,17 @@ def lce_forward(
     loss = None
     # if in training mode, don't materialize logits
     if self.training and (labels is not None):
-
-
-
-
-
-
-
-        shift_labels = shift_labels.view(-1)
-
-        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
-        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
-
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-        if reduction == "sum":
-            loss /= loss_kwargs["num_items_in_batch"]
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
 
     else:  # if in inference mode materialize logits
-
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
     if labels is not None:
         loss = self.loss_function(
             logits=logits,
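The new inference branch above selects which positions get materialized logits via `slice_indices`, matching the int-vs-tensor behavior described in the added docstring. Below is a minimal standalone sketch of that selection logic for illustration only; the helper name, the toy shapes, and the print-based usage are assumptions, not liger-kernel code.

import torch

def select_positions(hidden_states: torch.Tensor, logits_to_keep) -> torch.Tensor:
    # int: keep the last `logits_to_keep` positions; 0 keeps everything,
    # since slice(-0, None) == slice(0, None) selects the full sequence dim.
    # 1D tensor: keep exactly the listed sequence indices (packed format).
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

hidden = torch.randn(2, 16, 64)  # (batch, seq_len, hidden_size), toy sizes
print(select_positions(hidden, 1).shape)                          # torch.Size([2, 1, 64])  last token only
print(select_positions(hidden, 0).shape)                          # torch.Size([2, 16, 64]) all positions
print(select_positions(hidden, torch.tensor([3, 7, 15])).shape)   # torch.Size([2, 3, 64])  explicit indices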
liger_kernel/transformers/model/qwen2_5_vl.py

@@ -12,7 +12,7 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalL
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
@@ -36,6 +36,7 @@ def lce_forward(
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     second_per_grid_ts: Optional[torch.Tensor] = None,
+    **loss_kwargs,
 ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -166,15 +167,13 @@ def lce_forward(
     logits = None
 
     if self.training and (labels is not None):
-
-
-
-
-
-
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
     else:
         logits = self.lm_head(hidden_states)
     if labels is not None:
liger_kernel/transformers/model/qwen2_vl.py

@@ -14,7 +14,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutput
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@@ -37,6 +37,7 @@ def lce_forward(
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    **loss_kwargs,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -170,15 +171,13 @@ def lce_forward(
     logits = None
 
     if self.training and (labels is not None):
-
-
-
-
-
-
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
     else:
         logits = self.lm_head(hidden_states)
     if labels is not None:
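All three forwards now delegate the shift-and-flatten boilerplate that was previously inlined (visible in the removed lines above) to `LigerForCausalLMLoss` from the new `liger_kernel/transformers/model/loss_utils.py`. The following is a rough sketch of what such a helper has to do, reconstructed from the deleted per-model code rather than from the actual `loss_utils` implementation; the exact signature, the label shifting, and the `num_items_in_batch` handling are assumptions.

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

def causal_lm_loss_sketch(hidden_states, lm_head_weight, labels, hidden_size, **loss_kwargs):
    # Shift so position i predicts token i+1, then flatten batch and sequence,
    # as the removed per-model code did before calling the fused kernel.
    shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, hidden_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)

    # Mirror the removed reduction logic: "sum" plus manual normalization when the
    # trainer passes num_items_in_batch (gradient-accumulation safe), else "mean".
    reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
    lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
    loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
    if reduction == "sum":
        loss = loss / loss_kwargs["num_items_in_batch"]
    return loss

Centralizing this in one helper keeps the patched Qwen2, Qwen2-VL, and Qwen2.5-VL forwards down to a single call and lets the `**loss_kwargs` pass-through added to the VL signatures reach the reduction logic.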