liger-kernel 0.5.3__tar.gz → 0.5.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229)
  1. liger_kernel-0.5.4/.github/workflows/intel-ci.yml +71 -0
  2. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/Makefile +11 -5
  3. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/PKG-INFO +17 -3
  4. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/README.md +16 -2
  5. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/data/all_benchmark_data.csv +37 -0
  6. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_kto_loss.py +6 -6
  7. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_rope.py +1 -1
  8. liger_kernel-0.5.4/benchmark/scripts/benchmark_tvd.py +133 -0
  9. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/dev/modal/tests.py +1 -1
  10. liger_kernel-0.5.4/docs/images/post-training.png +0 -0
  11. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/lightning/training.py +1 -1
  12. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/callback.py +3 -3
  13. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/pyproject.toml +1 -1
  14. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/setup.py +13 -4
  15. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/__init__.py +1 -0
  16. liger_kernel-0.5.4/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +213 -0
  17. liger_kernel-0.5.4/src/liger_kernel/chunked_loss/grpo_loss.py +160 -0
  18. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/kto_loss.py +9 -9
  19. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/cross_entropy.py +3 -3
  20. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/fused_linear_cross_entropy.py +3 -3
  21. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/fused_linear_jsd.py +3 -3
  22. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/jsd.py +3 -3
  23. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/layer_norm.py +20 -7
  24. liger_kernel-0.5.4/src/liger_kernel/ops/tvd.py +207 -0
  25. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/utils.py +1 -2
  26. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/__init__.py +3 -0
  27. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/cross_entropy.py +3 -3
  28. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/functional.py +17 -0
  29. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +3 -3
  30. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/group_norm.py +6 -6
  31. liger_kernel-0.5.4/src/liger_kernel/transformers/model/olmo2.py +124 -0
  32. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/monkey_patch.py +171 -27
  33. liger_kernel-0.5.4/src/liger_kernel/transformers/tvd.py +13 -0
  34. liger_kernel-0.5.4/src/liger_kernel/utils.py +62 -0
  35. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/PKG-INFO +17 -3
  36. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/SOURCES.txt +18 -3
  37. liger_kernel-0.5.4/test/chunked_loss/test_grpo_loss.py +275 -0
  38. liger_kernel-0.5.4/test/convergence/bf16/__init__.py +0 -0
  39. {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.4/test/convergence/bf16}/test_mini_models.py +121 -37
  40. {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.4/test/convergence/bf16}/test_mini_models_multimodal.py +1 -36
  41. {liger_kernel-0.5.3/test/convergence → liger_kernel-0.5.4/test/convergence/bf16}/test_mini_models_with_logits.py +120 -36
  42. liger_kernel-0.5.4/test/convergence/fp32/__init__.py +0 -0
  43. liger_kernel-0.5.4/test/convergence/fp32/test_mini_models.py +664 -0
  44. liger_kernel-0.5.4/test/convergence/fp32/test_mini_models_multimodal.py +415 -0
  45. liger_kernel-0.5.4/test/convergence/fp32/test_mini_models_with_logits.py +663 -0
  46. liger_kernel-0.5.4/test/transformers/test_flex_attention.py +291 -0
  47. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_layer_norm.py +20 -5
  48. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_monkey_patch.py +54 -3
  49. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_qwen2vl_mrope.py +3 -2
  50. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_rope.py +1 -1
  51. liger_kernel-0.5.4/test/transformers/test_tvd.py +188 -0
  52. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/triton/test_triton_monkey_patch.py +3 -3
  53. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/utils.py +18 -41
  54. liger_kernel-0.5.3/docs/images/post-training.png +0 -0
  55. liger_kernel-0.5.3/src/liger_kernel/utils.py +0 -13
  56. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  57. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  58. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/pull_request_template.md +0 -0
  59. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/amd-ci.yml +0 -0
  60. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/docs.yml +0 -0
  61. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/nvi-ci.yml +0 -0
  62. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/publish-nightly.yml +0 -0
  63. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.github/workflows/publish-release.yml +0 -0
  64. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/.gitignore +0 -0
  65. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/LICENSE +0 -0
  66. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/NOTICE +0 -0
  67. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/README.md +0 -0
  68. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/__init__.py +0 -0
  69. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/benchmarks_visualizer.py +0 -0
  70. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/__init__.py +0 -0
  71. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  72. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  73. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  74. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  75. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_embedding.py +0 -0
  76. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  77. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  78. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_geglu.py +0 -0
  79. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_group_norm.py +0 -0
  80. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_jsd.py +0 -0
  81. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_kl_div.py +0 -0
  82. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  83. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  84. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  85. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  86. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  87. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/benchmark_swiglu.py +0 -0
  88. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/benchmark/scripts/utils.py +0 -0
  89. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/dev/fmt-requirements.txt +0 -0
  90. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/dev/modal/tests_bwd.py +0 -0
  91. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/Examples.md +0 -0
  92. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/Getting-Started.md +0 -0
  93. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/High-Level-APIs.md +0 -0
  94. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/Low-Level-APIs.md +0 -0
  95. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/acknowledgement.md +0 -0
  96. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/contributing.md +0 -0
  97. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/banner.GIF +0 -0
  98. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/compose.gif +0 -0
  99. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/e2e-memory.png +0 -0
  100. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/e2e-tps.png +0 -0
  101. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/logo-banner.png +0 -0
  102. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/images/patch.gif +0 -0
  103. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/index.md +0 -0
  104. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/docs/license.md +0 -0
  105. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/alignment/accelerate_config.yaml +0 -0
  106. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/alignment/run_orpo.py +0 -0
  107. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/README.md +0 -0
  108. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/callback.py +0 -0
  109. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/config/fsdp_config.json +0 -0
  110. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  111. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  112. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  113. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/llama_tps.png +0 -0
  114. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  115. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/img/qwen_tps.png +0 -0
  116. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/launch_on_modal.py +0 -0
  117. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/requirements.txt +0 -0
  118. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_benchmarks.sh +0 -0
  119. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_gemma.sh +0 -0
  120. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_llama.sh +0 -0
  121. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_qwen.sh +0 -0
  122. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/run_qwen2_vl.sh +0 -0
  123. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/training.py +0 -0
  124. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/huggingface/training_multimodal.py +0 -0
  125. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/lightning/README.md +0 -0
  126. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/lightning/requirements.txt +0 -0
  127. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/README.md +0 -0
  128. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  129. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  130. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  131. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  132. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  133. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  134. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  135. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  136. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  137. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/medusa_util.py +0 -0
  138. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/requirements.txt +0 -0
  139. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  140. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/examples/medusa/train.py +0 -0
  141. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-Apache-2.0 +0 -0
  142. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  143. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  144. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-MIT-llmc +0 -0
  145. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/licenses/LICENSE-MIT-triton +0 -0
  146. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/mkdocs.yml +0 -0
  147. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/setup.cfg +0 -0
  148. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/__init__.py +0 -0
  149. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/README.md +0 -0
  150. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  151. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  152. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/functional.py +0 -0
  153. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  154. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  155. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  156. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  157. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  158. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  159. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/env_report.py +0 -0
  160. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/__init__.py +0 -0
  161. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  162. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  163. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/geglu.py +0 -0
  164. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/group_norm.py +0 -0
  165. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/kl_div.py +0 -0
  166. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  167. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/rms_norm.py +0 -0
  168. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/rope.py +0 -0
  169. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/ops/swiglu.py +0 -0
  170. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/auto_model.py +0 -0
  171. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  172. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  173. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/geglu.py +0 -0
  174. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/jsd.py +0 -0
  175. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/kl_div.py +0 -0
  176. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/layer_norm.py +0 -0
  177. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/__init__.py +0 -0
  178. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/gemma.py +0 -0
  179. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  180. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/llama.py +0 -0
  181. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/mistral.py +0 -0
  182. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  183. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/mllama.py +0 -0
  184. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/phi3.py +0 -0
  185. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  186. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  187. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  188. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/rms_norm.py +0 -0
  189. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/rope.py +0 -0
  190. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/swiglu.py +0 -0
  191. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  192. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  193. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  194. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/triton/__init__.py +0 -0
  195. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel/triton/monkey_patch.py +0 -0
  196. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  197. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/requires.txt +0 -0
  198. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/src/liger_kernel.egg-info/top_level.txt +0 -0
  199. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/__init__.py +0 -0
  200. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/__init__.py +0 -0
  201. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_cpo_loss.py +0 -0
  202. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_dpo_loss.py +0 -0
  203. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_jsd_loss.py +0 -0
  204. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_kto_loss.py +0 -0
  205. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_orpo_loss.py +0 -0
  206. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/chunked_loss/test_simpo_loss.py +0 -0
  207. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/conftest.py +0 -0
  208. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/convergence/__init__.py +0 -0
  209. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  210. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  211. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  212. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/tiny_shakespeare.txt +0 -0
  213. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  214. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  215. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  216. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_auto_model.py +0 -0
  217. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_cross_entropy.py +0 -0
  218. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_embedding.py +0 -0
  219. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  220. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_fused_linear_jsd.py +0 -0
  221. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_geglu.py +0 -0
  222. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_group_norm.py +0 -0
  223. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_jsd.py +0 -0
  224. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_kl_div.py +0 -0
  225. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_mm_int8int2.py +0 -0
  226. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_rms_norm.py +0 -0
  227. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_swiglu.py +0 -0
  228. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_trainer_integration.py +0 -0
  229. {liger_kernel-0.5.3 → liger_kernel-0.5.4}/test/transformers/test_transformers.py +0 -0

.github/workflows/intel-ci.yml (new file)

@@ -0,0 +1,71 @@
+ name: Intel GPU
+
+ on:
+   push:
+     branches:
+       - main
+     paths:
+       - "src/**"
+       - "test/**"
+   pull_request:
+     branches:
+       - main
+     paths:
+       - "src/**"
+       - "test/**"
+   schedule:
+     # Runs at 00:00 UTC daily
+     - cron: '0 0 * * *'
+   workflow_dispatch: # Enables manual trigger
+
+ concurrency:
+   # This causes it to cancel previous in-progress actions on the same PR / branch,
+   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+   cancel-in-progress: true
+
+ jobs:
+   checkstyle:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout code
+         uses: actions/checkout@v3
+
+       - name: Set up Python
+         uses: actions/setup-python@v3
+         with:
+           python-version: '3.10'
+
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -r dev/fmt-requirements.txt
+
+       - name: Run checkstyle
+         run: make checkstyle
+
+   tests:
+     runs-on: linux-max1550-gpu-8
+     needs: [checkstyle]
+
+     steps:
+       - name: Checkout code
+         uses: actions/checkout@v3
+
+       - name: Set up Python
+         uses: actions/setup-python@v3
+         with:
+           python-version: '3.10'
+
+       - name: Setup Dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/test/xpu
+
+       - name: List Python Environments
+         run: python -m pip list
+
+       - name: Run Unit Tests
+         run: |
+           make test
+           make test-convergence

Makefile

@@ -9,8 +9,10 @@ test:

  # Command to run ruff for linting and formatting code
  checkstyle:
- ruff check . --fix; ruff_check_status=$$?; \
- ruff format .; ruff_format_status=$$?; \
+ ruff check .; ruff_check_status=$$?; \
+ ruff format --check .; ruff_format_status=$$?; \
+ ruff check . --fix; \
+ ruff format .; \
  if [ $$ruff_check_status -ne 0 ] || [ $$ruff_format_status -ne 0 ]; then \
  exit 1; \
  fi
@@ -18,9 +20,13 @@ checkstyle:
  # Command to run pytest for convergence tests
  # We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
  test-convergence:
- HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
- HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py
- HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_multimodal.py
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_with_logits.py
+
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models.py
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py

  # Command to run all benchmark scripts and update benchmarking data file
  # By default this doesn't overwrite existing data for the same benchmark experiment
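
The convergence suite is now split into per-precision directories (fp32/ and bf16/). For reference, the same suites can be exercised directly from Python without Make; this is only a sketch of an equivalent invocation, not a new entry point:

```python
import os

import pytest

# HF_DATASETS_OFFLINE=1 mirrors the Makefile comment above: it keeps the
# datasets library from silently sending metrics and timing out.
os.environ["HF_DATASETS_OFFLINE"] = "1"

# fp32 and bf16 mini-model convergence tests now live in separate directories.
pytest.main(["--disable-warnings", "test/convergence/fp32/test_mini_models.py"])
pytest.main(["--disable-warnings", "test/convergence/bf16/test_mini_models.py"])
```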

PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: liger_kernel
- Version: 0.5.3
+ Version: 0.5.4
  Summary: Efficient Triton kernels for LLM Training
  License: BSD 2-CLAUSE LICENSE
  Copyright 2024 LinkedIn Corporation
@@ -97,6 +97,11 @@ Dynamic: requires-dist
  <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
  </a>
  </div>
+ <div style="display: block;">
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
+ </a>
+ </div>
  </td>
  </tr>
  </table>
@@ -123,7 +128,7 @@ Dynamic: requires-dist

  **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.

- We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
+ We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).

  ## Supercharge Your Model with Liger Kernel

@@ -188,6 +193,11 @@ y = orpo_loss(lm_head.weight, x, target)
  - `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
  - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)

+ ```bash
+ # Need to pass the url when installing
+ pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
+ ```
+
  ### Optional Dependencies

  - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
@@ -305,6 +315,8 @@ loss.backward()
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+ | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |


  ## Low-level APIs
@@ -333,6 +345,7 @@ loss.backward()
  | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
  | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
  | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+ | Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |

  ### Distillation Kernels

@@ -341,6 +354,7 @@ loss.backward()
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
  | JSD | `liger_kernel.transformers.LigerJSD` |
  | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
+ | TVD | `liger_kernel.transformers.LigerTVDLoss` |

  ### Experimental Kernels

@@ -372,7 +386,7 @@ loss.backward()

  - For issues, create a Github ticket in this repository
  - For open discussion, join [our discord channel](https://discord.gg/gpumode)
- - For formal collaboration, send an email to byhsu@linkedin.com
+ - For formal collaboration, send an email to yannchen@linkedin.com

  ## Cite this work


README.md

@@ -47,6 +47,11 @@
  <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml/badge.svg?event=schedule" alt="Build">
  </a>
  </div>
+ <div style="display: block;">
+ <a href="https://github.com/linkedin/Liger-Kernel/actions/workflows/amd-ci.yml">
+ <img src="https://github.com/linkedin/Liger-Kernel/actions/workflows/intel-ci.yml/badge.svg?event=schedule" alt="Build">
+ </a>
+ </div>
  </td>
  </tr>
  </table>
@@ -73,7 +78,7 @@

  **Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduces **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.

- We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
+ We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).

  ## Supercharge Your Model with Liger Kernel

@@ -138,6 +143,11 @@ y = orpo_loss(lm_head.weight, x, target)
  - `torch >= 2.5.0` Install according to the instruction in Pytorch official webpage.
  - `triton >= 3.0.0` Install from pypi. (e.g. `pip install triton==3.0.0`)

+ ```bash
+ # Need to pass the url when installing
+ pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
+ ```
+
  ### Optional Dependencies

  - `transformers >= 4.x`: Required if you plan to use the transformers models patching APIs. The specific model you are working will dictate the minimum version of transformers.
@@ -255,6 +265,8 @@ loss.backward()
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+ | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+ | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |


  ## Low-level APIs
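
The two rows added to the table above are patching entry points in the same family as the existing `apply_liger_kernel_to_*` functions, which are invoked before the model is instantiated. A minimal sketch for OLMo2 follows (the checkpoint name is only a placeholder and all kwargs are left at their defaults):

```python
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_olmo2

# Swap OLMo2's RoPE, RMSNorm, SwiGLU and (fused linear) cross entropy for the
# Liger kernels; this must run before the model is constructed.
apply_liger_kernel_to_olmo2()

model = transformers.AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-1124-7B")
```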

@@ -283,6 +295,7 @@ loss.backward()
  | Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
  | Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
  | Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
+ | Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |

  ### Distillation Kernels

@@ -291,6 +304,7 @@ loss.backward()
  | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
  | JSD | `liger_kernel.transformers.LigerJSD` |
  | Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
+ | TVD | `liger_kernel.transformers.LigerTVDLoss` |

  ### Experimental Kernels

@@ -322,7 +336,7 @@ loss.backward()

  - For issues, create a Github ticket in this repository
  - For open discussion, join [our discord channel](https://discord.gg/gpumode)
- - For formal collaboration, send an email to byhsu@linkedin.com
+ - For formal collaboration, send an email to yannchen@linkedin.com

  ## Cite this work


benchmark/data/all_benchmark_data.csv

@@ -505,6 +505,42 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859
  fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
  fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
  fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1
+ tvd,liger,full,memory,MB,V,vocab size,4096,1792.0009765625,1792.0009765625,1792.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,liger,full,memory,MB,V,vocab size,8192,3584.0009765625,3584.0009765625,3584.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,liger,full,memory,MB,V,vocab size,16384,7168.0009765625,7168.0009765625,7168.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,liger,full,memory,MB,V,vocab size,32768,14336.0009765625,14336.0009765625,14336.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,liger,full,memory,MB,V,vocab size,65536,28672.0,28672.0,28672.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,liger,full,memory,MB,V,vocab size,131072,57344.0,57344.0,57344.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,torch,full,memory,MB,V,vocab size,4096,2048.0009765625,2048.0009765625,2048.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,torch,full,memory,MB,V,vocab size,8192,4096.0009765625,4096.0009765625,4096.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,torch,full,memory,MB,V,vocab size,16384,8192.0009765625,8192.0009765625,8192.0009765625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,torch,full,memory,MB,V,vocab size,32768,16384.0,16384.0,16384.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,torch,full,memory,MB,V,vocab size,65536,32768.0,32768.0,32768.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,torch,full,memory,MB,V,vocab size,131072,65536.0,65536.0,65536.0,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:04,0.3.1
+ tvd,liger,forward,speed,ms,V,vocab size,4096,0.47814399003982544,0.4774720072746277,0.4790079891681671,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+ tvd,liger,forward,speed,ms,V,vocab size,8192,0.906495988368988,0.905951976776123,0.9073920249938965,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+ tvd,liger,forward,speed,ms,V,vocab size,16384,1.8787360191345215,1.8778239488601685,1.8797119855880737,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+ tvd,liger,forward,speed,ms,V,vocab size,32768,3.5788800716400146,3.5772159099578857,3.58076810836792,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+ tvd,liger,forward,speed,ms,V,vocab size,65536,7.008831977844238,7.007718086242676,7.010636806488037,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+ tvd,liger,forward,speed,ms,V,vocab size,131072,13.88646411895752,13.88128662109375,13.890560150146484,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:08,0.3.1
+ tvd,torch,forward,speed,ms,V,vocab size,4096,1.308608055114746,1.306502342224121,1.3104127645492554,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+ tvd,torch,forward,speed,ms,V,vocab size,8192,2.4735519886016846,2.472287893295288,2.4749441146850586,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+ tvd,torch,forward,speed,ms,V,vocab size,16384,4.828320026397705,4.826848030090332,4.830643177032471,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+ tvd,torch,forward,speed,ms,V,vocab size,32768,9.5206880569458,9.517024040222168,9.525145530700684,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+ tvd,torch,forward,speed,ms,V,vocab size,65536,19.01535987854004,19.011123657226562,19.01806640625,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+ tvd,torch,forward,speed,ms,V,vocab size,131072,38.022865295410156,38.01945877075195,38.02627182006836,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:09,0.3.1
+ tvd,liger,full,speed,ms,V,vocab size,4096,2.626512050628662,2.621260643005371,2.646751880645752,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+ tvd,liger,full,speed,ms,V,vocab size,8192,4.661711692810059,4.657618999481201,4.662930965423584,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+ tvd,liger,full,speed,ms,V,vocab size,16384,9.088272094726562,9.080741882324219,9.092268943786621,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+ tvd,liger,full,speed,ms,V,vocab size,32768,18.116064071655273,18.112728118896484,18.118234634399414,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+ tvd,liger,full,speed,ms,V,vocab size,65536,35.85124969482422,35.849971771240234,35.85252380371094,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+ tvd,liger,full,speed,ms,V,vocab size,131072,71.1648941040039,71.1648941040039,71.1648941040039,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:11,0.3.1
+ tvd,torch,full,speed,ms,V,vocab size,4096,4.361599922180176,4.360159873962402,4.3639678955078125,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+ tvd,torch,full,speed,ms,V,vocab size,8192,8.11302375793457,8.11075210571289,8.114463806152344,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+ tvd,torch,full,speed,ms,V,vocab size,16384,15.841055870056152,15.837087631225586,15.841856002807617,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+ tvd,torch,full,speed,ms,V,vocab size,32768,31.71219253540039,31.706951141357422,31.715898513793945,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+ tvd,torch,full,speed,ms,V,vocab size,65536,63.17919921875,63.17919921875,63.17919921875,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
+ tvd,torch,full,speed,ms,V,vocab size,131072,126.0436782836914,126.0436782836914,126.0436782836914,"{""B"": 8, ""T"": 2048}",NVIDIA H100 PCIe,2024-10-26 01:17:13,0.3.1
  group_norm,liger,forward,speed,ms,C,num_channels,32,0.03481600061058998,0.03379200026392937,0.03993599861860275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
  group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05119999870657921,0.05222399905323982,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
  group_norm,liger,forward,speed,ms,C,num_channels,128,0.08499199897050858,0.08396799862384796,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1
@@ -769,3 +805,4 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
+

benchmark/scripts/benchmark_kto_loss.py

@@ -103,8 +103,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
  H=H,
  V=V,
  dtype=dtype,
- bias=bias,
- ref_bias=bias,
+ use_bias=bias,
+ use_ref_bias=bias,
  ignore_index=ignore_index,
  beta=beta,
  ).to(device)
@@ -113,8 +113,8 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
  H=H,
  V=V,
  dtype=dtype,
- bias=bias,
- ref_bias=bias,
+ use_bias=bias,
+ use_ref_bias=bias,
  ignore_index=ignore_index,
  beta=beta,
  ).to(device)
@@ -189,7 +189,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
  dtype=dtype,
  beta=beta,
  ignore_index=ignore_index,
- bias=bias,
+ use_bias=bias,
  ).to(device)
  liger_kto_loss = LigerLMHeadKTO(
  H=H,
@@ -197,7 +197,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
  dtype=dtype,
  beta=beta,
  ignore_index=ignore_index,
- bias=bias,
+ use_bias=bias,
  ).to(device)

  # Input shape: [B, T, H]
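
The renamed `use_bias` / `use_ref_bias` arguments above belong to the benchmark's `LigerLMHeadKTO` test wrapper; the user-facing class is the `LigerFusedLinearKTOLoss` listed in the README tables. The sketch below only constructs it with the `beta` / `ignore_index` knobs this benchmark passes (the values are illustrative); the forward call is left as a comment because its exact signature lives in the chunked-loss tests rather than in this diff:

```python
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss

# beta and ignore_index mirror the parameters wired through benchmark_kto_loss.py;
# other constructor options are left at their defaults here.
kto_loss = LigerFusedLinearKTOLoss(beta=0.1, ignore_index=-100)

# Assumed call pattern, analogous to the ORPO example in the README
# (`y = orpo_loss(lm_head.weight, x, target)`); KTO additionally consumes
# reference-model tensors, so see test/chunked_loss/test_kto_loss.py for the
# authoritative forward signature.
# loss = kto_loss(lm_head.weight, x, target, ...)
```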

benchmark/scripts/benchmark_rope.py

@@ -1,7 +1,6 @@
  import torch
  import triton

- from test.utils import transformers_version_dispatch
  from transformers.models.llama.configuration_llama import LlamaConfig
  from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
@@ -14,6 +13,7 @@ from utils import run_benchmarks

  from liger_kernel.transformers.rope import liger_rotary_pos_emb
  from liger_kernel.utils import infer_device
+ from liger_kernel.utils import transformers_version_dispatch

  device = infer_device()


benchmark/scripts/benchmark_tvd.py (new file)

@@ -0,0 +1,133 @@
+ import torch
+ import triton
+
+ from utils import QUANTILES
+ from utils import SingleBenchmarkRunInput
+ from utils import SingleBenchmarkRunOutput
+ from utils import _test_memory
+ from utils import parse_benchmark_script_args
+ from utils import run_benchmarks
+
+ from liger_kernel.transformers.tvd import LigerTVDLoss
+
+
+ class TorchTVDLoss(torch.nn.Module):
+     def __init__(self, reduction="batchmean"):
+         super(TorchTVDLoss, self).__init__()
+         self.reduction = reduction
+
+     def forward(self, p, q):
+         tvd = torch.abs(p - q) / 2.0
+         if self.reduction == "mean":
+             return torch.sum(tvd) / (p.size(0) * p.size(1))
+         elif self.reduction == "sum":
+             return torch.sum(tvd)
+         elif self.reduction == "none":
+             return tvd
+         elif self.reduction == "batchmean":
+             return torch.sum(tvd) / p.size(0)
+         else:
+             raise ValueError("Invalid reduction type.")
+
+
+ S, E = 12, 18
+
+
+ def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+     reduction = "batchmean"
+     V = input.x
+     B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
+     torch_tvd = TorchTVDLoss(reduction=reduction)
+     liger_tvd = LigerTVDLoss(reduction=reduction)
+
+     _input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
+     target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
+
+     def fwd():
+         if input.kernel_provider == "liger":
+             return liger_tvd(_input, target)
+         else:
+             return torch_tvd(_input, target)
+
+     if input.kernel_operation_mode == "forward":
+         ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
+     elif input.kernel_operation_mode == "backward":
+         y = fwd()
+
+         ms_50, ms_20, ms_80 = triton.testing.do_bench(
+             lambda: y.backward(retain_graph=True),
+             quantiles=QUANTILES,
+             grad_to_none=[_input],
+             rep=100,
+         )
+     elif input.kernel_operation_mode == "full":
+
+         def full():
+             y = fwd()
+             y.backward(retain_graph=True)
+
+         ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
+     return SingleBenchmarkRunOutput(
+         y_20=ms_20,
+         y_50=ms_50,
+         y_80=ms_80,
+     )
+
+
+ def bench_memory_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+     reduction = "batchmean"
+     torch_tvd = TorchTVDLoss(reduction=reduction)
+     liger_tvd = LigerTVDLoss(reduction=reduction)
+
+     V = input.x
+     B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
+
+     _input = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
+     target = torch.randn(B * T, V, device="cuda").softmax(dim=-1)
+
+     def fwd():
+         if input.kernel_provider == "liger":
+             return liger_tvd(_input, target)
+         else:
+             return torch_tvd(_input, target)
+
+     def full():
+         y = fwd()
+         y.backward(retain_graph=True)
+
+     mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+
+     return SingleBenchmarkRunOutput(
+         y_20=mem_20,
+         y_50=mem_50,
+         y_80=mem_80,
+     )
+
+
+ if __name__ == "__main__":
+     args = parse_benchmark_script_args()
+     common_args = {
+         "kernel_name": "tvd",
+         "x_name": "V",
+         "x_label": "vocab size",
+         "x_values": [2**i for i in range(12, 18)],
+         "kernel_providers": ["liger", "torch"],
+         "extra_benchmark_configs": [{"B": 8, "T": 2048}],
+         "overwrite": args.overwrite,
+     }
+
+     run_benchmarks(
+         bench_test_fn=bench_memory_tvd,
+         kernel_operation_modes=["full"],
+         metric_name="memory",
+         metric_unit="MB",
+         **common_args,
+     )
+
+     run_benchmarks(
+         bench_test_fn=bench_speed_tvd,
+         kernel_operation_modes=["forward", "full"],
+         metric_name="speed",
+         metric_unit="ms",
+         **common_args,
+     )
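
For reference, the Torch baseline above computes the total variation distance between row-wise probability distributions, with the `batchmean` reduction dividing by the number of rows, N = p.size(0) = B·T in this script:

$$\mathrm{TVD}(P, Q) = \frac{1}{2} \sum_{v=1}^{V} \lvert P_v - Q_v \rvert, \qquad \mathcal{L}_{\mathrm{batchmean}} = \frac{1}{N} \sum_{n=1}^{N} \mathrm{TVD}(P_n, Q_n)$$

Outside the benchmark harness, the Liger module is used the same way the script uses it. A minimal sketch with the benchmark's shapes (CUDA device assumed, as in the script):

```python
import torch

from liger_kernel.transformers.tvd import LigerTVDLoss

tvd = LigerTVDLoss(reduction="batchmean")  # same reduction as the benchmark

B, T, V = 8, 2048, 4096  # the benchmark's extra config and smallest vocab size
p = torch.randn(B * T, V, requires_grad=True, device="cuda").softmax(dim=-1)
q = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

loss = tvd(p, q)   # scalar: TVD summed over the vocab, averaged over the rows
loss.backward()
```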

dev/modal/tests.py

@@ -14,7 +14,7 @@ app = modal.App("liger_tests", image=image)
  repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)


- @app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
+ @app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
  def liger_tests():
  import subprocess


examples/lightning/training.py

@@ -158,7 +158,7 @@ class DataModule(pl.LightningDataModule):
  for i in range(len(example["question"])):
  choices = ""
  for j in range(len(example["choices"][i])):
- choices += f"{j+1}. {example['choices'][i][j]}; "
+ choices += f"{j + 1}. {example['choices'][i][j]}; "
  s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
  s += f"{QUESTION}{example['question'][i]} "
  s += f"{CHOICES}{choices} "

examples/medusa/callback.py

@@ -352,9 +352,9 @@ class EfficiencyCallback(transformers.TrainerCallback):
  else:
  return world_size

- assert (
- world_size != 0
- ), "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
+ assert world_size != 0, (
+ "WORLD_SIZE should be set to a positive integer. For single GPU training, please explicitly set WORLD_SIZE=1."
+ )

  # TODO: add deepspeed support
  return world_size

pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "liger_kernel"
- version = "0.5.3"
+ version = "0.5.4"
  description = "Efficient Triton kernels for LLM Training"
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
  readme = { file = "README.md", content-type = "text/markdown" }

setup.py

@@ -21,6 +21,11 @@ def get_default_dependencies():
  "torch>=2.6.0.dev",
  "triton>=3.0.0",
  ]
+ elif platform == "xpu":
+ return [
+ "torch>=2.6.0",
+ "pytorch-triton-xpu>=3.2.0",
+ ]


  def get_optional_dependencies():
@@ -43,8 +48,7 @@ def get_optional_dependencies():
  }


- # TODO: add intel XPU
- def get_platform() -> Literal["cuda", "rocm", "cpu"]:
+ def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu"]:
  """
  Detect whether the system has NVIDIA or AMD GPU without torch dependency.
  """
@@ -60,8 +64,13 @@ def get_platform() -> Literal["cuda", "rocm", "cpu"]:
  print("ROCm GPU detected")
  return "rocm"
  except (subprocess.SubprocessError, FileNotFoundError):
- print("No GPU detected")
- return "cpu"
+ try:
+ subprocess.run(["xpu-smi"], check=True)
+ print("Intel GPU detected")
+ return "xpu"
+ except (subprocess.SubprocessError, FileNotFoundError):
+ print("No GPU detected")
+ return "cpu"


  setup(
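
The `get_platform()` probe above only decides which dependency set pip resolves at install time; at runtime the kernels pick their device via `liger_kernel.utils.infer_device`, the helper the benchmark scripts import. A quick sanity-check sketch; the `torch.xpu` probe assumes a PyTorch build with XPU support, hence the defensive `hasattr`:

```python
import torch

from liger_kernel.utils import infer_device

# infer_device() is what the benchmark scripts use to choose the device string.
print("liger device:", infer_device())

# torch.xpu only reports True on an XPU-enabled PyTorch build with an Intel GPU.
xpu_ok = hasattr(torch, "xpu") and torch.xpu.is_available()
print("XPU available:", xpu_ok, "| CUDA available:", torch.cuda.is_available())
```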

src/liger_kernel/chunked_loss/__init__.py

@@ -1,5 +1,6 @@
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
  from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
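
The new GRPO export sits alongside the other chunked losses. Only the import path below is taken from this diff; construction is omitted because the constructor and forward signatures are defined in the new `grpo_loss.py` / `fused_linear_rlhf.py` modules and their tests rather than shown here:

```python
# Added in 0.5.4; see src/liger_kernel/chunked_loss/grpo_loss.py and
# test/chunked_loss/test_grpo_loss.py for the concrete constructor/forward API.
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
```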