liger-kernel 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. liger_kernel/chunked_loss/cpo_loss.py +51 -11
  2. liger_kernel/chunked_loss/dpo_loss.py +30 -4
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/chunked_loss/fused_linear_distillation.py +20 -5
  5. liger_kernel/chunked_loss/fused_linear_ppo.py +331 -0
  6. liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
  7. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
  8. liger_kernel/chunked_loss/grpo_loss.py +137 -61
  9. liger_kernel/chunked_loss/jsd_loss.py +43 -13
  10. liger_kernel/chunked_loss/kto_loss.py +50 -12
  11. liger_kernel/chunked_loss/orpo_loss.py +37 -5
  12. liger_kernel/chunked_loss/simpo_loss.py +47 -11
  13. liger_kernel/ops/cross_entropy.py +7 -2
  14. liger_kernel/ops/dyt.py +225 -0
  15. liger_kernel/ops/fused_linear_jsd.py +2 -1
  16. liger_kernel/ops/jsd.py +30 -11
  17. liger_kernel/ops/kl_div.py +2 -2
  18. liger_kernel/transformers/__init__.py +4 -0
  19. liger_kernel/transformers/dyt.py +20 -0
  20. liger_kernel/transformers/functional.py +5 -0
  21. liger_kernel/transformers/model/gemma.py +8 -16
  22. liger_kernel/transformers/model/gemma2.py +7 -16
  23. liger_kernel/transformers/model/llama.py +8 -15
  24. liger_kernel/transformers/model/llava.py +369 -0
  25. liger_kernel/transformers/model/loss_utils.py +57 -0
  26. liger_kernel/transformers/model/mistral.py +9 -10
  27. liger_kernel/transformers/model/mixtral.py +8 -15
  28. liger_kernel/transformers/model/mllama.py +8 -15
  29. liger_kernel/transformers/model/olmo2.py +8 -16
  30. liger_kernel/transformers/model/paligemma.py +397 -0
  31. liger_kernel/transformers/model/phi3.py +8 -15
  32. liger_kernel/transformers/model/qwen2.py +8 -15
  33. liger_kernel/transformers/model/qwen2_5_vl.py +204 -0
  34. liger_kernel/transformers/model/qwen2_vl.py +9 -10
  35. liger_kernel/transformers/monkey_patch.py +286 -12
  36. liger_kernel/utils.py +1 -3
  37. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/METADATA +11 -7
  38. liger_kernel-0.5.6.dist-info/RECORD +80 -0
  39. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/WHEEL +1 -1
  40. liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -213
  41. liger_kernel-0.5.4.dist-info/RECORD +0 -74
  42. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/LICENSE +0 -0
  43. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info/licenses}/NOTICE +0 -0
  44. {liger_kernel-0.5.4.dist-info → liger_kernel-0.5.6.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/model/qwen2_5_vl.py ADDED
@@ -0,0 +1,204 @@
+ from typing import List
+ from typing import Optional
+ from typing import Tuple
+ from typing import Union
+
+ import torch
+
+ from torch.nn import CrossEntropyLoss
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _CONFIG_FOR_DOC
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import QWEN2_5_VL_INPUTS_DOCSTRING
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
+ from transformers.utils import add_start_docstrings_to_model_forward
+ from transformers.utils import replace_return_docstrings
+
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+ @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[List[torch.FloatTensor]] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     pixel_values: Optional[torch.Tensor] = None,
+     pixel_values_videos: Optional[torch.FloatTensor] = None,
+     image_grid_thw: Optional[torch.LongTensor] = None,
+     video_grid_thw: Optional[torch.LongTensor] = None,
+     rope_deltas: Optional[torch.LongTensor] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     second_per_grid_ts: Optional[torch.Tensor] = None,
+     **loss_kwargs,
+ ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+     r"""
+     Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from PIL import Image
+     >>> import requests
+     >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+     >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+     >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+     >>> messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": "What is shown in this image?"},
+             ],
+         },
+     ]
+     >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+     >>> image = Image.open(requests.get(url, stream=True).raw)
+
+     >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+     ```"""
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     if inputs_embeds is None:
+         inputs_embeds = self.model.embed_tokens(input_ids)
+         if pixel_values is not None:
+             pixel_values = pixel_values.type(self.visual.dtype)
+             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+             n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+             n_image_features = image_embeds.shape[0]
+             if n_image_tokens != n_image_features:
+                 raise ValueError(
+                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                 )
+
+             mask = input_ids == self.config.image_token_id
+             mask_unsqueezed = mask.unsqueeze(-1)
+             mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+             image_mask = mask_expanded.to(inputs_embeds.device)
+
+             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+         if pixel_values_videos is not None:
+             pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+             video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+             n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+             n_video_features = video_embeds.shape[0]
+             if n_video_tokens != n_video_features:
+                 raise ValueError(
+                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                 )
+
+             mask = input_ids == self.config.video_token_id
+             mask_unsqueezed = mask.unsqueeze(-1)
+             mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+             video_mask = mask_expanded.to(inputs_embeds.device)
+
+             video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+         if attention_mask is not None:
+             attention_mask = attention_mask.to(inputs_embeds.device)
+
+     # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+     if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+         # calculate RoPE index once per generation in the pre-fill stage only
+         if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+             position_ids, rope_deltas = self.get_rope_index(
+                 input_ids,
+                 image_grid_thw,
+                 video_grid_thw,
+                 second_per_grid_ts,
+                 attention_mask,
+             )
+             self.rope_deltas = rope_deltas
+         # then use the prev pre-calculated rope-deltas to get the correct position ids
+         else:
+             batch_size, seq_length, _ = inputs_embeds.shape
+             delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
+             position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+             position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+             if cache_position is not None:  # otherwise `deltas` is an int `0`
+                 delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+             position_ids = position_ids.add(delta)
+             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+     outputs = self.model(
+         input_ids=None,
+         position_ids=position_ids,
+         attention_mask=attention_mask,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         loss = LigerForCausalLMLoss(
+             hidden_states=hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             hidden_size=self.config.hidden_size,
+             **loss_kwargs,
+         )
+     else:
+         logits = self.lm_head(hidden_states)
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return Qwen2_5_VLCausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         rope_deltas=rope_deltas,
+     )
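The key behavioral change in this new forward is the `self.training and (labels is not None)` branch: during training, the hidden states and `lm_head.weight` go straight to `LigerForCausalLMLoss` and `logits` stays `None`, so the full `(batch, seq, vocab)` logits tensor is never materialized. A back-of-envelope sketch of what that avoids, with illustrative Qwen2.5-style dimensions (the numbers are assumptions, not values from this diff):

```python
# Illustrative arithmetic only; the dims are assumed, not taken from this diff.
batch, seq, vocab = 4, 2048, 152064        # Qwen2.5-style vocab size
logits_bytes = batch * seq * vocab * 2     # bf16 logits the eval/inference path materializes
print(f"~{logits_bytes / 2**30:.1f} GiB")  # ~2.3 GiB for the logits tensor alone
# The training path instead feeds (batch*seq, hidden) activations plus lm_head.weight
# to LigerForCausalLMLoss, which fuses the output projection with the cross entropy.
```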
liger_kernel/transformers/model/qwen2_vl.py CHANGED
@@ -14,7 +14,7 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutput
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
 
- from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
  @add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@@ -37,6 +37,7 @@ def lce_forward(
      video_grid_thw: Optional[torch.LongTensor] = None,
      rope_deltas: Optional[torch.LongTensor] = None,
      cache_position: Optional[torch.LongTensor] = None,
+     **loss_kwargs,
  ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
      r"""
      Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -170,15 +171,13 @@ def lce_forward(
      logits = None
 
      if self.training and (labels is not None):
-         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-         shift_labels = labels[..., 1:].contiguous()
-
-         # Flatten tokens
-         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-         shift_labels = shift_labels.view(-1)
-
-         lce = LigerFusedLinearCrossEntropyLoss()
-         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+         loss = LigerForCausalLMLoss(
+             hidden_states=hidden_states,
+             lm_head_weight=self.lm_head.weight,
+             labels=labels,
+             hidden_size=self.config.hidden_size,
+             **loss_kwargs,
+         )
      else:
          logits = self.lm_head(hidden_states)
          if labels is not None:
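The same refactor lands here for Qwen2-VL: the per-model shift-and-flatten boilerplate moves into the shared `LigerForCausalLMLoss` helper in the new `loss_utils.py`. A rough equivalent of what that helper centralizes, reconstructed from the removed lines above (the actual `loss_utils.py` body is not shown in full in this diff, and the real helper also threads through `**loss_kwargs`):

```python
# Sketch reconstructed from the removed inline code above; not the actual
# loss_utils.LigerForCausalLMLoss implementation, which this diff does not show.
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


def causal_lm_loss_sketch(hidden_states, lm_head_weight, labels, hidden_size, **loss_kwargs):
    # Shift so that tokens < n predict n, then flatten, exactly as the removed lines did.
    shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, hidden_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    lce = LigerFusedLinearCrossEntropyLoss()
    return lce(lm_head_weight, shift_hidden_states, shift_labels)
```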
liger_kernel/transformers/monkey_patch.py CHANGED
@@ -19,6 +19,8 @@ from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_for
  from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
+ from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
+ from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
  from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
  from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
@@ -52,13 +54,26 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
      module.in_place = in_place
      _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
      _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+     module.__class__.__name__ = LigerRMSNorm.__name__
 
 
  def _patch_layer_norm_module(module, eps=1e-6):
      module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-     module.hidden_size = module.normalized_shape
+     module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+
      _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
      _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
+     module.__class__.__name__ = LigerLayerNorm.__name__
+
+
+ def _patch_swiglu_module(module, liger_module):
+     _bind_method_to_module(module, "forward", liger_module.forward)
+     module.__class__.__name__ = liger_module.__name__
+
+
+ def _patch_geglu_module(module):
+     _bind_method_to_module(module, "forward", LigerGEGLUMLP.forward)
+     module.__class__.__name__ = LigerGEGLUMLP.__name__
 
 
  def apply_liger_kernel_to_granite(
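The new `_patch_swiglu_module` and `_patch_geglu_module` helpers bundle the forward rebinding that each apply function previously did inline with a rename of the module's class name, so `print(model)` reveals which modules were patched. A minimal sketch of the pattern, assuming `_bind_method_to_module` (whose body is not part of this diff) binds the function with `types.MethodType`:

```python
# Sketch of the patching pattern; types.MethodType is an assumption about what
# _bind_method_to_module does, since its body is not shown in this diff.
import types

import torch.nn as nn


class LigerDemoMLP:  # stand-in for a Liger module such as LigerSwiGLUMLP
    def forward(self, x):
        return x * 2


mlp = nn.Identity()
mlp.forward = types.MethodType(LigerDemoMLP.forward, mlp)  # rebind forward on the instance
mlp.__class__.__name__ = LigerDemoMLP.__name__             # rename the class, changing its repr
print(mlp)  # LigerDemoMLP()
```

Note the rename mutates the class object rather than the instance; that is fine here, since every instance of a patched HuggingFace module class receives the Liger forward anyway.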
@@ -134,7 +149,7 @@ def apply_liger_kernel_to_granite(
 
      for decoder_layer in base_model.layers:
          if swiglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+             _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
          if rms_norm:
              _patch_rms_norm_module(decoder_layer.input_layernorm)
              _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -206,12 +221,91 @@ def apply_liger_kernel_to_llama(
 
      for decoder_layer in base_model.layers:
          if swiglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+             _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
          if rms_norm:
              _patch_rms_norm_module(decoder_layer.input_layernorm)
              _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+ def apply_liger_kernel_to_llava(
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     model: PreTrainedModel = None,
+     **kwargs,
+ ) -> None:
+     """
+     Apply Liger kernels to replace the original implementation in HuggingFace LLaVA models.
+     Due to the characteristics of LLaVA, the model must be passed in so that Liger-Kernel's patches
+     can also be applied to the models connected to LLaVA.
+     However, if an LM not supported by Liger-Kernel is connected to LLaVA, unexpected side effects may occur.
+     NOTE: LLaVA is not available in transformers<4.36.0
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.llava import modeling_llava
+
+     if cross_entropy:
+         logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+         modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         if transformer_version >= version.parse("4.49.0"):
+             modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+         else:  # if version < 4.49.0
+             logger.warning(
+                 "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+             )
+             modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+
+     if model is not None:
+         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+         text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+         vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+         kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
+         if text_liger_fn:
+             accept_params = inspect.signature(text_liger_fn).parameters
+             remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+             text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+             if remain_params:
+                 logger.warning(
+                     f"These parameters are not supported by {text_model_name}: {list(remain_params)}. "
+                     f"Applying only the supported ones: {list(text_kwargs.keys())}.\n"
+                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                 )
+             text_kwargs["model"] = model.language_model
+             text_liger_fn(**text_kwargs)
+         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+         if vision_liger_fn:
+             accept_params = inspect.signature(vision_liger_fn).parameters
+             remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+             vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+             if remain_params:
+                 logger.warning(
+                     f"These parameters are not supported by {vision_model_name}: {list(remain_params)}. "
+                     f"Applying only the supported ones: {list(vision_kwargs.keys())}.\n"
+                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+                 )
+             vision_kwargs["model"] = model.vision_tower
+             vision_liger_fn(**vision_kwargs)
+         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
  def apply_liger_kernel_to_mllama(
      rope: bool = True,
      cross_entropy: bool = False,
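Because the text and vision towers are looked up in `MODEL_TYPE_TO_APPLY_LIGER_FN` by `model_type`, the instance-level patch only takes effect when the loaded model is passed in, as the docstring above notes. A minimal usage sketch (the checkpoint id is illustrative):

```python
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava

model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
# Patches LlavaForConditionalGeneration.forward, then forwards rms_norm/swiglu via
# **kwargs to the text tower's apply function (apply_liger_kernel_to_llama for a
# llama text_config); kwargs that function does not accept are dropped with a warning.
apply_liger_kernel_to_llava(model=model, rms_norm=True, swiglu=True)
```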
@@ -296,7 +390,7 @@ def apply_liger_kernel_to_mllama(
      _patch_rms_norm_module(text_model.norm)
      for decoder_layer in text_model.layers:
          if swiglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+             _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
          if rms_norm:
              _patch_rms_norm_module(decoder_layer.input_layernorm)
              _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -370,7 +464,7 @@ def apply_liger_kernel_to_mistral(
 
      for decoder_layer in base_model.layers:
          if swiglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+             _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
          if rms_norm:
              _patch_rms_norm_module(decoder_layer.input_layernorm)
              _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -442,7 +536,7 @@ def apply_liger_kernel_to_mixtral(
      for decoder_layer in base_model.layers:
          if swiglu:
              for expert in decoder_layer.block_sparse_moe.experts:
-                 _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
+                 _patch_swiglu_module(expert, LigerBlockSparseTop2MLP)
          if rms_norm:
              _patch_rms_norm_module(decoder_layer.input_layernorm)
              _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -516,7 +610,7 @@ def apply_liger_kernel_to_gemma(
 
      for decoder_layer in base_model.layers:
          if geglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+             _patch_geglu_module(decoder_layer.mlp)
          if rms_norm:
              _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
              _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -592,7 +686,7 @@ def apply_liger_kernel_to_gemma2(
 
      for decoder_layer in base_model.layers:
          if geglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
+             _patch_geglu_module(decoder_layer.mlp)
          if rms_norm:
              _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
              _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
@@ -600,6 +694,116 @@
              _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
+ def apply_liger_kernel_to_paligemma(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     layer_norm: bool = True,
+     rms_norm: bool = True,
+     geglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace the original implementation in HuggingFace PaliGemma models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+         layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
+
+     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+     from transformers.models.paligemma import modeling_paligemma
+     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
+     from transformers.models.siglip import modeling_siglip
+     from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
+     from transformers.models.siglip.modeling_siglip import SiglipVisionModel
+
+     from liger_kernel.transformers.model.paligemma import lce_forward
+     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
+
+     # The vision_tower is a SiglipVisionModel
+     if layer_norm:
+         modeling_siglip.nn.LayerNorm = LigerLayerNorm
+
+     # SiglipMLP is a standard FFN, so LigerGEGLUMLP is not compatible
+     # The multi_modal_projector is Linear, nothing to do
+
+     # The language_model is GemmaForCausalLM or Gemma2ForCausalLM
+     apply_liger_kernel_to_gemma(
+         rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+     )
+     apply_liger_kernel_to_gemma2(
+         rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, rms_norm=rms_norm, geglu=geglu
+     )
+     # Handle loss function
+     if cross_entropy:
+         modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+         else:  # if version < 4.46.1
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if not isinstance(model, PaliGemmaForConditionalGeneration):
+             raise TypeError("model must be of type PaliGemmaForConditionalGeneration")
+
+         vision_tower: SiglipVisionModel = model.vision_tower
+
+         _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
+
+         for layer in vision_tower.vision_model.encoder.layers:
+             layer: SiglipEncoderLayer
+             if layer_norm:
+                 _patch_layer_norm_module(layer.layer_norm1)
+                 _patch_layer_norm_module(layer.layer_norm2)
+
+         language_model = model.language_model
+
+         if isinstance(language_model, GemmaForCausalLM):
+             apply_liger_kernel_to_gemma(
+                 rope=rope,
+                 cross_entropy=False,
+                 fused_linear_cross_entropy=False,
+                 rms_norm=rms_norm,
+                 geglu=geglu,
+                 model=language_model,
+             )
+
+         elif isinstance(language_model, Gemma2ForCausalLM):
+             apply_liger_kernel_to_gemma2(
+                 rope=rope,
+                 cross_entropy=False,
+                 fused_linear_cross_entropy=False,
+                 rms_norm=rms_norm,
+                 geglu=geglu,
+                 model=language_model,
+             )
+         else:
+             raise TypeError(
+                 "The language_model of a PaliGemma model must be either GemmaForCausalLM or Gemma2ForCausalLM."
+             )
+
+
  def apply_liger_kernel_to_qwen2(
      rope: bool = True,
      cross_entropy: bool = False,
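As with LLaVA, passing an already-loaded instance patches the instantiated submodules directly: the SigLIP encoder LayerNorms via `_patch_layer_norm_module` and the Gemma or Gemma2 language model via its own apply function. A minimal sketch (checkpoint id illustrative; PaliGemma checkpoints are gated and may require authentication):

```python
from transformers import PaliGemmaForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma

model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
apply_liger_kernel_to_paligemma(model=model)  # defaults: rope, layer_norm, rms_norm,
                                              # geglu, and fused_linear_cross_entropy on
```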
@@ -666,7 +870,7 @@ def apply_liger_kernel_to_qwen2(
 
      for decoder_layer in base_model.layers:
          if swiglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+             _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
          if rms_norm:
              _patch_rms_norm_module(decoder_layer.input_layernorm)
              _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -739,7 +943,74 @@ def apply_liger_kernel_to_qwen2_vl(
      _patch_rms_norm_module(base_model.norm)
      for decoder_layer in base_model.layers:
          if swiglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+             _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+         if rms_norm:
+             _patch_rms_norm_module(decoder_layer.input_layernorm)
+             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_qwen2_5_vl(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace the original implementation in HuggingFace Qwen2.5-VL models.
+     NOTE: Qwen2.5-VL is not available in transformers<4.48.2
+
+     Args:
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (cross_entropy and fused_linear_cross_entropy), (
+         "cross_entropy and fused_linear_cross_entropy cannot both be True."
+     )
+
+     from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+
+     from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
+
+     if rope:
+         modeling_qwen2_5_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
+     if rms_norm:
+         modeling_qwen2_5_vl.Qwen2RMSNorm = LigerRMSNorm
+     if cross_entropy:
+         modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+     if swiglu:
+         modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         # get the base model from the model instance
+         base_model: Qwen2_5_VLModel = getattr(model, model.base_model_prefix, model)
+
+         if hasattr(model, "visual"):
+             # Patch Qwen2_5_VisionTransformerPretrainedModel
+             for vision_block in model.visual.blocks:
+                 if rms_norm:
+                     _patch_rms_norm_module(vision_block.norm1)
+                     _patch_rms_norm_module(vision_block.norm2)
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
              if rms_norm:
                  _patch_rms_norm_module(decoder_layer.input_layernorm)
                  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
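The function above supports both patching styles the other apply functions offer: class-level patching before the model is constructed, and instance-level patching afterwards. A minimal sketch (checkpoint id illustrative):

```python
from transformers import Qwen2_5_VLForConditionalGeneration

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl

# (a) Patch the modeling classes first, so new instances are built with Liger modules.
apply_liger_kernel_to_qwen2_5_vl()
model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# (b) Or patch an already-loaded instance: RMSNorm is rebound on the decoder layers and
# on every vision block's norm1/norm2, and the decoder MLPs get the SwiGLU forward.
apply_liger_kernel_to_qwen2_5_vl(model=model)
```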
@@ -808,7 +1079,7 @@ def apply_liger_kernel_to_phi3(
 
      for decoder_layer in base_model.layers:
          if swiglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
+             _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
          if rms_norm:
              _patch_rms_norm_module(decoder_layer.input_layernorm)
              _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -871,7 +1142,7 @@ def apply_liger_kernel_to_olmo2(
 
      for decoder_layer in base_model.layers:
          if swiglu:
-             _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+             _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
          if rms_norm:
              _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
              _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
@@ -882,6 +1153,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "gemma": apply_liger_kernel_to_gemma,
      "gemma2": apply_liger_kernel_to_gemma2,
      "llama": apply_liger_kernel_to_llama,
+     "llava": apply_liger_kernel_to_llava,
      "granite": apply_liger_kernel_to_granite,
      "mllama": apply_liger_kernel_to_mllama,
      "mllama_text_model": apply_liger_kernel_to_mllama,
@@ -890,7 +1162,9 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
      "olmo2": apply_liger_kernel_to_olmo2,
      "qwen2": apply_liger_kernel_to_qwen2,
      "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
      "phi3": apply_liger_kernel_to_phi3,
+     "paligemma": apply_liger_kernel_to_paligemma,
  }
 
 
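These registry entries are what the LLaVA patch above consults for its towers, and they key auto-patching on `config.model_type`. A minimal sketch of the lookup, using the dict directly since the higher-level dispatch helper is not part of this diff:

```python
from transformers import AutoConfig

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # illustrative id
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(config.model_type)      # model_type == "qwen2_5_vl"
if apply_fn is not None:
    apply_fn()  # patch the modeling classes before instantiating the model
```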
liger_kernel/utils.py CHANGED
@@ -5,12 +5,10 @@ def infer_device():
      """
      Get current device name based on available devices
      """
-     if torch.cuda.is_available():
+     if torch.cuda.is_available():  # Works for both Nvidia and AMD
          return "cuda"
      elif torch.xpu.is_available():
          return "xpu"
-     elif torch.hip.is_available():
-         return "hip"
      else:
          return "cpu"