liger-kernel 0.5.6__tar.gz → 0.5.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/PKG-INFO +3 -1
  2. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/README.md +2 -0
  3. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/pyproject.toml +1 -1
  4. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +15 -0
  5. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/grpo_loss.py +33 -1
  6. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/jsd.py +2 -1
  7. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/kl_div.py +13 -6
  8. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/layer_norm.py +14 -1
  9. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/rms_norm.py +12 -1
  10. liger_kernel-0.5.7/src/liger_kernel/transformers/__init__.py +145 -0
  11. liger_kernel-0.5.7/src/liger_kernel/transformers/gema3_rms.py +8 -0
  12. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/gemma.py +9 -4
  13. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/gemma2.py +10 -5
  14. liger_kernel-0.5.7/src/liger_kernel/transformers/model/gemma3.py +335 -0
  15. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/llama.py +9 -4
  16. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/loss_utils.py +17 -10
  17. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/mistral.py +19 -15
  18. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/mixtral.py +12 -11
  19. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/mllama.py +9 -4
  20. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/olmo2.py +9 -4
  21. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/phi3.py +9 -4
  22. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/qwen2.py +9 -4
  23. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/monkey_patch.py +173 -0
  24. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/PKG-INFO +3 -1
  25. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/SOURCES.txt +3 -0
  26. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/chunked_loss/test_grpo_loss.py +49 -24
  27. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/chunked_loss/test_kto_loss.py +62 -58
  28. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/convergence/bf16/test_mini_models.py +59 -0
  29. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/convergence/bf16/test_mini_models_multimodal.py +100 -0
  30. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/convergence/bf16/test_mini_models_with_logits.py +59 -0
  31. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/convergence/fp32/test_mini_models.py +55 -0
  32. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/convergence/fp32/test_mini_models_multimodal.py +96 -2
  33. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/convergence/fp32/test_mini_models_with_logits.py +55 -0
  34. liger_kernel-0.5.7/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +90 -0
  35. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_monkey_patch.py +150 -0
  36. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/utils.py +29 -1
  37. liger_kernel-0.5.6/src/liger_kernel/transformers/__init__.py +0 -30
  38. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  39. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  40. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.github/pull_request_template.md +0 -0
  41. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.github/workflows/amd-ci.yml +0 -0
  42. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.github/workflows/docs.yml +0 -0
  43. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.github/workflows/intel-ci.yml +0 -0
  44. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.github/workflows/nvi-ci.yml +0 -0
  45. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.github/workflows/publish-nightly.yml +0 -0
  46. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.github/workflows/publish-release.yml +0 -0
  47. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/.gitignore +0 -0
  48. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/LICENSE +0 -0
  49. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/Makefile +0 -0
  50. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/NOTICE +0 -0
  51. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/README.md +0 -0
  52. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/__init__.py +0 -0
  53. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/benchmarks_visualizer.py +0 -0
  54. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/data/all_benchmark_data.csv +0 -0
  55. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/__init__.py +0 -0
  56. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  57. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  58. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  59. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  60. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_dyt.py +0 -0
  61. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_embedding.py +0 -0
  62. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  63. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  64. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_geglu.py +0 -0
  65. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_group_norm.py +0 -0
  66. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_jsd.py +0 -0
  67. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_kl_div.py +0 -0
  68. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  69. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  70. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  71. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  72. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  73. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_rope.py +0 -0
  74. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  75. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_swiglu.py +0 -0
  76. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_tvd.py +0 -0
  77. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/benchmark/scripts/utils.py +0 -0
  78. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/dev/fmt-requirements.txt +0 -0
  79. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/dev/modal/tests.py +0 -0
  80. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/dev/modal/tests_bwd.py +0 -0
  81. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/Examples.md +0 -0
  82. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/Getting-Started.md +0 -0
  83. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/High-Level-APIs.md +0 -0
  84. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/Low-Level-APIs.md +0 -0
  85. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/acknowledgement.md +0 -0
  86. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/contributing.md +0 -0
  87. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/images/banner.GIF +0 -0
  88. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/images/compose.gif +0 -0
  89. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/images/e2e-memory.png +0 -0
  90. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/images/e2e-tps.png +0 -0
  91. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/images/logo-banner.png +0 -0
  92. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/images/patch.gif +0 -0
  93. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/images/post-training.png +0 -0
  94. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/index.md +0 -0
  95. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/docs/license.md +0 -0
  96. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/alignment/accelerate_config.yaml +0 -0
  97. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/alignment/run_orpo.py +0 -0
  98. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/README.md +0 -0
  99. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/callback.py +0 -0
  100. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/config/fsdp_config.json +0 -0
  101. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  102. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  103. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  104. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/img/llama_tps.png +0 -0
  105. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  106. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/img/qwen_tps.png +0 -0
  107. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/launch_on_modal.py +0 -0
  108. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/requirements.txt +0 -0
  109. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/run_benchmarks.sh +0 -0
  110. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/run_gemma.sh +0 -0
  111. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/run_llama.sh +0 -0
  112. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/run_qwen.sh +0 -0
  113. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/run_qwen2_vl.sh +0 -0
  114. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/training.py +0 -0
  115. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/huggingface/training_multimodal.py +0 -0
  116. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/lightning/README.md +0 -0
  117. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/lightning/requirements.txt +0 -0
  118. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/lightning/training.py +0 -0
  119. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/README.md +0 -0
  120. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/callback.py +0 -0
  121. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  122. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  123. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  124. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  125. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  126. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  127. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  128. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  129. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  130. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/medusa_util.py +0 -0
  131. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/requirements.txt +0 -0
  132. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  133. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/examples/medusa/train.py +0 -0
  134. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/licenses/LICENSE-Apache-2.0 +0 -0
  135. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  136. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  137. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/licenses/LICENSE-MIT-llmc +0 -0
  138. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/licenses/LICENSE-MIT-triton +0 -0
  139. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/mkdocs.yml +0 -0
  140. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/setup.cfg +0 -0
  141. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/setup.py +0 -0
  142. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/__init__.py +0 -0
  143. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/README.md +0 -0
  144. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  145. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  146. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  147. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/functional.py +0 -0
  148. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  149. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  150. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  151. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  152. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  153. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  154. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  155. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/env_report.py +0 -0
  156. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/__init__.py +0 -0
  157. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/cross_entropy.py +0 -0
  158. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/dyt.py +0 -0
  159. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  160. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  161. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  162. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  163. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/geglu.py +0 -0
  164. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/group_norm.py +0 -0
  165. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  166. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/rope.py +0 -0
  167. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/swiglu.py +0 -0
  168. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/tvd.py +0 -0
  169. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/utils.py +0 -0
  170. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/auto_model.py +0 -0
  171. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  172. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/dyt.py +0 -0
  173. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  174. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/functional.py +0 -0
  175. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  176. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  177. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/geglu.py +0 -0
  178. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/group_norm.py +0 -0
  179. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/jsd.py +0 -0
  180. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/kl_div.py +0 -0
  181. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/layer_norm.py +0 -0
  182. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/__init__.py +0 -0
  183. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/llava.py +0 -0
  184. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  185. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  186. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  187. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  188. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/rms_norm.py +0 -0
  189. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/rope.py +0 -0
  190. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/swiglu.py +0 -0
  191. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  192. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  193. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  194. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/tvd.py +0 -0
  195. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/triton/__init__.py +0 -0
  196. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/triton/monkey_patch.py +0 -0
  197. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/utils.py +0 -0
  198. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  199. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/requires.txt +0 -0
  200. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/top_level.txt +0 -0
  201. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/__init__.py +0 -0
  202. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/chunked_loss/__init__.py +0 -0
  203. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/chunked_loss/test_cpo_loss.py +0 -0
  204. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/chunked_loss/test_dpo_loss.py +0 -0
  205. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/chunked_loss/test_jsd_loss.py +0 -0
  206. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/chunked_loss/test_orpo_loss.py +0 -0
  207. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/chunked_loss/test_simpo_loss.py +0 -0
  208. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/conftest.py +0 -0
  209. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/convergence/__init__.py +0 -0
  210. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/convergence/bf16/__init__.py +0 -0
  211. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/convergence/fp32/__init__.py +0 -0
  212. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  213. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  214. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  215. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  216. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  217. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  218. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  219. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  220. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/tiny_shakespeare.txt +0 -0
  221. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  222. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  223. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  224. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_auto_model.py +0 -0
  225. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_cross_entropy.py +0 -0
  226. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_dyt.py +0 -0
  227. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_embedding.py +0 -0
  228. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_flex_attention.py +0 -0
  229. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  230. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_fused_linear_jsd.py +0 -0
  231. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_geglu.py +0 -0
  232. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_group_norm.py +0 -0
  233. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_jsd.py +0 -0
  234. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_kl_div.py +0 -0
  235. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_layer_norm.py +0 -0
  236. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_mm_int8int2.py +0 -0
  237. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_qwen2vl_mrope.py +0 -0
  238. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_rms_norm.py +0 -0
  239. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_rope.py +0 -0
  240. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_swiglu.py +0 -0
  241. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_trainer_integration.py +0 -0
  242. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_transformers.py +0 -0
  243. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/transformers/test_tvd.py +0 -0
  244. {liger_kernel-0.5.6 → liger_kernel-0.5.7}/test/triton/test_triton_monkey_patch.py +0 -0
{liger_kernel-0.5.6 → liger_kernel-0.5.7}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: liger_kernel
-Version: 0.5.6
+Version: 0.5.7
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -314,6 +314,8 @@ loss.backward()
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

{liger_kernel-0.5.6 → liger_kernel-0.5.7}/README.md
@@ -263,6 +263,8 @@ loss.backward()
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
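
The two new Gemma3 rows follow the same patch-before-instantiation convention as the rest of the table. A minimal sketch, assuming the no-argument default call that the other apply_liger_kernel_to_* functions expose; the checkpoint name is illustrative, not taken from this diff:

    from transformers import AutoModelForCausalLM
    from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text

    apply_liger_kernel_to_gemma3_text()  # patch the Gemma3 text modules before the model is built
    model = AutoModelForCausalLM.from_pretrained("google/gemma-3-4b-it")  # illustrative checkpoint
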
{liger_kernel-0.5.6 → liger_kernel-0.5.7}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel"
-version = "0.5.6"
+version = "0.5.7"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }

{liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -32,6 +32,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=0.2,
 epsilon_high=0.2,
 beta=0.04,
+loss_type="bnpo",
+max_completion_length=None,
 temperature=1.0,
 compiled=True,
 use_ref_model=False,
@@ -57,6 +59,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low: Lower bound for clipping the importance sampling ratio
 epsilon_high: Upper bound for clipping the importance sampling ratio
 beta: Weight for the KL penalty
+loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
+max_completion_length: Maximum completion length required for "dr_grpo"
 temperature: Temperature for the logits
 compiled: Whether to use torch compile
 use_ref_model: Whether to use a reference model
@@ -68,6 +72,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 )
 if ref_per_token_logps is not None and ref_input is not None:
     raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
+if loss_type == "dr_grpo":
+    assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
 # Initialize accumulators
 loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
 grad_weight = torch.zeros_like(weight) # [V, H]
@@ -84,6 +90,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=epsilon_low,
 epsilon_high=epsilon_high,
 beta=beta,
+loss_type=loss_type,
+max_completion_length=max_completion_length,
 temperature=temperature,
 use_ref_model=use_ref_model,
 ppo_loss_fn=cls.ppo_loss_fn,
@@ -251,6 +259,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=0.2,
 epsilon_high=0.2,
 beta=0.04,
+loss_type="bnpo",
+max_completion_length=None,
 temperature=1.0,
 use_ref_model=False,
 ppo_loss_fn=None,
@@ -280,6 +290,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 epsilon_low=epsilon_low,
 epsilon_high=epsilon_high,
 beta=beta,
+loss_type=loss_type,
+max_completion_length=max_completion_length,
 )
 
 return chunk_loss, chunk_metrics
@@ -303,6 +315,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 def backward(ctx, grad_output, *grad_metrics):
     """Backward pass for PPO loss."""
     grad_input, grad_weight, grad_bias = ctx.saved_tensors
+
     if grad_output != 1.0:
         grad_input = grad_input * grad_output
         grad_weight = grad_weight * grad_output
@@ -328,4 +341,6 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
 None, # grad_compiled
 None, # grad_use_ref_model
 None, # grad_chunk_size
+None, # grad_loss_type
+None, # grad_max_completion_length
 )

{liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/grpo_loss.py
@@ -27,6 +27,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 epsilon_low=0.2,
 epsilon_high=0.2,
 beta=0.04,
+loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
+max_completion_length=None, # Required for dr_grpo
 **kwargs,
 ):
 """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -61,7 +63,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
 # and TRL GRPO implementation
 # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
-loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+if loss_type == "grpo":
+    # Average per-sequence loss
+    loss = (
+        (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
+    ).sum() / full_attention_mask.shape[0]
+elif loss_type == "bnpo":
+    # Batch Normalized Per-token loss (original implementation)
+    loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
+elif loss_type == "dr_grpo":
+    # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
+    if max_completion_length is None:
+        raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
+    loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
+else:
+    raise ValueError(f"Unknown loss type: {loss_type}")
 
 # Calculate metrics
 metrics = []
@@ -91,6 +107,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 beta=0.04,
 epsilon_low=0.2,
 epsilon_high=0.2,
+loss_type="bnpo",
+max_completion_length=None,
 temperature=1.0,
 compiled=True,
 use_ref_model=True,
@@ -110,6 +128,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
 ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
 beta (float): Weight for the KL penalty
+loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
 temperature (float): Temperature for the logits
 compiled (bool): Whether to use torch compile
 use_ref_model (bool): Whether to use a reference model
@@ -134,6 +154,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 beta=beta,
 epsilon_low=epsilon_low,
 epsilon_high=epsilon_high,
+loss_type=loss_type,
+max_completion_length=max_completion_length,
 temperature=temperature,
 compiled=compiled,
 use_ref_model=use_ref_model,
@@ -161,6 +183,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
 None, # grad_beta
 None, # grad_epsilon_low
 None, # grad_epsilon_high
+None, # grad_loss_type (string, not differentiable)
+None, # grad_max_completion_length (int, not differentiable)
 None, # grad_temperature
 None, # grad_compiled
 None, # grad_use_ref_model
@@ -179,6 +203,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 chunk_size: int = 1,
 epsilon_low: float = 0.2,
 epsilon_high: float = 0.2,
+loss_type: str = "bnpo",
+max_completion_length: int | None = None,
 temperature: float = 1.0,
 ):
 """
@@ -189,6 +215,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 chunk_size (int): Size of chunks for processing.
 epsilon_low (float): Lower bound for the importance sampling ratio.
 epsilon_high (float): Upper bound for the importance sampling ratio.
+loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
+max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
 temperature (float): Temperature for the logits.
 """
 super().__init__()
@@ -198,6 +226,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 self.chunk_size = chunk_size
 self.epsilon_low = epsilon_low
 self.epsilon_high = epsilon_high
+self.loss_type = loss_type
+self.max_completion_length = max_completion_length
 self.temperature = temperature
 
 def forward(
@@ -229,6 +259,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
 self.beta,
 self.epsilon_low,
 self.epsilon_high,
+self.loss_type,
+self.max_completion_length,
 self.temperature,
 self.compiled,
 self.use_ref_model,
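
The three loss_type branches differ only in how the masked per-token loss sum is normalized. Ignoring chunking, for a batch of 2 completions with 3 and 5 unmasked tokens and max_completion_length=8: "bnpo" divides the summed token loss by 8 (the total number of unmasked tokens), "grpo" first averages within each sequence (dividing by 3 and by 5) and then averages the two per-sequence means, and "dr_grpo" divides by batch_size * max_completion_length = 16 regardless of how many tokens are unmasked. A minimal construction sketch, assuming the remaining constructor arguments keep their defaults:

    from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss

    # "dr_grpo" needs the completion-length cap up front; the default loss_type="bnpo"
    # reproduces the 0.5.6 normalization unchanged.
    loss_fn = LigerFusedLinearGRPOLoss(loss_type="dr_grpo", max_completion_length=256)
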
{liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/jsd.py
@@ -5,6 +5,7 @@ import triton
 import triton.language as tl
 
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import infer_device
 
 
 @triton.jit
@@ -92,7 +93,7 @@ def _jsd_kernel(
 tl.store(dX_ptr + offsets, dX, mask=mask)
 
 
-MAX_FUSED_SIZE = 65536
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536
 
 
 def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):

{liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/kl_div.py
@@ -6,6 +6,7 @@ import triton.language as tl
 
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 
 def get_num_warps(BLOCK_SIZE):
@@ -115,9 +116,12 @@ def _kldiv_kernel_backward(
 
 def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
     BT, V = y_pred.shape
-
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
     reduction = _str_to_reduction_mode[reduction]
@@ -155,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
 
 def kldiv_backward_triton(target, grad_output, new_grads, log_target):
     BT, V = target.shape
-
-    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-    num_warps = get_num_warps(BLOCK_SIZE)
+    BLOCK_SIZE = (
+        min(8192, triton.next_power_of_2(V))
+        if infer_device() == "xpu"
+        else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
+    )
+    num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
 
     grid = (BT,)
 

{liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/layer_norm.py
@@ -154,6 +154,11 @@ def layer_norm_forward(X, W, B, eps):
     f"must match weight size (W.shape[0]={W.shape[0]})"
 )
 
+# XPU-specific optimization
+kernel_args = {}
+if X.device.type == "xpu":
+    kernel_args["grf_mode"] = "large"
+
 _layer_norm_forward_kernel[(n_rows,)](
     Y,
     Y.stride(0),
@@ -171,6 +176,7 @@ def layer_norm_forward(X, W, B, eps):
     eps,
     BLOCK_SIZE=BLOCK_SIZE,
     num_warps=num_warps,
+    **kernel_args, # XPU-specific optimization
 )
 return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps
 
@@ -185,7 +191,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
 if X.device.type == "cuda":
     sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
 elif X.device.type == "xpu":
-    sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+    sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
 DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
 _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)
@@ -208,6 +214,12 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     if X.dtype == torch.float16
     else tl.float32 # fallback to float32 for other types
 )
+
+# XPU-specific optimization
+kernel_args = {}
+if X.device.type == "xpu":
+    kernel_args.update({"grf_mode": "large", "num_warps": 32, "num_stages": 4})
+
 _layer_norm_backward_kernel[grid](
     X,
     W,
@@ -227,6 +239,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
     rows_per_program,
     BLOCK_SIZE=BLOCK_SIZE,
     dtype=triton_dtype,
+    **kernel_args, # XPU-specific optimization
 )
 
 DW = _DW.sum(dim=0).to(W.dtype)

{liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/ops/rms_norm.py
@@ -223,6 +223,10 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
 # Check constraints.
 assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
 
+# XPU-specific optimization
+kernel_args = {}
+if X.device.type == "xpu":
+    kernel_args["grf_mode"] = "large"
 _rms_norm_forward_kernel[(n_rows,)](
     Y,
     Y.stride(0),
@@ -238,6 +242,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode):
     casting_mode,
     BLOCK_SIZE=BLOCK_SIZE,
     num_warps=num_warps,
+    **kernel_args, # XPU-specific optimization
 )
 return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
@@ -252,7 +257,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
 if X.device.type == "cuda":
     sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
 elif X.device.type == "xpu":
-    sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count
+    sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
 
 # fp32 for numerical stability especially.
 _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
@@ -267,6 +272,11 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
 else:
     dX = torch.zeros_like(dY)
 
+# XPU-specific optimization
+kernel_args = {}
+if X.device.type == "xpu":
+    kernel_args["grf_mode"] = "large"
+
 _rms_norm_backward_kernel[grid](
     dY,
     dY.stride(0),
@@ -288,6 +298,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     casting_mode,
     BLOCK_SIZE=BLOCK_SIZE,
     num_warps=num_warps,
+    **kernel_args, # XPU-specific optimization
 )
 dX = dX.view(*shape)
 dW = _dW.sum(dim=0).to(W.dtype)

liger_kernel-0.5.7/src/liger_kernel/transformers/__init__.py (new file)
@@ -0,0 +1,145 @@
+import importlib
+
+from typing import TYPE_CHECKING
+
+# Always-safe imports (independent of 'transformers')
+from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
+from liger_kernel.transformers.dyt import LigerDyT # noqa: F401
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
+from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
+from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
+from liger_kernel.transformers.jsd import LigerJSD # noqa: F401
+from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
+from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
+from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
+from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP # noqa: F401
+from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP # noqa: F401
+from liger_kernel.transformers.swiglu import LigerSwiGLUMLP # noqa: F401
+from liger_kernel.transformers.tvd import LigerTVDLoss # noqa: F401
+
+# Static-only imports for IDEs and type checkers
+if TYPE_CHECKING:
+    from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
+    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401
+    from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
+
+
+# Check if 'transformers' is installed
+try:
+    import transformers # noqa: F401
+
+    _TRANSFORMERS_AVAILABLE = True
+except ImportError:
+    _TRANSFORMERS_AVAILABLE = False
+
+
+def is_transformers_available() -> bool:
+    """
+    Returns True if the 'transformers' package is available.
+    Useful for conditional logic in downstream code.
+    """
+    return _TRANSFORMERS_AVAILABLE
+
+
+def __getattr__(name: str):
+    """
+    Handles lazy access to transformer-dependent attributes.
+    If 'transformers' is not installed, raises a user-friendly ImportError.
+    """
+    if not _TRANSFORMERS_AVAILABLE:
+        raise ImportError(
+            f"The attribute '{name}' requires the 'transformers' library, which is not installed.\n"
+            f"Please install it with `pip install transformers` to use this functionality."
+        )
+
+    if name == "AutoLigerKernelForCausalLM":
+        module = importlib.import_module("liger_kernel.transformers.auto_model")
+        return getattr(module, name)
+
+    monkey_patch_symbols = {
+        "_apply_liger_kernel",
+        "_apply_liger_kernel_to_instance",
+        "apply_liger_kernel_to_gemma",
+        "apply_liger_kernel_to_gemma2",
+        "apply_liger_kernel_to_gemma3",
+        "apply_liger_kernel_to_gemma3_text",
+        "apply_liger_kernel_to_granite",
+        "apply_liger_kernel_to_llama",
+        "apply_liger_kernel_to_llava",
+        "apply_liger_kernel_to_mistral",
+        "apply_liger_kernel_to_mixtral",
+        "apply_liger_kernel_to_mllama",
+        "apply_liger_kernel_to_olmo2",
+        "apply_liger_kernel_to_paligemma",
+        "apply_liger_kernel_to_phi3",
+        "apply_liger_kernel_to_qwen2",
+        "apply_liger_kernel_to_qwen2_5_vl",
+        "apply_liger_kernel_to_qwen2_vl",
+    }
+
+    if name in monkey_patch_symbols:
+        module = importlib.import_module("liger_kernel.transformers.monkey_patch")
+        return getattr(module, name)
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+# Shared symbols in all environments
+__all__ = [
+    "is_transformers_available",
+    "LigerCrossEntropyLoss",
+    "LigerDyT",
+    "LigerFusedLinearCrossEntropyLoss",
+    "LigerFusedLinearJSD",
+    "LigerGEGLUMLP",
+    "LigerJSD",
+    "LigerLayerNorm",
+    "LigerRMSNorm",
+    "liger_rotary_pos_emb",
+    "LigerBlockSparseTop2MLP",
+    "LigerPhi3SwiGLUMLP",
+    "LigerSwiGLUMLP",
+    "LigerTVDLoss",
+]
+
+# Add transformer-dependent symbols only if available
+if _TRANSFORMERS_AVAILABLE:
+    __all__.extend(
+        [
+            "AutoLigerKernelForCausalLM",
+            "_apply_liger_kernel",
+            "_apply_liger_kernel_to_instance",
+            "apply_liger_kernel_to_gemma",
+            "apply_liger_kernel_to_gemma2",
+            "apply_liger_kernel_to_gemma3",
+            "apply_liger_kernel_to_gemma3_text",
+            "apply_liger_kernel_to_granite",
+            "apply_liger_kernel_to_llama",
+            "apply_liger_kernel_to_llava",
+            "apply_liger_kernel_to_mistral",
+            "apply_liger_kernel_to_mixtral",
+            "apply_liger_kernel_to_mllama",
+            "apply_liger_kernel_to_olmo2",
+            "apply_liger_kernel_to_paligemma",
+            "apply_liger_kernel_to_phi3",
+            "apply_liger_kernel_to_qwen2",
+            "apply_liger_kernel_to_qwen2_5_vl",
+            "apply_liger_kernel_to_qwen2_vl",
+        ]
+    )
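
The new module keeps the Triton-backed layers importable even when 'transformers' is absent, and defers the patching API to the module-level __getattr__ above. A short sketch of the behaviour it implements:

    import liger_kernel.transformers as liger_transformers

    liger_transformers.is_transformers_available()        # True only if 'transformers' is installed
    liger_transformers.LigerRMSNorm                        # eagerly imported; available either way
    liger_transformers.apply_liger_kernel_to_gemma3_text   # resolved lazily from monkey_patch;
                                                            # raises ImportError when 'transformers' is missing
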
liger_kernel-0.5.7/src/liger_kernel/transformers/gema3_rms.py (new file)
@@ -0,0 +1,8 @@
+from .rms_norm import LigerRMSNorm
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
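
A one-line construction sketch for the adapter above; head_dim stands in for whatever per-head dimension the Gemma3 config provides:

    q_norm = LigerRMSNormForGemma3(dim=head_dim, eps=1e-6)  # forwards `dim` into LigerRMSNorm's hidden-size slot
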
{liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/gemma.py
@@ -12,6 +12,7 @@ from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
 from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -127,6 +128,7 @@ def lce_forward_deprecated(
 )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
@@ -142,7 +144,7 @@ def lce_forward(
 output_hidden_states: Optional[bool] = None,
 return_dict: Optional[bool] = None,
 cache_position: Optional[torch.LongTensor] = None,
-num_logits_to_keep: int = 0,
+logits_to_keep: Union[int, torch.Tensor] = 0,
 **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
@@ -152,10 +154,12 @@ def lce_forward(
 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-num_logits_to_keep (`int`, *optional*):
-    Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+logits_to_keep (`int` or `torch.Tensor`, *optional*):
+    If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
     `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
     token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+    If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+    This is useful when using packed tensor format (single dimension for batch and sequence length).
 
 Returns:
 
@@ -209,7 +213,8 @@ def lce_forward(
 **loss_kwargs,
 )
 else: # if in inference mode materialize logits
-    logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    logits = self.lm_head(hidden_states[:, slice_indices, :])
 if labels is not None:
     loss = self.loss_function(
         logits=logits,
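
The slice_indices line is what lets lce_forward accept either form of logits_to_keep; a small sketch of the three cases (the index values are illustrative):

    logits_to_keep = 1                     # slice(-1, None): keep only the last position
    logits_to_keep = 0                     # slice(-0, None) == slice(0, None): keep every position (the documented special case)
    logits_to_keep = torch.tensor([3, 7])  # 1D tensor: keep exactly these sequence positions (packed format)
    # in every case: logits = self.lm_head(hidden_states[:, slice_indices, :])
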
{liger_kernel-0.5.6 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/gemma2.py
@@ -13,6 +13,7 @@ from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
 from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -134,6 +135,7 @@ def lce_forward_deprecated(
 )
 
 
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
@@ -149,7 +151,7 @@ def lce_forward(
 output_hidden_states: Optional[bool] = None,
 return_dict: Optional[bool] = None,
 cache_position: Optional[torch.LongTensor] = None,
-num_logits_to_keep: int = 0,
+logits_to_keep: Union[int, torch.Tensor] = 0,
 **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
 r"""
@@ -159,10 +161,12 @@ def lce_forward(
 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-num_logits_to_keep (`int`, *optional*):
-    Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+logits_to_keep (`int` or `torch.Tensor`, *optional*):
+    If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
     `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
     token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+    If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+    This is useful when using packed tensor format (single dimension for batch and sequence length).
 
 Returns:
 
@@ -218,12 +222,13 @@ def lce_forward(
 lm_head_weight=self.lm_head.weight,
 labels=labels,
 hidden_size=self.config.hidden_size,
-softcap=self.config.final_logit_softcapping,
+final_logit_softcapping=self.config.final_logit_softcapping,
 **loss_kwargs,
 )
 
 else: # if in inference mode materialize logits
-    logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    logits = self.lm_head(hidden_states[:, slice_indices, :])
 if self.config.final_logit_softcapping is not None:
     logits = logits / self.config.final_logit_softcapping
     logits = torch.tanh(logits)