liger-kernel 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. liger_kernel/chunked_loss/README.md +25 -0
  2. liger_kernel/chunked_loss/__init__.py +3 -0
  3. liger_kernel/chunked_loss/cpo_loss.py +18 -8
  4. liger_kernel/chunked_loss/dpo_loss.py +20 -10
  5. liger_kernel/chunked_loss/functional.py +4 -0
  6. liger_kernel/chunked_loss/fused_linear_distillation.py +58 -44
  7. liger_kernel/chunked_loss/fused_linear_preference.py +108 -60
  8. liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
  9. liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +246 -0
  10. liger_kernel/chunked_loss/grpo_loss.py +160 -0
  11. liger_kernel/chunked_loss/jsd_loss.py +154 -0
  12. liger_kernel/chunked_loss/kto_loss.py +172 -0
  13. liger_kernel/chunked_loss/orpo_loss.py +8 -9
  14. liger_kernel/chunked_loss/simpo_loss.py +22 -8
  15. liger_kernel/env_report.py +5 -12
  16. liger_kernel/ops/cross_entropy.py +102 -51
  17. liger_kernel/ops/experimental/embedding.py +1 -3
  18. liger_kernel/ops/experimental/mm_int8int2.py +3 -9
  19. liger_kernel/ops/fused_linear_cross_entropy.py +89 -55
  20. liger_kernel/ops/fused_linear_jsd.py +14 -32
  21. liger_kernel/ops/geglu.py +6 -17
  22. liger_kernel/ops/group_norm.py +11 -28
  23. liger_kernel/ops/jsd.py +5 -9
  24. liger_kernel/ops/kl_div.py +8 -11
  25. liger_kernel/ops/layer_norm.py +23 -12
  26. liger_kernel/ops/qwen2vl_mrope.py +8 -25
  27. liger_kernel/ops/rms_norm.py +14 -32
  28. liger_kernel/ops/rope.py +31 -33
  29. liger_kernel/ops/swiglu.py +4 -8
  30. liger_kernel/ops/tvd.py +207 -0
  31. liger_kernel/ops/utils.py +3 -2
  32. liger_kernel/transformers/__init__.py +19 -24
  33. liger_kernel/transformers/auto_model.py +6 -13
  34. liger_kernel/transformers/cross_entropy.py +7 -9
  35. liger_kernel/transformers/experimental/embedding.py +1 -3
  36. liger_kernel/transformers/functional.py +28 -7
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +15 -10
  38. liger_kernel/transformers/geglu.py +1 -4
  39. liger_kernel/transformers/group_norm.py +9 -15
  40. liger_kernel/transformers/jsd.py +1 -3
  41. liger_kernel/transformers/kl_div.py +1 -3
  42. liger_kernel/transformers/layer_norm.py +3 -9
  43. liger_kernel/transformers/model/gemma.py +18 -40
  44. liger_kernel/transformers/model/gemma2.py +19 -41
  45. liger_kernel/transformers/model/llama.py +22 -48
  46. liger_kernel/transformers/model/mistral.py +14 -26
  47. liger_kernel/transformers/model/mixtral.py +24 -54
  48. liger_kernel/transformers/model/mllama.py +16 -36
  49. liger_kernel/transformers/model/olmo2.py +124 -0
  50. liger_kernel/transformers/model/phi3.py +18 -40
  51. liger_kernel/transformers/model/qwen2.py +18 -40
  52. liger_kernel/transformers/model/qwen2_vl.py +36 -32
  53. liger_kernel/transformers/monkey_patch.py +214 -144
  54. liger_kernel/transformers/rms_norm.py +4 -4
  55. liger_kernel/transformers/rope.py +2 -2
  56. liger_kernel/transformers/swiglu.py +2 -8
  57. liger_kernel/transformers/trainer/__init__.py +1 -3
  58. liger_kernel/transformers/trainer/orpo_trainer.py +31 -18
  59. liger_kernel/transformers/tvd.py +13 -0
  60. liger_kernel/triton/__init__.py +1 -3
  61. liger_kernel/triton/monkey_patch.py +1 -3
  62. liger_kernel/utils.py +49 -0
  63. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/METADATA +53 -26
  64. liger_kernel-0.5.4.dist-info/RECORD +74 -0
  65. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/WHEEL +1 -1
  66. liger_kernel-0.5.2.dist-info/RECORD +0 -65
  67. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/LICENSE +0 -0
  68. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/NOTICE +0 -0
  69. {liger_kernel-0.5.2.dist-info → liger_kernel-0.5.4.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,11 @@
 import inspect
 import logging
+
 from functools import partial
 from typing import Callable
 
 import transformers
+
 from packaging import version
 from transformers import PreTrainedModel
 
@@ -12,38 +14,24 @@ from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.gemma import (
-    lce_forward_deprecated as gemma_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
-from liger_kernel.transformers.model.gemma2 import (
-    lce_forward_deprecated as gemma2_lce_forward_deprected,
-)
+from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
-from liger_kernel.transformers.model.llama import (
-    lce_forward_deprecated as llama_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
-from liger_kernel.transformers.model.mixtral import (
-    lce_forward_deprecated as mixtral_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import (
-    lce_forward_deprecated as phi3_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
-from liger_kernel.transformers.model.qwen2 import (
-    lce_forward_deprecated as qwen2_lce_forward_deprecated,
-)
+from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
-from liger_kernel.transformers.swiglu import (
-    LigerBlockSparseTop2MLP,
-    LigerPhi3SwiGLUMLP,
-    LigerSwiGLUMLP,
-)
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
 
 transformer_version = version.parse(transformers.__version__)
 
@@ -57,28 +45,101 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(
-    module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
-):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
     module.offset = offset
     module.casting_mode = casting_mode
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.in_place = in_place
     _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
 
 
 def _patch_layer_norm_module(module, eps=1e-6):
-    module.variance_epsilon = (
-        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    )
+    module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
     module.hidden_size = module.normalized_shape
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
 
 
+def apply_liger_kernel_to_granite(
+    rope: bool = True,
+    cross_entropy: bool = True,
+    fused_linear_cross_entropy: bool = False,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Granite 3 models
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is False.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+
+
+
+    Debugging notes:
+        If LigerSwiGLUMLP is OK for Llama, it should be fine for Granite, but it's not.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.granite import modeling_granite
+    from transformers.models.granite.modeling_granite import GraniteModel
+
+    if swiglu:
+        modeling_granite.GraniteMLP = LigerSwiGLUMLP
+
+    if rms_norm:
+        modeling_granite.GraniteRMSNorm = LigerRMSNorm
+
+    if rope:
+        modeling_granite.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_granite.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        raise NotImplementedError("LigerFusedLinearCrossEntropy is not available for Granite models.")
+        # NOTE: Granite model `GraniteForCausalLM.forward` scales logits each
+        # call, so we can't sidestep logit materialization. A bit more work
+        # would be needed to add a scaling term to the `LigerFusedLinearCrossEntropyFunction`
+        # for the logit output.
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. GraniteRMSNorm or GraniteMLP)
+
+        # get the base model from the model instance
+        base_model: GraniteModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llama(
     rope: bool = True,
     cross_entropy: bool = False,
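The hunk above adds apply_liger_kernel_to_granite. The snippet below is a minimal usage sketch, not part of the diff: it assumes the function lives in liger_kernel/transformers/monkey_patch.py (where these hunks appear to come from) and uses an illustrative checkpoint name; it leaves fused_linear_cross_entropy off because the hunk raises NotImplementedError for Granite in that path.

# Sketch only: patch the transformers Granite classes before the model is built,
# so the class-level swaps (GraniteMLP, GraniteRMSNorm, rotary embedding) take effect.
from transformers import AutoModelForCausalLM

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # assumed import path

apply_liger_kernel_to_granite(
    rope=True,
    rms_norm=True,
    swiglu=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,  # NotImplementedError for Granite per the hunk above
)

model = AutoModelForCausalLM.from_pretrained("ibm-granite/granite-3.0-2b-instruct")  # illustrative checkpoint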
@@ -103,9 +164,9 @@ def apply_liger_kernel_to_llama(
         loaded. Default is None.
     """
 
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
 
     from transformers.models.llama import modeling_llama
     from transformers.models.llama.modeling_llama import LlamaModel
@@ -145,9 +206,7 @@ def apply_liger_kernel_to_llama(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -179,22 +238,18 @@ def apply_liger_kernel_to_mllama(
         loaded. Default is None.
     """
 
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
 
     from transformers.models.mllama import modeling_mllama
-    from transformers.models.mllama.modeling_mllama import (
-        MllamaForCausalLM,
-        MllamaForConditionalGeneration,
-        MllamaTextModel,
-        MllamaVisionModel,
-    )
+    from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
+    from transformers.models.mllama.modeling_mllama import MllamaForConditionalGeneration
+    from transformers.models.mllama.modeling_mllama import MllamaTextModel
+    from transformers.models.mllama.modeling_mllama import MllamaVisionModel
 
     from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
-    from liger_kernel.transformers.model.mllama import (
-        lce_forward_deprecated as mllama_lce_forward_deprecated,
-    )
+    from liger_kernel.transformers.model.mllama import lce_forward_deprecated as mllama_lce_forward_deprecated
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -241,9 +296,7 @@ def apply_liger_kernel_to_mllama(
                 _patch_rms_norm_module(text_model.norm)
             for decoder_layer in text_model.layers:
                 if swiglu:
-                    _bind_method_to_module(
-                        decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                    )
+                    _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
                 if rms_norm:
                     _patch_rms_norm_module(decoder_layer.input_layernorm)
                     _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -287,9 +340,9 @@ def apply_liger_kernel_to_mistral(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
 
     from transformers.models.mistral import modeling_mistral
     from transformers.models.mistral.modeling_mistral import MistralModel
@@ -317,9 +370,7 @@ def apply_liger_kernel_to_mistral(
 
        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -349,9 +400,9 @@ def apply_liger_kernel_to_mixtral(
         loaded. Default is None.
     """
 
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
 
     from transformers.models.mixtral import modeling_mixtral
     from transformers.models.mixtral.modeling_mixtral import MixtralModel
@@ -391,9 +442,7 @@ def apply_liger_kernel_to_mixtral(
         for decoder_layer in base_model.layers:
             if swiglu:
                 for expert in decoder_layer.block_sparse_moe.experts:
-                    _bind_method_to_module(
-                        expert, "forward", LigerBlockSparseTop2MLP.forward
-                    )
+                    _bind_method_to_module(expert, "forward", LigerBlockSparseTop2MLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -423,20 +472,16 @@ def apply_liger_kernel_to_gemma(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
 
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
     # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(
-        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
-    )
-    _patch_rms_norm_module_for_gemma = partial(
-        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
-    )
+    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
         modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
@@ -471,9 +516,7 @@ def apply_liger_kernel_to_gemma(
 
         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                 _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)
@@ -503,16 +546,14 @@ def apply_liger_kernel_to_gemma2(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
 
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-    LigerRMSNormForGemma2 = partial(
-        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
-    )
+    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -551,20 +592,12 @@ def apply_liger_kernel_to_gemma2(
 
         for decoder_layer in base_model.layers:
             if geglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerGEGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.post_attention_layernorm
-                )
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.pre_feedforward_layernorm
-                )
-                _patch_rms_norm_module_for_gemma2(
-                    decoder_layer.post_feedforward_layernorm
-                )
+                _patch_rms_norm_module_for_gemma2(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.pre_feedforward_layernorm)
+                _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
 
 
 def apply_liger_kernel_to_qwen2(
@@ -590,9 +623,9 @@ def apply_liger_kernel_to_qwen2(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
 
     from transformers.models.qwen2 import modeling_qwen2
     from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
@@ -633,9 +666,7 @@ def apply_liger_kernel_to_qwen2(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -667,21 +698,17 @@ def apply_liger_kernel_to_qwen2_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
 
     from transformers.models.qwen2_vl import modeling_qwen2_vl
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
 
-    from liger_kernel.transformers.model.qwen2_vl import (
-        lce_forward as qwen2_vl_lce_forward,
-    )
+    from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
     if rope:
-        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
-            liger_multimodal_rotary_pos_emb
-        )
+        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
@@ -712,9 +739,7 @@ def apply_liger_kernel_to_qwen2_vl(
            _patch_rms_norm_module(base_model.norm)
        for decoder_layer in base_model.layers:
            if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -743,9 +768,9 @@ def apply_liger_kernel_to_phi3(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
         loaded. Default is None.
     """
-    assert not (
-        cross_entropy and fused_linear_cross_entropy
-    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
 
     from transformers.models.phi3 import modeling_phi3
     from transformers.models.phi3.modeling_phi3 import Phi3Model
@@ -783,23 +808,86 @@ def apply_liger_kernel_to_phi3(
 
         for decoder_layer in base_model.layers:
             if swiglu:
-                _bind_method_to_module(
-                    decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
-                )
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_olmo2(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo2 import modeling_olmo2
+    from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
+
+    from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+
+    if rope:
+        modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
+    if swiglu:
+        modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
     "mllama_text_model": apply_liger_kernel_to_mllama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
+    "olmo2": apply_liger_kernel_to_olmo2,
    "qwen2": apply_liger_kernel_to_qwen2,
    "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
    "phi3": apply_liger_kernel_to_phi3,
@@ -826,24 +914,16 @@ def _apply_liger_kernel(model_type: str, **kwargs) -> None:
         return
 
     if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
 
-    logger.info(
-        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
-    )
+    logger.info(f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}")
 
     # Assume this is invoked pre-model initialization, so we only need to patch transformers code
     apply_fn(**applicable_kwargs)
@@ -857,20 +937,14 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
        - model: the model instance to apply Liger kernels to
        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
    """
-    model_type = getattr(model, "config", None) and getattr(
-        model.config, "model_type", None
-    )
+    model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
 
    if not model_type:
-        logger.info(
-            "Model type could not be determined from model config. No Liger kernels will be applied."
-        )
+        logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
        return
 
    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
-        logger.info(
-            f"There are currently no Liger kernels supported for model type: {model_type}."
-        )
+        logger.info(f"There are currently no Liger kernels supported for model type: {model_type}.")
        return
 
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
@@ -878,11 +952,7 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
-    applicable_kwargs = {
-        key: value
-        for key, value in kwargs.items()
-        if key in apply_fn_signature.parameters
-    }
+    applicable_kwargs = {key: value for key, value in kwargs.items() if key in apply_fn_signature.parameters}
     logger.info(
         f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
     )
@@ -19,9 +19,7 @@ class LigerRMSNorm(nn.Module):
             "ones",
             "zeros",
         ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}"
-        self.weight = nn.Parameter(
-            torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
-        )
+        self.weight = nn.Parameter(torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size))
         self.variance_epsilon, self.offset, self.casting_mode, self.in_place = (
             eps,
             offset,
@@ -40,4 +38,6 @@ class LigerRMSNorm(nn.Module):
         )
 
     def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        return (
+            f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}, offset={self.offset}, in_place={self.in_place}"
+        )
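The constructor and extra_repr reformatted above can be exercised directly. A sketch only, assuming a CUDA device for the Triton kernel and the keyword names shown in the partial(...) calls earlier in this diff (offset, casting_mode, init_fn, in_place); the first positional argument is taken to be the hidden size, as suggested by torch.ones(hidden_size) in the hunk.

# Sketch only: the Gemma-style configuration that the monkey patch builds via functools.partial.
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm

norm = LigerRMSNorm(256, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False).to("cuda")
print(norm)  # extra_repr reports weight shape, eps, offset and in_place, per the hunk above

x = torch.randn(2, 16, 256, device="cuda")
assert norm(x).shape == x.shape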
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     Args:
         q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
         k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
-        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
-        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
+        cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
+        sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
         position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
         unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
 
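The docstring change above widens the accepted cos/sin layout to a per-batch (bsz, seq_len, head_dim) variant. A shape-level sketch, not from the diff: it assumes a CUDA device and that the function returns the rotated (q, k) pair; the import path matches the one the monkey patch uses.

# Sketch only: exercise the per-batch cos/sin layout now documented above.
import torch

from liger_kernel.transformers.rope import liger_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 16, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")
cos = torch.randn(bsz, seq_len, head_dim, device="cuda")  # per-batch layout
sin = torch.randn(bsz, seq_len, head_dim, device="cuda")

q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape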
@@ -16,10 +16,7 @@ class LigerSwiGLUMLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
-        return self.down_proj(
-            LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))
-        )
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
 
 
 class LigerBlockSparseTop2MLP(nn.Module):
@@ -36,7 +33,6 @@ class LigerBlockSparseTop2MLP(nn.Module):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
 
     def forward(self, x):
-
         return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))
 
 
@@ -51,9 +47,7 @@ class LigerPhi3SwiGLUMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_up_proj = nn.Linear(
-            self.hidden_size, 2 * self.intermediate_size, bias=False
-        )
+        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         if config.hidden_act not in ["silu", "swish"]:
             raise ValueError(f"Activation function {config.hidden_act} not supported.")
@@ -1,6 +1,4 @@
 try:
-    from liger_kernel.transformers.trainer.orpo_trainer import (  # noqa: F401
-        LigerORPOTrainer,
-    )
+    from liger_kernel.transformers.trainer.orpo_trainer import LigerORPOTrainer  # noqa: F401
 except ImportError:
     raise ImportError("Please `pip install trl` to use LigerORPOTrainer")