liger-kernel 0.5.10__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. liger_kernel-0.6.0/.github/workflows/benchmark.yml +93 -0
  2. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.github/workflows/docs.yml +3 -1
  3. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.gitignore +4 -0
  4. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/Makefile +8 -2
  5. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/PKG-INFO +8 -2
  6. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/README.md +5 -0
  7. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/benchmarks_visualizer.py +125 -16
  8. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/data/all_benchmark_data.csv +640 -24
  9. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_cpo_loss.py +1 -1
  10. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_cross_entropy.py +1 -1
  11. liger_kernel-0.6.0/benchmark/scripts/benchmark_distill_cosine_loss.py +266 -0
  12. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_distill_jsd_loss.py +1 -1
  13. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_dpo_loss.py +1 -1
  14. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_embedding.py +1 -1
  15. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +1 -1
  16. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_fused_linear_jsd.py +1 -1
  17. liger_kernel-0.6.0/benchmark/scripts/benchmark_fused_neighborhood_attention.py +367 -0
  18. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_jsd.py +1 -1
  19. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_kl_div.py +1 -1
  20. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_kto_loss.py +1 -1
  21. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_layer_norm.py +1 -1
  22. liger_kernel-0.6.0/benchmark/scripts/benchmark_multi_token_attention.py +218 -0
  23. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_orpo_loss.py +1 -1
  24. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_simpo_loss.py +1 -1
  25. liger_kernel-0.6.0/benchmark/scripts/benchmark_softmax.py +140 -0
  26. liger_kernel-0.6.0/benchmark/scripts/benchmark_sparse_multi_token_attention.py +254 -0
  27. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_swiglu.py +1 -1
  28. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_tvd.py +1 -1
  29. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/utils.py +8 -1
  30. liger_kernel-0.6.0/dev/modal/benchmarks.py +73 -0
  31. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/dev/modal/tests.py +1 -1
  32. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/dev/modal/tests_bwd.py +3 -3
  33. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/Low-Level-APIs.md +15 -0
  34. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/mkdocs.yml +2 -2
  35. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/pyproject.toml +1 -1
  36. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/setup.py +2 -1
  37. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/__init__.py +1 -0
  38. liger_kernel-0.6.0/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +127 -0
  39. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/functional.py +2 -0
  40. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/dyt.py +0 -2
  41. liger_kernel-0.6.0/src/liger_kernel/ops/fused_neighborhood_attention.py +1022 -0
  42. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/geglu.py +1 -1
  43. liger_kernel-0.6.0/src/liger_kernel/ops/multi_token_attention.py +207 -0
  44. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/rms_norm.py +265 -54
  45. liger_kernel-0.6.0/src/liger_kernel/ops/softmax.py +201 -0
  46. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/sparsemax.py +62 -50
  47. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/swiglu.py +1 -1
  48. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/__init__.py +3 -0
  49. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/functional.py +62 -0
  50. liger_kernel-0.6.0/src/liger_kernel/transformers/fused_neighborhood_attention.py +234 -0
  51. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/gemma.py +25 -8
  52. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/gemma2.py +27 -8
  53. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/gemma3.py +62 -98
  54. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/glm4.py +16 -7
  55. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/llama.py +25 -7
  56. liger_kernel-0.6.0/src/liger_kernel/transformers/model/llama4.py +108 -0
  57. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/llava.py +95 -124
  58. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/mistral.py +13 -8
  59. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/mixtral.py +16 -7
  60. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/mllama.py +16 -7
  61. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/olmo2.py +16 -7
  62. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/paligemma.py +8 -1
  63. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/phi3.py +25 -8
  64. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/qwen2.py +24 -7
  65. liger_kernel-0.6.0/src/liger_kernel/transformers/model/qwen2_5_vl.py +150 -0
  66. liger_kernel-0.6.0/src/liger_kernel/transformers/model/qwen2_vl.py +142 -0
  67. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/qwen3.py +11 -3
  68. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/qwen3_moe.py +10 -6
  69. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/monkey_patch.py +304 -70
  70. liger_kernel-0.6.0/src/liger_kernel/transformers/multi_token_attention.py +64 -0
  71. liger_kernel-0.6.0/src/liger_kernel/transformers/rms_norm.py +79 -0
  72. liger_kernel-0.6.0/src/liger_kernel/transformers/softmax.py +12 -0
  73. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel.egg-info/PKG-INFO +8 -2
  74. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel.egg-info/SOURCES.txt +20 -2
  75. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel.egg-info/requires.txt +2 -1
  76. liger_kernel-0.6.0/test/chunked_loss/test_cosine_loss.py +320 -0
  77. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/convergence/bf16/test_mini_models.py +163 -65
  78. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/convergence/bf16/test_mini_models_multimodal.py +204 -49
  79. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/convergence/bf16/test_mini_models_with_logits.py +132 -52
  80. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/convergence/fp32/test_mini_models.py +141 -34
  81. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/convergence/fp32/test_mini_models_multimodal.py +185 -25
  82. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/convergence/fp32/test_mini_models_with_logits.py +96 -20
  83. liger_kernel-0.6.0/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +98 -0
  84. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_dyt.py +12 -8
  85. liger_kernel-0.6.0/test/transformers/test_fused_neighborhood_attention.py +572 -0
  86. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_monkey_patch.py +556 -39
  87. liger_kernel-0.6.0/test/transformers/test_multi_token_attention.py +327 -0
  88. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_qwen2vl_mrope.py +3 -3
  89. liger_kernel-0.6.0/test/transformers/test_softmax.py +103 -0
  90. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/utils.py +44 -2
  91. liger_kernel-0.5.10/.idea/workspace.xml +0 -79
  92. liger_kernel-0.5.10/src/liger_kernel/transformers/gema3_rms.py +0 -8
  93. liger_kernel-0.5.10/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -200
  94. liger_kernel-0.5.10/src/liger_kernel/transformers/model/qwen2_vl.py +0 -204
  95. liger_kernel-0.5.10/src/liger_kernel/transformers/rms_norm.py +0 -43
  96. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  97. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  98. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.github/pull_request_template.md +0 -0
  99. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.github/workflows/amd-ci.yml +0 -0
  100. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.github/workflows/intel-ci.yml +0 -0
  101. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.github/workflows/nvi-ci.yml +0 -0
  102. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.github/workflows/publish-nightly.yml +0 -0
  103. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/.github/workflows/publish-release.yml +0 -0
  104. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/LICENSE +0 -0
  105. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/NOTICE +0 -0
  106. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/README.md +0 -0
  107. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/__init__.py +0 -0
  108. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/__init__.py +0 -0
  109. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_dyt.py +0 -0
  110. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_geglu.py +0 -0
  111. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_group_norm.py +0 -0
  112. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  113. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  114. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_rope.py +0 -0
  115. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/benchmark/scripts/benchmark_sparsemax.py +0 -0
  116. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/dev/fmt-requirements.txt +0 -0
  117. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/Examples.md +0 -0
  118. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/Getting-Started.md +0 -0
  119. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/High-Level-APIs.md +0 -0
  120. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/acknowledgement.md +0 -0
  121. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/contributing.md +0 -0
  122. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/images/banner.GIF +0 -0
  123. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/images/compose.gif +0 -0
  124. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/images/e2e-memory.png +0 -0
  125. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/images/e2e-tps.png +0 -0
  126. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/images/logo-banner.png +0 -0
  127. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/images/patch.gif +0 -0
  128. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/images/post-training.png +0 -0
  129. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/index.md +0 -0
  130. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/docs/license.md +0 -0
  131. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/alignment/accelerate_config.yaml +0 -0
  132. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/alignment/run_orpo.py +0 -0
  133. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/README.md +0 -0
  134. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/callback.py +0 -0
  135. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/config/fsdp_config.json +0 -0
  136. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  137. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  138. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  139. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/img/llama_tps.png +0 -0
  140. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  141. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/img/qwen_tps.png +0 -0
  142. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/launch_on_modal.py +0 -0
  143. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/requirements.txt +0 -0
  144. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/run_benchmarks.sh +0 -0
  145. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/run_gemma.sh +0 -0
  146. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/run_llama.sh +0 -0
  147. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/run_qwen.sh +0 -0
  148. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/run_qwen2_vl.sh +0 -0
  149. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/training.py +0 -0
  150. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/huggingface/training_multimodal.py +0 -0
  151. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/lightning/README.md +0 -0
  152. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/lightning/requirements.txt +0 -0
  153. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/lightning/training.py +0 -0
  154. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/README.md +0 -0
  155. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/callback.py +0 -0
  156. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  157. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  158. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  159. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  160. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  161. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  162. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  163. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  164. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  165. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/medusa_util.py +0 -0
  166. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/requirements.txt +0 -0
  167. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  168. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/examples/medusa/train.py +0 -0
  169. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/licenses/LICENSE-Apache-2.0 +0 -0
  170. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  171. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  172. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/licenses/LICENSE-MIT-llmc +0 -0
  173. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/licenses/LICENSE-MIT-triton +0 -0
  174. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/setup.cfg +0 -0
  175. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/__init__.py +0 -0
  176. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/README.md +0 -0
  177. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  178. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  179. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  180. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  181. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  182. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  183. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  184. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  185. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  186. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  187. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  188. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/env_report.py +0 -0
  189. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/__init__.py +0 -0
  190. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/cross_entropy.py +0 -0
  191. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  192. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  193. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  194. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  195. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/group_norm.py +0 -0
  196. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/grpo_loss.py +0 -0
  197. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/jsd.py +0 -0
  198. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/kl_div.py +0 -0
  199. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/layer_norm.py +0 -0
  200. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  201. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/rope.py +0 -0
  202. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/tvd.py +0 -0
  203. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/ops/utils.py +0 -0
  204. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/auto_model.py +0 -0
  205. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  206. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/dyt.py +0 -0
  207. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  208. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/fsdp.py +0 -0
  209. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  210. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  211. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/geglu.py +0 -0
  212. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/group_norm.py +0 -0
  213. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/grpo_loss.py +0 -0
  214. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/jsd.py +0 -0
  215. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/kl_div.py +0 -0
  216. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/layer_norm.py +0 -0
  217. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/__init__.py +0 -0
  218. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  219. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  220. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/rope.py +0 -0
  221. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/sparsemax.py +0 -0
  222. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/swiglu.py +0 -0
  223. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  224. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  225. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  226. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/transformers/tvd.py +0 -0
  227. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/triton/__init__.py +0 -0
  228. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/triton/monkey_patch.py +0 -0
  229. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel/utils.py +0 -0
  230. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  231. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/src/liger_kernel.egg-info/top_level.txt +0 -0
  232. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/__init__.py +0 -0
  233. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/chunked_loss/__init__.py +0 -0
  234. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/chunked_loss/test_cpo_loss.py +0 -0
  235. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/chunked_loss/test_dpo_loss.py +0 -0
  236. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/chunked_loss/test_grpo_loss.py +0 -0
  237. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/chunked_loss/test_jsd_loss.py +0 -0
  238. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/chunked_loss/test_kto_loss.py +0 -0
  239. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/chunked_loss/test_orpo_loss.py +0 -0
  240. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/chunked_loss/test_simpo_loss.py +0 -0
  241. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/conftest.py +0 -0
  242. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/convergence/__init__.py +0 -0
  243. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/convergence/bf16/__init__.py +0 -0
  244. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/convergence/fp32/__init__.py +0 -0
  245. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  246. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  247. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  248. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  249. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  250. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  251. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  252. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  253. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  254. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/tiny_shakespeare.txt +0 -0
  255. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  256. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  257. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  258. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_auto_model.py +0 -0
  259. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_cross_entropy.py +0 -0
  260. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_embedding.py +0 -0
  261. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_flex_attention.py +0 -0
  262. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  263. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_fused_linear_jsd.py +0 -0
  264. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_geglu.py +0 -0
  265. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_group_norm.py +0 -0
  266. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_grpo_loss.py +0 -0
  267. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_jsd.py +0 -0
  268. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_kl_div.py +0 -0
  269. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_layer_norm.py +0 -0
  270. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_mm_int8int2.py +0 -0
  271. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_rms_norm.py +0 -0
  272. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_rope.py +0 -0
  273. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_sparsemax.py +0 -0
  274. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_swiglu.py +0 -0
  275. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_trainer_integration.py +0 -0
  276. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_transformers.py +0 -0
  277. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/transformers/test_tvd.py +0 -0
  278. {liger_kernel-0.5.10 → liger_kernel-0.6.0}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -0,0 +1,93 @@
1
+ name: Benchmarks
2
+
3
+ on:
4
+ schedule:
5
+ # Runs at 00:00 UTC every Friday
6
+ - cron: '0 0 * * 5'
7
+ workflow_dispatch: # Enables manual trigger
8
+
9
+ permissions:
10
+ contents: write
11
+
12
+ concurrency:
13
+ # This causes it to cancel previous in-progress actions on the same PR / branch,
14
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
15
+ cancel-in-progress: true
16
+
17
+ jobs:
18
+ benchmarks:
19
+ runs-on: ubuntu-latest
20
+ env:
21
+ MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
22
+ MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
23
+ GITHUB_USERNAME: linkedin
24
+ REPO_NAME: Liger-Kernel
25
+ OUTPUT_DIR: benchmarks
26
+ OUTPUT_FILENAME: benchmark.csv
27
+ GENERATED_CSV: benchmark/data/all_benchmark_data.csv
28
+
29
+
30
+ steps:
31
+ - name: Checkout code
32
+ uses: actions/checkout@v3
33
+
34
+ # Get the latest commit hash from main branch
35
+ - name: Get commit hash
36
+ id: get_hash
37
+ run: echo "hash=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
38
+
39
+ - name: Set up Python
40
+ uses: actions/setup-python@v3
41
+ with:
42
+ python-version: '3.10'
43
+
44
+ # Install dependencies
45
+ - name: Install dependencies
46
+ run: |
47
+ python -m pip install --upgrade pip
48
+ pip install modal
49
+ pip install pandas
50
+
51
+ # Delete previous benchmark results.
52
+ - name: Remove previous benchmark data
53
+ run: |
54
+ rm -f benchmark/data/all_benchmark_data.csv
55
+
56
+ - name: Run benchmarks on GPU
57
+ run: |
58
+ modal run dev.modal.benchmarks
59
+
60
+ # Step 5: Checkout gh-pages branch in a subfolder
61
+ - name: Checkout gh-pages
62
+ uses: actions/checkout@v3
63
+ with:
64
+ ref: gh-pages
65
+ path: gh-pages
66
+
67
+ # Step 6: Copy benchmark CSV to gh-pages directory
68
+ - name: Copy generated benchmark to gh-pages
69
+ run: |
70
+ mkdir -p gh-pages/${OUTPUT_DIR}/${{ steps.get_hash.outputs.hash }}
71
+ cp ${GENERATED_CSV} gh-pages/${OUTPUT_DIR}/${{ steps.get_hash.outputs.hash }}/${OUTPUT_FILENAME}
72
+ # Step 7: Append commit hash to commits.txt if not already present
73
+ - name: Update commits.txt
74
+ run: |
75
+ cd gh-pages
76
+ echo "commits.txt file path: ${OUTPUT_DIR}/commits.txt"
77
+ # Create file if it doesn't exist
78
+ mkdir -p ${OUTPUT_DIR}
79
+ touch ${OUTPUT_DIR}/commits.txt
80
+ # Append only if not already present
81
+ if ! grep -q "${{ steps.get_hash.outputs.hash }}" ${OUTPUT_DIR}/commits.txt; then
82
+ echo "${{ steps.get_hash.outputs.hash }}" >> ${OUTPUT_DIR}/commits.txt
83
+ fi
84
+ # Step 8: Commit and push
85
+ - name: Commit and push to gh-pages
86
+ run: |
87
+ cd gh-pages
88
+ git config user.name github-actions[bot]
89
+ git config user.email 41898282+github-actions[bot]@users.noreply.github.com
90
+ git add .
91
+ git commit -m "Add benchmark for commit ${{ steps.get_hash.outputs.hash }}" || echo "No changes to commit"
92
+ git push origin gh-pages
93
+
@@ -3,10 +3,12 @@ on:
3
3
  push:
4
4
  branches:
5
5
  - main
6
+
6
7
  permissions:
7
8
  contents: write
8
9
  jobs:
9
10
  deploy:
11
+ if: False
10
12
  runs-on: ubuntu-latest
11
13
  steps:
12
14
  - uses: actions/checkout@v4
@@ -25,4 +27,4 @@ jobs:
25
27
  restore-keys: |
26
28
  mkdocs-material-
27
29
  - run: pip install mkdocs-material
28
- - run: mkdocs gh-deploy --force
30
+ - run: mkdocs gh-deploy --force
@@ -6,6 +6,7 @@ site/
6
6
  venv/
7
7
  .ipynb_checkpoints/
8
8
  .vscode/
9
+ .idea/
9
10
 
10
11
  # Misc
11
12
  .DS_Store
@@ -14,6 +15,9 @@ venv/
14
15
  build/
15
16
  dist/
16
17
 
18
+ # Doc Build
19
+ site/
20
+
17
21
  # Lockfiles
18
22
  uv.lock
19
23
 
@@ -48,13 +48,19 @@ run-benchmarks:
48
48
  # MkDocs Configuration
49
49
  MKDOCS = mkdocs
50
50
  CONFIG_FILE = mkdocs.yml
51
+ SITE_DIR = site
51
52
 
52
53
  # MkDocs targets
54
+
55
+ # Serve the documentation
53
56
  serve:
54
57
  $(MKDOCS) serve -f $(CONFIG_FILE)
55
58
 
59
+ # Build the documentation into the specified site directory
56
60
  build:
57
- $(MKDOCS) build -f $(CONFIG_FILE)
61
+ $(MKDOCS) build -f $(CONFIG_FILE) --site-dir $(SITE_DIR)
58
62
 
63
+ # Clean the output directory
59
64
  clean:
60
- rm -rf site/
65
+ rm -rf $(SITE_DIR)/
66
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: liger_kernel
3
- Version: 0.5.10
3
+ Version: 0.6.0
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -33,7 +33,7 @@ License-File: NOTICE
33
33
  Requires-Dist: torch>=2.1.2
34
34
  Requires-Dist: triton>=2.3.1
35
35
  Provides-Extra: dev
36
- Requires-Dist: transformers>=4.44.2; extra == "dev"
36
+ Requires-Dist: transformers>=4.49.0; extra == "dev"
37
37
  Requires-Dist: matplotlib>=3.7.2; extra == "dev"
38
38
  Requires-Dist: flake8>=4.0.1.1; extra == "dev"
39
39
  Requires-Dist: black>=24.4.2; extra == "dev"
@@ -45,6 +45,7 @@ Requires-Dist: datasets>=2.19.2; extra == "dev"
45
45
  Requires-Dist: seaborn; extra == "dev"
46
46
  Requires-Dist: mkdocs; extra == "dev"
47
47
  Requires-Dist: mkdocs-material; extra == "dev"
48
+ Requires-Dist: torchvision>=0.20; extra == "dev"
48
49
  Dynamic: license-file
49
50
  Dynamic: provides-extra
50
51
  Dynamic: requires-dist
@@ -114,6 +115,8 @@ Dynamic: requires-dist
114
115
 
115
116
  We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
116
117
 
118
+ You can view the documentation site for additional installation instructions, usage examples, and API references: https://linkedin.github.io/Liger-Kernel/
119
+
117
120
  ## Supercharge Your Model with Liger Kernel
118
121
 
119
122
  ![Banner](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/banner.GIF)
@@ -290,6 +293,7 @@ loss.backward()
290
293
 
291
294
  | **Model** | **API** | **Supported Operations** |
292
295
  |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
296
+ | Llama4 (Text) & (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_llama4` | RMSNorm, LayerNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
293
297
  | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
294
298
  | LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
295
299
  | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -326,6 +330,8 @@ loss.backward()
326
330
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
327
331
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
328
332
  | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
333
+ | Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
334
+ | Softmax | `liger_kernel.transformers.LigerSoftmax` |
329
335
  | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
330
336
 
331
337
 
@@ -63,6 +63,8 @@
63
63
 
64
64
  We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
65
65
 
66
+ You can view the documentation site for additional installation instructions, usage examples, and API references: https://linkedin.github.io/Liger-Kernel/
67
+
66
68
  ## Supercharge Your Model with Liger Kernel
67
69
 
68
70
  ![Banner](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/banner.GIF)
@@ -239,6 +241,7 @@ loss.backward()
239
241
 
240
242
  | **Model** | **API** | **Supported Operations** |
241
243
  |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
244
+ | Llama4 (Text) & (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_llama4` | RMSNorm, LayerNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
242
245
  | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
243
246
  | LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
244
247
  | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -275,6 +278,8 @@ loss.backward()
275
278
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
276
279
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
277
280
  | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
281
+ | Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
282
+ | Softmax | `liger_kernel.transformers.LigerSoftmax` |
278
283
  | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
279
284
 
280
285
 
@@ -22,6 +22,9 @@ class VisualizationsConfig:
22
22
  kernel_name (str): Kernel name to benchmark. (Will run `scripts/benchmark_{kernel_name}.py`)
23
23
  metric_name (str): Metric name to visualize (speed/memory)
24
24
  kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
25
+ extra_config_filter (str, optional): A string to filter extra_benchmark_config.
26
+ Can be a substring to match or a 'key=value' pair (e.g., "'H': 4096").
27
+ Defaults to None, which means the first available config will be used if multiple exist.
25
28
  display (bool): Display the visualization. Defaults to False
26
29
  overwrite (bool): Overwrite existing visualization, if none exist this flag has no effect as ones are always created and saved. Defaults to False
27
30
 
@@ -30,6 +33,7 @@ class VisualizationsConfig:
30
33
  kernel_name: str
31
34
  metric_name: str
32
35
  kernel_operation_mode: str = "full"
36
+ extra_config_filter: str | None = None
33
37
  display: bool = False
34
38
  overwrite: bool = False
35
39
 
@@ -55,6 +59,14 @@ def parse_args() -> VisualizationsConfig:
55
59
  default=None,
56
60
  help="Kernel operation modes to visualize (forward/backward/full). If not provided, generate for all available modes.",
57
61
  )
62
+ parser.add_argument(
63
+ "--extra-config-filter",
64
+ type=str,
65
+ default=None,
66
+ help="A string to filter extra_benchmark_config. "
67
+ "Can be a substring to match or a JSON-like 'key=value' pair (e.g., \"'H': 4096\" or \"H=4096\" for simple cases). "
68
+ "Defaults to None (first available config if multiple exist).",
69
+ )
58
70
  parser.add_argument("--display", action="store_true", help="Display the visualization")
59
71
  parser.add_argument(
60
72
  "--overwrite",
@@ -81,19 +93,101 @@ def load_data(config: VisualizationsConfig) -> pd.DataFrame:
81
93
  df = pd.read_csv(DATA_PATH)
82
94
  df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)
83
95
 
84
- filtered_df = df[
96
+ base_filtered_df = df[
85
97
  (df["kernel_name"] == config.kernel_name)
86
98
  & (df["metric_name"] == config.metric_name)
87
99
  & (df["kernel_operation_mode"] == config.kernel_operation_mode)
88
- # Use this to filter by extra benchmark configuration property
89
- # & (data['extra_benchmark_config'].apply(lambda x: x.get('H') == 4096))
90
- # FIXME: maybe add a way to filter using some configuration, except of hardcoding it
91
100
  ]
92
101
 
93
- if filtered_df.empty:
94
- raise ValueError("No data found for the given filters")
102
+ if base_filtered_df.empty:
103
+ raise ValueError(
104
+ f"No data found for kernel_name='{config.kernel_name}', "
105
+ f"metric_name='{config.metric_name}', "
106
+ f"kernel_operation_mode='{config.kernel_operation_mode}'."
107
+ )
95
108
 
96
- return filtered_df
109
+ unique_extra_configs_str = base_filtered_df["extra_benchmark_config_str"].unique()
110
+ selected_extra_config_str = None
111
+
112
+ if len(unique_extra_configs_str) == 0:
113
+ print(
114
+ "Warning: No extra_benchmark_config found for the initial filters. "
115
+ "Proceeding with all data from initial filter."
116
+ )
117
+ return base_filtered_df
118
+
119
+ if config.extra_config_filter:
120
+ matched_configs = []
121
+ try:
122
+ if "=" in config.extra_config_filter:
123
+ key_filter, value_filter = config.extra_config_filter.split("=", 1)
124
+ for cfg_str in unique_extra_configs_str:
125
+ cfg_json = json.loads(cfg_str)
126
+ if str(cfg_json.get(key_filter.strip("'\" "))) == value_filter.strip("'\" "):
127
+ matched_configs.append(cfg_str)
128
+ if not matched_configs:
129
+ matched_configs = [
130
+ cfg_str for cfg_str in unique_extra_configs_str if config.extra_config_filter in cfg_str
131
+ ]
132
+ except Exception as e:
133
+ print(
134
+ f"Note: Could not parse extra_config_filter '{config.extra_config_filter}' as key=value ({e}), using substring match."
135
+ )
136
+ matched_configs = [cfg_str for cfg_str in unique_extra_configs_str if config.extra_config_filter in cfg_str]
137
+
138
+ if matched_configs:
139
+ if len(matched_configs) > 1:
140
+ print(
141
+ f"Warning: Multiple extra_benchmark_configs match filter '{config.extra_config_filter}': {matched_configs}. "
142
+ f"Using the first one: {matched_configs[0]}"
143
+ )
144
+ selected_extra_config_str = matched_configs[0]
145
+ else:
146
+ print(
147
+ f"Warning: No extra_benchmark_config matches filter '{config.extra_config_filter}'. "
148
+ f"Available configs for {config.kernel_name} ({config.metric_name}, {config.kernel_operation_mode}): {list(unique_extra_configs_str)}"
149
+ )
150
+ if len(unique_extra_configs_str) > 0:
151
+ selected_extra_config_str = unique_extra_configs_str[0]
152
+ print(f"Defaulting to the first available extra_benchmark_config: {selected_extra_config_str}")
153
+ else:
154
+ raise ValueError("No extra_benchmark_config available to select after failed filter attempt.")
155
+
156
+ elif len(unique_extra_configs_str) > 1:
157
+ selected_extra_config_str = unique_extra_configs_str[0]
158
+ print(
159
+ f"Warning: Multiple extra_benchmark_configs found for {config.kernel_name} ({config.metric_name}, {config.kernel_operation_mode})."
160
+ )
161
+ print(f"Defaulting to use: {selected_extra_config_str}")
162
+ print(f"Available configs: {list(unique_extra_configs_str)}")
163
+ print(
164
+ "Use the --extra-config-filter argument to select a specific one "
165
+ "(e.g., --extra-config-filter \"'H': 4096\" or a substring like \"'seq_len': 512\")."
166
+ )
167
+ elif len(unique_extra_configs_str) == 1:
168
+ selected_extra_config_str = unique_extra_configs_str[0]
169
+ print(f"Using unique extra_benchmark_config: {selected_extra_config_str}")
170
+
171
+ if selected_extra_config_str:
172
+ final_filtered_df = base_filtered_df[
173
+ base_filtered_df["extra_benchmark_config_str"] == selected_extra_config_str
174
+ ]
175
+ else:
176
+ print("Warning: Could not select an extra_benchmark_config. Using data from initial filter if any.")
177
+ final_filtered_df = base_filtered_df
178
+
179
+ if final_filtered_df.empty:
180
+ raise ValueError(
181
+ f"No data found after attempting to filter by extra_benchmark_config. "
182
+ f"Selected/Defaulted extra_config_str: {selected_extra_config_str}"
183
+ if selected_extra_config_str
184
+ else "No specific extra_config was selected."
185
+ )
186
+
187
+ print(
188
+ f"Plotting data for extra_benchmark_config: {json.loads(selected_extra_config_str if selected_extra_config_str else '{}')}"
189
+ )
190
+ return final_filtered_df
97
191
 
98
192
 
99
193
  def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
@@ -103,6 +197,10 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
103
197
  df (pd.DataFrame): Filtered benchmark dataframe.
104
198
  config (VisualizationsConfig): Configuration object for the visualizations script.
105
199
  """
200
+ for col in ["y_value_20", "y_value_50", "y_value_80"]:
201
+ if col in df.columns:
202
+ df[col] = pd.to_numeric(df[col], errors="coerce")
203
+
106
204
  xlabel = df["x_label"].iloc[0]
107
205
  ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
108
206
  # Sort by "kernel_provider" to ensure consistent color assignment
@@ -110,15 +208,26 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
110
208
 
111
209
  plt.figure(figsize=(10, 6))
112
210
  sns.set(style="whitegrid")
113
- ax = sns.lineplot(
114
- data=df,
115
- x="x_value",
116
- y="y_value_50",
117
- hue="kernel_provider",
118
- marker="o",
119
- palette="tab10",
120
- errorbar=("ci", None),
121
- )
211
+ try:
212
+ ax = sns.lineplot(
213
+ data=df,
214
+ x="x_value",
215
+ y="y_value_50",
216
+ hue="kernel_provider",
217
+ marker="o",
218
+ palette="tab10",
219
+ errorbar=("ci", None),
220
+ )
221
+ except Exception:
222
+ ax = sns.lineplot(
223
+ data=df,
224
+ x="x_value",
225
+ y="y_value_50",
226
+ hue="kernel_provider",
227
+ marker="o",
228
+ palette="tab10",
229
+ errorbar=None,
230
+ )
122
231
 
123
232
  # Seaborn can't plot pre-computed error bars, so we need to do it manually
124
233
  lines = ax.get_lines()