liger-kernel 0.5.3__tar.gz → 0.5.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. liger_kernel-0.5.5/.github/workflows/amd-ci.yml +74 -0
  2. liger_kernel-0.5.3/.github/workflows/amd-ci.yml → liger_kernel-0.5.5/.github/workflows/intel-ci.yml +3 -3
  3. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/Makefile +11 -5
  4. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/PKG-INFO +19 -4
  5. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/README.md +18 -3
  6. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/data/all_benchmark_data.csv +66 -30
  7. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_distill_jsd_loss.py +2 -0
  8. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_kto_loss.py +10 -10
  9. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_rope.py +1 -1
  10. liger_kernel-0.5.5/benchmark/scripts/benchmark_tvd.py +133 -0
  11. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/dev/modal/tests.py +1 -1
  12. liger_kernel-0.5.5/docs/images/post-training.png +0 -0
  13. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/lightning/training.py +1 -1
  14. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/callback.py +3 -3
  15. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/pyproject.toml +1 -1
  16. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/setup.py +13 -4
  17. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/__init__.py +1 -0
  18. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/cpo_loss.py +51 -11
  19. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/dpo_loss.py +30 -4
  20. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +3 -3
  21. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
  22. liger_kernel-0.5.5/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +240 -0
  23. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
  24. liger_kernel-0.5.5/src/liger_kernel/chunked_loss/grpo_loss.py +194 -0
  25. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/jsd_loss.py +31 -6
  26. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/kto_loss.py +53 -15
  27. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/orpo_loss.py +37 -5
  28. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/simpo_loss.py +47 -11
  29. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/cross_entropy.py +7 -3
  30. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
  31. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/fused_linear_jsd.py +3 -3
  32. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/jsd.py +3 -3
  33. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/layer_norm.py +20 -7
  34. liger_kernel-0.5.5/src/liger_kernel/ops/tvd.py +207 -0
  35. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/utils.py +1 -2
  36. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/__init__.py +4 -0
  37. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/cross_entropy.py +3 -3
  38. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/functional.py +17 -0
  39. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
  40. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/group_norm.py +6 -6
  41. liger_kernel-0.5.5/src/liger_kernel/transformers/model/olmo2.py +124 -0
  42. liger_kernel-0.5.5/src/liger_kernel/transformers/model/qwen2_5_vl.py +205 -0
  43. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/monkey_patch.py +239 -27
  44. liger_kernel-0.5.5/src/liger_kernel/transformers/tvd.py +13 -0
  45. liger_kernel-0.5.5/src/liger_kernel/utils.py +60 -0
  46. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/PKG-INFO +19 -4
  47. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/SOURCES.txt +20 -3
  48. liger_kernel-0.5.5/test/chunked_loss/test_grpo_loss.py +275 -0
  49. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_jsd_loss.py +49 -10
  50. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_kto_loss.py +85 -8
  51. liger_kernel-0.5.5/test/convergence/bf16/__init__.py +0 -0
  52. {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.5/test/convergence/bf16}/test_mini_models.py +206 -36
  53. {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.5/test/convergence/bf16}/test_mini_models_multimodal.py +92 -28
  54. {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.5/test/convergence/bf16}/test_mini_models_with_logits.py +205 -35
  55. liger_kernel-0.5.5/test/convergence/fp32/__init__.py +0 -0
  56. liger_kernel-0.5.5/test/convergence/fp32/test_mini_models.py +747 -0
  57. liger_kernel-0.5.5/test/convergence/fp32/test_mini_models_multimodal.py +513 -0
  58. liger_kernel-0.5.5/test/convergence/fp32/test_mini_models_with_logits.py +746 -0
  59. liger_kernel-0.5.5/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +63 -0
  60. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_cross_entropy.py +39 -0
  61. liger_kernel-0.5.5/test/transformers/test_flex_attention.py +291 -0
  62. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_layer_norm.py +20 -5
  63. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_monkey_patch.py +122 -3
  64. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_qwen2vl_mrope.py +3 -2
  65. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_rope.py +1 -1
  66. liger_kernel-0.5.5/test/transformers/test_tvd.py +188 -0
  67. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/triton/test_triton_monkey_patch.py +3 -3
  68. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/utils.py +36 -42
  69. liger_kernel-0.5.3/docs/images/post-training.png +0 -0
  70. liger_kernel-0.5.3/src/liger_kernel/utils.py +0 -13
  71. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  72. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  73. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/pull_request_template.md +0 -0
  74. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/workflows/docs.yml +0 -0
  75. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/workflows/nvi-ci.yml +0 -0
  76. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/workflows/publish-nightly.yml +0 -0
  77. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.github/workflows/publish-release.yml +0 -0
  78. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/.gitignore +0 -0
  79. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/LICENSE +0 -0
  80. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/NOTICE +0 -0
  81. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/README.md +0 -0
  82. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/__init__.py +0 -0
  83. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/benchmarks_visualizer.py +0 -0
  84. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/__init__.py +0 -0
  85. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  86. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  87. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  88. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_embedding.py +0 -0
  89. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  90. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  91. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_geglu.py +0 -0
  92. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_group_norm.py +0 -0
  93. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_jsd.py +0 -0
  94. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_kl_div.py +0 -0
  95. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  96. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  97. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  98. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  99. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  100. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_swiglu.py +0 -0
  101. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/benchmark/scripts/utils.py +0 -0
  102. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/dev/fmt-requirements.txt +0 -0
  103. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/dev/modal/tests_bwd.py +0 -0
  104. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/Examples.md +0 -0
  105. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/Getting-Started.md +0 -0
  106. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/High-Level-APIs.md +0 -0
  107. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/Low-Level-APIs.md +0 -0
  108. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/acknowledgement.md +0 -0
  109. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/contributing.md +0 -0
  110. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/banner.GIF +0 -0
  111. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/compose.gif +0 -0
  112. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/e2e-memory.png +0 -0
  113. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/e2e-tps.png +0 -0
  114. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/logo-banner.png +0 -0
  115. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/images/patch.gif +0 -0
  116. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/index.md +0 -0
  117. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/docs/license.md +0 -0
  118. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/alignment/accelerate_config.yaml +0 -0
  119. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/alignment/run_orpo.py +0 -0
  120. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/README.md +0 -0
  121. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/callback.py +0 -0
  122. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/config/fsdp_config.json +0 -0
  123. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  124. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  125. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  126. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/llama_tps.png +0 -0
  127. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  128. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/img/qwen_tps.png +0 -0
  129. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/launch_on_modal.py +0 -0
  130. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/requirements.txt +0 -0
  131. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_benchmarks.sh +0 -0
  132. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_gemma.sh +0 -0
  133. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_llama.sh +0 -0
  134. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_qwen.sh +0 -0
  135. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/run_qwen2_vl.sh +0 -0
  136. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/training.py +0 -0
  137. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/huggingface/training_multimodal.py +0 -0
  138. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/lightning/README.md +0 -0
  139. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/lightning/requirements.txt +0 -0
  140. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/README.md +0 -0
  141. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  142. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  143. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  144. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  145. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  146. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  147. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  148. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  149. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  150. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/medusa_util.py +0 -0
  151. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/requirements.txt +0 -0
  152. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  153. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/examples/medusa/train.py +0 -0
  154. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-Apache-2.0 +0 -0
  155. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  156. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  157. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-llmc +0 -0
  158. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-triton +0 -0
  159. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/mkdocs.yml +0 -0
  160. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/setup.cfg +0 -0
  161. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/__init__.py +0 -0
  162. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/README.md +0 -0
  163. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/functional.py +0 -0
  164. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/env_report.py +0 -0
  165. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/__init__.py +0 -0
  166. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  167. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  168. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/geglu.py +0 -0
  169. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/group_norm.py +0 -0
  170. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/kl_div.py +0 -0
  171. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  172. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/rms_norm.py +0 -0
  173. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/rope.py +0 -0
  174. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/ops/swiglu.py +0 -0
  175. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/auto_model.py +0 -0
  176. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  177. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  178. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/geglu.py +0 -0
  179. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/jsd.py +0 -0
  180. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/kl_div.py +0 -0
  181. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/layer_norm.py +0 -0
  182. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/__init__.py +0 -0
  183. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/gemma.py +0 -0
  184. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  185. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/llama.py +0 -0
  186. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/mistral.py +0 -0
  187. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  188. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/mllama.py +0 -0
  189. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/phi3.py +0 -0
  190. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  191. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  192. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  193. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/rms_norm.py +0 -0
  194. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/rope.py +0 -0
  195. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/swiglu.py +0 -0
  196. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  197. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  198. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  199. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/triton/__init__.py +0 -0
  200. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel/triton/monkey_patch.py +0 -0
  201. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  202. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/requires.txt +0 -0
  203. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/top_level.txt +0 -0
  204. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/__init__.py +0 -0
  205. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/__init__.py +0 -0
  206. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_cpo_loss.py +0 -0
  207. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_dpo_loss.py +0 -0
  208. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_orpo_loss.py +0 -0
  209. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/chunked_loss/test_simpo_loss.py +0 -0
  210. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/conftest.py +0 -0
  211. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/convergence/__init__.py +0 -0
  212. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  213. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  214. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  215. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare.txt +0 -0
  216. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  217. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  218. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  219. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_auto_model.py +0 -0
  220. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_embedding.py +0 -0
  221. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  222. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_fused_linear_jsd.py +0 -0
  223. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_geglu.py +0 -0
  224. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_group_norm.py +0 -0
  225. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_jsd.py +0 -0
  226. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_kl_div.py +0 -0
  227. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_mm_int8int2.py +0 -0
  228. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_rms_norm.py +0 -0
  229. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_swiglu.py +0 -0
  230. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_trainer_integration.py +0 -0
  231. {liger_kernel-0.5.3 → liger_kernel-0.5.5}/test/transformers/test_transformers.py +0 -0
@@ -0,0 +1,74 @@
1
+ name: AMD GPU
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ paths:
8
+ - "src/**"
9
+ - "test/**"
10
+ pull_request:
11
+ branches:
12
+ - main
13
+ paths:
14
+ - "src/**"
15
+ - "test/**"
16
+ schedule:
17
+ # Runs at 00:00 UTC daily
18
+ - cron: '0 0 * * *'
19
+ workflow_dispatch: # Enables manual trigger
20
+
21
+ concurrency:
22
+ # This causes it to cancel previous in-progress actions on the same PR / branch,
23
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
24
+ cancel-in-progress: true
25
+
26
+ jobs:
27
+ checkstyle:
28
+ runs-on: ubuntu-latest
29
+
30
+ steps:
31
+ - name: Checkout code
32
+ uses: actions/checkout@v3
33
+
34
+ - name: Set up Python
35
+ uses: actions/setup-python@v3
36
+ with:
37
+ python-version: '3.10'
38
+
39
+ - name: Install dependencies
40
+ run: |
41
+ python -m pip install --upgrade pip
42
+ pip install -r dev/fmt-requirements.txt
43
+
44
+ - name: Run checkstyle
45
+ run: make checkstyle
46
+
47
+ tests:
48
+ runs-on: linux-mi300-gpu-1
49
+ needs: [checkstyle]
50
+ strategy:
51
+ matrix:
52
+ rocm_version: ['6.2', '6.3']
53
+
54
+ steps:
55
+ - name: Checkout code
56
+ uses: actions/checkout@v3
57
+
58
+ - name: Set up Python
59
+ uses: actions/setup-python@v3
60
+ with:
61
+ python-version: '3.10'
62
+
63
+ - name: Setup Dependencies
64
+ run: |
65
+ python -m pip install --upgrade pip
66
+ pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm_version }}
67
+
68
+ - name: List Python Environments
69
+ run: python -m pip list
70
+
71
+ - name: Run Unit Tests
72
+ run: |
73
+ make test
74
+ make test-convergence
@@ -1,4 +1,4 @@
1
- name: AMD GPU
1
+ name: Intel GPU
2
2
 
3
3
  on:
4
4
  push:
@@ -45,7 +45,7 @@ jobs:
45
45
  run: make checkstyle
46
46
 
47
47
  tests:
48
- runs-on: linux-mi300-gpu-1
48
+ runs-on: linux-max1550-gpu-8
49
49
  needs: [checkstyle]
50
50
 
51
51
  steps:
@@ -60,7 +60,7 @@ jobs:
60
60
  - name: Setup Dependencies
61
61
  run: |
62
62
  python -m pip install --upgrade pip
63
- pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
63
+ pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/test/xpu
64
64
 
65
65
  - name: List Python Environments
66
66
  run: python -m pip list
@@ -9,8 +9,10 @@ test:
9
9
 
10
10
  # Command to run ruff for linting and formatting code
11
11
  checkstyle:
12
- ruff check . --fix; ruff_check_status=$$?; \
13
- ruff format .; ruff_format_status=$$?; \
12
+ ruff check .; ruff_check_status=$$?; \
13
+ ruff format --check .; ruff_format_status=$$?; \
14
+ ruff check . --fix; \
15
+ ruff format .; \
14
16
  if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \
15
17
  exit 1; \
16
18
  fi
@@ -18,9 +20,13 @@ checkstyle:
18
20
  # Command to run pytest for convergence tests
19
21
  # We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
20
22
  test-convergence:
21
- HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
22
- HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py
23
- HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py
23
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
24
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_multimodal.py
25
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_with_logits.py
26
+
27
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models.py
28
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
29
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
24
30
 
25
31
  # Command to run all benchmark scripts and update benchmarking data file
26
32
  # By default this doesn't overwrite existing data for the same benchmark experiment
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: liger_kernel
3
- Version: 0.5.3
3
+ Version: 0.5.5
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -97,6 +97,11 @@ Dynamic: requires-dist
97
97
  <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
98
98
  </a>
99
99
  </div>
100
+ <div style="display: block;">
101
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
102
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
103
+ </a>
104
+ </div>
100
105
  </td>
101
106
  </tr>
102
107
  </table>
@@ -123,7 +128,7 @@ Dynamic: requires-dist
123
128
 
124
129
 **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
125
130
 
126
- We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
131
+ We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
127
132
 
128
133
  ## Supercharge Your Model with Liger Kernel
129
134
 
@@ -149,7 +154,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
149
154
  We provide optimized post training kernels like DPO, ORPO, SimPO, and more which can reduce memory usage by up to 80%. You can easily use them as python modules.
150
155
 
151
156
  ```python
152
- from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
157
+ from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
153
158
  orpo_loss = LigerFusedLinearORPOLoss()
154
159
  y = orpo_loss(lm_head.weight, x, target)
155
160
  ```
@@ -188,6 +193,11 @@ y = orpo_loss(lm_head.weight, x, target)
188
193
  - `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
189
194
  - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
190
195
 
196
+ ```bash
197
+ # Need to pass the url when installing
198
+ pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
199
+ ```
200
+
191
201
  ### Optional Dependencies
192
202
 
193
203
 - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working with will dictate the minimum version of transformers.
@@ -304,7 +314,10 @@ loss.backward()
304
314
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
305
315
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
306
316
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
317
+ | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
307
318
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
319
+ | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
320
+ | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
308
321
 
309
322
 
310
323
  ## Low-level APIs
@@ -333,6 +346,7 @@ loss.backward()
333
346
  | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
334
347
  | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
335
348
  | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
349
+ | Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
336
350
 
337
351
  ### Distillation Kernels
338
352
 
@@ -341,6 +355,7 @@ loss.backward()
341
355
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
342
356
  | JSD | `liger_kernel.transformers.LigerJSD` |
343
357
  | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
358
+ | TVD | `liger_kernel.transformers.LigerTVDLoss` |
344
359
 
345
360
  ### Experimental Kernels
346
361
 
@@ -372,7 +387,7 @@ loss.backward()
372
387
 
373
388
  - For issues, create a Github ticket in this repository
374
389
  - For open discussion, join [our discord channel](https://discord.gg/gpumode)
375
- - For formal collaboration, send an email to byhsu@linkedin.com
390
+ - For formal collaboration, send an email to yannchen@linkedin.com
376
391
 
377
392
  ## Cite this work
378
393
 
@@ -47,6 +47,11 @@
47
47
  <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
48
48
  </a>
49
49
  </div>
50
+ <div style="display: block;">
51
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml">
52
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
53
+ </a>
54
+ </div>
50
55
  </td>
51
56
  </tr>
52
57
  </table>
@@ -73,7 +78,7 @@
73
78
 
74
79
  **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
75
80
 
76
- We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
81
+ We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
77
82
 
78
83
  ## Supercharge Your Model with Liger Kernel
79
84
 
@@ -99,7 +104,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
99
104
  We provide optimized post training kernels like DPO, ORPO, SimPO, and more which can reduce memory usage by up to 80%. You can easily use them as python modules.
100
105
 
101
106
  ```python
102
- from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
107
+ from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
103
108
  orpo_loss = LigerFusedLinearORPOLoss()
104
109
  y = orpo_loss(lm_head.weight, x, target)
105
110
  ```
@@ -138,6 +143,11 @@ y = orpo_loss(lm_head.weight, x, target)
138
143
  - `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
139
144
  - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)
140
145
 
146
+ ```bash
147
+ # Need to pass the url when installing
148
+ pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
149
+ ```
150
+
141
151
  ### Optional Dependencies
142
152
 
143
153
  - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working with will dictate the minimum version of transformers.
@@ -254,7 +264,10 @@ loss.backward()
254
264
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
255
265
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
256
266
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
267
+ | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
257
268
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
269
+ | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
270
+ | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
258
271
 
259
272
 
260
273
  ## Low-level APIs
@@ -283,6 +296,7 @@ loss.backward()
283
296
  | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
284
297
  | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
285
298
  | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
299
+ | Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |
286
300
 
287
301
  ### Distillation Kernels
288
302
 
@@ -291,6 +305,7 @@ loss.backward()
291
305
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
292
306
  | JSD | `liger_kernel.transformers.LigerJSD` |
293
307
  | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
308
+ | TVD | `liger_kernel.transformers.LigerTVDLoss` |
294
309
 
295
310
  ### Experimental Kernels
296
311
 
@@ -322,7 +337,7 @@ loss.backward()
322
337
 
323
338
  - For issues, create a Github ticket in this repository
324
339
  - For open discussion, join [our discord channel](https://discord.gg/gpumode)
325
- - For formal collaboration, send an email to byhsu@linkedin.com
340
+ - For formal collaboration, send an email to yannchen@linkedin.com
326
341
 
327
342
  ## Cite this work
328
343
 
@@ -505,6 +505,42 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859
505
505
  fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
506
506
  fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
507
507
  fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
508
+ tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
509
+ tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
510
+ tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
511
+ tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
512
+ tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
513
+ tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
514
+ tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
515
+ tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
516
+ tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
517
+ tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
518
+ tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
519
+ tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
520
+ tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
521
+ tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
522
+ tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
523
+ tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
524
+ tvd,liger,forward,speed,ms,V,vocab size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
525
+ tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
526
+ tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
527
+ tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
528
+ tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
529
+ tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
530
+ tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
531
+ tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
532
+ tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
533
+ tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
534
+ tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
535
+ tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
536
+ tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
537
+ tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
538
+ tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
539
+ tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
540
+ tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
541
+ tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
542
+ tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
543
+ tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
508
544
  group_norm,liger,forward,speed,ms,C,num_channels,32,0.03481600061058998,0.03379200026392937,0.03993599861860275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
509
545
  group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05119999870657921,0.05222399905323982,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
510
546
  group_norm,liger,forward,speed,ms,C,num_channels,128,0.08499199897050858,0.08396799862384796,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
@@ -715,36 +751,6 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314
715
751
  fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
716
752
  fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
717
753
  fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
718
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,7.841599941253662,7.801983833312988,7.849664211273193,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
719
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,15.568096160888672,15.555737495422363,16.054176330566406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
720
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,31.145376205444336,30.750951766967773,31.5398006439209,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
721
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,61.49708938598633,61.49708938598633,61.49708938598633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
722
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,122.01449584960938,122.01449584960938,122.01449584960938,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
723
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.892335891723633,7.8687615394592285,8.03729248046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
724
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,14.16302490234375,13.813311576843262,15.860223770141602,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
725
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,25.56470489501953,25.564167022705078,25.641658782958984,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
726
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,53.0928955078125,53.0928955078125,53.0928955078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
727
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,108.76080322265625,108.76080322265625,108.76080322265625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
728
- kto_loss,liger,full,speed,ms,B,Batch Size (B),2,8.662687301635742,8.488287925720215,9.611334800720215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
729
- kto_loss,liger,full,speed,ms,B,Batch Size (B),4,18.40096092224121,17.99224281311035,18.57883644104004,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
730
- kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.09159851074219,31.708070755004883,32.475128173828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
731
- kto_loss,liger,full,speed,ms,B,Batch Size (B),16,69.30239868164062,69.30239868164062,69.30239868164062,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
732
- kto_loss,liger,full,speed,ms,B,Batch Size (B),32,124.2437744140625,124.2437744140625,124.2437744140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
733
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,11.449472427368164,11.407564163208008,11.773555755615234,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
734
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,20.871471405029297,20.862951278686523,20.879276275634766,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
735
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,41.16409683227539,40.760780334472656,41.567413330078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
736
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,77.720703125,77.720703125,77.720703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
737
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,156.25794982910156,156.25794982910156,156.25794982910156,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
738
- kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2027.48583984375,2027.48583984375,2027.48583984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
739
- kto_loss,liger,full,memory,MB,B,Batch Size (B),4,2789.736328125,2789.736328125,2789.736328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
740
- kto_loss,liger,full,memory,MB,B,Batch Size (B),8,2801.751953125,2801.751953125,2801.751953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
741
- kto_loss,liger,full,memory,MB,B,Batch Size (B),16,2825.783203125,2825.783203125,2825.783203125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
742
- kto_loss,liger,full,memory,MB,B,Batch Size (B),32,2873.845703125,2873.845703125,2873.845703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
743
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,3786.7373046875,3786.7373046875,3786.7373046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
744
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.25390625,5544.25390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
745
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
746
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
747
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
748
754
  distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
749
755
  distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
750
756
  distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
@@ -769,3 +775,33 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,
769
775
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
770
776
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
771
777
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
778
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,3.9951679706573486,3.991487979888916,4.002252578735352,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
779
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,7.8037919998168945,7.788575649261475,7.808595180511475,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
780
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,15.43172836303711,15.430015563964844,15.4335355758667,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
781
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,30.66864013671875,30.66431999206543,30.670501708984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
782
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,61.1163215637207,61.1163215637207,61.1163215637207,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
783
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,3.8766400814056396,3.8680384159088135,3.8897151947021484,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
784
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,7.213727951049805,7.206470489501953,7.229574680328369,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
785
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,13.828800201416016,13.810944557189941,13.834943771362305,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
786
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,27.0930233001709,27.08517074584961,27.09713363647461,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
787
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,54.13715362548828,54.13715362548828,54.13715362548828,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
788
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),2,4.782928466796875,4.677459239959717,5.3430914878845215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
789
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),4,8.517248153686523,8.481344223022461,8.561504364013672,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
790
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),8,16.547504425048828,16.513471603393555,16.678144454956055,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
791
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),16,31.891263961791992,31.819705963134766,32.274131774902344,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
792
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),32,62.953758239746094,62.953758239746094,62.953758239746094,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
793
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,6.201632022857666,6.163315296173096,6.314668655395508,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
794
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,11.156224250793457,11.142304420471191,11.207296371459961,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
795
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,21.249855041503906,21.231891632080078,21.264543533325195,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
796
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,41.55686569213867,41.536956787109375,41.57677459716797,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
797
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,81.56924438476562,81.56924438476562,81.56924438476562,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
798
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2585.73876953125,2585.73876953125,2585.73876953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
799
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3348.9892578125,3348.9892578125,3348.9892578125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
800
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),8,3361.0048828125,3361.0048828125,3361.0048828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
801
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),16,3385.0361328125,3385.0361328125,3385.0361328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
802
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),32,3433.0986328125,3433.0986328125,3433.0986328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
803
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4341.74951171875,4341.74951171875,4341.74951171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
804
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.26513671875,6099.26513671875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
805
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
806
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
807
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
@@ -81,6 +81,8 @@ class LigerJSDLoss(torch.nn.Module):
81
81
  teacher,
82
82
  self.teacher_lin.weight,
83
83
  target,
84
+ self.student_lin.bias,
85
+ self.teacher_lin.bias,
84
86
  self.weight_hard_loss,
85
87
  self.weight_soft_loss,
86
88
  )
@@ -103,8 +103,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
103
103
  H=H,
104
104
  V=V,
105
105
  dtype=dtype,
106
- bias=bias,
107
- ref_bias=bias,
106
+ use_bias=bias,
107
+ use_ref_bias=bias,
108
108
  ignore_index=ignore_index,
109
109
  beta=beta,
110
110
  ).to(device)
@@ -113,8 +113,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
113
113
  H=H,
114
114
  V=V,
115
115
  dtype=dtype,
116
- bias=bias,
117
- ref_bias=bias,
116
+ use_bias=bias,
117
+ use_ref_bias=bias,
118
118
  ignore_index=ignore_index,
119
119
  beta=beta,
120
120
  ).to(device)
@@ -149,7 +149,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
149
149
  y=target,
150
150
  preference_labels=preference_labels,
151
151
  kl=kl,
152
- )
152
+ )[0]
153
153
  elif provider == "huggingface":
154
154
  return torch_kto_loss(
155
155
  x=_input,
@@ -157,7 +157,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
157
157
  y=target,
158
158
  preference_labels=preference_labels,
159
159
  kl=kl,
160
- )
160
+ )[0]
161
161
 
162
162
  def full():
163
163
  y = fwd()
@@ -189,7 +189,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
189
189
  dtype=dtype,
190
190
  beta=beta,
191
191
  ignore_index=ignore_index,
192
- bias=bias,
192
+ use_bias=bias,
193
193
  ).to(device)
194
194
  liger_kto_loss = LigerLMHeadKTO(
195
195
  H=H,
@@ -197,7 +197,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
197
197
  dtype=dtype,
198
198
  beta=beta,
199
199
  ignore_index=ignore_index,
200
- bias=bias,
200
+ use_bias=bias,
201
201
  ).to(device)
202
202
 
203
203
  # Input shape: [B, T, H]
@@ -230,7 +230,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
230
230
  y=target,
231
231
  preference_labels=preference_labels,
232
232
  kl=kl,
233
- )
233
+ )[0]
234
234
  elif provider == "huggingface":
235
235
  return torch_kto_loss(
236
236
  x=_input,
@@ -238,7 +238,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
238
238
  y=target,
239
239
  preference_labels=preference_labels,
240
240
  kl=kl,
241
- )
241
+ )[0]
242
242
 
243
243
  if mode == "forward":
244
244
  ms_50, ms_20, ms_80 = triton.testing.do_bench(
@@ -1,7 +1,6 @@
1
1
  import torch
2
2
  import triton
3
3
 
4
- from test.utils import transformers_version_dispatch
5
4
  from transformers.models.llama.configuration_llama import LlamaConfig
6
5
  from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
7
6
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
@@ -14,6 +13,7 @@ from utils import run_benchmarks
14
13
 
15
14
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
16
15
  from liger_kernel.utils import infer_device
16
+ from liger_kernel.utils import transformers_version_dispatch
17
17
 
18
18
  device = infer_device()
19
19