liger-kernel 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (39)
  1. liger_kernel/env_report.py +46 -0
  2. liger_kernel/ops/cross_entropy.py +130 -63
  3. liger_kernel/ops/experimental/embedding.py +143 -0
  4. liger_kernel/ops/fused_linear_cross_entropy.py +203 -126
  5. liger_kernel/ops/geglu.py +54 -42
  6. liger_kernel/ops/kl_div.py +247 -0
  7. liger_kernel/ops/layer_norm.py +236 -0
  8. liger_kernel/ops/rms_norm.py +220 -84
  9. liger_kernel/ops/rope.py +91 -84
  10. liger_kernel/ops/swiglu.py +48 -41
  11. liger_kernel/ops/utils.py +12 -0
  12. liger_kernel/transformers/__init__.py +22 -0
  13. liger_kernel/transformers/auto_model.py +33 -0
  14. liger_kernel/transformers/cross_entropy.py +11 -1
  15. liger_kernel/transformers/experimental/embedding.py +28 -0
  16. liger_kernel/transformers/functional.py +19 -0
  17. liger_kernel/transformers/fused_linear_cross_entropy.py +8 -2
  18. liger_kernel/transformers/geglu.py +4 -2
  19. liger_kernel/transformers/kl_div.py +13 -0
  20. liger_kernel/transformers/layer_norm.py +30 -0
  21. liger_kernel/transformers/model/gemma.py +138 -0
  22. liger_kernel/transformers/model/llama.py +1 -1
  23. liger_kernel/transformers/model/mistral.py +138 -0
  24. liger_kernel/transformers/model/mixtral.py +158 -0
  25. liger_kernel/transformers/model/phi3.py +136 -0
  26. liger_kernel/transformers/model/qwen2.py +135 -0
  27. liger_kernel/transformers/model/qwen2_vl.py +172 -0
  28. liger_kernel/transformers/monkey_patch.py +605 -14
  29. liger_kernel/transformers/rms_norm.py +23 -4
  30. liger_kernel/transformers/swiglu.py +24 -0
  31. liger_kernel/transformers/trainer_integration.py +2 -45
  32. liger_kernel-0.3.0.dist-info/METADATA +388 -0
  33. liger_kernel-0.3.0.dist-info/RECORD +42 -0
  34. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/WHEEL +1 -1
  35. liger_kernel-0.1.0.dist-info/METADATA +0 -16
  36. liger_kernel-0.1.0.dist-info/RECORD +0 -27
  37. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/LICENSE +0 -0
  38. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/NOTICE +0 -0
  39. {liger_kernel-0.1.0.dist-info → liger_kernel-0.3.0.dist-info}/top_level.txt +0 -0
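
The hunks that make up the rest of this page appear to come from liger_kernel/transformers/monkey_patch.py, by far the largest change in this release (+605 -14). For orientation, the apply_liger_kernel_to_* entry points are still intended to be called before the model is constructed, so that the patched transformers classes are used when the layers are built. A minimal sketch of that pre-load flow, assuming a Llama checkpoint (the model name and call site are illustrative, not part of the diff):

    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama

    # Patch the transformers Llama classes before the model is instantiated so that
    # RoPE, RMSNorm, SwiGLU and the fused linear cross entropy forward are swapped in.
    # Defaults per the diff: rope=True, fused_linear_cross_entropy=True, rms_norm=True,
    # swiglu=True, cross_entropy=False.
    apply_liger_kernel_to_llama()

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
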
@@ -1,9 +1,28 @@
+import inspect
+import logging
+from functools import partial
+
+from torch import nn
+from transformers import PretrainedConfig, PreTrainedModel
+
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
-from liger_kernel.transformers.model.llama import lce_forward
+from liger_kernel.transformers.layer_norm import LigerLayerNorm
+from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
+from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
+from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
+from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
+from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
+from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
+from liger_kernel.transformers.swiglu import (
+    LigerBlockSparseTop2MLP,
+    LigerPhi3SwiGLUMLP,
+    LigerSwiGLUMLP,
+)
+
+logger = logging.getLogger(__name__)


 def apply_liger_kernel_to_llama(
@@ -12,6 +31,7 @@ def apply_liger_kernel_to_llama(
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     swiglu: bool = True,
+    model: PreTrainedModel = None,
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)
@@ -20,11 +40,13 @@ def apply_liger_kernel_to_llama(
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
     """

     assert not (
@@ -42,14 +64,48 @@ def apply_liger_kernel_to_llama(
     if cross_entropy:
         modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_llama.LlamaForCausalLM.forward = lce_forward
+        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
+        config: PretrainedConfig = model.config
+
+        if hasattr(model, "model"):
+            # The case for LlamaForCausalLM or LlamaForSequenceClassification, for example
+            base_model = model.model
+        elif hasattr(model, "transformer"):
+            # LlamaForQuestionAnswering uses "transformer" instead of "model"
+            base_model = model.transformer
+        else:
+            # Direct LlamaModel
+            base_model = model
+
+        torch_dtype = config.torch_dtype
+        if rms_norm:
+            base_model.norm = LigerRMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            ).to(torch_dtype)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+            if rms_norm:
+                decoder_layer.input_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.post_attention_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)


 def apply_liger_kernel_to_mistral(
     rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     swiglu: bool = True,
+    model: PreTrainedModel = None,
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Mistral models
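
The hunks above and below rework apply_liger_kernel_to_mistral along the same lines as the Llama patch: cross_entropy now defaults to False in favour of fused_linear_cross_entropy, and an optional model argument patches the RMSNorm and MLP modules of an already-instantiated model in place. A minimal sketch of the post-load path (the checkpoint name is illustrative, not part of the diff):

    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral

    # Built before any patching, so its layers still use the stock Mistral modules.
    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")

    # A class-level patch alone cannot reach modules that already exist, so the loaded
    # instance is passed in and its norm/MLP submodules are replaced directly.
    apply_liger_kernel_to_mistral(model=model)
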
@@ -57,9 +113,19 @@ def apply_liger_kernel_to_mistral(
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
     """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

     from transformers.models.mistral import modeling_mistral

@@ -69,62 +135,587 @@ def apply_liger_kernel_to_mistral(
         modeling_mistral.MistralRMSNorm = LigerRMSNorm
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP

+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        config: PretrainedConfig = model.config
+
+        if hasattr(model, "model"):
+            # The case for MistralForCausalLM, MistralForTokenClassification for example
+            base_model = model.model
+        else:
+            # Direct MistralModel
+            base_model = model
+
+        torch_dtype = config.torch_dtype
+        if rms_norm:
+            base_model.norm = LigerRMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            ).to(torch_dtype)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+            if rms_norm:
+                decoder_layer.input_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.post_attention_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+

 def apply_liger_kernel_to_mixtral(
     rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     swiglu: bool = True,
+    model: PreTrainedModel = None,
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Mixtral models

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
     """

+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
     from transformers.models.mixtral import modeling_mixtral

     if rope:
         modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_mixtral.MistralRMSNorm = LigerRMSNorm
+        modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
     if cross_entropy:
         modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        config: PretrainedConfig = model.config
+
+        if hasattr(model, "model"):
+            # The case for MixtralForCausalLM, MixtralForTokenClassification for example
+            base_model = model.model
+        else:
+            # Direct MixtralModel
+            base_model = model
+
+        torch_dtype = config.torch_dtype
+        if rms_norm:
+            base_model.norm = LigerRMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            ).to(torch_dtype)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                block_sparse_moe = decoder_layer.block_sparse_moe
+                patched_experts = nn.ModuleList(
+                    [
+                        LigerBlockSparseTop2MLP(config)
+                        for _ in range(block_sparse_moe.num_experts)
+                    ]
+                )
+                decoder_layer.block_sparse_moe.experts = patched_experts.to(torch_dtype)
+            if rms_norm:
+                decoder_layer.input_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.post_attention_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+

 def apply_liger_kernel_to_gemma(
     rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     geglu: bool = True,
+    model: PreTrainedModel = None,
 ) -> None:
     """
-    Apply Liger kernels to replace original implementation in HuggingFace Gemma2 models
-    to make GPU go burrr.
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma
+    (Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2` ) to make GPU go burrr.

     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
-        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
     """
-    # TODO(yundai424): add convergence test for gemma
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
     from transformers.models.gemma import modeling_gemma

+    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
+    LigerRMSNormForGemma = partial(
+        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+    )
+
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_gemma.GemmaRMSNorm = LigerRMSNorm
+        modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
     if cross_entropy:
         modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
     if geglu:
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
+    if fused_linear_cross_entropy:
+        modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        config: PretrainedConfig = model.config
+
+        if hasattr(model, "model"):
+            # The case for GemmaForCausalLM, GemmaForTokenClassification for example
+            base_model = model.model
+        else:
+            # Direct GemmaModel
+            base_model = model
+
+        torch_dtype = config.torch_dtype
+        if rms_norm:
+            base_model.norm = LigerRMSNormForGemma(
+                config.hidden_size, eps=config.rms_norm_eps
+            ).to(torch_dtype)
+
+        for decoder_layer in base_model.layers:
+            if geglu:
+                decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype)
+            if rms_norm:
+                decoder_layer.input_layernorm = LigerRMSNormForGemma(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.post_attention_layernorm = LigerRMSNormForGemma(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+
+
+def apply_liger_kernel_to_gemma2(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    rms_norm: bool = True,
+    geglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Gemma2
+    (for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    print("Got here!")
+    from transformers.models.gemma2 import modeling_gemma2
+
+    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, init_fn="zeros")
+    if rope:
+        modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
+        modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
+    if cross_entropy:
+        modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
+    if geglu:
+        modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        config: PretrainedConfig = model.config
+
+        if hasattr(model, "model"):
+            # The case for Gemma2ForCausalLM, Gemma2ForTokenClassification for example
+            base_model = model.model
+        else:
+            # Direct Gemma2Model
+            base_model = model
+
+        torch_dtype = config.torch_dtype
+        if rms_norm:
+            base_model.norm = LigerRMSNormForGemma2(
+                config.hidden_size, eps=config.rms_norm_eps
+            ).to(torch_dtype)
+
+        for decoder_layer in base_model.layers:
+            if geglu:
+                decoder_layer.mlp = LigerGEGLUMLP(config).to(torch_dtype)
+            if rms_norm:
+                decoder_layer.input_layernorm = LigerRMSNormForGemma2(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.post_attention_layernorm = LigerRMSNormForGemma2(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.pre_feedforward_layernorm = LigerRMSNormForGemma2(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.post_feedforward_layernorm = LigerRMSNormForGemma2(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+
+
+def apply_liger_kernel_to_qwen2(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.qwen2 import modeling_qwen2
+
+    if rope:
+        modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+    if swiglu:
+        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        config: PretrainedConfig = model.config
+
+        if hasattr(model, "model"):
+            # The case for Qwen2ForCausalLM, Qwen2ForTokenClassification for example
+            base_model = model.model
+        else:
+            # Direct Qwen2Model
+            base_model = model
+
+        torch_dtype = config.torch_dtype
+        if rms_norm:
+            base_model.norm = LigerRMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            ).to(torch_dtype)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+            if rms_norm:
+                decoder_layer.input_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.post_attention_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+
+
+def apply_liger_kernel_to_qwen2_vl(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    layer_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
+    NOTE: Qwen2-VL is not available in transformers<=4.44.2
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.qwen2_vl import modeling_qwen2_vl
+
+    from liger_kernel.transformers.model.qwen2_vl import (
+        lce_forward as qwen2_vl_lce_forward,
+    )
+
+    # TODO: Support Qwen2-VL's multimodal RoPE implementation
+
+    LigerRMSNormForQwen2VL = partial(LigerRMSNorm, init_fn="ones", casting_mode="gemma")
+    if rms_norm:
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
+        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNormForQwen2VL
+    if layer_norm:
+        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
+    if cross_entropy:
+        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+    if swiglu:
+        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        config: PretrainedConfig = model.config
+
+        torch_dtype = config.torch_dtype
+
+        if hasattr(model, "model"):
+            # The case for Qwen2VLForConditionalGeneration.
+            base_model = model.model
+        else:
+            # Direct Qwen2VLModel
+            base_model = model
+
+        if hasattr(model, "visual"):
+            # Patch Qwen2VisionTransformerPretrainedModel
+            for vision_block in model.visual.blocks:
+                if layer_norm:
+                    vision_block.norm1 = LigerLayerNorm(config.embed_dim, eps=1e-6).to(
+                        torch_dtype
+                    )
+                    vision_block.norm2 = LigerLayerNorm(config.embed_dim, eps=1e-6).to(
+                        torch_dtype
+                    )
+
+        if rms_norm:
+            base_model.norm = LigerRMSNormForQwen2VL(
+                config.hidden_size, eps=config.rms_norm_eps
+            ).to(torch_dtype)
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                decoder_layer.mlp = LigerSwiGLUMLP(config).to(torch_dtype)
+            if rms_norm:
+                decoder_layer.input_layernorm = LigerRMSNormForQwen2VL(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.post_attention_layernorm = LigerRMSNormForQwen2VL(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+
+
+def apply_liger_kernel_to_phi3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.phi3 import modeling_phi3
+
+    if rope:
+        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
+    if rms_norm:
+        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
+    if swiglu:
+        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
+    if cross_entropy:
+        modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        config: PretrainedConfig = model.config
+
+        if hasattr(model, "model"):
+            # The case for Phi3ForCausalLM, Phi3ForTokenClassification for example
+            base_model = model.model
+        else:
+            # Direct Phi3Model
+            base_model = model
+
+        torch_dtype = config.torch_dtype
+        if rms_norm:
+            base_model.norm = LigerRMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            ).to(torch_dtype)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                decoder_layer.mlp = LigerPhi3SwiGLUMLP(config).to(torch_dtype)
+            if rms_norm:
+                decoder_layer.input_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+                decoder_layer.post_attention_layernorm = LigerRMSNorm(
+                    config.hidden_size, eps=config.rms_norm_eps
+                ).to(torch_dtype)
+
+
+# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
+MODEL_TYPE_TO_APPLY_LIGER_FN = {
+    "gemma": apply_liger_kernel_to_gemma,
+    "gemma2": apply_liger_kernel_to_gemma2,
+    "llama": apply_liger_kernel_to_llama,
+    "mistral": apply_liger_kernel_to_mistral,
+    "mixtral": apply_liger_kernel_to_mixtral,
+    "qwen2": apply_liger_kernel_to_qwen2,
+    "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "phi3": apply_liger_kernel_to_phi3,
+}
+
+
+def _apply_liger_kernel(model_type: str, **kwargs) -> None:
+    """
+    Applies Liger kernels based on the specified model type. The custom
+    kernels for the specified model type will be applied with the provided
+    keyword arguments, otherwise the default configuration will be used.
+
+    ** Note: Calling _apply_liger_kernel() after model initialization
+    will not be able to fully patch models. This must be called before model initialization.
+    If the model has already been instantiated
+
+    Args:
+        - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
+          and specified in the model's config.json
+        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
+    """
+    if not model_type:
+        logger.info("Model type was not provided. No Liger kernels will be applied.")
+        return
+
+    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
+        logger.info(
+            f"There are currently no Liger kernels supported for model type: {model_type}."
+        )
+        return
+
+    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+    apply_fn_signature = inspect.signature(apply_fn)
+
+    # Filter out the keyword arguments that are not supported by the apply function
+    applicable_kwargs = {
+        key: value
+        for key, value in kwargs.items()
+        if key in apply_fn_signature.parameters
+    }
+
+    logger.info(
+        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
+    )
+
+    # Assume this is invoked pre-model initialization, so we only need to patch transformers code
+    apply_fn(**applicable_kwargs)
+
+
+def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
+    """
+    Applies Liger kernels to the provided model instance.
+
+    Args:
+        - model: the model instance to apply Liger kernels to
+        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
+    """
+    model_type = getattr(model, "config", None) and getattr(
+        model.config, "model_type", None
+    )
+
+    if not model_type:
+        logger.info(
+            "Model type could not be determined from model config. No Liger kernels will be applied."
+        )
+        return
+
+    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
+        logger.info(
+            f"There are currently no Liger kernels supported for model type: {model_type}."
+        )
+        return
+
+    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+
+    apply_fn_signature = inspect.signature(apply_fn)
+
+    # Filter out the keyword arguments that are not supported by the apply function
+    applicable_kwargs = {
+        key: value
+        for key, value in kwargs.items()
+        if key in apply_fn_signature.parameters
+    }
+
+    logger.info(
+        f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
+    )
+
+    apply_fn(model=model, **applicable_kwargs)
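
The new MODEL_TYPE_TO_APPLY_LIGER_FN registry and the two private helpers at the end of the hunk tie everything together: _apply_liger_kernel patches by model type string before initialization, while _apply_liger_kernel_to_instance reads model.config.model_type from a live model, filters the keyword arguments against the chosen apply function's signature, and calls it with model= set. These helpers are internal, but a rough sketch of their contract (the checkpoint name is illustrative, not part of the diff):

    from transformers import AutoModelForCausalLM

    from liger_kernel.transformers.monkey_patch import (
        _apply_liger_kernel,
        _apply_liger_kernel_to_instance,
    )

    # Before model construction: dispatch on the model_type string from config.json.
    _apply_liger_kernel("qwen2", rope=True, swiglu=True)

    # After model construction: dispatch on model.config.model_type and patch the live
    # instance; kwargs not accepted by apply_liger_kernel_to_qwen2 are filtered out.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B")
    _apply_liger_kernel_to_instance(model=model, rms_norm=True)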