liger-kernel 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liger_kernel/chunked_loss/cpo_loss.py +51 -11
- liger_kernel/chunked_loss/dpo_loss.py +30 -4
- liger_kernel/chunked_loss/functional.py +2 -0
- liger_kernel/chunked_loss/fused_linear_distillation.py +20 -5
- liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
- liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
- liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
- liger_kernel/chunked_loss/grpo_loss.py +137 -61
- liger_kernel/chunked_loss/jsd_loss.py +43 -13
- liger_kernel/chunked_loss/kto_loss.py +50 -12
- liger_kernel/chunked_loss/orpo_loss.py +37 -5
- liger_kernel/chunked_loss/simpo_loss.py +47 -11
- liger_kernel/ops/cross_entropy.py +7 -2
- liger_kernel/ops/dyt.py +225 -0
- liger_kernel/ops/fused_linear_jsd.py +2 -1
- liger_kernel/ops/jsd.py +30 -11
- liger_kernel/ops/kl_div.py +2 -2
- liger_kernel/transformers/__init__.py +4 -0
- liger_kernel/transformers/dyt.py +20 -0
- liger_kernel/transformers/functional.py +5 -0
- liger_kernel/transformers/model/gemma.py +8 -16
- liger_kernel/transformers/model/gemma2.py +7 -16
- liger_kernel/transformers/model/llama.py +8 -15
- liger_kernel/transformers/model/llava.py +369 -0
- liger_kernel/transformers/model/loss_utils.py +57 -0
- liger_kernel/transformers/model/mistral.py +9 -10
- liger_kernel/transformers/model/mixtral.py +8 -15
- liger_kernel/transformers/model/mllama.py +8 -15
- liger_kernel/transformers/model/olmo2.py +8 -16
- liger_kernel/transformers/model/paligemma.py +397 -0
- liger_kernel/transformers/model/phi3.py +8 -15
- liger_kernel/transformers/model/qwen2.py +8 -15
- liger_kernel/transformers/model/qwen2_5_vl.py +204 -0
- liger_kernel/transformers/model/qwen2_vl.py +9 -10
- liger_kernel/transformers/monkey_patch.py +286 -12
- liger_kernel/utils.py +1 -3
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +11 -7
- liger_kernel-0.5.6.dist-info/RECORD +80 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
- liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -213
- liger_kernel-0.5.4.dist-info/RECORD +0 -74
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
- {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/qwen2_5_vl.py ADDED
@@ -0,0 +1,204 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from torch.nn import CrossEntropyLoss
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+@add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
+@replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    pixel_values: Optional[torch.Tensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    image_grid_thw: Optional[torch.LongTensor] = None,
+    video_grid_thw: Optional[torch.LongTensor] = None,
+    rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    second_per_grid_ts: Optional[torch.Tensor] = None,
+    **loss_kwargs,
+) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+    r"""
+    Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+    >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+    >>> messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "What is shown in this image?"},
+                ],
+            },
+        ]
+    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if inputs_embeds is None:
+        inputs_embeds = self.model.embed_tokens(input_ids)
+        if pixel_values is not None:
+            pixel_values = pixel_values.type(self.visual.dtype)
+            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+            n_image_features = image_embeds.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+
+            mask = input_ids == self.config.image_token_id
+            mask_unsqueezed = mask.unsqueeze(-1)
+            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+            image_mask = mask_expanded.to(inputs_embeds.device)
+
+            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+        if pixel_values_videos is not None:
+            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+            n_video_features = video_embeds.shape[0]
+            if n_video_tokens != n_video_features:
+                raise ValueError(
+                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                )
+
+            mask = input_ids == self.config.video_token_id
+            mask_unsqueezed = mask.unsqueeze(-1)
+            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+            video_mask = mask_expanded.to(inputs_embeds.device)
+
+            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(inputs_embeds.device)
+
+    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+        # calculate RoPE index once per generation in the pre-fill stage only
+        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+            position_ids, rope_deltas = self.get_rope_index(
+                input_ids,
+                image_grid_thw,
+                video_grid_thw,
+                second_per_grid_ts,
+                attention_mask,
+            )
+            self.rope_deltas = rope_deltas
+        # then use the prev pre-calculated rope-deltas to get the correct position ids
+        else:
+            batch_size, seq_length, _ = inputs_embeds.shape
+            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
+            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+            if cache_position is not None:  # otherwise `deltas` is an int `0`
+                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+            position_ids = position_ids.add(delta)
+            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+    outputs = self.model(
+        input_ids=None,
+        position_ids=position_ids,
+        attention_mask=attention_mask,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
+    else:
+        logits = self.lm_head(hidden_states)
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return Qwen2_5_VLCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        rope_deltas=rope_deltas,
+    )
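For orientation, a minimal wiring sketch: the new `lce_forward` above is swapped in for `Qwen2_5_VLForConditionalGeneration.forward` by `apply_liger_kernel_to_qwen2_5_vl` (added further down in `monkey_patch.py`), and the fused-loss branch only runs when the model is in training mode with `labels` provided. The checkpoint id and flag values below are illustrative, not part of the diff.

```python
# Hypothetical wiring sketch; model id and flags are illustrative.
from transformers import Qwen2_5_VLForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl

# Patch the class before instantiation so forward() resolves to lce_forward above.
apply_liger_kernel_to_qwen2_5_vl(fused_linear_cross_entropy=True, rms_norm=True, swiglu=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
model.train()  # the fused-loss branch runs only when self.training and labels is not None
```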
liger_kernel/transformers/model/qwen2_vl.py CHANGED
@@ -14,7 +14,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutput
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
 
-from liger_kernel.transformers.
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@@ -37,6 +37,7 @@ def lce_forward(
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
+    **loss_kwargs,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -170,15 +171,13 @@ def lce_forward(
     logits = None
 
     if self.training and (labels is not None):
-
-
-
-
-
-
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+        loss = LigerForCausalLMLoss(
+            hidden_states=hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
     else:
         logits = self.lm_head(hidden_states)
         if labels is not None:
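For context on this replacement: `LigerForCausalLMLoss` takes the unshifted hidden states and labels and handles the shift plus the fused linear cross entropy internally, and the new `**loss_kwargs` is the channel through which extra caller arguments (for example `num_items_in_batch` from recent `transformers` Trainer versions) reach the loss. A rough sketch with made-up shapes, assuming a CUDA/ROCm device for the Triton kernel:

```python
# Rough sketch with made-up shapes; requires a GPU-backed PyTorch build.
import torch

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss

batch, seq, hidden, vocab = 2, 128, 1024, 32000
hidden_states = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.bfloat16)
lm_head_weight = torch.randn(vocab, hidden, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, vocab, (batch, seq), device="cuda")

loss = LigerForCausalLMLoss(
    hidden_states=hidden_states,
    lm_head_weight=lm_head_weight,
    labels=labels,
    hidden_size=hidden,
    # anything else (e.g. num_items_in_batch) would arrive here via **loss_kwargs
)
print(loss)
```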
liger_kernel/transformers/monkey_patch.py CHANGED
@@ -19,6 +19,8 @@ from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_for
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
+from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
+from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
@@ -52,13 +54,26 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+    module.__class__.__name__ = LigerRMSNorm.__name__
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.hidden_size = module
+    module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+    module.__class__.__name__ = LigerLayerNorm.__name__
+
+
+def _patch_swiglu_module(module, liger_module):
+    _bind_method_to_module(module, "forward", liger_module.forward)
+    module.__class__.__name__ = liger_module.__name__
+
+
+def _patch_geglu_module(module):
+    _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+    module.__class__.__name__ = LigerGEGLUMLP.__name__
 
 
 def apply_liger_kernel_to_granite(
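The new `_patch_swiglu_module` / `_patch_geglu_module` helpers follow the same in-place patching idiom as the existing RMSNorm/LayerNorm ones: bind a Liger forward onto the already-instantiated module and rename its class so the patch is visible in `repr()`. The helper `_bind_method_to_module` is not shown in this diff; the sketch below is an assumed equivalent for illustration only.

```python
# Illustrative sketch of the in-place patching idiom (assumed equivalent of _bind_method_to_module).
import types

import torch.nn as nn


def bind_method_to_module(module: nn.Module, method_name: str, new_method) -> None:
    # Bind a plain function as an instance method, overriding only this instance.
    setattr(module, method_name, types.MethodType(new_method, module))


class FastForward:
    def forward(self, x):
        return x * 2  # stand-in for a Liger kernel forward


layer = nn.Identity()
bind_method_to_module(layer, "forward", FastForward.forward)
layer.__class__.__name__ = FastForward.__name__  # mirrors the __name__ reassignment above; affects repr()

print(layer(3))  # 6 -- the bound forward is used instead of Identity's
```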
@@ -134,7 +149,7 @@ def apply_liger_kernel_to_granite(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -206,12 +221,91 @@ def apply_liger_kernel_to_llama(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_llava(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    model: PreTrainedModel = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llava models.
+    Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa.
+    However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur.
+    NOTE: Llava is not available in transformers<4.36.0
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llava import modeling_llava
+
+    if cross_entropy:
+        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+        modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse("4.49.0"):
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        else:  # if version < 4.49.0
+            logger.warning(
+                "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+            )
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+
+    if model is not None:
+        text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = model.language_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        if vision_liger_fn:
+            accept_params = inspect.signature(vision_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
+                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+                )
+            vision_kwargs["model"] = model.vision_tower
+            vision_liger_fn(**vision_kwargs)
+        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
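A rough usage sketch for the new LLaVA entry point. Because the connected language model (and vision tower, if supported) is dispatched through `MODEL_TYPE_TO_APPLY_LIGER_FN`, the already-loaded instance should be passed in; extra kwargs such as `rms_norm`/`swiglu` are filtered against that model's apply function signature. The checkpoint id is illustrative.

```python
# Hypothetical usage; the checkpoint id is illustrative.
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
# Patches the LLaVA forward for the fused loss and forwards rms_norm/swiglu to the
# language model's own apply_liger_kernel_to_* function via MODEL_TYPE_TO_APPLY_LIGER_FN.
apply_liger_kernel_to_llava(fused_linear_cross_entropy=True, model=model, rms_norm=True, swiglu=True)
```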
@@ -296,7 +390,7 @@ def apply_liger_kernel_to_mllama(
         _patch_rms_norm_module(text_model.norm)
         for decoder_layer in text_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -370,7 +464,7 @@ def apply_liger_kernel_to_mistral(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -442,7 +536,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-
+                    _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -516,7 +610,7 @@ def apply_liger_kernel_to_gemma(
 
         for decoder_layer in base_model.layers:
             if geglu:
-
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -592,7 +686,7 @@ def apply_liger_kernel_to_gemma2(
 
         for decoder_layer in base_model.layers:
             if geglu:
-
+                _patch_geglu_module(decoder_layer.mlp)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
@@ -600,6 +694,116 @@ def apply_liger_kernel_to_gemma2(
                 _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
+def apply_liger_kernel_to_paligemma(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
+
+    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.paligemma import modeling_paligemma
+    from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+    from transformers.models.siglip import modeling_siglip
+    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
+    from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+
+    from liger_kernel.transformers.model.paligemma import lce_forward
+    from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
+
+    # The vision_tower is a SiglipVisionModel
+    if layer_norm:
+        modeling_siglip.nn.LayerNorm = LigerLayerNorm
+
+    # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
+    # The multi_modal_projector is Linear, nothing to do
+
+    # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
+    apply_liger_kernel_to_gemma(
+        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+    )
+    apply_liger_kernel_to_gemma2(
+        rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+    )
+    # Handle loss function
+    if cross_entropy:
+        modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if not isinstance(model, PaliGemmaForConditionalGeneration):
+            raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
+
+        vision_tower: SiglipVisionModel = model.vision_tower
+
+        _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
+
+        for layer in vision_tower.vision_model.encoder.layers:
+            layer: SiglipEncoderLayer
+            if layer_norm:
+                _patch_layer_norm_module(layer.layer_norm1)
+                _patch_layer_norm_module(layer.layer_norm2)
+
+        language_model = model.language_model
+
+        if isinstance(language_model, GemmaForCausalLM):
+            apply_liger_kernel_to_gemma(
+                rope=rope,
+                cross_entropy=False,
+                fused_linear_cross_entropy=False,
+                rms_norm=rms_norm,
+                geglu=geglu,
+                model=language_model,
+            )
+
+        elif isinstance(language_model, Gemma2ForCausalLM):
+            apply_liger_kernel_to_gemma2(
+                rope=rope,
+                cross_entropy=False,
+                fused_linear_cross_entropy=False,
+                rms_norm=rms_norm,
+                geglu=geglu,
+                model=language_model,
+            )
+        else:
+            raise TypeError(
+                "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
+            )
+
+
 def apply_liger_kernel_to_qwen2(
     rope: bool = True,
     cross_entropy: bool = False,
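A short usage sketch for the new PaliGemma entry point. Class-level patching (SigLIP LayerNorm, Gemma/Gemma2 kernels, fused loss) works before the model is loaded; passing `model=` additionally patches an already-instantiated vision tower and language model in place. The checkpoint id is illustrative.

```python
# Hypothetical usage; the checkpoint id is illustrative.
from transformers import PaliGemmaForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma

apply_liger_kernel_to_paligemma(layer_norm=True, rms_norm=True, geglu=True)
model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
```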
@@ -666,7 +870,7 @@ def apply_liger_kernel_to_qwen2(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -739,7 +943,74 @@ def apply_liger_kernel_to_qwen2_vl(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_qwen2_5_vl(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2.5-VL models.
+    NOTE: Qwen2.5-VL is not available in transformers<4.48.2
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+
+    from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
+
+    if rope:
+        modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+    if rms_norm:
+        modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+    if swiglu:
+        modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Qwen2_5_VLModel = getattr(model, model.base_model_prefix, model)
+
+        if hasattr(model, "visual"):
+            # Patch Qwen2_5_VisionTransformerPretrainedModel
+            for vision_block in model.visual.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
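For an already-loaded Qwen2.5-VL instance, the `model=` branch above patches RMSNorm on the live modules (vision block `norm1`/`norm2` and the decoder layers) in addition to the class-level swaps. A minimal sketch, with an illustrative checkpoint id:

```python
# Sketch: patching an already-loaded Qwen2.5-VL instance in place (checkpoint id illustrative).
from transformers import Qwen2_5_VLForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl

model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
# With `model=` given, the vision blocks and decoder layers of this instance are
# patched directly, on top of the module-level class replacements.
apply_liger_kernel_to_qwen2_5_vl(model=model)
```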
@@ -808,7 +1079,7 @@ def apply_liger_kernel_to_phi3(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -871,7 +1142,7 @@ def apply_liger_kernel_to_olmo2(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
@@ -882,6 +1153,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
     "mllama_text_model": apply_liger_kernel_to_mllama,
@@ -890,7 +1162,9 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "olmo2": apply_liger_kernel_to_olmo2,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "phi3": apply_liger_kernel_to_phi3,
+    "paligemma": apply_liger_kernel_to_paligemma,
 }
 
 
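The three new entries make the LLaVA, Qwen2.5-VL, and PaliGemma wrappers discoverable by `config.model_type`, which is exactly how `apply_liger_kernel_to_llava` above resolves the apply function for the connected language and vision models. A hedged sketch of that lookup pattern (the flag values are illustrative, and kwargs must match the chosen function's signature):

```python
# Sketch of consuming the mapping, mirroring the lookups in apply_liger_kernel_to_llava.
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

model_type = "qwen2_5_vl"  # e.g. model.config.model_type
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
if apply_fn is not None:
    apply_fn(fused_linear_cross_entropy=True, rms_norm=True)
```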
liger_kernel/utils.py CHANGED
@@ -5,12 +5,10 @@ def infer_device():
     """
     Get current device name based on available devices
     """
-    if torch.cuda.is_available():
+    if torch.cuda.is_available():  # Works for both Nvidia and AMD
         return "cuda"
     elif torch.xpu.is_available():
         return "xpu"
-    elif torch.hip.is_available():
-        return "hip"
     else:
         return "cpu"
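The removed `hip` branch was redundant: ROCm builds of PyTorch expose AMD GPUs through the `cuda` device type (with `torch.version.hip` set), which is what the amended comment points out. A quick check of the helper:

```python
# Quick usage check; the output depends on the local PyTorch build.
import torch

from liger_kernel.utils import infer_device

print(infer_device())     # "cuda" on NVIDIA and ROCm builds, "xpu" on Intel GPUs, else "cpu"
print(torch.version.hip)  # non-None on ROCm builds, even though the device type is still "cuda"
```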