liger-kernel 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. liger_kernel/ops/cross_entropy.py +5 -39
  2. liger_kernel/ops/experimental/mm_int8int2.py +355 -0
  3. liger_kernel/ops/fused_linear_cross_entropy.py +13 -10
  4. liger_kernel/ops/fused_linear_jsd.py +245 -0
  5. liger_kernel/ops/geglu.py +2 -2
  6. liger_kernel/ops/jsd.py +176 -0
  7. liger_kernel/ops/kl_div.py +45 -34
  8. liger_kernel/ops/rms_norm.py +67 -42
  9. liger_kernel/ops/swiglu.py +2 -2
  10. liger_kernel/ops/utils.py +62 -1
  11. liger_kernel/transformers/__init__.py +3 -0
  12. liger_kernel/transformers/auto_model.py +18 -6
  13. liger_kernel/transformers/functional.py +4 -0
  14. liger_kernel/transformers/fused_linear_jsd.py +98 -0
  15. liger_kernel/transformers/jsd.py +75 -0
  16. liger_kernel/transformers/kl_div.py +3 -2
  17. liger_kernel/transformers/model/gemma.py +124 -1
  18. liger_kernel/transformers/model/llama.py +135 -4
  19. liger_kernel/transformers/model/mistral.py +3 -0
  20. liger_kernel/transformers/model/mixtral.py +153 -2
  21. liger_kernel/transformers/model/mllama.py +274 -0
  22. liger_kernel/transformers/model/phi3.py +140 -2
  23. liger_kernel/transformers/model/qwen2.py +123 -2
  24. liger_kernel/transformers/model/qwen2_vl.py +8 -1
  25. liger_kernel/transformers/monkey_patch.py +254 -129
  26. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/METADATA +74 -35
  27. liger_kernel-0.4.0.dist-info/NOTICE +58 -0
  28. liger_kernel-0.4.0.dist-info/RECORD +48 -0
  29. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/WHEEL +1 -1
  30. liger_kernel-0.3.0.dist-info/NOTICE +0 -4
  31. liger_kernel-0.3.0.dist-info/RECORD +0 -42
  32. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/LICENSE +0 -0
  33. {liger_kernel-0.3.0.dist-info → liger_kernel-0.4.0.dist-info}/top_level.txt +0 -0
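The diff reproduced below is for liger_kernel/transformers/monkey_patch.py (+254 -129). Its main changes are: a transformers version gate that keeps the new lce_forward implementations on >= 4.46.1 and falls back to lce_forward_deprecated variants otherwise (because of the gradient-accumulation fix referenced in huggingface/transformers#34191), a new apply_liger_kernel_to_mllama entry point, and a switch from rebuilding modules with config and torch_dtype to patching already-instantiated modules in place via _bind_method_to_module, _patch_rms_norm_module, and _patch_layer_norm_module. As a rough illustration of that binding pattern, here is a minimal, self-contained sketch; ToyMLP and fused_forward are invented for this example, and only the _bind_method_to_module body mirrors the diff:

import torch
from torch import nn


class ToyMLP(nn.Module):
    # Stand-in for a HuggingFace MLP module; illustrative only.
    def __init__(self, dim: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


def fused_forward(self, x):
    # Stand-in for a Liger forward such as LigerSwiGLUMLP.forward.
    return self.proj(x) * 2.0


def _bind_method_to_module(module, method_name, new_method):
    # __get__ binds the plain function to this instance, so `self` is passed
    # automatically and only this one instance is affected, not its class.
    module.__dict__[method_name] = new_method.__get__(module, module.__class__)


mlp = ToyMLP()
_bind_method_to_module(mlp, "forward", fused_forward)
print(mlp(torch.randn(2, 4)).shape)  # nn.Module.__call__ now dispatches to fused_forward

Binding on the instance keeps each module's existing weights and dtype, which is presumably why the config.torch_dtype handling is dropped throughout the diff.
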
@@ -1,19 +1,36 @@
 import inspect
 import logging
 from functools import partial
+from typing import Callable
 
-from torch import nn
-from transformers import PretrainedConfig, PreTrainedModel
+import transformers
+from packaging import version
+from transformers import PreTrainedModel
 
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
+from liger_kernel.transformers.model.gemma import (
+    lce_forward_deprecated as gemma_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
+from liger_kernel.transformers.model.llama import (
+    lce_forward_deprecated as llama_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
+from liger_kernel.transformers.model.mixtral import (
+    lce_forward_deprecated as mixtral_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
+from liger_kernel.transformers.model.phi3 import (
+    lce_forward_deprecated as phi3_lce_forward_deprecated,
+)
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
+from liger_kernel.transformers.model.qwen2 import (
+    lce_forward_deprecated as qwen2_lce_forward_deprecated,
+)
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import (
@@ -22,7 +39,35 @@ from liger_kernel.transformers.swiglu import (
     LigerSwiGLUMLP,
 )
 
+transformer_version = version.parse(transformers.__version__)
+
 logger = logging.getLogger(__name__)
+SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
+TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"
+
+
+def _bind_method_to_module(module, method_name: str, new_method: Callable):
+    # Binds a new method to a module instance so that self is passed as the first argument
+    module.__dict__[method_name] = new_method.__get__(module, module.__class__)
+
+
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
+    module.offset = offset
+    module.casting_mode = casting_mode
+    module.variance_epsilon = (
+        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+    )
+    _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+    _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+
+
+def _patch_layer_norm_module(module, eps=1e-6):
+    module.variance_epsilon = (
+        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+    )
+    module.hidden_size = module.normalized_shape
+    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
 
 
 def apply_liger_kernel_to_llama(
@@ -64,12 +109,15 @@ def apply_liger_kernel_to_llama(
     if cross_entropy:
         modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
@@ -81,22 +129,121 @@ def apply_liger_kernel_to_llama(
             # Direct LlamaModel
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm = LigerRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                )
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+def apply_liger_kernel_to_mllama(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    layer_norm: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
+    NOTE: MLlama is not available in transformers<4.45.0
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.mllama import modeling_mllama
+    from transformers.models.mllama.modeling_mllama import (
+        MllamaForCausalLM,
+        MllamaForConditionalGeneration,
+        MllamaTextModel,
+        MllamaVisionModel,
+    )
+
+    from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
+    from liger_kernel.transformers.model.mllama import (
+        lce_forward_deprecated as mllama_lce_forward_deprecated,
+    )
+
+    if rope:
+        modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if layer_norm:
+        modeling_mllama.nn.LayerNorm = LigerLayerNorm
+    if rms_norm:
+        modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
+    if cross_entropy:
+        modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        if isinstance(model, MllamaForConditionalGeneration):
+            language_model: MllamaForCausalLM = model.language_model
+            vision_model: MllamaVisionModel = model.vision_model
+            text_model: MllamaTextModel = language_model.model
+        elif isinstance(model, MllamaForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, MllamaTextModel):
+            text_model = model
+            vision_model = None
+        else:
+            raise ValueError(f"Unsupported Mllama model type: {type(model)}")
+
+        if text_model:
             if rms_norm:
-                decoder_layer.input_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _bind_method_to_module(
+                        decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                    )
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+            for layer in vision_model.global_transformer.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_mistral(
@@ -143,7 +290,6 @@ def apply_liger_kernel_to_mistral(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
            # The case for MistralForCausalLM, MistralForTokenClassification for example
@@ -152,22 +298,17 @@ def apply_liger_kernel_to_mistral(
             # Direct MistralModel
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm = LigerRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_mixtral(
@@ -207,14 +348,17 @@ def apply_liger_kernel_to_mixtral(
     if cross_entropy:
         modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for MixtralForCausalLM, MixtralForTokenClassification for example
@@ -223,29 +367,18 @@ def apply_liger_kernel_to_mixtral(
             # Direct MixtralModel
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm = LigerRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                block_sparse_moe = decoder_layer.block_sparse_moe
-                patched_experts = nn.ModuleList(
-                    [
-                        LigerBlockSparseTop2MLP(config)
-                        for _ in range(block_sparse_moe.num_experts)
-                    ]
-                )
-                decoder_layer.block_sparse_moe.experts = patched_experts.to(torch_dtype)
+                for expert in decoder_layer.block_sparse_moe.experts:
+                    _bind_method_to_module(
+                        expert, "forward", LigerBlockSparseTop2MLP.forward
+                    )
             if rms_norm:
-                decoder_layer.input_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_gemma(
@@ -282,6 +415,9 @@ def apply_liger_kernel_to_gemma(
     LigerRMSNormForGemma = partial(
         LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
     )
+    _patch_rms_norm_module_for_gemma = partial(
+        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
+    )
 
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -292,12 +428,15 @@ def apply_liger_kernel_to_gemma(
     if geglu:
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
-        modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for GemmaForCausalLM, GemmaForTokenClassification for example
@@ -306,22 +445,17 @@ def apply_liger_kernel_to_gemma(
             # Direct GemmaModel
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm = LigerRMSNormForGemma(
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module_for_gemma(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if geglu:
-                decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype)
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm = LigerRMSNormForGemma(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNormForGemma(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_gemma2(
@@ -343,10 +477,15 @@ def apply_liger_kernel_to_gemma2(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    print("Got here!")
     from transformers.models.gemma2 import modeling_gemma2
 
-    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, init_fn="zeros")
+    LigerRMSNormForGemma2 = partial(
+        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
+    )
+    _patch_rms_norm_module_for_gemma2 = partial(
+        _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
+    )
+
     if rope:
         modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
@@ -360,7 +499,6 @@ def apply_liger_kernel_to_gemma2(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example
@@ -369,28 +507,25 @@ def apply_liger_kernel_to_gemma2(
             # Direct Gemma2Model
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm = LigerRMSNormForGemma2(
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module_for_gemma2(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if geglu:
-                decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype)
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm = LigerRMSNormForGemma2(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNormForGemma2(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.pre_feedforward_layernorm = LigerRMSNormForGemma2(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.post_feedforward_layernorm = LigerRMSNormForGemma2(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
+                _patch_rms_norm_module_for_gemma2(
+                    decoder_layer.post_attention_layernorm
+                )
+                _patch_rms_norm_module_for_gemma2(
+                    decoder_layer.pre_feedforward_layernorm
+                )
+                _patch_rms_norm_module_for_gemma2(
+                    decoder_layer.post_feedforward_layernorm
+                )
 
 
 def apply_liger_kernel_to_qwen2(
@@ -428,15 +563,22 @@ def apply_liger_kernel_to_qwen2(
         modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
     if cross_entropy:
         modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    # import pdb; pdb.set_trace()
     if fused_linear_cross_entropy:
-        modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example
@@ -445,22 +587,18 @@ def apply_liger_kernel_to_qwen2(
             # Direct Qwen2Model
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm = LigerRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+    print("Applied Liger kernels to Qwen2")
 
 
 def apply_liger_kernel_to_qwen2_vl(
@@ -473,7 +611,7 @@ def apply_liger_kernel_to_qwen2_vl(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-    NOTE: Qwen2-VL is not available in transformers<=4.44.2
+    NOTE: Qwen2-VL is not available in transformers<4.45.0
 
     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -499,10 +637,9 @@ def apply_liger_kernel_to_qwen2_vl(
 
     # TODO: Support Qwen2-VL's multimodal RoPE implementation
 
-    LigerRMSNormForQwen2VL = partial(LigerRMSNorm, init_fn="ones", casting_mode="gemma")
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
-        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNormForQwen2VL
+        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
     if layer_norm:
         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
     if cross_entropy:
@@ -515,9 +652,6 @@ def apply_liger_kernel_to_qwen2_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
-
-        torch_dtype = config.torch_dtype
 
         if hasattr(model, "model"):
             # The case for Qwen2VLForConditionalGeneration.
@@ -530,27 +664,19 @@ def apply_liger_kernel_to_qwen2_vl(
         # Patch Qwen2VisionTransformerPretrainedModel
         for vision_block in model.visual.blocks:
             if layer_norm:
-                vision_block.norm1 = LigerLayerNorm(config.embed_dim, eps=1e-6).to(
-                    torch_dtype
-                )
-                vision_block.norm2 = LigerLayerNorm(config.embed_dim, eps=1e-6).to(
-                    torch_dtype
-                )
+                _patch_layer_norm_module(vision_block.norm1)
+                _patch_layer_norm_module(vision_block.norm2)
 
         if rms_norm:
-            base_model.norm = LigerRMSNormForQwen2VL(
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm = LigerRMSNormForQwen2VL(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNormForQwen2VL(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_phi3(
@@ -591,12 +717,15 @@ def apply_liger_kernel_to_phi3(
     if cross_entropy:
         modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+        else:  # if version < 4.46.1
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        config: PretrainedConfig = model.config
 
         if hasattr(model, "model"):
             # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example
@@ -605,22 +734,17 @@ def apply_liger_kernel_to_phi3(
             # Direct Phi3Model
             base_model = model
 
-        torch_dtype = config.torch_dtype
         if rms_norm:
-            base_model.norm = LigerRMSNorm(
-                config.hidden_size, eps=config.rms_norm_eps
-            ).to(torch_dtype)
+            _patch_rms_norm_module(base_model.norm)
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                decoder_layer.mlp = LigerPhi3SwiGLUMLP(config).to(torch_dtype)
+                _bind_method_to_module(
+                    decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
+                )
             if rms_norm:
-                decoder_layer.input_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
-                decoder_layer.post_attention_layernorm = LigerRMSNorm(
-                    config.hidden_size, eps=config.rms_norm_eps
-                ).to(torch_dtype)
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
@@ -628,6 +752,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "mllama": apply_liger_kernel_to_mllama,
+    "mllama_text_model": apply_liger_kernel_to_mllama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "qwen2": apply_liger_kernel_to_qwen2,
@@ -713,7 +839,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         for key, value in kwargs.items()
         if key in apply_fn_signature.parameters
     }
-
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
     )
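For reference, a minimal sketch of how the patched entry points above are typically invoked on an already-loaded model; the checkpoint name and dtype are illustrative assumptions, not part of this diff:

import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

# Illustrative checkpoint; any Llama-architecture model would do.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
)

# Passing `model=` additionally patches the already-instantiated RMSNorm and MLP
# modules in place; on transformers < 4.46.1 the call logs
# TRANSFORMER_DEPRECATION_WARNING and falls back to the deprecated lce_forward variant.
apply_liger_kernel_to_llama(model=model)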