liger-kernel 0.3.1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. liger_kernel/env_report.py +2 -0
  2. liger_kernel/ops/cross_entropy.py +144 -65
  3. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +31 -11
  5. liger_kernel/ops/fused_linear_jsd.py +245 -0
  6. liger_kernel/ops/geglu.py +2 -2
  7. liger_kernel/ops/group_norm.py +322 -0
  8. liger_kernel/ops/jsd.py +176 -0
  9. liger_kernel/ops/kl_div.py +2 -2
  10. liger_kernel/ops/rms_norm.py +92 -46
  11. liger_kernel/ops/swiglu.py +2 -2
  12. liger_kernel/ops/utils.py +62 -1
  13. liger_kernel/transformers/__init__.py +3 -0
  14. liger_kernel/transformers/cross_entropy.py +44 -12
  15. liger_kernel/transformers/functional.py +38 -1
  16. liger_kernel/transformers/fused_linear_cross_entropy.py +31 -4
  17. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  18. liger_kernel/transformers/group_norm.py +56 -0
  19. liger_kernel/transformers/jsd.py +75 -0
  20. liger_kernel/transformers/model/gemma.py +124 -1
  21. liger_kernel/transformers/model/gemma2.py +277 -0
  22. liger_kernel/transformers/model/llama.py +135 -4
  23. liger_kernel/transformers/model/mistral.py +3 -0
  24. liger_kernel/transformers/model/mixtral.py +153 -2
  25. liger_kernel/transformers/model/mllama.py +274 -0
  26. liger_kernel/transformers/model/phi3.py +140 -2
  27. liger_kernel/transformers/model/qwen2.py +123 -2
  28. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  29. liger_kernel/transformers/monkey_patch.py +258 -68
  30. liger_kernel/transformers/rms_norm.py +11 -3
  31. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/METADATA +63 -29
  32. liger_kernel-0.4.1.dist-info/NOTICE +58 -0
  33. liger_kernel-0.4.1.dist-info/RECORD +51 -0
  34. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.3.1.dist-info/NOTICE +0 -4
  36. liger_kernel-0.3.1.dist-info/RECORD +0 -42
  37. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.3.1.dist-info → liger_kernel-0.4.1.dist-info}/top_level.txt +0 -0
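Before the hunks below, a quick orientation: the bulk of this release is the version-gated monkey patching in liger_kernel/transformers/monkey_patch.py. A minimal usage sketch follows; it assumes the apply_* functions are re-exported from liger_kernel.transformers as in 0.3.x and that a CUDA-capable environment is available, and the checkpoint name is only an example:

    import torch
    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers import apply_liger_kernel_to_llama

    # Patch the transformers Llama module classes in place. On transformers >= 4.46.1
    # the loss goes through the new nn.functional.cross_entropy / lce_forward path;
    # on older versions the deprecated patch path is used and a warning is logged.
    apply_liger_kernel_to_llama(
        rope=True, rms_norm=True, swiglu=True, fused_linear_cross_entropy=True
    )

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B",  # example checkpoint
        torch_dtype=torch.bfloat16,
    )
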
liger_kernel/transformers/monkey_patch.py
@@ -3,17 +3,39 @@ import logging
 from functools import partial
 from typing import Callable
 
+import transformers
+from packaging import version
 from transformers import PreTrainedModel
 
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
+from liger_kernel.transformers.model.gemma import (
+    lce_forward_deprecated as gemma_lce_forward_deprecated,
+)
+from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
+from liger_kernel.transformers.model.gemma2 import (
+    lce_forward_deprecated as gemma2_lce_forward_deprected,
+)
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
+from liger_kernel.transformers.model.llama import (
+    lce_forward_deprecated as llama_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
+from liger_kernel.transformers.model.mixtral import (
+    lce_forward_deprecated as mixtral_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
+from liger_kernel.transformers.model.phi3 import (
+    lce_forward_deprecated as phi3_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
+from liger_kernel.transformers.model.qwen2 import (
+    lce_forward_deprecated as qwen2_lce_forward_deprecated,
+)
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import (
@@ -22,7 +44,11 @@ from liger_kernel.transformers.swiglu import (
     LigerSwiGLUMLP,
 )
 
+transformer_version = version.parse(transformers.__version__)
+
 logger = logging.getLogger(__name__)
+SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
+TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
 
 
 def _bind_method_to_module(module, method_name: str, new_method: Callable):
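The two module-level constants above drive every patch site that follows: each kernel checks the installed transformers version once and either patches the new transformers.loss.loss_utils path or falls back to the pre-4.46.1 behaviour with a warning. A self-contained sketch of that gate (the patch_new/patch_legacy callables are illustrative stand-ins, not part of the package):

    import logging

    import transformers
    from packaging import version

    logger = logging.getLogger(__name__)
    SUPPORTED_TRANSFORMER_VERSION = "4.46.1"

    def version_gated_patch(patch_new, patch_legacy):
        # patch_new / patch_legacy stand in for the real targets, e.g. assigning
        # liger_cross_entropy to nn.functional.cross_entropy vs. swapping
        # modeling_*.CrossEntropyLoss for LigerCrossEntropyLoss.
        if version.parse(transformers.__version__) >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            patch_new()
        else:
            logger.warning(
                "transformers < %s: using deprecated patch path", SUPPORTED_TRANSFORMER_VERSION
            )
            patch_legacy()
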
@@ -78,6 +104,7 @@ def apply_liger_kernel_to_llama(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.llama import modeling_llama
+    from transformers.models.llama.modeling_llama import LlamaModel
 
     if rope:
         modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -85,24 +112,29 @@ def apply_liger_kernel_to_llama(
         modeling_llama.LlamaRMSNorm = LigerRMSNorm
     if swiglu:
         modeling_llama.LlamaMLP = LigerSwiGLUMLP
+
     if cross_entropy:
-        modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
+
     if fused_linear_cross_entropy:
-        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
 
-        if hasattr(model, "model"):
-            # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
-            base_model = model.model
-        elif hasattr(model, "transformer"):
-            # LlamaForQuestionAnswering uses "transformer" instead of "model"
-            base_model = model.transformer
-        else:
-            # Direct LlamaModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: LlamaModel = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
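The multi-branch hasattr chain is replaced with a single lookup through transformers' base_model_prefix, a class attribute every PreTrainedModel defines (it is "model" for LlamaForCausalLM and "transformer" for LlamaForQuestionAnswering, so both earlier branches are covered). A standalone sketch of the idea; get_base_model is a hypothetical helper, not part of the package:

    from transformers import PreTrainedModel

    def get_base_model(model: PreTrainedModel):
        # If the wrapper attribute named by base_model_prefix exists, return it;
        # otherwise `model` is already the bare decoder stack.
        return getattr(model, model.base_model_prefix, model)
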
@@ -117,6 +149,116 @@ def apply_liger_kernel_to_llama(
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_mllama(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
+    NOTE: MLlama is not available in transformers<4.45.0
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.mllama import modeling_mllama
+    from transformers.models.mllama.modeling_mllama import (
+        MllamaForCausalLM,
+        MllamaForConditionalGeneration,
+        MllamaTextModel,
+        MllamaVisionModel,
+    )
+
+    from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
+    from liger_kernel.transformers.model.mllama import (
+        lce_forward_deprecated as mllama_lce_forward_deprecated,
+    )
+
+    if rope:
+        modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if layer_norm:
+        modeling_mllama.nn.LayerNorm = LigerLayerNorm
+    if rms_norm:
+        modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if isinstance(model, MllamaForConditionalGeneration):
+            language_model: MllamaForCausalLM = model.language_model
+            vision_model: MllamaVisionModel = model.vision_model
+            text_model: MllamaTextModel = language_model.model
+        elif isinstance(model, MllamaForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, MllamaTextModel):
+            text_model = model
+            vision_model = None
+        else:
+            raise ValueError(f"Unsupported Mllama model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _bind_method_to_module(
+                        decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                    )
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+            for layer in vision_model.global_transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mistral(
     rope: bool = True,
     cross_entropy: bool = False,
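apply_liger_kernel_to_mllama above also accepts an already-loaded model, in which case it walks both the text and vision towers and patches the instantiated modules. A hedged usage sketch (assumes transformers >= 4.45.0 so that Mllama exists; the checkpoint name is only an example):

    from transformers import MllamaForConditionalGeneration

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama

    model = MllamaForConditionalGeneration.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct"  # example checkpoint
    )
    # Passing model= patches the existing RMSNorm / LayerNorm / MLP instances on the
    # text and vision towers, in addition to replacing the module classes globally.
    apply_liger_kernel_to_mllama(model=model, fused_linear_cross_entropy=True)
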
@@ -129,7 +271,7 @@ def apply_liger_kernel_to_mistral(
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models
 
     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
         fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -146,6 +288,7 @@ def apply_liger_kernel_to_mistral(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.mistral import modeling_mistral
+    from transformers.models.mistral.modeling_mistral import MistralModel
 
     if rope:
         modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -162,12 +305,8 @@
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if hasattr(model, "model"):
-            # The case for MistralForCausalLM, MistralForTokenClassification for example
-            base_model = model.model
-        else:
-            # Direct MistralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MistralModel = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -211,15 +350,27 @@ def apply_liger_kernel_to_mixtral(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.mixtral import modeling_mixtral
+    from transformers.models.mixtral.modeling_mixtral import MixtralModel
 
     if rope:
         modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
         modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
     if cross_entropy:
-        modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+
     if fused_linear_cross_entropy:
-        modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
@@ -227,12 +378,8 @@
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if hasattr(model, "model"):
-            # The case for MixtralForCausalLM, MixtralForTokenClassification for example
-            base_model = model.model
-        else:
-            # Direct MixtralModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: MixtralModel = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -277,6 +424,7 @@ def apply_liger_kernel_to_gemma(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.gemma import modeling_gemma
+    from transformers.models.gemma.modeling_gemma import GemmaModel
 
     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
     LigerRMSNormForGemma = partial(
@@ -291,22 +439,28 @@
     if rms_norm:
         modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
     if cross_entropy:
-        modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
     if geglu:
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
-        modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if hasattr(model, "model"):
-            # The case for GemmaForCausalLM, GemmaForTokenClassification for example
-            base_model = model.model
-        else:
-            # Direct GemmaModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: GemmaModel = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
             _patch_rms_norm_module_for_gemma(base_model.norm)
@@ -323,7 +477,8 @@
 
 def apply_liger_kernel_to_gemma2(
     rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
     model: PreTrainedModel = None,
@@ -334,16 +489,25 @@
 
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
     from transformers.models.gemma2 import modeling_gemma2
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
     LigerRMSNormForGemma2 = partial(
-        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
+        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
     )
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
@@ -355,7 +519,19 @@ def apply_liger_kernel_to_gemma2(
         # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
         modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
     if cross_entropy:
-        modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 
@@ -363,12 +539,8 @@
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if hasattr(model, "model"):
-            # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example
-            base_model = model.model
-        else:
-            # Direct Gemma2Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
             _patch_rms_norm_module_for_gemma2(base_model.norm)
@@ -419,15 +591,31 @@ def apply_liger_kernel_to_qwen2(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.qwen2 import modeling_qwen2
+    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
 
     if rope:
         modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
         modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
+
     if cross_entropy:
-        modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    # import pdb; pdb.set_trace()
     if fused_linear_cross_entropy:
-        modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
 
@@ -435,12 +623,8 @@
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if hasattr(model, "model"):
-            # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example
-            base_model = model.model
-        else:
-            # Direct Qwen2Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -453,6 +637,7 @@
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+    print("Applied Liger kernels to Qwen2")
 
 
 def apply_liger_kernel_to_qwen2_vl(
@@ -465,7 +650,7 @@
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-    NOTE: Qwen2-VL is not available in transformers<=4.44.2
+    NOTE: Qwen2-VL is not available in transformers<4.45.0
 
     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -484,6 +669,7 @@
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
 
     from liger_kernel.transformers.model.qwen2_vl import (
         lce_forward as qwen2_vl_lce_forward,
@@ -507,12 +693,8 @@
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if hasattr(model, "model"):
-            # The case for Qwen2VLForConditionalGeneration.
-            base_model = model.model
-        else:
-            # Direct Qwen2VLModel
-            base_model = model
+        # get the base model from the model instance
+        base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)
 
         if hasattr(model, "visual"):
             # Patch Qwen2VisionTransformerPretrainedModel
@@ -561,6 +743,7 @@ def apply_liger_kernel_to_phi3(
     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
     from transformers.models.phi3 import modeling_phi3
+    from transformers.models.phi3.modeling_phi3 import Phi3Model
 
     if rope:
         modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
@@ -569,20 +752,26 @@
     if swiglu:
         modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
-        modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if hasattr(model, "model"):
-            # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example
-            base_model = model.model
-        else:
-            # Direct Phi3Model
-            base_model = model
+        # get the base model from the model instance
+        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)
 
         if rms_norm:
             _patch_rms_norm_module(base_model.norm)
@@ -602,6 +791,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "mllama": apply_liger_kernel_to_mllama,
+    "mllama_text_model": apply_liger_kernel_to_mllama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "qwen2": apply_liger_kernel_to_qwen2,
@@ -687,7 +878,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         for key, value in kwargs.items()
         if key in apply_fn_signature.parameters
     }
-
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
     )
liger_kernel/transformers/rms_norm.py
@@ -6,7 +6,13 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 
 class LigerRMSNorm(nn.Module):
     def __init__(
-        self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
+        self,
+        hidden_size,
+        eps=1e-6,
+        offset=0.0,
+        casting_mode="llama",
+        init_fn="ones",
+        in_place=True,
     ):
         super().__init__()
         assert init_fn in [
@@ -16,10 +22,11 @@ class LigerRMSNorm(nn.Module):
         self.weight = nn.Parameter(
             torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
         )
-        self.variance_epsilon, self.offset, self.casting_mode = (
+        self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
            eps,
            offset,
            casting_mode,
+           in_place,
        )
 
     def forward(self, hidden_states):
@@ -29,7 +36,8 @@
             self.variance_epsilon,
             self.offset,
             self.casting_mode,
+            self.in_place,
         )
 
     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}"
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"