liger-kernel 0.5.9__tar.gz → 0.5.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/docs.yml +1 -1
  2. liger_kernel-0.5.10/.idea/workspace.xml +79 -0
  3. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/PKG-INFO +34 -20
  4. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/README.md +33 -19
  5. liger_kernel-0.5.10/benchmark/README.md +48 -0
  6. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/benchmarks_visualizer.py +35 -9
  7. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/data/all_benchmark_data.csv +72 -0
  8. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_dyt.py +37 -34
  9. liger_kernel-0.5.10/benchmark/scripts/benchmark_sparsemax.py +172 -0
  10. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/dev/modal/tests.py +2 -2
  11. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/dev/modal/tests_bwd.py +2 -2
  12. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/Low-Level-APIs.md +9 -1
  13. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/training_multimodal.py +1 -1
  14. liger_kernel-0.5.10/examples/medusa/requirements.txt +3 -0
  15. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/scripts/llama3_8b_medusa.sh +2 -5
  16. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/train.py +37 -39
  17. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/pyproject.toml +1 -1
  18. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/setup.py +23 -3
  19. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/dpo_loss.py +1 -1
  20. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -1
  21. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/jsd_loss.py +2 -2
  22. liger_kernel-0.5.10/src/liger_kernel/ops/dyt.py +159 -0
  23. liger_kernel-0.5.10/src/liger_kernel/ops/grpo_loss.py +310 -0
  24. liger_kernel-0.5.10/src/liger_kernel/ops/sparsemax.py +167 -0
  25. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/__init__.py +5 -0
  26. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/dyt.py +5 -3
  27. liger_kernel-0.5.10/src/liger_kernel/transformers/fsdp.py +55 -0
  28. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/functional.py +8 -0
  29. liger_kernel-0.5.10/src/liger_kernel/transformers/grpo_loss.py +98 -0
  30. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/gemma.py +0 -8
  31. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/gemma2.py +0 -6
  32. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/gemma3.py +0 -8
  33. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/glm4.py +0 -6
  34. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/llama.py +56 -11
  35. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/llava.py +0 -8
  36. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/mistral.py +0 -6
  37. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/mixtral.py +0 -8
  38. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/mllama.py +0 -7
  39. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/olmo2.py +0 -6
  40. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/paligemma.py +0 -8
  41. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/phi3.py +0 -8
  42. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/qwen2.py +0 -8
  43. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -6
  44. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -6
  45. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/qwen3.py +0 -6
  46. liger_kernel-0.5.10/src/liger_kernel/transformers/model/qwen3_moe.py +128 -0
  47. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/monkey_patch.py +122 -13
  48. liger_kernel-0.5.10/src/liger_kernel/transformers/sparsemax.py +16 -0
  49. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/swiglu.py +21 -0
  50. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/trainer/orpo_trainer.py +1 -53
  51. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/utils.py +11 -0
  52. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/PKG-INFO +34 -20
  53. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/SOURCES.txt +10 -0
  54. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_dpo_loss.py +2 -0
  55. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/bf16/test_mini_models.py +58 -1
  56. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/bf16/test_mini_models_multimodal.py +0 -1
  57. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/bf16/test_mini_models_with_logits.py +58 -1
  58. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/fp32/test_mini_models.py +55 -1
  59. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/fp32/test_mini_models_multimodal.py +0 -1
  60. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/fp32/test_mini_models_with_logits.py +55 -1
  61. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_dyt.py +40 -20
  62. liger_kernel-0.5.10/test/transformers/test_grpo_loss.py +190 -0
  63. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_monkey_patch.py +40 -0
  64. liger_kernel-0.5.10/test/transformers/test_sparsemax.py +111 -0
  65. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/utils.py +12 -0
  66. liger_kernel-0.5.9/benchmark/README.md +0 -30
  67. liger_kernel-0.5.9/examples/medusa/requirements.txt +0 -3
  68. liger_kernel-0.5.9/src/liger_kernel/ops/dyt.py +0 -225
  69. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  70. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  71. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/pull_request_template.md +0 -0
  72. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/amd-ci.yml +0 -0
  73. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/intel-ci.yml +0 -0
  74. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/nvi-ci.yml +0 -0
  75. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/publish-nightly.yml +0 -0
  76. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.github/workflows/publish-release.yml +0 -0
  77. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/.gitignore +0 -0
  78. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/LICENSE +0 -0
  79. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/Makefile +0 -0
  80. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/NOTICE +0 -0
  81. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/__init__.py +0 -0
  82. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/__init__.py +0 -0
  83. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  84. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  85. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  86. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  87. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_embedding.py +0 -0
  88. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  89. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  90. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_geglu.py +0 -0
  91. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_group_norm.py +0 -0
  92. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_jsd.py +0 -0
  93. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_kl_div.py +0 -0
  94. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  95. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  96. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  97. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  98. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  99. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_rope.py +0 -0
  100. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  101. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_swiglu.py +0 -0
  102. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/benchmark_tvd.py +0 -0
  103. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/benchmark/scripts/utils.py +0 -0
  104. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/dev/fmt-requirements.txt +0 -0
  105. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/Examples.md +0 -0
  106. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/Getting-Started.md +0 -0
  107. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/High-Level-APIs.md +0 -0
  108. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/acknowledgement.md +0 -0
  109. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/contributing.md +0 -0
  110. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/banner.GIF +0 -0
  111. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/compose.gif +0 -0
  112. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/e2e-memory.png +0 -0
  113. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/e2e-tps.png +0 -0
  114. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/logo-banner.png +0 -0
  115. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/patch.gif +0 -0
  116. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/images/post-training.png +0 -0
  117. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/index.md +0 -0
  118. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/docs/license.md +0 -0
  119. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/alignment/accelerate_config.yaml +0 -0
  120. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/alignment/run_orpo.py +0 -0
  121. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/README.md +0 -0
  122. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/callback.py +0 -0
  123. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/config/fsdp_config.json +0 -0
  124. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  125. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  126. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  127. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/llama_tps.png +0 -0
  128. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  129. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/img/qwen_tps.png +0 -0
  130. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/launch_on_modal.py +0 -0
  131. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/requirements.txt +0 -0
  132. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_benchmarks.sh +0 -0
  133. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_gemma.sh +0 -0
  134. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_llama.sh +0 -0
  135. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_qwen.sh +0 -0
  136. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/run_qwen2_vl.sh +0 -0
  137. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/huggingface/training.py +0 -0
  138. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/lightning/README.md +0 -0
  139. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/lightning/requirements.txt +0 -0
  140. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/lightning/training.py +0 -0
  141. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/README.md +0 -0
  142. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/callback.py +0 -0
  143. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  144. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  145. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  146. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  147. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  148. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  149. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  150. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  151. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  152. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/examples/medusa/medusa_util.py +0 -0
  153. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-Apache-2.0 +0 -0
  154. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  155. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  156. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-MIT-llmc +0 -0
  157. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/licenses/LICENSE-MIT-triton +0 -0
  158. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/mkdocs.yml +0 -0
  159. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/setup.cfg +0 -0
  160. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/__init__.py +0 -0
  161. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/README.md +0 -0
  162. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  163. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  164. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/functional.py +0 -0
  165. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  166. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  167. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  168. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  169. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  170. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  171. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  172. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/env_report.py +0 -0
  173. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/__init__.py +0 -0
  174. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/cross_entropy.py +0 -0
  175. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  176. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  177. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  178. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  179. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/geglu.py +0 -0
  180. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/group_norm.py +0 -0
  181. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/jsd.py +0 -0
  182. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/kl_div.py +0 -0
  183. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/layer_norm.py +0 -0
  184. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  185. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/rms_norm.py +0 -0
  186. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/rope.py +0 -0
  187. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/swiglu.py +0 -0
  188. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/tvd.py +0 -0
  189. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/ops/utils.py +0 -0
  190. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/auto_model.py +0 -0
  191. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  192. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  193. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  194. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  195. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/geglu.py +0 -0
  196. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  197. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/group_norm.py +0 -0
  198. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/jsd.py +0 -0
  199. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/kl_div.py +0 -0
  200. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/layer_norm.py +0 -0
  201. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/__init__.py +0 -0
  202. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  203. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  204. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/rms_norm.py +0 -0
  205. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/rope.py +0 -0
  206. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  207. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  208. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/transformers/tvd.py +0 -0
  209. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/triton/__init__.py +0 -0
  210. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel/triton/monkey_patch.py +0 -0
  211. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  212. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/requires.txt +0 -0
  213. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/src/liger_kernel.egg-info/top_level.txt +0 -0
  214. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/__init__.py +0 -0
  215. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/__init__.py +0 -0
  216. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_cpo_loss.py +0 -0
  217. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_grpo_loss.py +0 -0
  218. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_jsd_loss.py +0 -0
  219. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_kto_loss.py +0 -0
  220. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_orpo_loss.py +0 -0
  221. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/chunked_loss/test_simpo_loss.py +0 -0
  222. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/conftest.py +0 -0
  223. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/__init__.py +0 -0
  224. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/bf16/__init__.py +0 -0
  225. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/convergence/fp32/__init__.py +0 -0
  226. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  227. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  228. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  229. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  230. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  231. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  232. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  233. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  234. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  235. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/tiny_shakespeare.txt +0 -0
  236. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  237. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  238. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  239. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_auto_model.py +0 -0
  240. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_cross_entropy.py +0 -0
  241. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_embedding.py +0 -0
  242. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_flex_attention.py +0 -0
  243. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  244. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_fused_linear_jsd.py +0 -0
  245. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_geglu.py +0 -0
  246. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_group_norm.py +0 -0
  247. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_jsd.py +0 -0
  248. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_kl_div.py +0 -0
  249. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_layer_norm.py +0 -0
  250. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_mm_int8int2.py +0 -0
  251. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_qwen2vl_mrope.py +0 -0
  252. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_rms_norm.py +0 -0
  253. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_rope.py +0 -0
  254. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_swiglu.py +0 -0
  255. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_trainer_integration.py +0 -0
  256. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_transformers.py +0 -0
  257. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/transformers/test_tvd.py +0 -0
  258. {liger_kernel-0.5.9 → liger_kernel-0.5.10}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -2,7 +2,7 @@ name: Publish documentation
2
2
  on:
3
3
  push:
4
4
  branches:
5
- - gh-pages
5
+ - main
6
6
  permissions:
7
7
  contents: write
8
8
  jobs:
@@ -0,0 +1,79 @@
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="AutoImportSettings">
4
+ <option name="autoReloadType" value="SELECTIVE" />
5
+ </component>
6
+ <component name="ChangeListManager">
7
+ <list default="true" id="d7400753-faa8-4997-a53e-65fd3a6e6146" name="Changes" comment="Reference Unsloth in header" />
8
+ <option name="SHOW_DIALOG" value="false" />
9
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
10
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
11
+ <option name="LAST_RESOLUTION" value="IGNORE" />
12
+ </component>
13
+ <component name="Git.Settings">
14
+ <option name="RECENT_BRANCH_BY_REPOSITORY">
15
+ <map>
16
+ <entry key="$PROJECT_DIR$" value="main" />
17
+ </map>
18
+ </option>
19
+ <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
20
+ </component>
21
+ <component name="GitHubPullRequestSearchHistory"><![CDATA[{
22
+ "lastFilter": {
23
+ "state": "OPEN",
24
+ "assignee": "momochen"
25
+ }
26
+ }]]></component>
27
+ <component name="GithubPullRequestsUISettings"><![CDATA[{
28
+ "selectedUrlAndAccountId": {
29
+ "url": "https://github.com/momochen/Liger-Kernel",
30
+ "accountId": "639f3e12-86db-4b12-a409-51cc017415fb"
31
+ }
32
+ }]]></component>
33
+ <component name="ProjectColorInfo"><![CDATA[{
34
+ "associatedIndex": 5
35
+ }]]></component>
36
+ <component name="ProjectId" id="2lfyDxCjSnvFrbllYmf9VBSCcMx" />
37
+ <component name="ProjectViewState">
38
+ <option name="hideEmptyMiddlePackages" value="true" />
39
+ <option name="showLibraryContents" value="true" />
40
+ </component>
41
+ <component name="PropertiesComponent"><![CDATA[{
42
+ "keyToString": {
43
+ "RunOnceActivity.ShowReadmeOnStart": "true",
44
+ "git-widget-placeholder": "ref__unsloth",
45
+ "last_opened_file_path": "/Users/ychen/workspace/github/Liger-Kernel"
46
+ }
47
+ }]]></component>
48
+ <component name="SharedIndexes">
49
+ <attachedChunks>
50
+ <set>
51
+ <option value="bundled-python-sdk-975db3bf15a3-31b6be0877a2-com.jetbrains.pycharm.community.sharedIndexes.bundled-PC-241.18034.82" />
52
+ </set>
53
+ </attachedChunks>
54
+ </component>
55
+ <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
56
+ <component name="TaskManager">
57
+ <task active="true" id="Default" summary="Default task">
58
+ <changelist id="d7400753-faa8-4997-a53e-65fd3a6e6146" name="Changes" comment="" />
59
+ <created>1725585310555</created>
60
+ <option name="number" value="Default" />
61
+ <option name="presentableId" value="Default" />
62
+ <updated>1725585310555</updated>
63
+ </task>
64
+ <task id="LOCAL-00001" summary="Reference Unsloth in header">
65
+ <option name="closed" value="true" />
66
+ <created>1725585434299</created>
67
+ <option name="number" value="00001" />
68
+ <option name="presentableId" value="LOCAL-00001" />
69
+ <option name="project" value="LOCAL" />
70
+ <updated>1725585434299</updated>
71
+ </task>
72
+ <option name="localTasksCounter" value="2" />
73
+ <servers />
74
+ </component>
75
+ <component name="VcsManagerConfiguration">
76
+ <MESSAGE value="Reference Unsloth in header" />
77
+ <option name="LAST_COMMIT_MESSAGE" value="Reference Unsloth in header" />
78
+ </component>
79
+ </project>
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: liger_kernel
3
- Version: 0.5.9
3
+ Version: 0.5.10
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -59,7 +59,6 @@ Dynamic: requires-dist
59
59
  <th style="padding: 10px;" colspan="2">Stable</th>
60
60
  <th style="padding: 10px;" colspan="2">Nightly</th>
61
61
  <th style="padding: 10px;">Discord</th>
62
- <th style="padding: 10px;">Build</th>
63
62
  </tr>
64
63
  <tr>
65
64
  <td style="padding: 10px;">
@@ -87,23 +86,6 @@ Dynamic: requires-dist
87
86
  <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
88
87
  </a>
89
88
  </td>
90
- <td style="padding: 10px;">
91
- <div style="display: block;">
92
- <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
93
- <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
94
- </a>
95
- </div>
96
- <div style="display: block;">
97
- <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
98
- <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
99
- </a>
100
- </div>
101
- <div style="display: block;">
102
- <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
103
- <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
104
- </a>
105
- </div>
106
- </td>
107
89
  </tr>
108
90
  </table>
109
91
 
@@ -321,6 +303,7 @@ loss.backward()
321
303
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
322
304
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
323
305
  | Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
306
+ | Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
324
307
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
325
308
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
326
309
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -342,7 +325,8 @@ loss.backward()
342
325
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
343
326
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
344
327
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
345
- | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
328
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
329
+ | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
346
330
 
347
331
 
348
332
  ### Alignment Kernels
@@ -390,6 +374,36 @@ loss.backward()
390
374
  - [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
391
375
  - [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
392
376
 
377
+
378
+ ## CI status
379
+
380
+ <table style="width: 100%; text-align: center; border-collapse: collapse;">
381
+ <tr>
382
+ <th style="padding: 10px;">Build</th>
383
+ </tr>
384
+ <tr>
385
+ <td style="padding: 10px;">
386
+ <div style="display: block;">
387
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
388
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
389
+ </a>
390
+ </div>
391
+ <div style="display: block;">
392
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
393
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
394
+ </a>
395
+ </div>
396
+ <div style="display: block;">
397
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
398
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
399
+ </a>
400
+ </div>
401
+ </td>
402
+ </tr>
403
+ </table>
404
+
405
+
406
+
393
407
  ## Contact
394
408
 
395
409
  - For issues, create a Github ticket in this repository
@@ -8,7 +8,6 @@
8
8
  <th style="padding: 10px;" colspan="2">Stable</th>
9
9
  <th style="padding: 10px;" colspan="2">Nightly</th>
10
10
  <th style="padding: 10px;">Discord</th>
11
- <th style="padding: 10px;">Build</th>
12
11
  </tr>
13
12
  <tr>
14
13
  <td style="padding: 10px;">
@@ -36,23 +35,6 @@
36
35
  <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
37
36
  </a>
38
37
  </td>
39
- <td style="padding: 10px;">
40
- <div style="display: block;">
41
- <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
42
- <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
43
- </a>
44
- </div>
45
- <div style="display: block;">
46
- <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
47
- <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
48
- </a>
49
- </div>
50
- <div style="display: block;">
51
- <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
52
- <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
53
- </a>
54
- </div>
55
- </td>
56
38
  </tr>
57
39
  </table>
58
40
 
@@ -270,6 +252,7 @@ loss.backward()
270
252
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
271
253
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
272
254
  | Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
255
+ | Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
273
256
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
274
257
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
275
258
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -291,7 +274,8 @@ loss.backward()
291
274
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
292
275
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
293
276
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
294
- | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
277
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
278
+ | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
295
279
 
296
280
 
297
281
  ### Alignment Kernels
@@ -339,6 +323,36 @@ loss.backward()
339
323
  - [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
340
324
  - [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.
341
325
 
326
+
327
+ ## CI status
328
+
329
+ <table style="width: 100%; text-align: center; border-collapse: collapse;">
330
+ <tr>
331
+ <th style="padding: 10px;">Build</th>
332
+ </tr>
333
+ <tr>
334
+ <td style="padding: 10px;">
335
+ <div style="display: block;">
336
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml">
337
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/nvi-ci.yml/badge.svg?event=schedule" alt="Build">
338
+ </a>
339
+ </div>
340
+ <div style="display: block;">
341
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
342
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
343
+ </a>
344
+ </div>
345
+ <div style="display: block;">
346
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
347
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
348
+ </a>
349
+ </div>
350
+ </td>
351
+ </tr>
352
+ </table>
353
+
354
+
355
+
342
356
  ## Contact
343
357
 
344
358
  - For issues, create a Github ticket in this repository
@@ -0,0 +1,48 @@
1
+ ## Benchmarking Liger Kernels
2
+
3
+ Follow these steps to benchmark and visualize kernel performance:
4
+
5
+ 1. Create a benchmark script
6
+ - Add your script under `benchmark/scripts/`
7
+ - Name it according to the kernel (e.g., `benchmark_<kernel_name>.py`)
8
+
9
+ 2. Run the benchmark
10
+ - Results will be saved to `benchmark/data/all_benchmark_data.csv`
11
+
12
+ Example: Benchmarking KTO Loss
13
+ ```bash
14
+ cd benchmark
15
+ python scripts/benchmark_kto_loss.py
16
+ ```
17
+
18
+ 3. Visualize results
19
+ - Use the visualization script with optional modes:
20
+
21
+ * To target specific mode(s), pass one or more values to `--kernel-operation-mode`.
22
+ * If you omit `--kernel-operation-mode`, the script will:
23
+ - For `speed` metrics: generate plots for all available modes (forward/backward/full).
24
+ - For `memory` metrics: generate only the `full` plot.
25
+
26
+ Examples:
27
+ 1. Specific modes (speed):
28
+ ```bash
29
+ python benchmarks_visualizer.py \
30
+ --kernel-name kto_loss \
31
+ --metric-name speed \
32
+ --kernel-operation-mode forward backward
33
+ ```
34
+ 2. All modes (speed):
35
+ ```bash
36
+ python benchmarks_visualizer.py \
37
+ --kernel-name kto_loss \
38
+ --metric-name speed
39
+ ```
40
+ 3. Memory (always full):
41
+ ```bash
42
+ python benchmarks_visualizer.py \
43
+ --kernel-name kto_loss \
44
+ --metric-name memory
45
+ ```
46
+
47
+ 4. View results
48
+ - Generated plots will be saved in `benchmark/visualizations/`
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  import os
3
+ import sys
3
4
 
4
5
  from argparse import ArgumentParser
5
6
  from dataclasses import dataclass
@@ -50,8 +51,9 @@ def parse_args() -> VisualizationsConfig:
50
51
  parser.add_argument(
51
52
  "--kernel-operation-mode",
52
53
  type=str,
53
- required=True,
54
- help="Kernel operation mode to visualize (forward/backward/full)",
54
+ nargs="*",
55
+ default=None,
56
+ help="Kernel operation modes to visualize (forward/backward/full). If not provided, generate for all available modes.",
55
57
  )
56
58
  parser.add_argument("--display", action="store_true", help="Display the visualization")
57
59
  parser.add_argument(
@@ -61,8 +63,7 @@ def parse_args() -> VisualizationsConfig:
61
63
  )
62
64
 
63
65
  args = parser.parse_args()
64
-
65
- return VisualizationsConfig(**dict(args._get_kwargs()))
66
+ return args
66
67
 
67
68
 
68
69
  def load_data(config: VisualizationsConfig) -> pd.DataFrame:
@@ -123,7 +124,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
123
124
  lines = ax.get_lines()
124
125
  colors = [line.get_color() for line in lines]
125
126
 
126
- for (_, group_data), color in zip(df.groupby("kernel_provider"), colors, strict=False):
127
+ for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
127
128
  # for i, row in group_data.iterrows():
128
129
  y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
129
130
  y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
@@ -142,7 +143,10 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
142
143
  plt.ylabel(ylabel)
143
144
  plt.tight_layout()
144
145
 
145
- out_path = os.path.join(VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png")
146
+ out_path = os.path.join(
147
+ VISUALIZATIONS_PATH,
148
+ f"{config.kernel_name}_{config.metric_name}_{config.kernel_operation_mode}.png",
149
+ )
146
150
 
147
151
  if config.display:
148
152
  plt.show()
@@ -155,9 +159,31 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
155
159
 
156
160
 
157
161
  def main():
158
- config = parse_args()
159
- df = load_data(config)
160
- plot_data(df, config)
162
+ args = parse_args()
163
+ all_df = pd.read_csv(DATA_PATH)
164
+ all_df["extra_benchmark_config"] = all_df["extra_benchmark_config_str"].apply(json.loads)
165
+
166
+ if args.metric_name == "memory":
167
+ modes = ["full"]
168
+ elif args.kernel_operation_mode:
169
+ modes = args.kernel_operation_mode
170
+ else:
171
+ filtered = all_df[(all_df["kernel_name"] == args.kernel_name) & (all_df["metric_name"] == args.metric_name)]
172
+ modes = filtered["kernel_operation_mode"].unique().tolist()
173
+ if not modes:
174
+ print(f"No data found for kernel '{args.kernel_name}' and metric '{args.metric_name}'.", file=sys.stderr)
175
+ sys.exit(1)
176
+
177
+ for mode in modes:
178
+ config = VisualizationsConfig(
179
+ kernel_name=args.kernel_name,
180
+ metric_name=args.metric_name,
181
+ kernel_operation_mode=mode,
182
+ display=args.display,
183
+ overwrite=args.overwrite,
184
+ )
185
+ df = load_data(config)
186
+ plot_data(df, config)
161
187
 
162
188
 
163
189
  if __name__ == "__main__":
@@ -805,3 +805,75 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.265
805
805
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
806
806
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
807
807
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
808
+ sparsemax,liger,forward,speed,ms,V,feature size,1024,0.41471999883651733,0.4126720130443573,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
809
+ sparsemax,liger,forward,speed,ms,V,feature size,2048,0.7608320116996765,0.7598080039024353,0.7628800272941589,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
810
+ sparsemax,liger,forward,speed,ms,V,feature size,4096,1.4561280012130737,1.4540799856185913,1.4581760168075562,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
811
+ sparsemax,liger,forward,speed,ms,V,feature size,8192,5.288959980010986,5.2848639488220215,5.29986572265625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
812
+ sparsemax,liger,forward,speed,ms,V,feature size,16384,10.734624862670898,10.729472160339355,11.096882820129395,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
813
+ sparsemax,liger,forward,speed,ms,V,feature size,32768,21.729312896728516,21.7128963470459,22.20728302001953,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
814
+ sparsemax,torch,forward,speed,ms,V,feature size,1024,0.42291200160980225,0.42188799381256104,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
815
+ sparsemax,torch,forward,speed,ms,V,feature size,2048,0.7782400250434875,0.7772160172462463,0.779263973236084,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
816
+ sparsemax,torch,forward,speed,ms,V,feature size,4096,1.4940160512924194,1.491968035697937,1.4960639476776123,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
817
+ sparsemax,torch,forward,speed,ms,V,feature size,8192,5.359615802764893,5.356544017791748,5.366579055786133,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
818
+ sparsemax,torch,forward,speed,ms,V,feature size,16384,10.883584022521973,10.874879837036133,11.224268913269043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
819
+ sparsemax,torch,forward,speed,ms,V,feature size,32768,22.19878387451172,22.018457412719727,22.48888397216797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
820
+ sparsemax,liger,full,speed,ms,V,feature size,1024,0.4558719992637634,0.45558398962020874,0.45772799849510193,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
821
+ sparsemax,liger,full,speed,ms,V,feature size,2048,0.8488960266113281,0.8478720188140869,0.8509439826011658,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
822
+ sparsemax,liger,full,speed,ms,V,feature size,4096,1.6476160287857056,1.6465920209884644,1.6499264240264893,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
823
+ sparsemax,liger,full,speed,ms,V,feature size,8192,5.664768218994141,5.660672187805176,5.681356906890869,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
824
+ sparsemax,liger,full,speed,ms,V,feature size,16384,11.486207962036133,11.478015899658203,11.874713897705078,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
825
+ sparsemax,liger,full,speed,ms,V,feature size,32768,23.457279205322266,23.289682388305664,23.76642608642578,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
826
+ sparsemax,torch,full,speed,ms,V,feature size,1024,0.6021119952201843,0.6010879874229431,0.6041600108146667,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
827
+ sparsemax,torch,full,speed,ms,V,feature size,2048,1.1212799549102783,1.119264006614685,1.1223039627075195,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
828
+ sparsemax,torch,full,speed,ms,V,feature size,4096,2.1637120246887207,2.1616640090942383,2.165760040283203,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
829
+ sparsemax,torch,full,speed,ms,V,feature size,8192,6.693888187408447,6.68723201751709,6.705561637878418,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
830
+ sparsemax,torch,full,speed,ms,V,feature size,16384,13.523456573486328,13.518848419189453,13.878681182861328,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
831
+ sparsemax,torch,full,speed,ms,V,feature size,32768,27.604991912841797,27.295129776000977,27.77518081665039,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
832
+ sparsemax,liger,backward,speed,ms,V,feature size,1024,0.04403200000524521,0.043007999658584595,0.05222399905323982,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
833
+ sparsemax,liger,backward,speed,ms,V,feature size,2048,0.08806400001049042,0.08713600039482117,0.08806400001049042,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
834
+ sparsemax,liger,backward,speed,ms,V,feature size,4096,0.1884160041809082,0.1884160041809082,0.18943999707698822,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
835
+ sparsemax,liger,backward,speed,ms,V,feature size,8192,0.374783992767334,0.37376001477241516,0.37486720085144043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
836
+ sparsemax,liger,backward,speed,ms,V,feature size,16384,0.7516160011291504,0.7505919933319092,0.7516160011291504,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
837
+ sparsemax,liger,backward,speed,ms,V,feature size,32768,1.5738879442214966,1.572864055633545,1.575935959815979,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
838
+ sparsemax,torch,backward,speed,ms,V,feature size,1024,0.1812479943037033,0.1802240014076233,0.18227200210094452,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
839
+ sparsemax,torch,backward,speed,ms,V,feature size,2048,0.34406399726867676,0.34406399726867676,0.34508800506591797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
840
+ sparsemax,torch,backward,speed,ms,V,feature size,4096,0.6717439889907837,0.6707199811935425,0.6727679967880249,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
841
+ sparsemax,torch,backward,speed,ms,V,feature size,8192,1.3250559568405151,1.3241215944290161,1.3260799646377563,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
842
+ sparsemax,torch,backward,speed,ms,V,feature size,16384,2.629631996154785,2.628607988357544,2.6306560039520264,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
843
+ sparsemax,torch,backward,speed,ms,V,feature size,32768,5.236735820770264,5.235712051391602,5.239808082580566,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
844
+ sparsemax,liger,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
845
+ sparsemax,liger,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
846
+ sparsemax,liger,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
847
+ sparsemax,liger,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
848
+ sparsemax,liger,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
849
+ sparsemax,liger,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
850
+ sparsemax,torch,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
851
+ sparsemax,torch,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
852
+ sparsemax,torch,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
853
+ sparsemax,torch,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
854
+ sparsemax,torch,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
855
+ sparsemax,torch,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
856
+ sparsemax,liger,forward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
857
+ sparsemax,liger,forward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
858
+ sparsemax,liger,forward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
859
+ sparsemax,liger,forward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
860
+ sparsemax,liger,forward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
861
+ sparsemax,liger,forward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
862
+ sparsemax,torch,forward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
863
+ sparsemax,torch,forward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
864
+ sparsemax,torch,forward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
865
+ sparsemax,torch,forward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
866
+ sparsemax,torch,forward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
867
+ sparsemax,torch,forward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
868
+ sparsemax,liger,backward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
869
+ sparsemax,liger,backward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
870
+ sparsemax,liger,backward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
871
+ sparsemax,liger,backward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
872
+ sparsemax,liger,backward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
873
+ sparsemax,liger,backward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
874
+ sparsemax,torch,backward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
875
+ sparsemax,torch,backward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
876
+ sparsemax,torch,backward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
877
+ sparsemax,torch,backward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
878
+ sparsemax,torch,backward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
879
+ sparsemax,torch,backward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
@@ -22,17 +22,18 @@ def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
22
22
  from test.transformers.test_dyt import LigerDyT
23
23
  from test.transformers.test_dyt import TorchDyT
24
24
 
25
- BT = input.x
25
+ hidden_size = input.x
26
26
  provider = input.kernel_provider
27
27
  mode = input.kernel_operation_mode
28
28
  extra_benchmark_config = input.extra_benchmark_config
29
- hidden_size = extra_benchmark_config["hidden_size"]
29
+ BT = extra_benchmark_config["BT"]
30
+ beta = extra_benchmark_config["beta"]
30
31
  dtype = extra_benchmark_config["dtype"]
31
32
 
32
33
  x_shape = (BT, hidden_size)
33
- torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
34
- torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
35
- triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
34
+ torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
35
+ torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
36
+ triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)
36
37
 
37
38
  x = torch.randn(x_shape, dtype=dtype, device=device)
38
39
  dy = torch.randn_like(x)
@@ -75,16 +76,17 @@ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
75
76
  from test.transformers.test_dyt import LigerDyT
76
77
  from test.transformers.test_dyt import TorchDyT
77
78
 
78
- BT = input.x
79
+ hidden_size = input.x
79
80
  provider = input.kernel_provider
80
81
  extra_benchmark_config = input.extra_benchmark_config
81
- hidden_size = extra_benchmark_config["hidden_size"]
82
+ BT = extra_benchmark_config["BT"]
83
+ beta = extra_benchmark_config["beta"]
82
84
  dtype = extra_benchmark_config["dtype"]
83
85
 
84
86
  x_shape = (BT, hidden_size)
85
- torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
86
- torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
87
- triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
87
+ torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
88
+ torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
89
+ triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)
88
90
 
89
91
  x = torch.randn(x_shape, dtype=dtype, device=device)
90
92
  dy = torch.randn_like(x)
@@ -113,27 +115,28 @@ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
113
115
  if __name__ == "__main__":
114
116
  args = parse_benchmark_script_args()
115
117
 
116
- common_configs = {
117
- "kernel_name": "dyt",
118
- "x_name": "BT",
119
- "x_label": "batch_size * seq_len",
120
- "x_values": [2**i for i in range(10, 15)],
121
- "kernel_providers": ["liger", "torch", "torch_compile"],
122
- "extra_benchmark_configs": [{"hidden_size": 4096, "dtype": torch.float32}],
123
- "overwrite": args.overwrite,
124
- }
125
-
126
- run_benchmarks(
127
- bench_test_fn=bench_speed_dyt,
128
- kernel_operation_modes=["forward", "backward", "full"],
129
- metric_name="speed",
130
- metric_unit="ms",
131
- **common_configs,
132
- )
133
- run_benchmarks(
134
- bench_test_fn=bench_memory_dyt,
135
- kernel_operation_modes=["full"],
136
- metric_name="memory",
137
- metric_unit="MB",
138
- **common_configs,
139
- )
118
+ for beta in [False, True]:
119
+ common_configs = {
120
+ "kernel_name": f"dyt_beta={beta}",
121
+ "x_name": "hidden_size",
122
+ "x_label": "hidden_size",
123
+ "x_values": [1024 * i for i in range(1, 17)],
124
+ "kernel_providers": ["liger", "torch", "torch_compile"],
125
+ "extra_benchmark_configs": [{"BT": 4096, "dtype": torch.bfloat16, "beta": beta}],
126
+ "overwrite": args.overwrite,
127
+ }
128
+
129
+ run_benchmarks(
130
+ bench_test_fn=bench_speed_dyt,
131
+ kernel_operation_modes=["forward", "backward", "full"],
132
+ metric_name="speed",
133
+ metric_unit="ms",
134
+ **common_configs,
135
+ )
136
+ run_benchmarks(
137
+ bench_test_fn=bench_memory_dyt,
138
+ kernel_operation_modes=["full"],
139
+ metric_name="memory",
140
+ metric_unit="MB",
141
+ **common_configs,
142
+ )