liger-kernel-nightly 0.6.2.dev20251011154427__py3-none-any.whl → 0.6.4.dev20260107111351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (97)
  1. liger_kernel/chunked_loss/cosine_similarity_loss.py +20 -5
  2. liger_kernel/chunked_loss/fused_linear_distillation.py +23 -5
  3. liger_kernel/chunked_loss/fused_linear_ppo.py +21 -5
  4. liger_kernel/chunked_loss/grpo_loss.py +8 -5
  5. liger_kernel/chunked_loss/jsd_loss.py +39 -11
  6. liger_kernel/ops/__init__.py +141 -0
  7. liger_kernel/ops/backends/README.md +151 -0
  8. liger_kernel/ops/backends/__init__.py +13 -0
  9. liger_kernel/ops/backends/_ascend/__init__.py +5 -0
  10. liger_kernel/ops/backends/_ascend/ascend-ub-manager-design.md +485 -0
  11. liger_kernel/ops/backends/_ascend/ops/__init__.py +43 -0
  12. liger_kernel/ops/backends/_ascend/ops/geglu.py +244 -0
  13. liger_kernel/ops/backends/_ascend/ops/qwen2vl_mrope.py +285 -0
  14. liger_kernel/ops/backends/_ascend/ops/rope.py +290 -0
  15. liger_kernel/ops/backends/_ascend/ops/swiglu.py +142 -0
  16. liger_kernel/ops/backends/_ascend/ub_manager.py +349 -0
  17. liger_kernel/ops/backends/registry.py +61 -0
  18. liger_kernel/ops/cross_entropy.py +75 -12
  19. liger_kernel/ops/dyt.py +5 -2
  20. liger_kernel/ops/fused_add_rms_norm.py +5 -1
  21. liger_kernel/ops/fused_linear_cross_entropy.py +45 -14
  22. liger_kernel/ops/geglu.py +5 -3
  23. liger_kernel/ops/group_norm.py +2 -1
  24. liger_kernel/ops/grpo_loss.py +3 -1
  25. liger_kernel/ops/layer_norm.py +86 -66
  26. liger_kernel/ops/poly_norm.py +390 -0
  27. liger_kernel/ops/rms_norm.py +131 -49
  28. liger_kernel/ops/tiled_mlp.py +136 -0
  29. liger_kernel/ops/utils.py +14 -0
  30. liger_kernel/transformers/__init__.py +30 -0
  31. liger_kernel/transformers/auto_model.py +21 -0
  32. liger_kernel/transformers/cross_entropy.py +9 -4
  33. liger_kernel/transformers/dyt.py +1 -1
  34. liger_kernel/transformers/experimental/embedding.py +1 -1
  35. liger_kernel/transformers/functional.py +48 -25
  36. liger_kernel/transformers/fused_add_rms_norm.py +1 -1
  37. liger_kernel/transformers/fused_linear_cross_entropy.py +9 -4
  38. liger_kernel/transformers/fused_linear_jsd.py +1 -1
  39. liger_kernel/transformers/fused_neighborhood_attention.py +1 -1
  40. liger_kernel/transformers/geglu.py +1 -1
  41. liger_kernel/transformers/group_norm.py +1 -1
  42. liger_kernel/transformers/grpo_loss.py +57 -2
  43. liger_kernel/transformers/jsd.py +1 -1
  44. liger_kernel/transformers/kl_div.py +1 -1
  45. liger_kernel/transformers/layer_norm.py +1 -1
  46. liger_kernel/transformers/llama4_rope.py +1 -1
  47. liger_kernel/transformers/model/falcon_h1.py +19 -5
  48. liger_kernel/transformers/model/gemma.py +17 -6
  49. liger_kernel/transformers/model/gemma2.py +14 -5
  50. liger_kernel/transformers/model/gemma3.py +26 -12
  51. liger_kernel/transformers/model/glm4.py +16 -4
  52. liger_kernel/transformers/model/glm4v.py +16 -4
  53. liger_kernel/transformers/model/glm4v_moe.py +23 -4
  54. liger_kernel/transformers/model/gpt_oss.py +211 -0
  55. liger_kernel/transformers/model/hunyuan_v1.py +134 -0
  56. liger_kernel/transformers/model/internvl.py +12 -5
  57. liger_kernel/transformers/model/llama.py +14 -5
  58. liger_kernel/transformers/model/llama4.py +16 -4
  59. liger_kernel/transformers/model/llava.py +12 -4
  60. liger_kernel/transformers/model/loss_utils.py +31 -3
  61. liger_kernel/transformers/model/mistral.py +15 -6
  62. liger_kernel/transformers/model/mixtral.py +16 -7
  63. liger_kernel/transformers/model/mllama.py +12 -4
  64. liger_kernel/transformers/model/olmo2.py +16 -4
  65. liger_kernel/transformers/model/olmo3.py +142 -0
  66. liger_kernel/transformers/model/output_classes.py +147 -0
  67. liger_kernel/transformers/model/paligemma.py +23 -5
  68. liger_kernel/transformers/model/phi3.py +14 -7
  69. liger_kernel/transformers/model/qwen2.py +16 -3
  70. liger_kernel/transformers/model/qwen2_5_vl.py +14 -6
  71. liger_kernel/transformers/model/qwen2_vl.py +16 -4
  72. liger_kernel/transformers/model/qwen3.py +20 -5
  73. liger_kernel/transformers/model/qwen3_moe.py +19 -5
  74. liger_kernel/transformers/model/qwen3_next.py +146 -0
  75. liger_kernel/transformers/model/qwen3_vl.py +150 -0
  76. liger_kernel/transformers/model/qwen3_vl_moe.py +126 -0
  77. liger_kernel/transformers/model/smollm3.py +15 -6
  78. liger_kernel/transformers/model/smolvlm.py +158 -0
  79. liger_kernel/transformers/monkey_patch.py +702 -48
  80. liger_kernel/transformers/multi_token_attention.py +1 -1
  81. liger_kernel/transformers/poly_norm.py +42 -0
  82. liger_kernel/transformers/qwen2vl_mrope.py +1 -1
  83. liger_kernel/transformers/rms_norm.py +15 -3
  84. liger_kernel/transformers/rope.py +45 -1
  85. liger_kernel/transformers/softmax.py +1 -1
  86. liger_kernel/transformers/sparsemax.py +1 -1
  87. liger_kernel/transformers/swiglu.py +18 -1
  88. liger_kernel/transformers/tiled_mlp.py +133 -0
  89. liger_kernel/transformers/tvd.py +1 -1
  90. liger_kernel/utils.py +52 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/METADATA +12 -3
  92. liger_kernel_nightly-0.6.4.dev20260107111351.dist-info/RECORD +130 -0
  93. liger_kernel_nightly-0.6.2.dev20251011154427.dist-info/RECORD +0 -107
  94. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/LICENSE +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/NOTICE +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/WHEEL +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154427.dist-info → liger_kernel_nightly-0.6.4.dev20260107111351.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forwa
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
  from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+ from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
  from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
  from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -34,6 +35,7 @@ from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_f
  from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
  from liger_kernel.transformers.rms_norm import LigerRMSNorm
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
  from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP
  from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP
  from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -428,7 +430,7 @@ def apply_liger_kernel_to_llava(
  f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
  f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
  )
- text_kwargs["model"] = model.language_model
+ text_kwargs["model"] = model.model.language_model
  text_liger_fn(**text_kwargs)
  elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
  logger.warning(f"{text_model_name} is not supported by Liger kernel.")
@@ -443,7 +445,7 @@ def apply_liger_kernel_to_llava(
  f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
  f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
  )
- vision_kwargs["model"] = model.vision_tower
+ vision_kwargs["model"] = model.model.vision_tower
  vision_liger_fn(**vision_kwargs)
  elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
  logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
@@ -469,7 +471,7 @@ def apply_liger_kernel_to_llama4(
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
@@ -522,7 +524,10 @@ def apply_liger_kernel_to_llama4(
  _patch_rms_norm_module(text_model.norm)
  for decoder_layer in text_model.layers:
  if swiglu:
- _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+ if decoder_layer.is_moe_layer:
+ _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
+ else:
+ _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
  if rms_norm:
  _patch_rms_norm_module(decoder_layer.input_layernorm)
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -610,8 +615,8 @@ def apply_liger_kernel_to_mllama(
  # instance variables that reference already-instantiated modules

  if isinstance(model, MllamaForConditionalGeneration):
- language_model: MllamaForCausalLM = model.language_model
- vision_model: MllamaVisionModel = model.vision_model
+ language_model: MllamaForCausalLM = model.model.language_model
+ vision_model: MllamaVisionModel = model.model.vision_model
  if isinstance(language_model, MllamaForCausalLM):
  text_model: MllamaTextModel = language_model.model
  else:
@@ -1113,8 +1118,8 @@ def apply_liger_kernel_to_gemma3(
  # instance variables that reference already-instantiated modules

  if isinstance(model, Gemma3ForConditionalGeneration):
- if isinstance(model.vision_tower, SiglipVisionModel):
- vision_tower = model.vision_tower
+ if isinstance(model.model.vision_tower, SiglipVisionModel):
+ vision_tower = model.model.vision_tower

  _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1127,7 +1132,7 @@ def apply_liger_kernel_to_gemma3(
  raise TypeError("The vision tower must be SiglipVisionModel")

  if rms_norm:
- _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+ _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)

  apply_liger_kernel_to_gemma3_text(
  rope=rope,
@@ -1135,7 +1140,7 @@ def apply_liger_kernel_to_gemma3(
  fused_linear_cross_entropy=False,
  rms_norm=rms_norm,
  geglu=geglu,
- model=model.language_model,
+ model=model.model.language_model,
  )

  else:
@@ -1223,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
  if not isinstance(model, PaliGemmaForConditionalGeneration):
  raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

- vision_tower: SiglipVisionModel = model.vision_tower
+ vision_tower: SiglipVisionModel = model.model.vision_tower

  _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1233,7 +1238,7 @@ def apply_liger_kernel_to_paligemma(
  _patch_layer_norm_module(layer.layer_norm1)
  _patch_layer_norm_module(layer.layer_norm2)

- language_model = model.language_model
+ language_model = model.model.language_model

  if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
  apply_liger_kernel_to_gemma(
@@ -1454,6 +1459,79 @@ def apply_liger_kernel_to_qwen3_moe(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_gpt_oss(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+ NOTE: GPT-OSS is supported in transformers >= 4.55.0
+ NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+ implementation with clamping and MXFP4 quantization.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ if version.parse(transformers.__version__) < version.parse("4.55.0"):
+ logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+ return
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.gpt_oss import modeling_gpt_oss
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+ if rope:
+ modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(gpt_oss_lce_forward, model)
+ else:
+ modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+ # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+ # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  def apply_liger_kernel_to_qwen2_vl(
  rope: bool = True,
  cross_entropy: bool = False,
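For reference, a minimal usage sketch of the new GPT-OSS entry point added above, assuming transformers >= 4.55.0 and importing the helper from liger_kernel.transformers.monkey_patch, where this diff defines it (the checkpoint id is only illustrative):

    import transformers
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss

    # Called before instantiation, this patches the GPT-OSS modeling code at the
    # class level: RoPE, RMSNorm and the fused linear cross entropy forward.
    apply_liger_kernel_to_gpt_oss(rope=True, rms_norm=True, fused_linear_cross_entropy=True)

    model = transformers.AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")  # example checkpoint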
@@ -1515,11 +1593,10 @@ def apply_liger_kernel_to_qwen2_vl(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
-
- if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+ if isinstance(model, Qwen2VLForConditionalGeneration):
+ text_model: Qwen2VLTextModel = model.model.language_model
+ vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
+ elif isinstance(model, Qwen2VLModel):
  text_model: Qwen2VLTextModel = model.language_model
  vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
  elif isinstance(model, Qwen2VLTextModel):
@@ -1606,11 +1683,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
-
- if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+ if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+ text_model: Qwen2_5_VLTextModel = model.model.language_model
+ vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+ elif isinstance(model, Qwen2_5_VLModel):
  text_model: Qwen2_5_VLTextModel = model.language_model
  vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
  elif isinstance(model, Qwen2_5_VLTextModel):
@@ -1624,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

  if vision_model is not None:
  # Patch Qwen2_5_VisionTransformerPretrainedModel
- for vision_block in model.visual.blocks:
+ for vision_block in vision_model.blocks:
  if rms_norm:
  _patch_rms_norm_module(vision_block.norm1)
  _patch_rms_norm_module(vision_block.norm2)
@@ -1640,6 +1716,162 @@ def apply_liger_kernel_to_qwen2_5_vl(
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+ def apply_liger_kernel_to_qwen3_vl(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL models.
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_vl import modeling_qwen3_vl
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel
+ from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextModel
+
+ from liger_kernel.transformers.model.qwen3_vl import lce_forward as qwen3_vl_lce_forward
+
+ if rope:
+ modeling_qwen3_vl.apply_rotary_pos_emb = liger_rotary_pos_emb
+ modeling_qwen3_vl.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+ if rms_norm:
+ modeling_qwen3_vl.Qwen3VLTextRMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(qwen3_vl_lce_forward, model)
+ else:
+ modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward
+
+ if model is not None and rms_norm:
+ if isinstance(model, Qwen3VLForConditionalGeneration):
+ text_model: Qwen3VLTextModel = model.model.language_model
+ elif isinstance(model, Qwen3VLModel):
+ text_model: Qwen3VLTextModel = model.language_model
+ elif isinstance(model, Qwen3VLTextModel):
+ text_model = model
+ else:
+ raise TypeError(
+ f"Unsupported Qwen3VL model type. `model` must be `Qwen3VLForConditionalGeneration`, `Qwen3VLModel` or `Qwen3VLTextModel`. Got: {type(model)}"
+ )
+
+ _patch_qwen3_vl_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+ if text_model is not None:
+ _patch_qwen3_vl_rms_norm(text_model.norm)
+ for decoder_layer in text_model.layers:
+ _patch_qwen3_vl_rms_norm(decoder_layer.input_layernorm)
+ _patch_qwen3_vl_rms_norm(decoder_layer.post_attention_layernorm)
+ self_attn = getattr(decoder_layer, "self_attn", None)
+ if self_attn is not None:
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+ _patch_qwen3_vl_rms_norm(self_attn.q_norm)
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+ _patch_qwen3_vl_rms_norm(self_attn.k_norm)
+
+
+ def apply_liger_kernel_to_qwen3_vl_moe(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = False,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3-VL MoE models.
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is False.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeModel
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextModel
+
+ from liger_kernel.transformers.model.qwen3_vl_moe import lce_forward as qwen3_vl_moe_lce_forward
+
+ if rope:
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+ modeling_qwen3_vl_moe.apply_rotary_pos_emb_vision = liger_rotary_pos_emb_vision
+
+ if rms_norm:
+ modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(qwen3_vl_moe_lce_forward, model)
+ else:
+ modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward
+
+ if model is not None and rms_norm:
+ if isinstance(model, Qwen3VLMoeForConditionalGeneration):
+ text_model: Qwen3VLMoeTextModel = model.model.language_model
+ elif isinstance(model, Qwen3VLMoeModel):
+ text_model: Qwen3VLMoeTextModel = model.language_model
+ elif isinstance(model, Qwen3VLMoeTextModel):
+ text_model = model
+ else:
+ raise TypeError(
+ f"Unsupported Qwen3VLMoe model type. `model` must be `Qwen3VLMoeForConditionalGeneration`, `Qwen3VLMoeModel` or `Qwen3VLMoeTextModel`. Got: {type(model)}"
+ )
+
+ _patch_qwen3_vl_moe_rms_norm = partial(_patch_rms_norm_module, offset=0.0, casting_mode="llama")
+
+ if text_model is not None:
+ _patch_qwen3_vl_moe_rms_norm(text_model.norm)
+ for decoder_layer in text_model.layers:
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.input_layernorm)
+ _patch_qwen3_vl_moe_rms_norm(decoder_layer.post_attention_layernorm)
+ self_attn = getattr(decoder_layer, "self_attn", None)
+ if self_attn is not None:
+ if hasattr(self_attn, "q_norm") and self_attn.q_norm is not None:
+ _patch_qwen3_vl_moe_rms_norm(self_attn.q_norm)
+ if hasattr(self_attn, "k_norm") and self_attn.k_norm is not None:
+ _patch_qwen3_vl_moe_rms_norm(self_attn.k_norm)
+
+
  def apply_liger_kernel_to_phi3(
  rope: bool = True,
  cross_entropy: bool = False,
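A hedged usage sketch for the Qwen3-VL entry point added above: passing an already-instantiated model routes the in-place RMSNorm patching shown in the diff, while calling it with no model patches at the class level (the checkpoint path is a placeholder):

    from transformers import AutoModelForImageTextToText
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_vl

    model = AutoModelForImageTextToText.from_pretrained("path/to/qwen3-vl-checkpoint")  # illustrative path
    # Patches rotary embeddings (text and vision), text RMSNorm including the
    # per-head q_norm/k_norm modules, and the fused linear cross entropy forward.
    apply_liger_kernel_to_qwen3_vl(model=model)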
@@ -1771,6 +2003,74 @@ def apply_liger_kernel_to_olmo2(
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)


+ def apply_liger_kernel_to_olmo3(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.olmo3 import modeling_olmo3
+ from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+ from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+ # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+ if rope:
+ modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if rms_norm:
+ modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2 # same as olmo2
+ if swiglu:
+ modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(olmo3_lce_forward, model)
+ else:
+ modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
  def apply_liger_kernel_to_glm4(
  rope: bool = False,
  cross_entropy: bool = False,
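The Olmo3 path mirrors the existing Olmo2 integration. A minimal sketch of pre-load patching, assuming the helper is imported from liger_kernel.transformers.monkey_patch and using a placeholder checkpoint path:

    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3

    # Without a model instance, Olmo3RMSNorm, Olmo3MLP and the causal-LM forward
    # are swapped at the class level, so later instantiations pick them up.
    apply_liger_kernel_to_olmo3()
    model = AutoModelForCausalLM.from_pretrained("path/to/olmo3-checkpoint")  # illustrative path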
@@ -1893,10 +2193,10 @@ def apply_liger_kernel_to_glm4v(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
+ if isinstance(model, Glm4vForConditionalGeneration):
+ text_model: Glm4vTextModel = model.model.language_model
+ vision_model: Glm4vVisionModel = model.model.visual
+ elif isinstance(model, Glm4vModel):
  text_model: Glm4vTextModel = model.language_model
  vision_model: Glm4vVisionModel = model.visual
  elif isinstance(model, Glm4vTextModel):
@@ -1968,7 +2268,8 @@ def apply_liger_kernel_to_glm4v_moe(
  if rope:
  raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
  if rms_norm:
- modeling_glm4v_moe.Glm4vRMSNorm = LigerRMSNormForGlm4
+ modeling_glm4v_moe.Glm4vMoeRMSNorm = LigerRMSNormForGlm4
+ modeling_glm4v_moe.Glm4vMoeTextRMSNorm = LigerRMSNormForGlm4
  if cross_entropy:
  from transformers.loss.loss_utils import nn

@@ -1982,10 +2283,11 @@ def apply_liger_kernel_to_glm4v_moe(
  if model is not None:
  # The model instance already exists, so we need to additionally patch the
  # instance variables that reference already-instantiated modules
- if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
- # Note: language_model and visual properties can be accessed throught conditional class for BC.
- # Not sure if it is subject to changes in the future.
- # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
+ if isinstance(model, Glm4vMoeForConditionalGeneration):
+ text_model: Glm4vMoeTextModel = model.model.language_model
+ vision_model: Glm4vMoeVisionModel = model.model.visual
+ Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+ elif isinstance(model, Glm4vMoeModel):
  text_model: Glm4vMoeTextModel = model.language_model
  vision_model: Glm4vMoeVisionModel = model.visual
  Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
@@ -2034,6 +2336,7 @@ def apply_liger_kernel_to_internvl(
  cross_entropy: bool = False,
  fused_linear_cross_entropy: bool = True,
  rms_norm: bool = True,
+ layer_norm: bool = True,
  model: Optional[PreTrainedModel] = None,
  **kwargs,
  ) -> None:
@@ -2044,37 +2347,62 @@ def apply_liger_kernel_to_internvl(
  NOTE: InternVL is not available in transformers<4.52.1

  Args:
- rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
  cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
  fused_linear_cross_entropy (bool):
  Whether to apply Liger's fused linear cross entropy loss. Default is True.
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
  loaded. Default is None.
  """
  assert not (cross_entropy and fused_linear_cross_entropy), (
  "cross_entropy and fused_linear_cross_entropy cannot both be True."
  )
+ import torch.nn as torch_nn

  from transformers.models.internvl import modeling_internvl
+ from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+ from transformers.models.internvl.modeling_internvl import InternVLModel
+ from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+ from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+ from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm

+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
  from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+ if layer_norm and model is None:
+ modeling_internvl.nn.LayerNorm = LigerLayerNorm

  if cross_entropy:
- logger.warning(TRANSFORMER_DEPRECATION_WARNING)
- modeling_internvl.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+ logger.info("Apply liger cross entropy")
+
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
  if fused_linear_cross_entropy:
  modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
  if rms_norm:
  modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm

  if model is not None:
- text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, InternVLForConditionalGeneration):
+ text_model = model.model.language_model
+ vision_model: InternVLVisionModel = model.model.vision_tower
+ elif isinstance(model, InternVLModel):
+ text_model = model.language_model
+ vision_model: InternVLVisionModel = model.vision_tower
+ else:
+ raise TypeError(
+ f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
+ )
+
+ text_model_name = model.config.text_config.model_type
  text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
- vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)

  kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
  if text_liger_fn:
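The new layer_norm flag above splits behavior by whether a model instance is passed. A minimal sketch, assuming the helper is imported from liger_kernel.transformers.monkey_patch:

    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl

    # With layer_norm=True (the new default) and no model instance, nn.LayerNorm in
    # modeling_internvl is replaced by LigerLayerNorm before the model is built;
    # with an instance, the vision tower's norms are patched in place instead.
    apply_liger_kernel_to_internvl(layer_norm=True, rms_norm=True)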
@@ -2087,25 +2415,133 @@ def apply_liger_kernel_to_internvl(
  f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
  f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
  )
- text_kwargs["model"] = model.language_model
+ text_kwargs["model"] = text_model
  text_liger_fn(**text_kwargs)
  elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
  logger.warning(f"{text_model_name} is not supported by Liger kernel.")

- if vision_liger_fn:
- accept_params = inspect.signature(vision_liger_fn).parameters
+ # Patch vision model RMSNorm layers
+ if rms_norm:
+ for encoder_layer in vision_model.encoder.layer:
+ encoder_layer: InternVLVisionLayer
+ if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+ _patch_rms_norm_module(encoder_layer.attention.q_norm)
+ if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+ _patch_rms_norm_module(encoder_layer.attention.k_norm)
+
+ # Patch vision model LayerNorm layers
+ if layer_norm:
+ # Patch layernorm
+ if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+ _patch_layer_norm_module(vision_model.layernorm)
+
+ # Patch encoder layers
+ for encoder_layer in vision_model.encoder.layer:
+ encoder_layer: InternVLVisionLayer
+ if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+ _patch_layer_norm_module(encoder_layer.layernorm_before)
+ if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+ _patch_layer_norm_module(encoder_layer.layernorm_after)
+
+
+ def apply_liger_kernel_to_smolvlm(
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ layer_norm: bool = True,
+ model: Optional[PreTrainedModel] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace SmolVLM models.
+ Due to the characteristics of SmolVLM, the model must be passed to apply Liger-Kernel's patch to other models connected to SmolVLM.
+ However, if an LM not supported by Liger-Kernel is connected to SmolVLM, unexpected side effects may occur.
+ NOTE: SmolVLM is not available in transformers<4.50.0
+
+ Args:
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.smolvlm import modeling_smolvlm
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMEncoderLayer
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMForConditionalGeneration
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMModel
+ from transformers.models.smolvlm.modeling_smolvlm import SmolVLMVisionTransformer
+
+ from liger_kernel.transformers.model.smolvlm import lce_forward as smolvlm_lce_forward
+
+ # Patch LayerNorm for vision model if model is not provided (pre-initialization)
+ if layer_norm and model is None:
+ modeling_smolvlm.nn.LayerNorm = LigerLayerNorm
+
+ if cross_entropy:
+ logger.info("Apply liger cross entropy")
+
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(smolvlm_lce_forward, model)
+ else:
+ modeling_smolvlm.SmolVLMForConditionalGeneration.forward = smolvlm_lce_forward
+ if rms_norm:
+ modeling_smolvlm.SmolVLMRMSNorm = LigerRMSNorm
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, SmolVLMForConditionalGeneration):
+ text_model = model.model.text_model
+ vision_model: SmolVLMVisionTransformer = model.model.vision_model
+ elif isinstance(model, SmolVLMModel):
+ text_model = model.text_model
+ vision_model: SmolVLMVisionTransformer = model.vision_model
+ else:
+ raise TypeError(
+ f"Unsupported smolvlm model type. `model` must be `SmolVLMForConditionalGeneration`, `SmolVLMModel`. Got: {type(model)}"
+ )
+
+ text_model_name = model.config.text_config.model_type
+ text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+
+ kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
+ if text_liger_fn:
+ accept_params = inspect.signature(text_liger_fn).parameters
  remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
- vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+ text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

  if remain_params:
  logger.warning(
- f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
- f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+ f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
+ f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
  )
- vision_kwargs["model"] = model.vision_tower
- vision_liger_fn(**vision_kwargs)
- elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
- logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+ text_kwargs["model"] = text_model
+ text_liger_fn(**text_kwargs)
+ elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+ # Patch vision model LayerNorm layers
+ if layer_norm:
+ # Patch post_layernorm
+ _patch_layer_norm_module(vision_model.post_layernorm)
+
+ # Patch encoder layers
+ for encoder_layer in vision_model.encoder.layers:
+ encoder_layer: SmolVLMEncoderLayer
+ _patch_layer_norm_module(encoder_layer.layer_norm1)
+ _patch_layer_norm_module(encoder_layer.layer_norm2)


  def apply_liger_kernel_to_falcon_h1(
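Per the SmolVLM docstring above, the model instance must be passed so the connected text model can be dispatched to its own apply_liger_kernel_to_* function. A minimal sketch (the checkpoint id is an example only):

    from transformers import AutoModelForImageTextToText
    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smolvlm

    model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")  # example checkpoint
    # Dispatches the text tower to the matching Liger patch function and patches
    # the vision encoder's LayerNorm modules in place.
    apply_liger_kernel_to_smolvlm(model=model)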
@@ -2177,6 +2613,214 @@ def apply_liger_kernel_to_falcon_h1(
  _patch_rms_norm_module(decoder_layer.pre_ff_layernorm)


+ def apply_liger_kernel_to_qwen3_next(
+ rope: bool = False,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace GLM4v_moe models.
+
+ Args:
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+ fused_linear_cross_entropy (bool):
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+ swiglu (bool): Whether to apply Liger's SwiGLUMLP. Default is True.
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+ loaded. Default is None.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.qwen3_next import modeling_qwen3_next
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextForCausalLM
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextMLP
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextModel
+ from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+ from liger_kernel.transformers.model.qwen3_next import lce_forward as qwen3_next_lce_forward
+ from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
+ from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
+
+ if rope:
+ # It might enocunter nan issue
+ # modeling_qwen3_next.apply_rotary_pos_emb = liger_rotary_pos_emb
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3Next models.")
+ if rms_norm:
+ modeling_qwen3_next.Qwen3NextRMSNorm = LigerRMSNormForQwen3Next
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if fused_linear_cross_entropy:
+ if model is not None:
+ if isinstance(model, Qwen3NextForCausalLM):
+ model.forward = MethodType(qwen3_next_lce_forward, model)
+ else:
+ raise TypeError(
+ f" fused_linear_cross_entropy is only applicable on Qwen3NextForCausalLM. Got: {type(model)}"
+ )
+ else:
+ modeling_qwen3_next.Qwen3NextForCausalLM.forward = qwen3_next_lce_forward
+ if swiglu:
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+ modeling_qwen3_next.Qwen3NextMLP = LigerQwen3MoeSwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+ if isinstance(model, (Qwen3NextForCausalLM, Qwen3NextModel)):
+ base_model: Qwen3NextForCausalLM = getattr(model, model.base_model_prefix, model)
+ else:
+ raise TypeError(
+ f"Unsupported qwen3_next model type. `model` must be `Qwen3NextForCausalLM`, `Qwen3NextModel`. Got: {type(model)}"
+ )
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+
+ for decoder_layer in base_model.layers:
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+ # Qwen3MoeMLP and Qwen3NextMLP are identical, hence we reuse LigerQwen3MoeSwiGLUMLP
+ if swiglu:
+ if isinstance(decoder_layer.mlp, Qwen3NextMLP):
+ _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+ if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
+ _patch_swiglu_module(decoder_layer.mlp.shared_expert, LigerQwen3MoeSwiGLUMLP)
+ experts = getattr(decoder_layer.mlp, "experts", None)
+ if experts is not None:
+ for expert in experts:
+ _patch_swiglu_module(expert, LigerQwen3MoeSwiGLUMLP)
+
+
+ def apply_liger_kernel_to_hunyuan_v1_dense(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Hunyuan v1 dense models.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.hunyuan_v1_dense import modeling_hunyuan_v1_dense
+ from transformers.models.hunyuan_v1_dense.modeling_hunyuan_v1_dense import HunYuanDenseV1Model
+
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_lce_forward
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+ if rope:
+ modeling_hunyuan_v1_dense.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1RMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(hunyuan_v1_lce_forward, model)
+ else:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1ForCausalLM.forward = hunyuan_v1_lce_forward
+
+ if swiglu:
+ modeling_hunyuan_v1_dense.HunYuanDenseV1MLP = LigerHunyuanV1SwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: HunYuanDenseV1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ _patch_swiglu_module(decoder_layer.mlp, LigerHunyuanV1SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
+ def apply_liger_kernel_to_hunyuan_v1_moe(
+ rope: bool = True,
+ cross_entropy: bool = False,
+ fused_linear_cross_entropy: bool = True,
+ rms_norm: bool = True,
+ swiglu: bool = True,
+ model: PreTrainedModel = None,
+ ) -> None:
+ """
+ Apply Liger kernels to replace original implementation in HuggingFace Qwen3 models.
+ """
+ assert not (cross_entropy and fused_linear_cross_entropy), (
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ )
+
+ from transformers.models.hunyuan_v1_moe import modeling_hunyuan_v1_moe
+ from transformers.models.hunyuan_v1_moe.modeling_hunyuan_v1_moe import HunYuanMoEV1Model
+
+ from liger_kernel.transformers.model.hunyuan_v1 import lce_forward as hunyuan_v1_moe_lce_forward
+ from liger_kernel.transformers.swiglu import LigerHunyuanV1SwiGLUMLP
+
+ if rope:
+ modeling_hunyuan_v1_moe.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+ if rms_norm:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1RMSNorm = LigerRMSNorm
+
+ if cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+
+ if fused_linear_cross_entropy:
+ if model is not None:
+ model.forward = MethodType(hunyuan_v1_moe_lce_forward, model)
+ else:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1ForCausalLM.forward = hunyuan_v1_moe_lce_forward
+
+ if swiglu:
+ modeling_hunyuan_v1_moe.HunYuanMoEV1MLP = LigerHunyuanV1SwiGLUMLP
+
+ if model is not None:
+ # The model instance already exists, so we need to additionally patch the
+ # instance variables that reference already-instantiated modules
+
+ # get the base model from the model instance
+ base_model: HunYuanMoEV1Model = getattr(model, model.base_model_prefix, model)
+
+ if rms_norm:
+ _patch_rms_norm_module(base_model.norm)
+ for decoder_layer in base_model.layers:
+ if swiglu:
+ for mlp_expert in decoder_layer.mlp.experts:
+ _patch_swiglu_module(mlp_expert, LigerHunyuanV1SwiGLUMLP)
+ if rms_norm:
+ _patch_rms_norm_module(decoder_layer.input_layernorm)
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "gemma": apply_liger_kernel_to_gemma,
@@ -2186,6 +2830,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "glm4": apply_liger_kernel_to_glm4,
  "glm4v": apply_liger_kernel_to_glm4v,
  "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+ "gpt_oss": apply_liger_kernel_to_gpt_oss,
  "internvl": apply_liger_kernel_to_internvl,
  "llama": apply_liger_kernel_to_llama,
  "llama4_text": apply_liger_kernel_to_llama4,
@@ -2197,6 +2842,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "mistral": apply_liger_kernel_to_mistral,
  "mixtral": apply_liger_kernel_to_mixtral,
  "olmo2": apply_liger_kernel_to_olmo2,
+ "olmo3": apply_liger_kernel_to_olmo3,
  "qwen2": apply_liger_kernel_to_qwen2,
  "qwen3": apply_liger_kernel_to_qwen3,
  "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
@@ -2204,10 +2850,18 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
  "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
  "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+ "qwen3_next": apply_liger_kernel_to_qwen3_next,
+ "qwen3_vl": apply_liger_kernel_to_qwen3_vl,
+ "qwen3_vl_text": apply_liger_kernel_to_qwen3_vl,
+ "qwen3_vl_moe": apply_liger_kernel_to_qwen3_vl_moe,
+ "qwen3_vl_moe_text": apply_liger_kernel_to_qwen3_vl_moe,
  "smollm3": apply_liger_kernel_to_smollm3,
  "phi3": apply_liger_kernel_to_phi3,
  "paligemma": apply_liger_kernel_to_paligemma,
  "falcon_h1": apply_liger_kernel_to_falcon_h1,
+ "smolvlm": apply_liger_kernel_to_smolvlm,
+ "hunyuan_v1_dense": apply_liger_kernel_to_hunyuan_v1_dense,
+ "hunyuan_v1_moe": apply_liger_kernel_to_hunyuan_v1_moe,
  }

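The MODEL_TYPE_TO_APPLY_LIGER_FN registry extended above maps HuggingFace model_type strings to their patch functions, which is what lets callers dispatch by config instead of naming a model family explicitly. A minimal sketch of that lookup, using a placeholder checkpoint path:

    from transformers import AutoConfig
    from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

    config = AutoConfig.from_pretrained("path/to/checkpoint")  # illustrative path
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(config.model_type)
    if apply_fn is not None:
        # e.g. model_type "qwen3_vl" resolves to apply_liger_kernel_to_qwen3_vl
        apply_fn(rms_norm=True, fused_linear_cross_entropy=True)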