liger-kernel 0.1.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. liger_kernel/env_report.py +46 -0
  2. liger_kernel/ops/cross_entropy.py +130 -63
  3. liger_kernel/ops/experimental/embedding.py +143 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
  5. liger_kernel/ops/geglu.py +56 -44
  6. liger_kernel/ops/kl_div.py +258 -0
  7. liger_kernel/ops/layer_norm.py +236 -0
  8. liger_kernel/ops/rms_norm.py +220 -84
  9. liger_kernel/ops/rope.py +91 -84
  10. liger_kernel/ops/swiglu.py +50 -43
  11. liger_kernel/ops/utils.py +12 -0
  12. liger_kernel/transformers/__init__.py +22 -0
  13. liger_kernel/transformers/auto_model.py +45 -0
  14. liger_kernel/transformers/cross_entropy.py +11 -1
  15. liger_kernel/transformers/experimental/embedding.py +28 -0
  16. liger_kernel/transformers/functional.py +19 -0
  17. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
  18. liger_kernel/transformers/geglu.py +4 -2
  19. liger_kernel/transformers/kl_div.py +14 -0
  20. liger_kernel/transformers/layer_norm.py +30 -0
  21. liger_kernel/transformers/model/gemma.py +138 -0
  22. liger_kernel/transformers/model/llama.py +1 -1
  23. liger_kernel/transformers/model/mistral.py +138 -0
  24. liger_kernel/transformers/model/mixtral.py +158 -0
  25. liger_kernel/transformers/model/phi3.py +136 -0
  26. liger_kernel/transformers/model/qwen2.py +135 -0
  27. liger_kernel/transformers/model/qwen2_vl.py +172 -0
  28. liger_kernel/transformers/monkey_patch.py +579 -14
  29. liger_kernel/transformers/rms_norm.py +23 -4
  30. liger_kernel/transformers/swiglu.py +24 -0
  31. liger_kernel/transformers/trainer_integration.py +2 -45
  32. liger_kernel-0.3.1.dist-info/METADATA +395 -0
  33. liger_kernel-0.3.1.dist-info/RECORD +42 -0
  34. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.1.0.dist-info/METADATA +0 -16
  36. liger_kernel-0.1.0.dist-info/RECORD +0 -27
  37. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/NOTICE +0 -0
  39. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.1.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,52 @@
+ import inspect
+ import logging
+ from functools import partial
+ from typing import Callable
+
+ from transformers import PreTrainedModel
+
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
  from liger_kernel.transformers.geglu import LigerGEGLUMLP
- from liger_kernel.transformers.model.llama import lce_forward
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
+ from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
+ from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
+ from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
+ from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
+ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
- from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
+ from liger_kernel.transformers.swiglu import (
+     LigerBlockSparseTop2MLP,
+     LigerPhi3SwiGLUMLP,
+     LigerSwiGLUMLP,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def _bind_method_to_module(module, method_name: str, new_method: Callable):
+     # Binds a new method to a module instance so that self is passed as the first argument
+     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
+
+
+ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama"):
+     module.offset = offset
+     module.casting_mode = casting_mode
+     module.variance_epsilon = (
+         getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+     )
+     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
+     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
+
+
+ def _patch_layer_norm_module(module, eps=1e-6):
+     module.variance_epsilon = (
+         getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
+     )
+     module.hidden_size = module.normalized_shape
+     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
+     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
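The new `_bind_method_to_module` helper relies on Python's descriptor protocol: calling `function.__get__(instance, cls)` produces a method bound to that one instance, so a single module object can be patched without touching its class or other instances. A minimal standalone sketch of the mechanism (illustrative names, not from this package):

    class Norm:
        def forward(self):
            return "original"

    def liger_forward(self):
        return "liger"

    n = Norm()
    # Bind liger_forward to this instance only; other Norm instances keep the original
    n.__dict__["forward"] = liger_forward.__get__(n, Norm)
    assert n.forward() == "liger"
    assert Norm().forward() == "original"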
 
 
  def apply_liger_kernel_to_llama(
@@ -12,6 +55,7 @@ def apply_liger_kernel_to_llama(
      fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
      swiglu: bool = True,
+     model: PreTrainedModel = None,
  ) -> None:
      """
      Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -20,11 +64,13 @@ def apply_liger_kernel_to_llama(
          rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
          cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
          fused_linear_cross_entropy (bool):
-             Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
              `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
              If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
          rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
          swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
      """
 
      assert not (
@@ -42,14 +88,42 @@ def apply_liger_kernel_to_llama(
      if cross_entropy:
          modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
      if fused_linear_cross_entropy:
-         modeling_llama.LlamaForCausalLM.forward = lce_forward
+         modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
+
+         if hasattr(model, "model"):
+             # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
+             base_model = model.model
+         elif hasattr(model, "transformer"):
+             # LlamaForQuestionAnswering uses "transformer" instead of "model"
+             base_model = model.transformer
+         else:
+             # Direct LlamaModel
+             base_model = model
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _bind_method_to_module(
+                     decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                 )
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
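The new `model` argument is what enables patching after a checkpoint has been loaded: class-level swaps alone would miss modules that were already instantiated. A hypothetical usage sketch of the instance flow (checkpoint name is illustrative; assumes the function is re-exported from `liger_kernel.transformers`, as the updated `__init__.py` in the file list suggests):

    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers import apply_liger_kernel_to_llama

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
    # Patches the live RMSNorm and SwiGLU modules on this instance;
    # rope and fused_linear_cross_entropy still patch at the class/module level.
    apply_liger_kernel_to_llama(model=model)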
 
 
  def apply_liger_kernel_to_mistral(
      rope: bool = True,
-     cross_entropy: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
      swiglu: bool = True,
+     model: PreTrainedModel = None,
  ) -> None:
      """
      Apply Liger kernels to replace original implementation in HuggingFace Mistral models
@@ -57,9 +131,19 @@ def apply_liger_kernel_to_mistral(
      Args:
          rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
          cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
          rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
          swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
      """
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
 
      from transformers.models.mistral import modeling_mistral
 
@@ -69,62 +153,543 @@ def apply_liger_kernel_to_mistral(
          modeling_mistral.MistralRMSNorm = LigerRMSNorm
      if cross_entropy:
          modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
      if swiglu:
          modeling_mistral.MistralMLP = LigerSwiGLUMLP
 
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if hasattr(model, "model"):
+             # The case for MistralForCausalLM, MistralForTokenClassification for example
+             base_model = model.model
+         else:
+             # Direct MistralModel
+             base_model = model
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _bind_method_to_module(
+                     decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                 )
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
 
  def apply_liger_kernel_to_mixtral(
      rope: bool = True,
-     cross_entropy: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
      swiglu: bool = True,
+     model: PreTrainedModel = None,
  ) -> None:
      """
      Apply Liger kernels to replace original implementation in HuggingFace Mixtral models
 
      Args:
          rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
          rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
          swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
      """
 
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
      from transformers.models.mixtral import modeling_mixtral
 
      if rope:
          modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
      if rms_norm:
-         modeling_mixtral.MistralRMSNorm = LigerRMSNorm
+         modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
      if cross_entropy:
          modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
      if swiglu:
          modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if hasattr(model, "model"):
+             # The case for MixtralForCausalLM, MixtralForTokenClassification for example
+             base_model = model.model
+         else:
+             # Direct MixtralModel
+             base_model = model
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 for expert in decoder_layer.block_sparse_moe.experts:
+                     _bind_method_to_module(
+                         expert, "forward", LigerBlockSparseTop2MLP.forward
+                     )
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
 
  def apply_liger_kernel_to_gemma(
      rope: bool = True,
-     cross_entropy: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
      rms_norm: bool = True,
      geglu: bool = True,
+     model: PreTrainedModel = None,
  ) -> None:
      """
-     Apply Liger kernels to replace original implementation in HuggingFace Gemma2 models
-     to make GPU go burrr.
+     Apply Liger kernels to replace original implementation in HuggingFace Gemma
+     (Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2`) to make GPU go burrr.
 
      Args:
          rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
          rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
          geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
      """
-     # TODO(yundai424): add convergence test for gemma
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
      from transformers.models.gemma import modeling_gemma
 
+     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
+     LigerRMSNormForGemma = partial(
+         LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+     )
+     _patch_rms_norm_module_for_gemma = partial(
+         _patch_rms_norm_module, casting_mode="gemma", offset=1.0
+     )
+
      if rope:
          modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
      if rms_norm:
-         modeling_gemma.GemmaRMSNorm = LigerRMSNorm
+         modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
      if cross_entropy:
          modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
      if geglu:
          modeling_gemma.GemmaMLP = LigerGEGLUMLP
+     if fused_linear_cross_entropy:
+         modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if hasattr(model, "model"):
+             # The case for GemmaForCausalLM, GemmaForTokenClassification for example
+             base_model = model.model
+         else:
+             # Direct GemmaModel
+             base_model = model
+
+         if rms_norm:
+             _patch_rms_norm_module_for_gemma(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if geglu:
+                 _bind_method_to_module(
+                     decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
+                 )
+             if rms_norm:
+                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
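Gemma needs a differently-configured RMSNorm (unit offset, zero-initialized weight, its own casting mode), and `functools.partial` lets one class serve both cases: partially applying constructor arguments yields a callable that works as a drop-in substitute for the original class. A minimal sketch of the pattern (illustrative class, not from this package):

    from functools import partial

    class Norm:
        def __init__(self, hidden_size, offset=0.0):
            self.hidden_size = hidden_size
            self.offset = offset

    # NormWithOffset acts like a class whose constructor pre-fills offset=1.0
    NormWithOffset = partial(Norm, offset=1.0)
    n = NormWithOffset(4096)
    assert n.offset == 1.0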
+
+
+ def apply_liger_kernel_to_gemma2(
+     rope: bool = True,
+     cross_entropy: bool = True,
+     rms_norm: bool = True,
+     geglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Gemma2
+     (for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     from transformers.models.gemma2 import modeling_gemma2
+
+     LigerRMSNormForGemma2 = partial(
+         LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros"
+     )
+     _patch_rms_norm_module_for_gemma2 = partial(
+         _patch_rms_norm_module, offset=1.0, casting_mode="gemma"
+     )
+
+     if rope:
+         modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
+     if rms_norm:
+         # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
+         modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
+     if cross_entropy:
+         modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+     if geglu:
+         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if hasattr(model, "model"):
+             # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example
+             base_model = model.model
+         else:
+             # Direct Gemma2Model
+             base_model = model
+
+         if rms_norm:
+             _patch_rms_norm_module_for_gemma2(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if geglu:
+                 _bind_method_to_module(
+                     decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
+                 )
+             if rms_norm:
+                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module_for_gemma2(
+                     decoder_layer.post_attention_layernorm
+                 )
+                 _patch_rms_norm_module_for_gemma2(
+                     decoder_layer.pre_feedforward_layernorm
+                 )
+                 _patch_rms_norm_module_for_gemma2(
+                     decoder_layer.post_feedforward_layernorm
+                 )
+
+
+ def apply_liger_kernel_to_qwen2(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+     from transformers.models.qwen2 import modeling_qwen2
+
+     if rope:
+         modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
+     if rms_norm:
+         modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
+     if cross_entropy:
+         modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+     if swiglu:
+         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if hasattr(model, "model"):
+             # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example
+             base_model = model.model
+         else:
+             # Direct Qwen2Model
+             base_model = model
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _bind_method_to_module(
+                     decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                 )
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_qwen2_vl(
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     layer_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
+     NOTE: Qwen2-VL is not available in transformers<=4.44.2
+
+     Args:
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+     from transformers.models.qwen2_vl import modeling_qwen2_vl
+
+     from liger_kernel.transformers.model.qwen2_vl import (
+         lce_forward as qwen2_vl_lce_forward,
+     )
+
+     # TODO: Support Qwen2-VL's multimodal RoPE implementation
+
+     if rms_norm:
+         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
+         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
+     if layer_norm:
+         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
+     if cross_entropy:
+         modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+     if swiglu:
+         modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if hasattr(model, "model"):
+             # The case for Qwen2VLForConditionalGeneration.
+             base_model = model.model
+         else:
+             # Direct Qwen2VLModel
+             base_model = model
+
+         if hasattr(model, "visual"):
+             # Patch Qwen2VisionTransformerPretrainedModel
+             for vision_block in model.visual.blocks:
+                 if layer_norm:
+                     _patch_layer_norm_module(vision_block.norm1)
+                     _patch_layer_norm_module(vision_block.norm2)
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _bind_method_to_module(
+                     decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
+                 )
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_phi3(
+     rope: bool = True,
+     cross_entropy: bool = False,
+     fused_linear_cross_entropy: bool = True,
+     rms_norm: bool = True,
+     swiglu: bool = True,
+     model: PreTrainedModel = None,
+ ) -> None:
+     """
+     Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
+
+     Args:
+         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+         fused_linear_cross_entropy (bool):
+             Whether to apply Liger's fused linear cross entropy loss. Default is True.
+             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+         swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+             loaded. Default is None.
+     """
+     assert not (
+         cross_entropy and fused_linear_cross_entropy
+     ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+     from transformers.models.phi3 import modeling_phi3
+
+     if rope:
+         modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
+     if rms_norm:
+         modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
+     if swiglu:
+         modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+     if cross_entropy:
+         modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+     if fused_linear_cross_entropy:
+         modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+
+     if model is not None:
+         # The model instance already exists, so we need to additionally patch the
+         # instance variables that reference already-instantiated modules
+
+         if hasattr(model, "model"):
+             # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example
+             base_model = model.model
+         else:
+             # Direct Phi3Model
+             base_model = model
+
+         if rms_norm:
+             _patch_rms_norm_module(base_model.norm)
+
+         for decoder_layer in base_model.layers:
+             if swiglu:
+                 _bind_method_to_module(
+                     decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
+                 )
+             if rms_norm:
+                 _patch_rms_norm_module(decoder_layer.input_layernorm)
+                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
+ MODEL_TYPE_TO_APPLY_LIGER_FN = {
+     "gemma": apply_liger_kernel_to_gemma,
+     "gemma2": apply_liger_kernel_to_gemma2,
+     "llama": apply_liger_kernel_to_llama,
+     "mistral": apply_liger_kernel_to_mistral,
+     "mixtral": apply_liger_kernel_to_mixtral,
+     "qwen2": apply_liger_kernel_to_qwen2,
+     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+     "phi3": apply_liger_kernel_to_phi3,
+ }
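This registry keys each patching function by the `model_type` string from a model's config, which is what the dispatch helpers below and the new `liger_kernel/transformers/auto_model.py` (see the file list) build on. A hypothetical usage sketch, assuming the `AutoLigerKernelForCausalLM` wrapper that auto_model.py provides (checkpoint name illustrative):

    from liger_kernel.transformers import AutoLigerKernelForCausalLM

    # Looks up config.model_type ("qwen2" here), applies the matching
    # apply_liger_kernel_to_* function, then loads the model as usual.
    model = AutoLigerKernelForCausalLM.from_pretrained("Qwen/Qwen2-7B")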
+
+
+ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
+     """
+     Applies Liger kernels based on the specified model type. The custom
+     kernels for the specified model type will be applied with the provided
+     keyword arguments, otherwise the default configuration will be used.
+
+     ** Note: Calling _apply_liger_kernel() after model initialization
+     will not be able to fully patch models. This must be called before model initialization.
+     If the model has already been instantiated, use _apply_liger_kernel_to_instance() instead.
+
+     Args:
+         - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
+             and specified in the model's config.json
+         - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
+     """
+     if not model_type:
+         logger.info("Model type was not provided. No Liger kernels will be applied.")
+         return
+
+     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
+         logger.info(
+             f"There are currently no Liger kernels supported for model type: {model_type}."
+         )
+         return
+
+     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+     apply_fn_signature = inspect.signature(apply_fn)
+
+     # Filter out the keyword arguments that are not supported by the apply function
+     applicable_kwargs = {
+         key: value
+         for key, value in kwargs.items()
+         if key in apply_fn_signature.parameters
+     }
+
+     logger.info(
+         f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
+     )
+
+     # Assume this is invoked pre-model initialization, so we only need to patch transformers code
+     apply_fn(**applicable_kwargs)
+
655
+
656
+ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
657
+ """
658
+ Applies Liger kernels to the provided model instance.
659
+
660
+ Args:
661
+ - model: the model instance to apply Liger kernels to
662
+ - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
663
+ """
664
+ model_type = getattr(model, "config", None) and getattr(
665
+ model.config, "model_type", None
666
+ )
667
+
668
+ if not model_type:
669
+ logger.info(
670
+ "Model type could not be determined from model config. No Liger kernels will be applied."
671
+ )
672
+ return
673
+
674
+ if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
675
+ logger.info(
676
+ f"There are currently no Liger kernels supported for model type: {model_type}."
677
+ )
678
+ return
679
+
680
+ apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
681
+
682
+ apply_fn_signature = inspect.signature(apply_fn)
683
+
684
+ # Filter out the keyword arguments that are not supported by the apply function
685
+ applicable_kwargs = {
686
+ key: value
687
+ for key, value in kwargs.items()
688
+ if key in apply_fn_signature.parameters
689
+ }
690
+
691
+ logger.info(
692
+ f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
693
+ )
694
+
695
+ apply_fn(model=model, **applicable_kwargs)
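Taken together, the 0.3.1 monkey-patch module supports both patch-before-load and patch-after-load flows. A hypothetical end-to-end sketch of the instance flow, noting that `_apply_liger_kernel_to_instance` is an internal helper (checkpoint name illustrative):

    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance

    model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    # Dispatches on model.config.model_type ("phi3") and patches the live modules;
    # kwargs not accepted by the phi3 apply function are silently filtered out.
    _apply_liger_kernel_to_instance(model=model, rms_norm=True, swiglu=True)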