liger-kernel 0.6.1__tar.gz → 0.6.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (285)
  1. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/workflows/benchmark.yml +10 -9
  2. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/PKG-INFO +2 -2
  3. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/README.md +1 -1
  4. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/data/all_benchmark_data.csv +65 -1
  5. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +11 -3
  6. liger_kernel-0.6.2/benchmark/scripts/benchmark_llama4_rope.py +249 -0
  7. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/pyproject.toml +1 -1
  8. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/dpo_loss.py +54 -3
  9. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/fused_linear_cross_entropy.py +21 -13
  10. liger_kernel-0.6.2/src/liger_kernel/ops/llama4_rope.py +225 -0
  11. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/__init__.py +15 -0
  12. liger_kernel-0.6.2/src/liger_kernel/transformers/experimental/__init__.py +5 -0
  13. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/functional.py +2 -0
  14. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +3 -0
  15. liger_kernel-0.6.2/src/liger_kernel/transformers/llama4_rope.py +93 -0
  16. liger_kernel-0.6.2/src/liger_kernel/transformers/model/glm4v.py +150 -0
  17. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/loss_utils.py +2 -0
  18. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/mllama.py +4 -2
  19. liger_kernel-0.6.2/src/liger_kernel/transformers/model/phi3.py +112 -0
  20. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/monkey_patch.py +100 -20
  21. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel.egg-info/PKG-INFO +2 -2
  22. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel.egg-info/SOURCES.txt +5 -0
  23. liger_kernel-0.6.2/test/chunked_loss/test_dpo_loss.py +938 -0
  24. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/convergence/bf16/test_mini_models.py +99 -2
  25. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/convergence/bf16/test_mini_models_multimodal.py +13 -3
  26. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/convergence/bf16/test_mini_models_with_logits.py +97 -0
  27. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/convergence/fp32/test_mini_models.py +94 -1
  28. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/convergence/fp32/test_mini_models_multimodal.py +2 -1
  29. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/convergence/fp32/test_mini_models_with_logits.py +94 -0
  30. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_fused_linear_cross_entropy.py +12 -5
  31. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_monkey_patch.py +96 -5
  32. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/utils.py +12 -0
  33. liger_kernel-0.6.1/src/liger_kernel/transformers/model/phi3.py +0 -263
  34. liger_kernel-0.6.1/test/chunked_loss/test_dpo_loss.py +0 -358
  35. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  36. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  37. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/pull_request_template.md +0 -0
  38. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/workflows/amd-ci.yml +0 -0
  39. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/workflows/docs.yml +0 -0
  40. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/workflows/intel-ci.yml +0 -0
  41. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/workflows/nvi-ci.yml +0 -0
  42. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/workflows/publish-nightly.yml +0 -0
  43. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.github/workflows/publish-release.yml +0 -0
  44. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/.gitignore +0 -0
  45. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/LICENSE +0 -0
  46. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/Makefile +0 -0
  47. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/NOTICE +0 -0
  48. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/README.md +0 -0
  49. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/__init__.py +0 -0
  50. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/benchmarks_visualizer.py +0 -0
  51. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/__init__.py +0 -0
  52. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  53. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  54. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_distill_cosine_loss.py +0 -0
  55. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  56. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  57. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_dyt.py +0 -0
  58. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_embedding.py +0 -0
  59. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_fused_add_rms_norm.py +0 -0
  60. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  61. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_fused_neighborhood_attention.py +0 -0
  62. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_geglu.py +0 -0
  63. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_group_norm.py +0 -0
  64. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_jsd.py +0 -0
  65. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_kl_div.py +0 -0
  66. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  67. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  68. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_multi_token_attention.py +0 -0
  69. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  70. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  71. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  72. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_rope.py +0 -0
  73. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  74. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_softmax.py +0 -0
  75. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_sparse_multi_token_attention.py +0 -0
  76. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_sparsemax.py +0 -0
  77. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_swiglu.py +0 -0
  78. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/benchmark_tvd.py +0 -0
  79. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/benchmark/scripts/utils.py +0 -0
  80. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/dev/fmt-requirements.txt +0 -0
  81. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/dev/modal/benchmarks.py +0 -0
  82. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/dev/modal/tests.py +0 -0
  83. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/dev/modal/tests_bwd.py +0 -0
  84. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/Examples.md +0 -0
  85. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/Getting-Started.md +0 -0
  86. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/High-Level-APIs.md +0 -0
  87. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/Low-Level-APIs.md +0 -0
  88. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/acknowledgement.md +0 -0
  89. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/contributing.md +0 -0
  90. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/images/banner.GIF +0 -0
  91. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/images/compose.gif +0 -0
  92. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/images/e2e-memory.png +0 -0
  93. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/images/e2e-tps.png +0 -0
  94. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/images/logo-banner.png +0 -0
  95. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/images/patch.gif +0 -0
  96. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/images/post-training.png +0 -0
  97. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/index.md +0 -0
  98. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/docs/license.md +0 -0
  99. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/alignment/accelerate_config.yaml +0 -0
  100. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/alignment/run_orpo.py +0 -0
  101. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/README.md +0 -0
  102. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/callback.py +0 -0
  103. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/config/fsdp_config.json +0 -0
  104. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  105. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  106. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  107. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/img/llama_tps.png +0 -0
  108. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  109. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/img/qwen_tps.png +0 -0
  110. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/launch_on_modal.py +0 -0
  111. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/requirements.txt +0 -0
  112. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/run_benchmarks.sh +0 -0
  113. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/run_gemma.sh +0 -0
  114. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/run_llama.sh +0 -0
  115. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/run_qwen.sh +0 -0
  116. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/run_qwen2_vl.sh +0 -0
  117. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/training.py +0 -0
  118. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/huggingface/training_multimodal.py +0 -0
  119. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/lightning/README.md +0 -0
  120. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/lightning/requirements.txt +0 -0
  121. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/lightning/training.py +0 -0
  122. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/README.md +0 -0
  123. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/callback.py +0 -0
  124. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  125. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  126. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  127. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  128. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  129. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  130. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  131. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  132. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  133. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/medusa_util.py +0 -0
  134. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/requirements.txt +0 -0
  135. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  136. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/examples/medusa/train.py +0 -0
  137. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/licenses/LICENSE-Apache-2.0 +0 -0
  138. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  139. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  140. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/licenses/LICENSE-MIT-llmc +0 -0
  141. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/licenses/LICENSE-MIT-triton +0 -0
  142. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/mkdocs.yml +0 -0
  143. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/setup.cfg +0 -0
  144. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/setup.py +0 -0
  145. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/__init__.py +0 -0
  146. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/README.md +0 -0
  147. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  148. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +0 -0
  149. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  150. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/functional.py +0 -0
  151. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  152. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  153. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  154. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  155. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  156. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  157. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  158. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  159. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  160. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/env_report.py +0 -0
  161. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/__init__.py +0 -0
  162. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/cross_entropy.py +0 -0
  163. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/dyt.py +0 -0
  164. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  165. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  166. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/fused_add_rms_norm.py +0 -0
  167. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  168. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/fused_neighborhood_attention.py +0 -0
  169. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/geglu.py +0 -0
  170. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/group_norm.py +0 -0
  171. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/grpo_loss.py +0 -0
  172. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/jsd.py +0 -0
  173. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/kl_div.py +0 -0
  174. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/layer_norm.py +0 -0
  175. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/multi_token_attention.py +0 -0
  176. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  177. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/rms_norm.py +0 -0
  178. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/rope.py +0 -0
  179. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/softmax.py +0 -0
  180. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/sparsemax.py +0 -0
  181. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/swiglu.py +0 -0
  182. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/tvd.py +0 -0
  183. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/ops/utils.py +0 -0
  184. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/auto_model.py +0 -0
  185. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  186. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/dyt.py +0 -0
  187. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  188. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/fsdp.py +0 -0
  189. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/fused_add_rms_norm.py +0 -0
  190. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  191. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/fused_neighborhood_attention.py +0 -0
  192. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/geglu.py +0 -0
  193. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/group_norm.py +0 -0
  194. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/grpo_loss.py +0 -0
  195. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/jsd.py +0 -0
  196. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/kl_div.py +0 -0
  197. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/layer_norm.py +0 -0
  198. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/__init__.py +0 -0
  199. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/gemma.py +0 -0
  200. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  201. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  202. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/glm4.py +0 -0
  203. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/llama.py +0 -0
  204. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/llama4.py +0 -0
  205. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/llava.py +0 -0
  206. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/mistral.py +0 -0
  207. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  208. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  209. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  210. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  211. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  212. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  213. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  214. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  215. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/model/smollm3.py +0 -0
  216. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/multi_token_attention.py +0 -0
  217. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  218. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/rms_norm.py +0 -0
  219. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/rope.py +0 -0
  220. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/softmax.py +0 -0
  221. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/sparsemax.py +0 -0
  222. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/swiglu.py +0 -0
  223. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  224. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  225. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  226. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/transformers/tvd.py +0 -0
  227. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/triton/__init__.py +0 -0
  228. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/triton/monkey_patch.py +0 -0
  229. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel/utils.py +0 -0
  230. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  231. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel.egg-info/requires.txt +0 -0
  232. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/src/liger_kernel.egg-info/top_level.txt +0 -0
  233. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/__init__.py +0 -0
  234. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/chunked_loss/__init__.py +0 -0
  235. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/chunked_loss/test_cosine_loss.py +0 -0
  236. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/chunked_loss/test_cpo_loss.py +0 -0
  237. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/chunked_loss/test_grpo_loss.py +0 -0
  238. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/chunked_loss/test_jsd_loss.py +0 -0
  239. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/chunked_loss/test_kto_loss.py +0 -0
  240. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/chunked_loss/test_orpo_loss.py +0 -0
  241. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/chunked_loss/test_simpo_loss.py +0 -0
  242. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/conftest.py +0 -0
  243. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/convergence/__init__.py +0 -0
  244. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/convergence/bf16/__init__.py +0 -0
  245. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/convergence/fp32/__init__.py +0 -0
  246. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  247. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  248. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  249. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  250. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  251. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  252. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  253. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  254. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +0 -0
  255. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  256. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/tiny_shakespeare.txt +0 -0
  257. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  258. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  259. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  260. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_auto_model.py +0 -0
  261. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_cross_entropy.py +0 -0
  262. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_dyt.py +0 -0
  263. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_embedding.py +0 -0
  264. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_flex_attention.py +0 -0
  265. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_fused_add_rms_norm.py +0 -0
  266. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_fused_linear_jsd.py +0 -0
  267. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_fused_neighborhood_attention.py +0 -0
  268. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_geglu.py +0 -0
  269. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_group_norm.py +0 -0
  270. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_grpo_loss.py +0 -0
  271. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_jsd.py +0 -0
  272. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_kl_div.py +0 -0
  273. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_layer_norm.py +0 -0
  274. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_mm_int8int2.py +0 -0
  275. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_multi_token_attention.py +0 -0
  276. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_qwen2vl_mrope.py +0 -0
  277. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_rms_norm.py +0 -0
  278. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_rope.py +0 -0
  279. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_softmax.py +0 -0
  280. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_sparsemax.py +0 -0
  281. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_swiglu.py +0 -0
  282. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_trainer_integration.py +0 -0
  283. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_transformers.py +0 -0
  284. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/transformers/test_tvd.py +0 -0
  285. {liger_kernel-0.6.1 → liger_kernel-0.6.2}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -39,24 +39,24 @@ jobs:
 
     steps:
       # Step: Decide the commit hash to use
+      # Step: Checkout full history so we can check out any commit
+      - name: Checkout full repo history
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0 # Important: so we can checkout arbitrary commit
+
       - name: Determine commit hash to checkout
         id: choose_commit
         run: |
-          if [ "${{ github.event.inputs.commit_hash }}" != "" ]; then
+          if [ "${{ github.event_name}}" == "workflow_dispatch" ] && [ "${{ github.event.inputs.commit_hash }}" != "main" ]; then
            echo "Using manual input commit: ${{ github.event.inputs.commit_hash }}"
            echo "hash=${{ github.event.inputs.commit_hash }}" >> $GITHUB_OUTPUT
          else
            echo "Using latest commit from main"
-            git fetch origin main
-            echo "hash=$(git rev-parse origin/main)" >> $GITHUB_OUTPUT
+            echo "hash=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
          fi
 
-      # Step: Checkout full history so we can check out any commit
-      - name: Checkout full repo history
-        uses: actions/checkout@v3
-        with:
-          fetch-depth: 0 # Important: so we can checkout arbitrary commit
-      # Step: Conditionally replace benchmark folder from main
+      # Step: Conditionally replace benchmark folder from main
       - name: Replace benchmark folder from main (manual only, commit ≠ main)
         if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.commit_hash != 'main' }}
         run: |
@@ -133,6 +133,7 @@ jobs:
            echo "Not a release event"
            path=${{steps.choose_commit.outputs.hash}}
          fi
+          echo "path=$path" >> $GITHUB_OUTPUT
          COMMIT_DIR="gh-pages/${OUTPUT_DIR}/${path}"
 
          mkdir -p "$COMMIT_DIR"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: liger_kernel
-Version: 0.6.1
+Version: 0.6.2
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
         Copyright 2024 LinkedIn Corporation
@@ -400,7 +400,7 @@ loss.backward()
         </a>
     </div>
     <div style="display: block;">
-        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
             <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
         </a>
     </div>
@@ -348,7 +348,7 @@ loss.backward()
         </a>
     </div>
     <div style="display: block;">
-        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+        <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
             <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
         </a>
     </div>
@@ -1574,4 +1574,68 @@ fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,2048,208.06298828
 fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,4096,416.11767578125,416.11767578125,416.11767578125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0
 fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,8192,832.22705078125,832.22705078125,832.22705078125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0
 fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,16384,1544.44580078125,1544.44580078125,1544.44580078125,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0
-fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,32768,2960.8837890625,2960.8837890625,2960.8837890625,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0
+fused_add_rms_norm,liger_rms_norm,full,memory,MB,H,hidden size,32768,2960.8837890625,2960.8837890625,2960.8837890625,"{""M"": 2048, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 NVL,2025-07-16 07:04:45,0.6.0
+llama4_rope,liger,forward,speed,ms,H,hidden size,512,0.08249600231647491,0.08102399855852127,0.08432000130414963,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:01,0.6.1
+llama4_rope,liger,forward,speed,ms,H,hidden size,2048,0.08169600367546082,0.08037760108709335,0.08329600095748901,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:01,0.6.1
+llama4_rope,liger,forward,speed,ms,H,hidden size,8192,0.08128000050783157,0.07980799674987793,0.08329600095748901,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:01,0.6.1
+llama4_rope,huggingface,forward,speed,ms,H,hidden size,512,0.03759999945759773,0.03612799942493439,0.03907199949026108,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:03,0.6.1
+llama4_rope,huggingface,forward,speed,ms,H,hidden size,2048,0.06185600161552429,0.061267200857400894,0.06252799928188324,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:03,0.6.1
+llama4_rope,huggingface,forward,speed,ms,H,hidden size,8192,0.206496000289917,0.20582400262355804,0.20716799795627594,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:03,0.6.1
+llama4_rope,liger,backward,speed,ms,H,hidden size,512,0.15404799580574036,0.15241600573062897,0.15615999698638916,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:04,0.6.1
+llama4_rope,liger,backward,speed,ms,H,hidden size,2048,0.1536320000886917,0.15190400183200836,0.1558080017566681,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:04,0.6.1
+llama4_rope,liger,backward,speed,ms,H,hidden size,8192,0.15263999998569489,0.15094399452209473,0.15491199493408203,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:04,0.6.1
+llama4_rope,huggingface,backward,speed,ms,H,hidden size,512,0.13760000467300415,0.13574400544166565,0.14009599387645721,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:05,0.6.1
+llama4_rope,huggingface,backward,speed,ms,H,hidden size,2048,0.13600000739097595,0.13449600338935852,0.1382720023393631,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:05,0.6.1
+llama4_rope,huggingface,backward,speed,ms,H,hidden size,8192,0.21011200547218323,0.20924800634384155,0.21110400557518005,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:05,0.6.1
+llama4_rope,liger,full,speed,ms,H,hidden size,512,0.3652159869670868,0.3619840145111084,0.3699840009212494,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:07,0.6.1
+llama4_rope,liger,full,speed,ms,H,hidden size,2048,0.3599040061235428,0.2881920039653778,0.36559998989105225,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:07,0.6.1
+llama4_rope,liger,full,speed,ms,H,hidden size,8192,0.2874239981174469,0.2852480113506317,0.29029120206832887,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:07,0.6.1
+llama4_rope,huggingface,full,speed,ms,H,hidden size,512,0.24691200256347656,0.24489599466323853,0.24961919784545897,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1
+llama4_rope,huggingface,full,speed,ms,H,hidden size,2048,0.24774399399757385,0.24582399427890778,0.2505407989025116,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1
+llama4_rope,huggingface,full,speed,ms,H,hidden size,8192,0.41414400935173035,0.41337600350379944,0.41491198539733887,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1
+llama4_rope,liger,full,memory,MB,H,hidden size,512,37.23486328125,37.23486328125,37.23486328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1
+llama4_rope,liger,full,memory,MB,H,hidden size,2048,52.89111328125,52.89111328125,52.89111328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1
+llama4_rope,liger,full,memory,MB,H,hidden size,8192,115.51611328125,115.51611328125,115.51611328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1
+llama4_rope,huggingface,full,memory,MB,H,hidden size,512,49.64111328125,49.64111328125,49.64111328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1
+llama4_rope,huggingface,full,memory,MB,H,hidden size,2048,102.51611328125,102.51611328125,102.51611328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1
+llama4_rope,huggingface,full,memory,MB,H,hidden size,8192,314.01611328125,314.01611328125,314.01611328125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:08,0.6.1
+llama4_rope,liger,forward,speed,ms,T,sequence length,1024,0.07417599856853485,0.07248000055551529,0.07596799731254578,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1
+llama4_rope,liger,forward,speed,ms,T,sequence length,2048,0.08182399719953537,0.08006399869918823,0.08380799740552902,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1
+llama4_rope,liger,forward,speed,ms,T,sequence length,4096,0.11708799749612808,0.1167680025100708,0.11744000017642975,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1
+llama4_rope,liger,forward,speed,ms,T,sequence length,8192,0.2165440022945404,0.21596799790859222,0.21715199947357178,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1
+llama4_rope,liger,forward,speed,ms,T,sequence length,16384,0.41756799817085266,0.41705599427223206,0.41811200976371765,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:10,0.6.1
+llama4_rope,huggingface,forward,speed,ms,T,sequence length,1024,0.11644800007343292,0.11590400338172913,0.11708799749612808,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1
+llama4_rope,huggingface,forward,speed,ms,T,sequence length,2048,0.20659199357032776,0.20608000457286835,0.2072640061378479,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1
+llama4_rope,huggingface,forward,speed,ms,T,sequence length,4096,0.38553598523139954,0.3846847891807556,0.38624000549316406,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1
+llama4_rope,huggingface,forward,speed,ms,T,sequence length,8192,0.7411519885063171,0.7403839826583862,0.7420480251312256,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1
+llama4_rope,huggingface,forward,speed,ms,T,sequence length,16384,1.4553920030593872,1.4543871641159059,1.4562879800796509,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:12,0.6.1
+llama4_rope,liger,backward,speed,ms,T,sequence length,1024,0.11840000003576279,0.11711999773979187,0.12031999975442886,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1
+llama4_rope,liger,backward,speed,ms,T,sequence length,2048,0.12336000055074692,0.12198399752378464,0.12489599734544754,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1
+llama4_rope,liger,backward,speed,ms,T,sequence length,4096,0.12380799651145935,0.12240000069141388,0.12559999525547028,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1
+llama4_rope,liger,backward,speed,ms,T,sequence length,8192,0.2170879989862442,0.2165759950876236,0.21753600239753723,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1
+llama4_rope,liger,backward,speed,ms,T,sequence length,16384,0.4175359904766083,0.41705599427223206,0.4181375920772552,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:15,0.6.1
+llama4_rope,huggingface,backward,speed,ms,T,sequence length,1024,0.1189119964838028,0.11769600212574005,0.12003199756145477,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1
+llama4_rope,huggingface,backward,speed,ms,T,sequence length,2048,0.21011200547218323,0.20927999913692474,0.21119999885559082,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1
+llama4_rope,huggingface,backward,speed,ms,T,sequence length,4096,0.39740800857543945,0.3963199853897095,0.39824000000953674,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1
+llama4_rope,huggingface,backward,speed,ms,T,sequence length,8192,0.7540159821510315,0.7528960108757019,0.7550719976425171,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1
+llama4_rope,huggingface,backward,speed,ms,T,sequence length,16384,1.4822720289230347,1.4810559749603271,1.4833600521087646,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:17,0.6.1
+llama4_rope,liger,full,speed,ms,T,sequence length,1024,0.2874400019645691,0.2853440046310425,0.29052799940109253,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1
+llama4_rope,liger,full,speed,ms,T,sequence length,2048,0.28646400570869446,0.2845759987831116,0.28963199257850647,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1
+llama4_rope,liger,full,speed,ms,T,sequence length,4096,0.29897600412368774,0.29660800099372864,0.302131199836731,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1
+llama4_rope,liger,full,speed,ms,T,sequence length,8192,0.4315840005874634,0.4304639995098114,0.43270400166511536,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1
+llama4_rope,liger,full,speed,ms,T,sequence length,16384,0.833184003829956,0.8322240114212036,0.8345024228096007,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:19,0.6.1
+llama4_rope,huggingface,full,speed,ms,T,sequence length,1024,0.24592000246047974,0.24396799504756927,0.24876800179481506,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,huggingface,full,speed,ms,T,sequence length,2048,0.4138239920139313,0.41308799386024475,0.4145599901676178,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,huggingface,full,speed,ms,T,sequence length,4096,0.7800959944725037,0.7790719866752625,0.7810239791870117,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,huggingface,full,speed,ms,T,sequence length,8192,1.4911680221557617,1.4902976036071778,1.4922879934310913,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,huggingface,full,speed,ms,T,sequence length,16384,2.9344160556793213,2.9333438873291016,2.9353599548339844,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,liger,full,memory,MB,T,sequence length,1024,73.75830078125,73.75830078125,73.75830078125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,liger,full,memory,MB,T,sequence length,2048,115.51611328125,115.51611328125,115.51611328125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,liger,full,memory,MB,T,sequence length,4096,199.03173828125,199.03173828125,199.03173828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,liger,full,memory,MB,T,sequence length,8192,366.06298828125,366.06298828125,366.06298828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,liger,full,memory,MB,T,sequence length,16384,700.12548828125,700.12548828125,700.12548828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,huggingface,full,memory,MB,T,sequence length,1024,173.00830078125,173.00830078125,173.00830078125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,huggingface,full,memory,MB,T,sequence length,2048,314.01611328125,314.01611328125,314.01611328125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,huggingface,full,memory,MB,T,sequence length,4096,596.03173828125,596.03173828125,596.03173828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,huggingface,full,memory,MB,T,sequence length,8192,1160.06298828125,1160.06298828125,1160.06298828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
+llama4_rope,huggingface,full,memory,MB,T,sequence length,16384,2288.12548828125,2288.12548828125,2288.12548828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
@@ -34,10 +34,12 @@ class TorchLMHeadCE(torch.nn.Module):
 
 
 class LigerLMHeadCE(torch.nn.Module):
-    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
+    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100, accum_dtype=None):
         super().__init__()
         self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
-        self.ce_loss = LigerFusedLinearCrossEntropyLoss(ignore_index=ignore_index, reduction="mean")
+        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
+            ignore_index=ignore_index, reduction="mean", accum_dtype=accum_dtype
+        )
 
     def forward(self, x, y):
         return self.ce_loss(self.lin.weight, x, y)
@@ -59,6 +61,7 @@ def bench_memory_fused_linear_cross_entropy(
 
     torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
     liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
 
     _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
@@ -66,6 +69,8 @@ def bench_memory_fused_linear_cross_entropy(
     def fwd():
         if provider == "liger":
             return liger_lm_head_ce(_input, target)
+        elif provider == "liger-fp32-accum":
+            return liger_lm_head_ce_fp32_accum(_input, target)
         elif provider == "huggingface":
             return torch_lm_head_ce(_input, target)
 
@@ -98,6 +103,7 @@ def bench_speed_fused_linear_cross_entropy(
 
     torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
     liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
 
     _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
     target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
@@ -105,6 +111,8 @@ def bench_speed_fused_linear_cross_entropy(
     def fwd():
         if provider == "liger":
             return liger_lm_head_ce(_input, target)
+        elif provider == "liger-fp32-accum":
+            return liger_lm_head_ce_fp32_accum(_input, target)
         elif provider == "huggingface":
             return torch_lm_head_ce(_input, target)
 
@@ -149,7 +157,7 @@ if __name__ == "__main__":
         "x_name": "BT",
         "x_label": "B x T",
         "x_values": [2**i for i in range(12, 16)],
-        "kernel_providers": ["liger", "huggingface"],
+        "kernel_providers": ["liger", "liger-fp32-accum", "huggingface"],
         "extra_benchmark_configs": [{"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}],
         "overwrite": args.overwrite,
     }
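
Note: the hunks above add an accum_dtype option to LigerFusedLinearCrossEntropyLoss and benchmark it as a third provider, "liger-fp32-accum". A minimal usage sketch based only on the signature shown in this diff (the import path assumes the usual top-level re-export; shapes, dtypes, and device are illustrative):

import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

# Keep weights/activations in bf16 but accumulate the loss in fp32.
loss_fn = LigerFusedLinearCrossEntropyLoss(ignore_index=-100, reduction="mean", accum_dtype=torch.float32)
lm_head = torch.nn.Linear(4096, 128256, bias=False, dtype=torch.bfloat16, device="cuda")
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = torch.randint(0, 128256, (8,), device="cuda")
loss = loss_fn(lm_head.weight, x, y)  # same call pattern as LigerLMHeadCE.forward above
loss.backward()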
@@ -0,0 +1,249 @@
1
+ import torch
2
+ import triton
3
+
4
+ from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
5
+ from transformers.models.llama4.modeling_llama4 import Llama4TextRotaryEmbedding
6
+ from transformers.models.llama4.modeling_llama4 import apply_rotary_emb
7
+ from utils import QUANTILES
8
+ from utils import SingleBenchmarkRunInput
9
+ from utils import SingleBenchmarkRunOutput
10
+ from utils import _test_memory
11
+ from utils import parse_benchmark_script_args
12
+ from utils import run_benchmarks
13
+
14
+ from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb
15
+ from liger_kernel.utils import infer_device
16
+ from liger_kernel.utils import transformers_version_dispatch
17
+
18
+ device = infer_device()
19
+
20
+
21
+ def bench_speed_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
22
+ provider = input.kernel_provider
23
+ mode = input.kernel_operation_mode
24
+
25
+ extra_benchmark_config = input.extra_benchmark_config
26
+ num_q_heads = extra_benchmark_config["num_q_heads"]
27
+ num_kv_heads = extra_benchmark_config["num_kv_heads"]
28
+ dtype = extra_benchmark_config["dtype"]
29
+
30
+ # x can be either hidden_size or seq_len
31
+ hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
32
+ seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
33
+
34
+ head_dim = hidden_size // num_q_heads
35
+
36
+ # Create Llama4TextConfig for the rotary embedding
37
+ config = Llama4TextConfig(
38
+ hidden_size=hidden_size,
39
+ num_attention_heads=num_q_heads,
40
+ num_key_value_heads=num_kv_heads,
41
+ head_dim=head_dim,
42
+ max_position_embeddings=seq_len,
43
+ rope_theta=10000.0,
44
+ rope_scaling=None, # Use default rope type
45
+ )
46
+
47
+ rotary_emb = transformers_version_dispatch(
48
+ "4.48.0",
49
+ Llama4TextRotaryEmbedding,
50
+ Llama4TextRotaryEmbedding,
51
+        before_kwargs={"config": config, "device": device},
+        after_kwargs={"config": config, "device": device},
+    )
+
+    q = torch.randn(
+        (1, seq_len, num_q_heads, head_dim),
+        device=device,
+        requires_grad=True,
+        dtype=dtype,
+    )
+    k = torch.randn(
+        (1, seq_len, num_kv_heads, head_dim),
+        device=device,
+        requires_grad=True,
+        dtype=dtype,
+    )
+    dq, dk = (
+        torch.randn_like(q, device=device, dtype=dtype),
+        torch.randn_like(k, device=device),
+    )
+    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
+    freqs_cis = rotary_emb(q, pos_ids)
+
+    def fwd():
+        if provider == "liger":
+            return liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
+        elif provider == "huggingface":
+            return apply_rotary_emb(q, k, freqs_cis)
+        else:
+            raise ValueError(f"Invalid provider: {provider} for Llama4 RoPE embedding")
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            fwd,
+            grad_to_none=[q, k],
+            rep=400,
+            quantiles=QUANTILES,
+        )
+    elif mode == "backward":
+        q_out, k_out = fwd()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
+            grad_to_none=[q, k],
+            rep=400,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
+
+        def full():
+            q_out, k_out = fwd()
+            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            full,
+            grad_to_none=[q, k],
+            rep=400,
+            quantiles=QUANTILES,
+        )
+    return SingleBenchmarkRunOutput(
+        y_20=ms_20,
+        y_50=ms_50,
+        y_80=ms_80,
+    )
+
+
+def bench_memory_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    provider = input.kernel_provider
+
+    extra_benchmark_config = input.extra_benchmark_config
+    num_q_heads = extra_benchmark_config["num_q_heads"]
+    num_kv_heads = extra_benchmark_config["num_kv_heads"]
+    dtype = extra_benchmark_config["dtype"]
+
+    # x can be either hidden_size or seq_len
+    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
+    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
+
+    head_dim = hidden_size // num_q_heads
+
+    # Create Llama4TextConfig for the rotary embedding
+    config = Llama4TextConfig(
+        hidden_size=hidden_size,
+        num_attention_heads=num_q_heads,
+        num_key_value_heads=num_kv_heads,
+        head_dim=head_dim,
+        max_position_embeddings=seq_len,
+        rope_theta=10000.0,
+        rope_scaling=None,  # Use default rope type
+    )
+
+    rotary_emb = transformers_version_dispatch(
+        "4.48.0",
+        Llama4TextRotaryEmbedding,
+        Llama4TextRotaryEmbedding,
+        before_kwargs={"config": config, "device": device},
+        after_kwargs={"config": config, "device": device},
+    )
+
+    q = torch.randn(
+        (1, seq_len, num_q_heads, head_dim),
+        device=device,
+        requires_grad=True,
+        dtype=dtype,
+    )
+    k = torch.randn(
+        (1, seq_len, num_kv_heads, head_dim),
+        device=device,
+        requires_grad=True,
+        dtype=dtype,
+    )
+    dq, dk = (
+        torch.randn_like(q, device=device, dtype=dtype),
+        torch.randn_like(k, device=device),
+    )
+    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
+    freqs_cis = rotary_emb(q, pos_ids)
+
+    def full():
+        if provider == "liger":
+            q_out, k_out = liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
+        else:
+            q_out, k_out = apply_rotary_emb(q, k, freqs_cis)
+        torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)
+
+    mem_50, mem_20, mem_80 = _test_memory(
+        full,
+        quantiles=QUANTILES,
+    )
+    return SingleBenchmarkRunOutput(
+        y_20=mem_20,
+        y_50=mem_50,
+        y_80=mem_80,
+    )
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+
+    common_configs_varying_hidden_size = {
+        "kernel_name": "llama4_rope",
+        "x_name": "H",
+        "x_label": "hidden size",
+        "x_values": [32 * (2**i) for i in range(4, 10, 2)],
+        "kernel_providers": ["liger", "huggingface"],
+        "extra_benchmark_configs": [
+            {
+                "dtype": torch.bfloat16,
+                "seq_len": 2048,
+                "num_q_heads": 32,
+                "num_kv_heads": 8,
+            }
+        ],
+        "overwrite": args.overwrite,
+    }
+    run_benchmarks(
+        bench_test_fn=bench_speed_llama4_rope,
+        kernel_operation_modes=["forward", "backward", "full"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_configs_varying_hidden_size,
+    )
+    run_benchmarks(
+        bench_test_fn=bench_memory_llama4_rope,
+        kernel_operation_modes=["full"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_configs_varying_hidden_size,
+    )
+
+    common_configs_varying_seq_len = {
+        "kernel_name": "llama4_rope",
+        "x_name": "T",
+        "x_label": "sequence length",
+        "x_values": [2**i for i in range(10, 15)],
+        "kernel_providers": ["liger", "huggingface"],
+        "extra_benchmark_configs": [
+            {
+                "dtype": torch.bfloat16,
+                "hidden_size": 8192,
+                "num_q_heads": 32,
+                "num_kv_heads": 8,
+            }
+        ],
+        "overwrite": args.overwrite,
+    }
+    run_benchmarks(
+        bench_test_fn=bench_speed_llama4_rope,
+        kernel_operation_modes=["forward", "backward", "full"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_configs_varying_seq_len,
+    )
+    run_benchmarks(
+        bench_test_fn=bench_memory_llama4_rope,
+        kernel_operation_modes=["full"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_configs_varying_seq_len,
+    )
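For orientation, here is a minimal standalone sketch of what one timed iteration of this new benchmark exercises. The import paths are assumptions inferred from this release's file list (the script's own import block is not part of this hunk), and the shapes follow the benchmark's hidden-size defaults:

import torch
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextRotaryEmbedding, apply_rotary_emb

from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb  # assumed path

device = "cuda"  # assumption: the benchmark targets a CUDA device
seq_len, num_q_heads, num_kv_heads, head_dim = 2048, 32, 8, 64

config = Llama4TextConfig(
    hidden_size=num_q_heads * head_dim,
    num_attention_heads=num_q_heads,
    num_key_value_heads=num_kv_heads,
    head_dim=head_dim,
    max_position_embeddings=seq_len,
    rope_theta=10000.0,
    rope_scaling=None,
)
rotary_emb = Llama4TextRotaryEmbedding(config=config, device=device)

q = torch.randn(1, seq_len, num_q_heads, head_dim, device=device, dtype=torch.bfloat16)
k = torch.randn(1, seq_len, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16)
pos_ids = torch.arange(seq_len, device=device).unsqueeze(0)
freqs_cis = rotary_emb(q, pos_ids)  # complex rotation factors, shared by both providers

# One "forward" iteration per provider; the two outputs should agree numerically.
q_liger, k_liger = liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
q_hf, k_hf = apply_rotary_emb(q, k, freqs_cis)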
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel"
-version = "0.6.1"
+version = "0.6.2"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
@@ -13,6 +13,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         ref_chosen_logps=None,
         ref_rejected_logps=None,
         beta=0.1,
+        loss_type="sigmoid",
     ):
         """
         Paper: https://arxiv.org/pdf/2305.18290
@@ -48,8 +49,50 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         chosen_rewards = beta * chosen_logratios
         rejected_rewards = beta * rejected_logratios
 
-        logits_diff = beta * (chosen_logratios - rejected_logratios)
-        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+        if loss_type == "sigmoid":
+            logits_diff = beta * (chosen_logratios - rejected_logratios)
+            loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_zero":
+            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are better than your model's default output
+            losses_chosen = 1 - F.sigmoid(beta * chosen_logratios)  # Increase chosen likelihood
+            losses_rejected = F.sigmoid(beta * rejected_logratios)
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "apo_down":
+            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
+            # Use this loss when you believe the chosen outputs are worse than your model's default output.
+            # Decrease chosen likelihood and decrease rejected likelihood more
+            losses_chosen = F.sigmoid(beta * chosen_logratios)
+            losses_rejected = 1 - F.sigmoid(beta * (chosen_logratios - rejected_logratios))
+            losses = losses_chosen + losses_rejected
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "sppo_hard":
+            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach,
+            # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class.
+            # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is
+            # set to 1 for the winner and 0 for the loser.
+            a = chosen_logps - ref_chosen_logps
+            b = rejected_logps - ref_rejected_logps
+            losses = (a - 0.5 / beta) ** 2 + (b + 0.5 / beta) ** 2
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        elif loss_type == "nca_pair":
+            losses = (
+                -F.logsigmoid(chosen_rewards)
+                - 0.5 * F.logsigmoid(-chosen_rewards)
+                - 0.5 * F.logsigmoid(-rejected_rewards)
+            )
+            loss = losses.sum() / (full_target.shape[0] // 2)
+
+        else:
+            raise ValueError(
+                f"Unsupported loss_type: {loss_type}. Supported types are: sigmoid, apo_zero, apo_down, sppo_hard, nca_pair"
+            )
+
         return loss, chosen_rewards, rejected_rewards
 
     @classmethod
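Transcribing the branch above into formulas: write the per-pair log-ratios as r_c = chosen_logps - ref_chosen_logps and r_r = rejected_logps - ref_rejected_logps, so chosen_rewards = beta * r_c and rejected_rewards = beta * r_r. Each branch then computes, per preference pair (summed and divided by full_target.shape[0] // 2):

\begin{align*}
\text{sigmoid:} \quad & -\log\sigma\!\big(\beta (r_c - r_r)\big) \\
\text{apo\_zero:} \quad & \big(1 - \sigma(\beta r_c)\big) + \sigma(\beta r_r) \\
\text{apo\_down:} \quad & \sigma(\beta r_c) + \big(1 - \sigma(\beta (r_c - r_r))\big) \\
\text{sppo\_hard:} \quad & \Big(r_c - \tfrac{1}{2\beta}\Big)^{2} + \Big(r_r + \tfrac{1}{2\beta}\Big)^{2} \\
\text{nca\_pair:} \quad & -\log\sigma(\beta r_c) - \tfrac{1}{2}\log\sigma(-\beta r_c) - \tfrac{1}{2}\log\sigma(-\beta r_r)
\end{align*}

These are direct readings of the code, not restatements of the cited papers; see the APO (2408.06266) and SPPO (2405.00675) papers linked in the comments for the derivations.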
@@ -70,6 +113,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         use_ref_model=True,
         average_log_prob=False,
         chunk_size=1,
+        loss_type="sigmoid",
     ):
         """
         Fused linear layer with DPO loss.
@@ -108,12 +152,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             ref_bias=ref_bias,
             average_log_prob=average_log_prob,
             chunk_size=chunk_size,
+            loss_type=loss_type,
         )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
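(The extra trailing None in backward is bookkeeping for the new loss_type argument: torch.autograd.Function.backward must return one gradient slot per forward input, and non-differentiable arguments receive None.)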
@@ -130,6 +175,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         use_ref_model: bool = True,
         average_log_prob: bool = False,
         chunk_size: int = 1,
+        loss_type: str = "sigmoid",
     ):
         """
         Args:
@@ -149,6 +195,10 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.use_ref_model = use_ref_model
         self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
+        self.loss_type = loss_type
+        supported_loss_types = {"sigmoid", "apo_zero", "apo_down", "sppo_hard", "nca_pair"}
+        if self.loss_type not in supported_loss_types:
+            raise ValueError(f"Unsupported loss_type: {self.loss_type}. Supported types are: {supported_loss_types}")
 
     def forward(
         self,
@@ -175,4 +225,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.use_ref_model,
             self.average_log_prob,
             self.chunk_size,
+            self.loss_type,
         )
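A minimal usage sketch of the new knob, touching only constructor keywords visible in the hunks above (the forward signature is unchanged by this diff and omitted here):

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

# Select the preference-loss variant at construction time.
dpo_apo = LigerFusedLinearDPOLoss(loss_type="apo_zero", chunk_size=1)

# Unsupported names fail fast in __init__ rather than deep inside the fused kernel.
try:
    LigerFusedLinearDPOLoss(loss_type="hinge")
except ValueError as err:
    print(err)  # Unsupported loss_type: hinge. Supported types are: {...}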