liger-kernel 0.5.10__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. liger_kernel/chunked_loss/__init__.py +1 -0
  2. liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  3. liger_kernel/chunked_loss/functional.py +2 -0
  4. liger_kernel/ops/dyt.py +0 -2
  5. liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  6. liger_kernel/ops/geglu.py +1 -1
  7. liger_kernel/ops/multi_token_attention.py +207 -0
  8. liger_kernel/ops/rms_norm.py +265 -54
  9. liger_kernel/ops/softmax.py +201 -0
  10. liger_kernel/ops/sparsemax.py +62 -50
  11. liger_kernel/ops/swiglu.py +1 -1
  12. liger_kernel/transformers/__init__.py +3 -0
  13. liger_kernel/transformers/functional.py +62 -0
  14. liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  15. liger_kernel/transformers/model/gemma.py +25 -8
  16. liger_kernel/transformers/model/gemma2.py +27 -8
  17. liger_kernel/transformers/model/gemma3.py +62 -98
  18. liger_kernel/transformers/model/glm4.py +16 -7
  19. liger_kernel/transformers/model/llama.py +25 -7
  20. liger_kernel/transformers/model/llama4.py +108 -0
  21. liger_kernel/transformers/model/llava.py +95 -124
  22. liger_kernel/transformers/model/mistral.py +13 -8
  23. liger_kernel/transformers/model/mixtral.py +16 -7
  24. liger_kernel/transformers/model/mllama.py +16 -7
  25. liger_kernel/transformers/model/olmo2.py +16 -7
  26. liger_kernel/transformers/model/paligemma.py +8 -1
  27. liger_kernel/transformers/model/phi3.py +25 -8
  28. liger_kernel/transformers/model/qwen2.py +24 -7
  29. liger_kernel/transformers/model/qwen2_5_vl.py +41 -91
  30. liger_kernel/transformers/model/qwen2_vl.py +38 -100
  31. liger_kernel/transformers/model/qwen3.py +11 -3
  32. liger_kernel/transformers/model/qwen3_moe.py +10 -6
  33. liger_kernel/transformers/monkey_patch.py +304 -70
  34. liger_kernel/transformers/multi_token_attention.py +64 -0
  35. liger_kernel/transformers/rms_norm.py +40 -4
  36. liger_kernel/transformers/softmax.py +12 -0
  37. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/METADATA +8 -2
  38. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/RECORD +42 -35
  39. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/WHEEL +1 -1
  40. liger_kernel/transformers/gema3_rms.py +0 -8
  41. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/LICENSE +0 -0
  42. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/licenses/NOTICE +0 -0
  43. {liger_kernel-0.5.10.dist-info → liger_kernel-0.6.0.dist-info}/top_level.txt +0 -0
liger_kernel/transformers/monkey_patch.py
@@ -2,6 +2,7 @@ import inspect
 import logging
 
 from functools import partial
+from types import MethodType
 from typing import Callable
 
 import transformers
@@ -54,7 +55,7 @@ def _bind_method_to_module(module, method_name: str, new_method: Callable):
     module.__dict__[method_name] = new_method.__get__(module, module.__class__)
 
 
-def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True):
+def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True, row_mode=None):
     # Check if the module is a PEFT ModulesToSaveWrapper
     # If it is, we need to patch the modules_to_save.default and original_modules
     if PEFT_AVAILABLE and isinstance(module, peft.utils.other.ModulesToSaveWrapper):
@@ -64,12 +65,14 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         )
         module.modules_to_save.default.in_place = in_place
+        module.modules_to_save.default.row_mode = row_mode
         module.original_module.offset = offset
         module.original_module.casting_mode = casting_mode
         module.original_module.variance_epsilon = (
             getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         )
         module.original_module.in_place = in_place
+        module.original_module.row_mode = row_mode
         _bind_method_to_module(module.modules_to_save.default, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module.modules_to_save.default, "extra_repr", LigerRMSNorm.extra_repr)
         _bind_method_to_module(module.original_module, "forward", LigerRMSNorm.forward)
@@ -81,6 +84,7 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
         module.casting_mode = casting_mode
         module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
         module.in_place = in_place
+        module.row_mode = row_mode
         _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
         _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)
         module.__class__.__name__ = LigerRMSNorm.__name__
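The instance-level patching above hinges on `_bind_method_to_module`, which uses the descriptor protocol (`__get__`) to bind a plain function as a method of one existing module instance, leaving the class untouched. A minimal self-contained sketch of the mechanism (toy classes, not the Liger code):

class Norm:
    def forward(self):
        return "original"

def liger_forward(self):
    return "liger"

m = Norm()
# Bind liger_forward to this one instance, exactly as _bind_method_to_module does
m.__dict__["forward"] = liger_forward.__get__(m, Norm)
print(m.forward())       # "liger"
print(Norm().forward())  # "original" - other instances are unaffected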
@@ -257,10 +261,16 @@ def apply_liger_kernel_to_llama(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(llama_lce_forward_deprecated, model)
+            else:
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
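The new `model is not None` branches change the patching granularity: when a loaded model is passed in, only that instance's `forward` is replaced via `types.MethodType`, instead of mutating `LlamaForCausalLM.forward` for every model in the process. A minimal sketch of the difference (toy classes, not the actual Liger forwards):

from types import MethodType

class CausalLM:
    def forward(self, x):
        return f"stock({x})"

def lce_forward(self, x):
    return f"fused({x})"

model = CausalLM()
model.forward = MethodType(lce_forward, model)  # instance-level patch
print(model.forward(1))       # fused(1)
print(CausalLM().forward(1))  # stock(1) - the class itself is untouched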
@@ -314,13 +324,20 @@ def apply_liger_kernel_to_llava(
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse("4.49.0"):
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        if transformer_version >= version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+            if model is not None:
+                model.forward = MethodType(llava_lce_forward_deprecated, model)
+            else:
+                modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
-                "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
             )
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
 
     if model is not None:
         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
@@ -359,6 +376,92 @@ def apply_liger_kernel_to_llava(
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
 
 
+def apply_liger_kernel_to_llama4(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+    layer_norm: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llama4 import modeling_llama4
+    from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
+    from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
+    from transformers.models.llama4.modeling_llama4 import Llama4TextModel
+    from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
+
+    from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
+    if rms_norm:
+        modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, Llama4ForConditionalGeneration):
+            language_model: Llama4ForCausalLM = model.language_model
+            vision_model: Llama4VisionModel = model.vision_model
+            text_model: Llama4TextModel = language_model.model
+        elif isinstance(model, Llama4ForCausalLM):
+            text_model = model.model
+            vision_model = None
+        elif isinstance(model, Llama4TextModel):
+            text_model = model
+            vision_model = None
+
+        else:
+            raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
+
+        if text_model:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+        if vision_model:
+            _patch_layer_norm_module(vision_model.layernorm_pre)
+            _patch_layer_norm_module(vision_model.layernorm_post)
+
+            for layer in vision_model.model.layers:
+                if layer_norm:
+                    _patch_layer_norm_module(layer.input_layernorm)
+                    _patch_layer_norm_module(layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
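Assuming the new `apply_liger_kernel_to_llama4` is importable from this module, applying it to an already-loaded model would look roughly like this (the checkpoint id is illustrative, not taken from the diff):

from transformers import AutoModelForCausalLM
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-4-Scout-17B-16E")  # illustrative checkpoint
apply_liger_kernel_to_llama4(
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,
    model=model,  # patches this instance's norms/MLPs in place
)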
@@ -400,7 +503,7 @@ def apply_liger_kernel_to_mllama(
 
     if rope:
         modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_mllama.nn.LayerNorm = LigerLayerNorm
     if rms_norm:
         modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
@@ -416,10 +519,16 @@ def apply_liger_kernel_to_mllama(
             modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mllama_lce_forward_deprecated, model)
+            else:
+                modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -428,13 +537,17 @@ def apply_liger_kernel_to_mllama(
         if isinstance(model, MllamaForConditionalGeneration):
             language_model: MllamaForCausalLM = model.language_model
             vision_model: MllamaVisionModel = model.vision_model
-            text_model: MllamaTextModel = language_model.model
+            if isinstance(language_model, MllamaForCausalLM):
+                text_model: MllamaTextModel = language_model.model
+            else:
+                text_model = language_model
         elif isinstance(model, MllamaForCausalLM):
             text_model = model.model
             vision_model = None
         elif isinstance(model, MllamaTextModel):
             text_model = model
             vision_model = None
+
         else:
             raise ValueError(f"Unsupported Mllama model type: {type(model)}")
 
@@ -501,7 +614,17 @@ def apply_liger_kernel_to_mistral(
     if cross_entropy:
         modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        if transformer_version >= version.parse("4.49.0"):
+            if model is not None:
+                model.forward = MethodType(mistral_lce_forward, model)
+            else:
+                modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
+        else:
+            logger.warning(
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
+            )
+            logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
+
     if swiglu:
         modeling_mistral.MistralMLP = LigerSwiGLUMLP
 
@@ -569,10 +692,16 @@ def apply_liger_kernel_to_mixtral(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(mixtral_lce_forward_deprecated, model)
+            else:
+                modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
     if swiglu:
         modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP
 
@@ -626,8 +755,8 @@ def apply_liger_kernel_to_gemma(
     from transformers.models.gemma import modeling_gemma
     from transformers.models.gemma.modeling_gemma import GemmaModel
 
-    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
-    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma
+
     _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
 
     if rope:
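The module-local `partial(...)` is replaced by a named `LigerRMSNormForGemma` class imported from `liger_kernel.transformers.rms_norm`. A plausible shape for such a class (an illustrative sketch under an assumed constructor signature, not the library's actual definition) is a subclass that bakes in the Gemma defaults the old `partial` supplied; unlike a `partial`, a real class works with `isinstance` checks and the `modeling_*` class-attribute assignments used throughout this file:

from liger_kernel.transformers.rms_norm import LigerRMSNorm

class LigerRMSNormForGemma(LigerRMSNorm):  # illustrative only
    def __init__(self, hidden_size, eps=1e-6):
        # Same defaults the removed partial passed: offset=1.0, init_fn="zeros", casting_mode="gemma"
        super().__init__(hidden_size, eps=eps, offset=1.0, init_fn="zeros", casting_mode="gemma")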
@@ -646,10 +775,16 @@ def apply_liger_kernel_to_gemma(
         modeling_gemma.GemmaMLP = LigerGEGLUMLP
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(gemma_lce_forward_deprecated, model)
+            else:
+                modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -700,7 +835,8 @@ def apply_liger_kernel_to_gemma2(
     from transformers.models.gemma2 import modeling_gemma2
     from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
 
-    LigerRMSNormForGemma2 = partial(LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False)
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma2
+
     _patch_rms_norm_module_for_gemma2 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
@@ -720,10 +856,16 @@ def apply_liger_kernel_to_gemma2(
             modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
         else:
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
+            if model is not None:
+                model.forward = MethodType(gemma2_lce_forward_deprected, model)
+            else:
+                modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprected
     if geglu:
         modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 
@@ -777,9 +919,10 @@ def apply_liger_kernel_to_gemma3_text(
     from transformers.models.gemma3 import modeling_gemma3
     from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
     from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextModel
 
-    from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3
     from liger_kernel.transformers.model.gemma3 import causal_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGemma3
 
     _patch_rms_norm_module_for_gemma3 = partial(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
@@ -801,15 +944,18 @@ def apply_liger_kernel_to_gemma3_text(
         nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
+        if model is not None:
+            model.forward = MethodType(causal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForCausalLM.forward = causal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        if isinstance(model, Gemma3ForCausalLM):
+        if isinstance(model, Gemma3ForCausalLM) or isinstance(model, Gemma3TextModel):
             # get the base model from the model instance
-            base_model = model.model
+            base_model = model.model if isinstance(model, Gemma3ForCausalLM) else model
 
             if rms_norm:
                 _patch_rms_norm_module_for_gemma3(base_model.norm)
@@ -871,7 +1017,7 @@ def apply_liger_kernel_to_gemma3(
         _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
     )
 
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     apply_liger_kernel_to_gemma3_text(
@@ -882,7 +1028,10 @@ def apply_liger_kernel_to_gemma3(
             modeling_gemma3.nn.CrossEntropyLoss = LigerCrossEntropyLoss
 
     if fused_linear_cross_entropy:
-        modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
+        if model is not None:
+            model.forward = MethodType(multimodal_forward, model)
+        else:
+            modeling_gemma3.Gemma3ForConditionalGeneration.forward = multimodal_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -950,7 +1099,9 @@ def apply_liger_kernel_to_paligemma(
     # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
 
     from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
+    from transformers.models.gemma.modeling_gemma import GemmaModel
     from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
+    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
     from transformers.models.paligemma import modeling_paligemma
     from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
     from transformers.models.siglip import modeling_siglip
@@ -961,7 +1112,7 @@ def apply_liger_kernel_to_paligemma(
     from liger_kernel.transformers.model.paligemma import lce_forward_deprecated
 
     # The vision_tower is a SiglipVisionModel
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_siglip.nn.LayerNorm = LigerLayerNorm
 
     # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
@@ -979,10 +1130,16 @@ def apply_liger_kernel_to_paligemma(
             modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
+            if model is not None:
+                model.forward = MethodType(lce_forward, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(lce_forward_deprecated, model)
+            else:
+                modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1003,7 +1160,7 @@ def apply_liger_kernel_to_paligemma(
 
         language_model = model.language_model
 
-        if isinstance(language_model, GemmaForCausalLM):
+        if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
             apply_liger_kernel_to_gemma(
                 rope=rope,
                 cross_entropy=False,
@@ -1013,7 +1170,7 @@ def apply_liger_kernel_to_paligemma(
                 model=language_model,
             )
 
-        elif isinstance(language_model, Gemma2ForCausalLM):
+        elif isinstance(language_model, (Gemma2ForCausalLM, Gemma2Model)):
             apply_liger_kernel_to_gemma2(
                 rope=rope,
                 cross_entropy=False,
@@ -1074,10 +1231,16 @@ def apply_liger_kernel_to_qwen2(
 
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(qwen2_lce_forward_deprecated, model)
+            else:
+                modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated
 
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
@@ -1133,7 +1296,10 @@ def apply_liger_kernel_to_qwen3(
         nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3.Qwen3ForCausalLM.forward = qwen3_lce_forward
 
     if swiglu:
         modeling_qwen3.Qwen3MLP = LigerSwiGLUMLP
@@ -1188,7 +1354,10 @@ def apply_liger_kernel_to_qwen3_moe(
         nn.functional.cross_entropy = liger_cross_entropy
 
     if fused_linear_cross_entropy:
-        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen3_lce_forward, model)
+        else:
+            modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = qwen3_lce_forward
 
     if swiglu:
         modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP
@@ -1204,7 +1373,8 @@ def apply_liger_kernel_to_qwen3_moe(
             _patch_rms_norm_module(base_model.norm)
         for decoder_layer in base_model.layers:
             if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerQwen3MoeSwiGLUMLP)
+                for mlp_expert in decoder_layer.mlp.experts:
+                    _patch_swiglu_module(mlp_expert, LigerQwen3MoeSwiGLUMLP)
             if rms_norm:
                 _patch_rms_norm_module(decoder_layer.input_layernorm)
                 _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
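The fix above patches each expert rather than the MoE block itself: `decoder_layer.mlp` on Qwen3-MoE is a sparse-MoE container, and the SwiGLU projections live on its `experts` submodules. A toy sketch of the shape (stand-in modules, not the transformers classes):

import torch.nn as nn

class ToyMoEBlock(nn.Module):
    def __init__(self, n_experts=4):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(8, 8) for _ in range(n_experts))

def patch_swiglu(module):  # stand-in for _patch_swiglu_module
    module.patched = True

mlp = ToyMoEBlock()
for expert in mlp.experts:  # patch every expert, not the router container
    patch_swiglu(expert)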
@@ -1221,7 +1391,7 @@ def apply_liger_kernel_to_qwen2_vl(
 ) -> None:
     """
     Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
-    NOTE: Qwen2-VL is not available in transformers<4.45.0
+    NOTE: Qwen2-VL is not supported in transformers<4.52.4
 
     Args:
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
@@ -1235,12 +1405,19 @@ def apply_liger_kernel_to_qwen2_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_vl import modeling_qwen2_vl
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
 
     from liger_kernel.transformers.model.qwen2_vl import lce_forward as qwen2_vl_lce_forward
 
@@ -1249,12 +1426,15 @@ def apply_liger_kernel_to_qwen2_vl(
     if rms_norm:
         # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
         modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
-    if layer_norm:
+    if layer_norm and model is None:
         modeling_qwen2_vl.LayerNorm = LigerLayerNorm
     if cross_entropy:
         modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_vl_lce_forward, model)
+        else:
+            modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
     if swiglu:
         modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1262,24 +1442,38 @@ def apply_liger_kernel_to_qwen2_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        # get the base model from the model instance
-        base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)
+        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+            text_model: Qwen2VLTextModel = model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2VLTextModel):
+            text_model: Qwen2VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
 
-        if hasattr(model, "visual"):
-            # Patch Qwen2VisionTransformerPretrainedModel
-            for vision_block in model.visual.blocks:
+        # Patch Qwen2VisionTransformerPretrainedModel
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
                 if layer_norm:
                     _patch_layer_norm_module(vision_block.norm1)
                     _patch_layer_norm_module(vision_block.norm2)
 
-        if rms_norm:
-            _patch_rms_norm_module(base_model.norm)
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        # Patch Qwen2VisionTextModel
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(decoder_layer.input_layernorm)
-                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_qwen2_5_vl(
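With the restructure above, patching an already-loaded Qwen2-VL instance requires one of the three supported classes; anything else raises `TypeError`. A usage sketch (the checkpoint id is illustrative):

from transformers import Qwen2VLForConditionalGeneration
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl

model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")  # illustrative
apply_liger_kernel_to_qwen2_vl(rms_norm=True, layer_norm=True, swiglu=True, model=model)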
@@ -1305,12 +1499,19 @@ def apply_liger_kernel_to_qwen2_5_vl(
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
+    if transformer_version < version.parse("4.52.4"):
+        logger.warning("Qwen2.5-VL support is only compatible with transformers >= 4.52.4")
+        return
+
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
 
     from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
     from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
 
     from liger_kernel.transformers.model.qwen2_5_vl import lce_forward as qwen2_5_vl_lce_forward
 
@@ -1321,7 +1522,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if cross_entropy:
         modeling_qwen2_5_vl.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
+        if model is not None:
+            model.forward = MethodType(qwen2_5_vl_lce_forward, model)
+        else:
+            modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_vl_lce_forward
     if swiglu:
         modeling_qwen2_5_vl.Qwen2MLP = LigerSwiGLUMLP
 
@@ -1329,24 +1533,37 @@ def apply_liger_kernel_to_qwen2_5_vl(
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
 
-        # get the base model from the model instance
-        base_model: Qwen2_5_VLModel = getattr(model, model.base_model_prefix, model)
+        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
+            # Note: language_model and visual properties can be accessed throught conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+            text_model: Qwen2_5_VLTextModel = model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
+        elif isinstance(model, Qwen2_5_VLTextModel):
+            text_model: Qwen2_5_VLTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported Qwen2VL model type. `model` must be `Qwen2VLForConditionalGeneration`, `Qwen2VLModel` or `Qwen2VLTextModel`. Got: {type(model)}"
+            )
 
-        if hasattr(model, "visual"):
+        if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
             for vision_block in model.visual.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)
 
-        if rms_norm:
-            _patch_rms_norm_module(base_model.norm)
-        for decoder_layer in base_model.layers:
-            if swiglu:
-                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+        if text_model is not None:
             if rms_norm:
-                _patch_rms_norm_module(decoder_layer.input_layernorm)
-                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
 def apply_liger_kernel_to_phi3(
@@ -1395,10 +1612,16 @@ def apply_liger_kernel_to_phi3(
             modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
         if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
         else:  # if version < 4.46.1
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+            if model is not None:
+                model.forward = MethodType(phi3_lce_forward_deprecated, model)
+            else:
+                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1449,11 +1672,12 @@ def apply_liger_kernel_to_olmo2(
     from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
 
     from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
 
     if rope:
         modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
     if rms_norm:
-        modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
+        modeling_olmo2.Olmo2RMSNorm = LigerRMSNormForOlmo2
     if swiglu:
         modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
     if cross_entropy:
@@ -1461,7 +1685,10 @@ def apply_liger_kernel_to_olmo2(
 
         nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
+        if model is not None:
+            model.forward = MethodType(olmo2_lce_forward, model)
+        else:
+            modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
@@ -1512,11 +1739,12 @@ def apply_liger_kernel_to_glm4(
     from transformers.models.glm4.modeling_glm4 import Glm4Model
 
     from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
 
     if rope:
         raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
     if rms_norm:
-        modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
+        modeling_glm4.Glm4RMSNorm = LigerRMSNormForGlm4
     if swiglu:
         modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
1525
1753
  nn.functional.cross_entropy = liger_cross_entropy
1526
1754
  if fused_linear_cross_entropy:
1527
- modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
1755
+ if model is not None:
1756
+ model.forward = MethodType(glm4_lce_forward, model)
1757
+ else:
1758
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
1528
1759
 
1529
1760
  if model is not None:
1530
1761
  # The model instance already exists, so we need to additionally patch the
@@ -1554,6 +1785,8 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma3": apply_liger_kernel_to_gemma3,
     "glm4": apply_liger_kernel_to_glm4,
     "llama": apply_liger_kernel_to_llama,
+    "llama4_text": apply_liger_kernel_to_llama4,
+    "llama4": apply_liger_kernel_to_llama4,
     "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
@@ -1565,7 +1798,9 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
+    "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
+    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
 }
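The new `*_text` keys let the instance-level entry point dispatch on the `model_type` of the standalone text submodels that recent transformers releases expose separately. A sketch of the lookup (assuming `model.config.model_type` yields one of the keys above):

from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

model_type = "qwen2_5_vl_text"  # e.g. read from model.config.model_type
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
# apply_fn is apply_liger_kernel_to_qwen2_5_vl; pass model=<instance> to patch it in place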
@@ -1625,7 +1860,6 @@ def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
         return
 
     apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
-
     apply_fn_signature = inspect.signature(apply_fn)
 
     # Filter out the keyword arguments that are not supported by the apply function
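The kwarg filtering referenced in the final context line intersects the caller's kwargs with the apply function's signature so unsupported flags are dropped. A minimal standalone sketch of the same pattern:

import inspect

def apply_fn(rope=True, rms_norm=True, model=None):  # stand-in apply function
    return {"rope": rope, "rms_norm": rms_norm}

kwargs = {"rope": False, "rms_norm": True, "geglu": True}  # "geglu" is not a parameter of apply_fn
supported = inspect.signature(apply_fn).parameters.keys()
filtered = {k: v for k, v in kwargs.items() if k in supported}
apply_fn(**filtered)  # only rope and rms_norm are passed through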