liger-kernel 0.5.9__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/__init__.py +1 -0
- liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
- liger_kernel/chunked_loss/dpo_loss.py +1 -1
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
- liger_kernel/chunked_loss/jsd_loss.py +2 -2
- liger_kernel/ops/dyt.py +111 -179
- liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
- liger_kernel/ops/geglu.py +1 -1
- liger_kernel/ops/grpo_loss.py +310 -0
- liger_kernel/ops/multi_token_attention.py +207 -0
- liger_kernel/ops/rms_norm.py +265 -54
- liger_kernel/ops/softmax.py +201 -0
- liger_kernel/ops/sparsemax.py +179 -0
- liger_kernel/ops/swiglu.py +1 -1
- liger_kernel/transformers/__init__.py +8 -0
- liger_kernel/transformers/dyt.py +5 -3
- liger_kernel/transformers/fsdp.py +55 -0
- liger_kernel/transformers/functional.py +70 -0
- liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
- liger_kernel/transformers/grpo_loss.py +98 -0
- liger_kernel/transformers/model/gemma.py +25 -16
- liger_kernel/transformers/model/gemma2.py +27 -14
- liger_kernel/transformers/model/gemma3.py +62 -106
- liger_kernel/transformers/model/glm4.py +16 -13
- liger_kernel/transformers/model/llama.py +81 -18
- liger_kernel/transformers/model/llama4.py +108 -0
- liger_kernel/transformers/model/llava.py +95 -132
- liger_kernel/transformers/model/mistral.py +13 -14
- liger_kernel/transformers/model/mixtral.py +16 -15
- liger_kernel/transformers/model/mllama.py +16 -14
- liger_kernel/transformers/model/olmo2.py +16 -13
- liger_kernel/transformers/model/paligemma.py +8 -9
- liger_kernel/transformers/model/phi3.py +25 -16
- liger_kernel/transformers/model/qwen2.py +24 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +41 -97
- liger_kernel/transformers/model/qwen2_vl.py +38 -106
- liger_kernel/transformers/model/qwen3.py +11 -9
- liger_kernel/transformers/model/qwen3_moe.py +132 -0
- liger_kernel/transformers/monkey_patch.py +424 -81
- liger_kernel/transformers/multi_token_attention.py +64 -0
- liger_kernel/transformers/rms_norm.py +40 -4
- liger_kernel/transformers/softmax.py +12 -0
- liger_kernel/transformers/sparsemax.py +16 -0
- liger_kernel/transformers/swiglu.py +21 -0
- liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
- liger_kernel/utils.py +11 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +41 -21
- liger_kernel-0.6.0.dist-info/RECORD +97 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
- liger_kernel/transformers/gema3_rms.py +0 -8
- liger_kernel-0.5.9.dist-info/RECORD +0 -84
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
- {liger_kernel-0.5.9.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/grpo_loss.py (new file, +98 -0)
@@ -0,0 +1,98 @@
+from liger_kernel.ops.grpo_loss import GrpoLossFunction
+
+
+def triton_grpo_loss(
+    logits,
+    old_logp,
+    ref_logp,
+    completion_ids,
+    advantages,
+    completion_mask=None,
+    temperature=0.9,
+    beta=0.04,
+    eps_low=0.2,
+    eps_high=0.4,
+    inplace=True,
+):
+    assert logits is not None and completion_ids is not None and advantages is not None, (
+        "must provide logits、completion_ids and advantages"
+    )
+
+    return GrpoLossFunction.apply(
+        logits,
+        old_logp,
+        ref_logp,
+        completion_ids,
+        advantages,
+        completion_mask,
+        temperature,
+        beta,
+        eps_low,
+        eps_high,
+        inplace,
+    )
+
+
+# This is a demo how to use grpo_loss in GRPOTrainer. The Trl version must be 0.16
+"""
+import torch
+import trl
+assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
+from trl.extras.profiling import profiling_decorator
+
+@profiling_decorator
+def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+    # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+    return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)
+
+@profiling_decorator
+def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+    if return_outputs:
+        raise ValueError("The GRPOTrainer does not support returning outputs")
+    # Compute the per-token log probabilities for the model
+
+    prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+    completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+    attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+    logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+    logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+
+    ref_per_token_logps = inputs["ref_per_token_logps"]
+    advantages = inputs["advantages"]
+    old_per_token_logps = inputs["old_per_token_logps"]
+
+
+    per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits,
+                                                                old_per_token_logps,
+                                                                ref_per_token_logps,
+                                                                completion_ids,
+                                                                advantages,
+                                                                completion_mask,
+                                                                self.temperature,
+                                                                self.beta,
+                                                                self.epsilon_low,
+                                                                self.epsilon_high,)
+    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+
+    # Log the metrics
+    mode = "eval" if self.control.should_evaluate else "train"
+
+    if self.beta != 0.0:
+        mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+        self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+    clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+    self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+    return loss
+
+trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
+trl.GRPOTrainer.compute_loss = compute_loss
+trigger = None
+"""
+
+# add this line at the first line of grpo.py in open-r1
+"""
+from liger_kernel.transformers.grpo_loss import trigger
+"""
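For orientation, a minimal smoke test of the new triton_grpo_loss entry point follows. This is a sketch only: the tensor shapes (batch B, completion length T, vocabulary V), the dtypes, and the CUDA requirement are assumptions inferred from the demo docstring above rather than a documented contract; the three-tuple return value mirrors how the demo unpacks the result.

import torch

from liger_kernel.transformers.grpo_loss import triton_grpo_loss

# Illustrative shapes only (assumed): logits cover one extra position, matching the
# demo's model(..., logits_to_keep=logits_to_keep + 1) call.
B, T, V = 2, 16, 128
device = "cuda"  # the underlying Triton kernel runs on GPU tensors (assumption)

logits = torch.randn(B, T + 1, V, device=device)
completion_ids = torch.randint(0, V, (B, T), device=device)
advantages = torch.randn(B, device=device)
completion_mask = torch.ones(B, T, device=device, dtype=torch.int32)
old_logp = -torch.rand(B, T, device=device)  # per-token log-probs from the sampling policy (dummy values)
ref_logp = -torch.rand(B, T, device=device)  # per-token log-probs from the reference policy (dummy values)

# Same positional order as the demo above; temperature/beta passed by keyword from the signature.
per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask,
    temperature=1.0,
    beta=0.04,
)
# Reduce exactly as the demo's compute_loss does.
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()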
liger_kernel/transformers/model/gemma.py (+25 -16)
@@ -8,18 +8,12 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
-from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward_deprecated(
     self,
     input_ids: torch.LongTensor = None,
@@ -33,6 +27,7 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""

@@ -87,7 +82,14 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
-    if
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
 
@@ -129,8 +131,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -145,7 +145,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -197,6 +198,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -204,27 +206,34 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )
 
     if not return_dict:
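The recurring change in these model patches is the new skip_logits switch, which decides whether the lm_head logits are materialized or the fused linear + cross-entropy path computes the loss directly from hidden states. A minimal sketch of the defaulting rule, restated outside the forward pass for clarity (the helper name and flags below are ours, not part of liger-kernel):

from typing import Optional


def resolve_skip_logits(
    skip_logits: Optional[bool],
    training: bool,
    has_labels: bool,
    has_shift_labels: bool = False,
) -> bool:
    # An explicit True without any labels is rejected, mirroring the ValueError above.
    if skip_logits and not (has_labels or has_shift_labels):
        raise ValueError("skip_logits is True, but labels and shift_labels are None")
    if skip_logits is None:
        # Default: only skip materializing logits while training with labels available,
        # i.e. when the fused path can compute the loss without the full logits tensor.
        skip_logits = training and (has_labels or has_shift_labels)
    return skip_logits


assert resolve_skip_logits(None, training=True, has_labels=True) is True    # fused loss path
assert resolve_skip_logits(None, training=False, has_labels=True) is False  # eval keeps logits
assert resolve_skip_logits(False, training=True, has_labels=True) is False  # explicit override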
liger_kernel/transformers/model/gemma2.py (+27 -14)
@@ -9,10 +9,6 @@ import torch
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
-from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
@@ -34,6 +30,8 @@ def lce_forward_deprecated(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -80,6 +78,7 @@ def lce_forward_deprecated(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -87,7 +86,14 @@ def lce_forward_deprecated(
     loss = None
     logits = None
 
-    if
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and labels is not None
+
+    if skip_logits:
         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
 
@@ -136,8 +142,6 @@ def lce_forward_deprecated(
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -152,7 +156,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -209,6 +214,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -216,11 +222,18 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -228,10 +241,10 @@ def lce_forward(
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
-            **
+            **kwargs,
         )
 
-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
@@ -240,7 +253,7 @@ def lce_forward(
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **
+            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
liger_kernel/transformers/model/gemma3.py (+62 -106)
@@ -1,4 +1,3 @@
-from typing import List
 from typing import Optional
 from typing import Tuple
 from typing import Union
@@ -9,14 +9,8 @@ import torch.nn as nn
 from transformers.cache_utils import Cache
 from transformers.cache_utils import HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.gemma3.modeling_gemma3 import _CONFIG_FOR_DOC
-from transformers.models.gemma3.modeling_gemma3 import GEMMA3_INPUTS_DOCSTRING
 from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import is_torchdynamo_compiling
 from transformers.utils import logging
-from transformers.utils import replace_return_docstrings
-from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -24,9 +17,6 @@ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 logger = logging.get_logger(__name__)
 
 
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def causal_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -41,6 +31,7 @@ def causal_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -107,7 +98,11 @@ def causal_forward(
     shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
-
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
@@ -140,16 +135,13 @@ def causal_forward(
     )
 
 
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def multimodal_forward(
     self,
     input_ids: torch.LongTensor = None,
     pixel_values: torch.FloatTensor = None,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[Union[
+    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
     token_type_ids: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -159,22 +151,14 @@ def multimodal_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
     **lm_kwargs,
-) -> Union[
+) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
     r"""
-
-
-
-
-
-        logits_to_keep (`int` or `torch.Tensor`, *optional*):
-            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-            This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
 
     Example:
 
@@ -183,23 +167,37 @@ def multimodal_forward(
     >>> import requests
     >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
 
-    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/
-    >>> processor = AutoProcessor.from_pretrained("google/
-
-    >>>
-
-
-
-
-
+    >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
+    >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
+
+    >>> messages = [
+    ...     {
+    ...         "role": "system",
+    ...         "content": [
+    ...             {"type": "text", "text": "You are a helpful assistant."}
+    ...         ]
+    ...     },
+    ...     {
+    ...         "role": "user", "content": [
+    ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
+    ...             {"type": "text", "text": "Where is the cat standing?"},
+    ...         ]
+    ...     },
+    ... ]
+
+    >>> inputs = processor.apply_chat_template(
+    ...     messages,
+    ...     tokenize=True,
+    ...     return_dict=True,
+    ...     return_tensors="pt",
+    ...     add_generation_prompt=True
+    ... )
     >>> # Generate
-    >>> generate_ids = model.generate(**inputs
+    >>> generate_ids = model.generate(**inputs)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    "
-    ```
-
-    if (input_ids is None) ^ (inputs_embeds is not None):
-        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+    "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
+    ```
+    """
 
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
@@ -207,81 +205,38 @@ def multimodal_forward(
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-
-
-
-
-
-        llm_input_ids = input_ids.clone()
-        llm_input_ids[special_image_mask] = 0
-    else:
-        llm_input_ids = input_ids
-
-    if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(llm_input_ids)
-
-    if cache_position is None:
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        cache_position = torch.arange(
-            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-        )
-
-    if position_ids is None:
-        position_ids = cache_position.unsqueeze(0) + 1  # Gemma3 positions are 1-indexed
-
-    # Merge text and images
-    if pixel_values is not None:
-        image_features = self.get_image_features(pixel_values)
-
-        if input_ids is None:
-            special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
-            )
-        else:
-            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
-            raise ValueError(
-                f"Number of images does not match number of special image tokens in the input text. "
-                f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
-                "tokens from image embeddings."
-            )
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-    # mask out pad-token-ids in labels for BC
-    if labels is not None and self.pad_token_id in labels:
-        logger.warning_once(
-            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
-            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
-        )
-        labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
-
-    causal_mask = self._update_causal_mask(
-        attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
-    )
-    outputs = self.language_model.model(
-        attention_mask=causal_mask,
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        token_type_ids=token_type_ids,
+        attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
+        labels=labels,
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
        cache_position=cache_position,
-        logits_to_keep=logits_to_keep,
         **lm_kwargs,
     )
 
     hidden_states = outputs[0]
+
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
     loss = None
     logits = None
+    if skip_logits and labels is None:
+        raise ValueError("skip_logits is True, but labels is None")
+
+    if skip_logits is None:
+        skip_logits = self.training and (labels is not None)
 
-    if
-        shift_hidden_states =
+    if skip_logits:
+        shift_hidden_states = kept_hidden_states[..., :-1, :]
         shift_labels = labels[..., 1:]
 
         hidden_device = shift_hidden_states.device
@@ -302,7 +257,7 @@ def multimodal_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
     else:
-        logits = self.
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             # Upcast to float if we need to compute the loss to avoid potential precision issues
             logits = logits.float()
@@ -323,6 +278,7 @@ def multimodal_forward(
             flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
             flat_labels = shift_labels.view(-1).to(shift_logits.device)
             loss = loss_fct(flat_logits, flat_labels)
+
     if not return_dict:
         output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
@@ -333,5 +289,5 @@ def multimodal_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        image_hidden_states=
+        image_hidden_states=outputs.image_hidden_states,
     )
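The rewritten multimodal_forward above now delegates the image merging to self.model and slices hidden states with the same logits_to_keep convention as the text-only forwards. A small standalone illustration of that slicing follows (not liger-kernel code; the shapes are arbitrary):

import torch

hidden_states = torch.randn(2, 6, 4)  # (batch, sequence, hidden)

# An int keeps the last N positions; 0 is the special case that keeps everything.
for logits_to_keep, expected_len in [(3, 3), (0, 6)]:
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    assert hidden_states[:, slice_indices, :].shape == (2, expected_len, 4)

# A 1-D index tensor keeps arbitrary positions instead (e.g. the last token of each packed sequence).
keep = torch.tensor([1, 5])
slice_indices = slice(-keep, None) if isinstance(keep, int) else keep
assert hidden_states[:, slice_indices, :].shape == (2, 2, 4)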
liger_kernel/transformers/model/glm4.py (+16 -13)
@@ -6,18 +6,12 @@ from typing import Union
 import torch
 
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
-from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
-from transformers.utils import add_start_docstrings_to_model_forward
-from transformers.utils import replace_return_docstrings
 from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -32,7 +26,8 @@ def lce_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,
-
+    skip_logits: Optional[bool] = None,
+    **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -85,6 +80,7 @@ def lce_forward(
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
         cache_position=cache_position,
+        **kwargs,
     )
 
     hidden_states = outputs[0]
@@ -92,28 +88,35 @@ def lce_forward(
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    shift_labels =
+    shift_labels = kwargs.pop("shift_labels", None)
     logits = None
     loss = None
-
-    if
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
-            **
+            **kwargs,
         )
 
-    else:
+    else:
         logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
                 labels=labels,
                 vocab_size=self.config.vocab_size,
-                **
+                **kwargs,
             )
 
     return CausalLMOutputWithPast(
|