liger-kernel 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,277 @@
+ import logging
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from torch.nn import CrossEntropyLoss
+ from transformers.cache_utils import HybridCache
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+ from transformers.models.gemma2.modeling_gemma2 import (
+     _CONFIG_FOR_DOC,
+     GEMMA2_INPUTS_DOCSTRING,
+ )
+ from transformers.utils import (
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+
+ from liger_kernel.transformers.fused_linear_cross_entropy import (
+     LigerFusedLinearCrossEntropyLoss,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def lce_forward_deprecated(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[HybridCache] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, GemmaForCausalLM
+     >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+     >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+     >>> prompt = "What is your favorite condiment?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "What is your favorite condiment?"
+     ```"""
+
+     if self.training and self.config._attn_implementation != "eager":
+         logger.warning_once(
+             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+         )
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     loss = None
+     logits = None
+
+     if self.training and (labels is not None):
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         # flatten
+
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         lce = LigerFusedLinearCrossEntropyLoss(
+             softcap=self.config.final_logit_softcapping
+         )
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+
+     else:
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+         logits = self.lm_head(hidden_states)
+         if self.config.final_logit_softcapping is not None:
+             logits = logits / self.config.final_logit_softcapping
+             logits = torch.tanh(logits)
+             logits = logits * self.config.final_logit_softcapping
+
+         loss = None
+         if labels is not None:
+             # Upcast to float if we need to compute the loss to avoid potential precision issues
+             logits = logits.float()
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
+
+
+ @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(
+     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+ )
+ def lce_forward(
+     self,
+     input_ids: torch.LongTensor = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[HybridCache] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_attentions: Optional[bool] = None,
+     output_hidden_states: Optional[bool] = None,
+     return_dict: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     num_logits_to_keep: int = 0,
+     **loss_kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+     r"""
+     Args:
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         num_logits_to_keep (`int`, *optional*):
+             Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+     Returns:
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+     >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+     >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
+
+     >>> prompt = "What is your favorite condiment?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "What is your favorite condiment?"
+     ```"""
+
+     if self.training and self.config._attn_implementation != "eager":
+         logger.warning_once(
+             "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
+             f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
+         )
+     output_attentions = (
+         output_attentions
+         if output_attentions is not None
+         else self.config.output_attentions
+     )
+     output_hidden_states = (
+         output_hidden_states
+         if output_hidden_states is not None
+         else self.config.output_hidden_states
+     )
+     return_dict = (
+         return_dict if return_dict is not None else self.config.use_return_dict
+     )
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+         cache_position=cache_position,
+     )
+
+     hidden_states = outputs[0]
+
+     logits = None
+     loss = None
+     # if in training mode, don't materialize logits
+     if self.training and (labels is not None):
+         # We do the same thing as ForCausalLMLoss but using Liger FLCE
+
+         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+
+         # flatten tokens
+         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+         shift_labels = shift_labels.view(-1)
+
+         reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
+         lce = LigerFusedLinearCrossEntropyLoss(
+             softcap=self.config.final_logit_softcapping,
+             reduction=reduction,
+         )
+
+         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+         if reduction == "sum":
+             loss /= loss_kwargs["num_items_in_batch"]
+
+     else:  # if in inference mode materialize logits
+         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+         if self.config.final_logit_softcapping is not None:
+             logits = logits / self.config.final_logit_softcapping
+             logits = torch.tanh(logits)
+             logits = logits * self.config.final_logit_softcapping
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+     if not return_dict:
+         output = (logits,) + outputs[1:]
+         return (loss,) + output if loss is not None else output
+
+     return CausalLMOutputWithPast(
+         loss=loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+     )
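Editor's note: the training branch of both forwards above never materializes the `(batch, seq_len, vocab_size)` logits tensor; the `lm_head` weight and the shifted hidden states go straight into the fused kernel. A minimal standalone sketch of that call pattern follows; the sizes and the `softcap` value are illustrative assumptions, and the Triton-backed kernel needs a CUDA device:

```python
# Hedged sketch of the fused-linear-cross-entropy call pattern used above.
# Sizes and softcap are made up; the kernel is Triton-based, so CUDA is required.
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

batch, seq_len, hidden, vocab = 2, 16, 64, 1000
hidden_states = torch.randn(batch, seq_len, hidden, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq_len), device="cuda")
lm_head_weight = torch.randn(vocab, hidden, device="cuda", requires_grad=True)

# Shift so position i predicts token i+1, then flatten, as the forward does.
shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, hidden)
shift_labels = labels[..., 1:].contiguous().view(-1)

# Projection, softcapping, and cross entropy happen inside one fused kernel,
# so the (batch * (seq_len - 1), vocab) logits tensor is never allocated.
lce = LigerFusedLinearCrossEntropyLoss(softcap=30.0)
loss = lce(lm_head_weight, shift_hidden, shift_labels)
loss.backward()
```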
@@ -8,12 +8,17 @@ from packaging import version
  from transformers import PreTrainedModel

  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+ from liger_kernel.transformers.functional import liger_cross_entropy
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
  from liger_kernel.transformers.layer_norm import LigerLayerNorm
  from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
  from liger_kernel.transformers.model.gemma import (
      lce_forward_deprecated as gemma_lce_forward_deprecated,
  )
+ from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
+ from liger_kernel.transformers.model.gemma2 import (
+     lce_forward_deprecated as gemma2_lce_forward_deprected,
+ )
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
  from liger_kernel.transformers.model.llama import (
      lce_forward_deprecated as llama_lce_forward_deprecated,
@@ -99,6 +104,7 @@ def apply_liger_kernel_to_llama(
      ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

      from transformers.models.llama import modeling_llama
+     from transformers.models.llama.modeling_llama import LlamaModel

      if rope:
          modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -106,8 +112,16 @@ def apply_liger_kernel_to_llama(
          modeling_llama.LlamaRMSNorm = LigerRMSNorm
      if swiglu:
          modeling_llama.LlamaMLP = LigerSwiGLUMLP
+
      if cross_entropy:
-         modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             from transformers.loss.loss_utils import nn
+
+             nn.functional.cross_entropy = liger_cross_entropy
+         else:
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
              modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
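Editor's note: the version gate above exists because newer transformers releases compute the loss through `transformers.loss.loss_utils`, which calls `nn.functional.cross_entropy` directly, while older releases instantiate a `CrossEntropyLoss` class resolved through the modeling module's namespace. The patch therefore targets the functional API on new versions and the class reference on old ones. A toy sketch of the functional-patch mechanism, with a stand-in instead of the real Liger kernel:

```python
# Toy illustration of patching nn.functional.cross_entropy; the stand-in
# below is NOT the Liger kernel, it only shows the interception point.
import torch
import torch.nn as nn
import torch.nn.functional as F

_orig_cross_entropy = F.cross_entropy  # keep a handle to avoid self-recursion

def patched_cross_entropy(input, target, *args, **kwargs):
    print("intercepted cross_entropy call")
    return _orig_cross_entropy(input, target, *args, **kwargs)

# nn.functional and F are the same module object, so one assignment
# reroutes every later lookup of nn.functional.cross_entropy.
nn.functional.cross_entropy = patched_cross_entropy

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = nn.functional.cross_entropy(logits, labels)  # hits the patched path
```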
@@ -119,15 +133,8 @@ def apply_liger_kernel_to_llama(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)

-         if hasattr(model, "model"):
-             # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
-             base_model = model.model
-         elif hasattr(model, "transformer"):
-             # LlamaForQuestionAnswering uses "transformer" instead of "model"
-             base_model = model.transformer
-         else:
-             # Direct LlamaModel
-             base_model = model
+         # get the base model from the model instance
+         base_model: LlamaModel = getattr(model, model.base_model_prefix, model)

          if rms_norm:
              _patch_rms_norm_module(base_model.norm)
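Editor's note: the refactor above works because every `PreTrainedModel` subclass declares a `base_model_prefix` naming the attribute that holds its backbone, and a bare backbone simply lacks an attribute of that name, so `getattr` falls back to the model itself. One line replaces the old three-branch `hasattr` chain. A toy reconstruction of the convention with hypothetical classes (no transformers dependency):

```python
# Toy reconstruction of the base_model_prefix convention; the classes are
# hypothetical stand-ins for e.g. LlamaForCausalLM / LlamaModel.
class Backbone:
    base_model_prefix = "model"  # a bare backbone names an attribute it lacks

class ForCausalLM:
    base_model_prefix = "model"
    def __init__(self):
        self.model = Backbone()

class ForQuestionAnswering:
    base_model_prefix = "transformer"  # some heads use a different prefix
    def __init__(self):
        self.transformer = Backbone()

for m in (ForCausalLM(), ForQuestionAnswering(), Backbone()):
    base = getattr(m, m.base_model_prefix, m)
    print(type(base).__name__)  # Backbone in all three cases
```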
@@ -194,7 +201,13 @@ def apply_liger_kernel_to_mllama(
      if swiglu:
          modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
      if cross_entropy:
-         modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             from transformers.loss.loss_utils import nn
+
+             nn.functional.cross_entropy = liger_cross_entropy
+         else:
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
              modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
@@ -258,7 +271,7 @@ def apply_liger_kernel_to_mistral(
      Apply Liger kernels to replace original implementation in HuggingFace Mistral models

      Args:
-         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
          cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
          fused_linear_cross_entropy (bool):
              Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -275,6 +288,7 @@ def apply_liger_kernel_to_mistral(
      ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

      from transformers.models.mistral import modeling_mistral
+     from transformers.models.mistral.modeling_mistral import MistralModel

      if rope:
          modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -291,12 +305,8 @@ def apply_liger_kernel_to_mistral(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         if hasattr(model, "model"):
-             # The case for MistralForCausalLM, MistralForTokenClassification for example
-             base_model = model.model
-         else:
-             # Direct MistralModel
-             base_model = model
+         # get the base model from the model instance
+         base_model: MistralModel = getattr(model, model.base_model_prefix, model)

          if rms_norm:
              _patch_rms_norm_module(base_model.norm)
@@ -340,13 +350,21 @@ def apply_liger_kernel_to_mixtral(
      ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

      from transformers.models.mixtral import modeling_mixtral
+     from transformers.models.mixtral.modeling_mixtral import MixtralModel

      if rope:
          modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
      if rms_norm:
          modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
      if cross_entropy:
-         modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             from transformers.loss.loss_utils import nn
+
+             nn.functional.cross_entropy = liger_cross_entropy
+         else:
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
              modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
@@ -360,12 +378,8 @@ def apply_liger_kernel_to_mixtral(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         if hasattr(model, "model"):
-             # The case for MixtralForCausalLM, MixtralForTokenClassification for example
-             base_model = model.model
-         else:
-             # Direct MixtralModel
-             base_model = model
+         # get the base model from the model instance
+         base_model: MixtralModel = getattr(model, model.base_model_prefix, model)

          if rms_norm:
              _patch_rms_norm_module(base_model.norm)
@@ -410,6 +424,7 @@ def apply_liger_kernel_to_gemma(
      ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

      from transformers.models.gemma import modeling_gemma
+     from transformers.models.gemma.modeling_gemma import GemmaModel

      # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
      LigerRMSNormForGemma = partial(
@@ -424,7 +439,13 @@ def apply_liger_kernel_to_gemma(
      if rms_norm:
          modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
      if cross_entropy:
-         modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             from transformers.loss.loss_utils import nn
+
+             nn.functional.cross_entropy = liger_cross_entropy
+         else:
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
      if geglu:
          modeling_gemma.GemmaMLP = LigerGEGLUMLP
      if fused_linear_cross_entropy:
@@ -438,12 +459,8 @@ def apply_liger_kernel_to_gemma(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         if hasattr(model, "model"):
-             # The case for GemmaForCausalLM, GemmaForTokenClassification for example
-             base_model = model.model
-         else:
-             # Direct GemmaModel
-             base_model = model
+         # get the base model from the model instance
+         base_model: GemmaModel = getattr(model, model.base_model_prefix, model)

          if rms_norm:
              _patch_rms_norm_module_for_gemma(base_model.norm)
@@ -460,7 +477,8 @@ def apply_liger_kernel_to_gemma(

  def apply_liger_kernel_to_gemma2(
      rope: bool = True,
-     cross_entropy: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
      geglu: bool = True,
      model: PreTrainedModel = None,
@@ -471,16 +489,25 @@ def apply_liger_kernel_to_gemma2(

      Args:
          rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
          rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
          geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
          model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
          loaded. Default is None.
      """
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
      from transformers.models.gemma2 import modeling_gemma2
+     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

      LigerRMSNormForGemma2 = partial(
-         LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
+         LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
      )
      _patch_rms_norm_module_for_gemma2 = partial(
          _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
@@ -492,7 +519,19 @@ def apply_liger_kernel_to_gemma2(
          # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
          modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
      if cross_entropy:
-         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             from transformers.loss.loss_utils import nn
+
+             nn.functional.cross_entropy = liger_cross_entropy
+         else:
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+         else:
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
      if geglu:
          modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

@@ -500,12 +539,8 @@ def apply_liger_kernel_to_gemma2(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         if hasattr(model, "model"):
-             # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example
-             base_model = model.model
-         else:
-             # Direct Gemma2Model
-             base_model = model
+         # get the base model from the model instance
+         base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)

          if rms_norm:
              _patch_rms_norm_module_for_gemma2(base_model.norm)
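Editor's note: with the new `fused_linear_cross_entropy` flag, applying the Gemma2 kernels looks like the sketch below. Patch before constructing the model so the swapped `forward` and modules are picked up, and remember the two loss flags are mutually exclusive. This assumes the public `liger_kernel.transformers` entry point; the checkpoint name is illustrative:

```python
# Usage sketch, assuming apply_liger_kernel_to_gemma2 is re-exported from
# the liger_kernel.transformers package as in prior releases.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gemma2

apply_liger_kernel_to_gemma2(
    rope=True,
    cross_entropy=False,              # must stay False when FLCE is on
    fused_linear_cross_entropy=True,  # new default in 0.4.1
    rms_norm=True,
    geglu=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b", attn_implementation="eager"
)
```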
@@ -556,13 +591,21 @@ def apply_liger_kernel_to_qwen2(
      ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

      from transformers.models.qwen2 import modeling_qwen2
+     from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

      if rope:
          modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
      if rms_norm:
          modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
+
      if cross_entropy:
-         modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             from transformers.loss.loss_utils import nn
+
+             nn.functional.cross_entropy = liger_cross_entropy
+         else:
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss

      # import pdb; pdb.set_trace()
      if fused_linear_cross_entropy:
@@ -580,12 +623,8 @@ def apply_liger_kernel_to_qwen2(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         if hasattr(model, "model"):
-             # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example
-             base_model = model.model
-         else:
-             # Direct Qwen2Model
-             base_model = model
+         # get the base model from the model instance
+         base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)

          if rms_norm:
              _patch_rms_norm_module(base_model.norm)
@@ -630,6 +669,7 @@ def apply_liger_kernel_to_qwen2_vl(
      ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

      from transformers.models.qwen2_vl import modeling_qwen2_vl
+     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

      from liger_kernel.transformers.model.qwen2_vl import (
          lce_forward as qwen2_vl_lce_forward,
@@ -653,12 +693,8 @@ def apply_liger_kernel_to_qwen2_vl(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         if hasattr(model, "model"):
-             # The case for Qwen2VLForConditionalGeneration.
-             base_model = model.model
-         else:
-             # Direct Qwen2VLModel
-             base_model = model
+         # get the base model from the model instance
+         base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)

          if hasattr(model, "visual"):
              # Patch Qwen2VisionTransformerPretrainedModel
@@ -707,6 +743,7 @@ def apply_liger_kernel_to_phi3(
      ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

      from transformers.models.phi3 import modeling_phi3
+     from transformers.models.phi3.modeling_phi3 import Phi3Model

      if rope:
          modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
@@ -715,7 +752,13 @@ def apply_liger_kernel_to_phi3(
      if swiglu:
          modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
      if cross_entropy:
-         modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+             from transformers.loss.loss_utils import nn
+
+             nn.functional.cross_entropy = liger_cross_entropy
+         else:
+             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+             modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
          if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
              modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
@@ -727,12 +770,8 @@ def apply_liger_kernel_to_phi3(
          # The model instance already exists, so we need to additionally patch the
          # instance variables that reference already-instantiated modules

-         if hasattr(model, "model"):
-             # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example
-             base_model = model.model
-         else:
-             # Direct Phi3Model
-             base_model = model
+         # get the base model from the model instance
+         base_model: Phi3Model = getattr(model, model.base_model_prefix, model)

          if rms_norm:
              _patch_rms_norm_module(base_model.norm)
@@ -6,7 +6,13 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction

  class LigerRMSNorm(nn.Module):
      def __init__(
-         self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
+         self,
+         hidden_size,
+         eps=1e-6,
+         offset=0.0,
+         casting_mode="llama",
+         init_fn="ones",
+         in_place=True,
      ):
          super().__init__()
          assert init_fn in [
@@ -16,10 +22,11 @@ class LigerRMSNorm(nn.Module):
          self.weight = nn.Parameter(
              torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
          )
-         self.variance_epsilon, self.offset, self.casting_mode = (
+         self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
              eps,
              offset,
              casting_mode,
+             in_place,
          )

      def forward(self, hidden_states):
@@ -29,7 +36,8 @@ class LigerRMSNorm(nn.Module):
              self.variance_epsilon,
              self.offset,
              self.casting_mode,
+             self.in_place,
          )

      def extra_repr(self):
-         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}"
+         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"