liger-kernel 0.5.5__tar.gz → 0.5.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247):
  1. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.github/workflows/amd-ci.yml +2 -1
  2. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/PKG-INFO +11 -6
  3. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/README.md +8 -4
  4. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/benchmarks_visualizer.py +2 -2
  5. liger_kernel-0.5.7/benchmark/scripts/benchmark_dyt.py +139 -0
  6. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/dev/modal/tests.py +1 -1
  7. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/dev/modal/tests_bwd.py +1 -1
  8. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/pyproject.toml +1 -1
  9. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/functional.py +2 -0
  10. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +17 -2
  11. liger_kernel-0.5.7/src/liger_kernel/chunked_loss/fused_linear_ppo.py +346 -0
  12. liger_kernel-0.5.7/src/liger_kernel/chunked_loss/grpo_loss.py +268 -0
  13. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/jsd_loss.py +12 -7
  14. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/cross_entropy.py +3 -2
  15. liger_kernel-0.5.7/src/liger_kernel/ops/dyt.py +225 -0
  16. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/fused_linear_jsd.py +2 -1
  17. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/jsd.py +32 -12
  18. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/kl_div.py +15 -8
  19. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/layer_norm.py +14 -1
  20. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/rms_norm.py +12 -1
  21. liger_kernel-0.5.7/src/liger_kernel/transformers/__init__.py +145 -0
  22. liger_kernel-0.5.7/src/liger_kernel/transformers/dyt.py +20 -0
  23. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/functional.py +5 -0
  24. liger_kernel-0.5.7/src/liger_kernel/transformers/gema3_rms.py +8 -0
  25. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/gemma.py +17 -20
  26. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/gemma2.py +17 -21
  27. liger_kernel-0.5.7/src/liger_kernel/transformers/model/gemma3.py +335 -0
  28. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/llama.py +17 -19
  29. liger_kernel-0.5.7/src/liger_kernel/transformers/model/llava.py +369 -0
  30. liger_kernel-0.5.7/src/liger_kernel/transformers/model/loss_utils.py +64 -0
  31. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/mistral.py +28 -25
  32. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/mixtral.py +20 -26
  33. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/mllama.py +17 -19
  34. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/olmo2.py +17 -20
  35. liger_kernel-0.5.7/src/liger_kernel/transformers/model/paligemma.py +397 -0
  36. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/phi3.py +17 -19
  37. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/qwen2.py +17 -19
  38. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/qwen2_5_vl.py +9 -10
  39. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/qwen2_vl.py +9 -10
  40. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/monkey_patch.py +392 -13
  41. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/PKG-INFO +11 -6
  42. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/SOURCES.txt +15 -1
  43. liger_kernel-0.5.7/test/chunked_loss/test_grpo_loss.py +495 -0
  44. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/chunked_loss/test_jsd_loss.py +15 -10
  45. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/chunked_loss/test_kto_loss.py +62 -58
  46. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/chunked_loss/test_orpo_loss.py +6 -0
  47. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/convergence/bf16/test_mini_models.py +152 -0
  48. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/convergence/bf16/test_mini_models_multimodal.py +417 -2
  49. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/convergence/bf16/test_mini_models_with_logits.py +153 -0
  50. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/convergence/fp32/test_mini_models.py +145 -0
  51. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/convergence/fp32/test_mini_models_multimodal.py +403 -2
  52. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/convergence/fp32/test_mini_models_with_logits.py +149 -2
  53. liger_kernel-0.5.7/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +90 -0
  54. liger_kernel-0.5.7/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +61 -0
  55. liger_kernel-0.5.7/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +28 -0
  56. liger_kernel-0.5.7/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +7 -0
  57. liger_kernel-0.5.7/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +66 -0
  58. liger_kernel-0.5.7/test/transformers/test_dyt.py +136 -0
  59. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_jsd.py +5 -5
  60. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_monkey_patch.py +150 -0
  61. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/utils.py +80 -2
  62. liger_kernel-0.5.5/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  63. liger_kernel-0.5.5/src/liger_kernel/chunked_loss/grpo_loss.py +0 -194
  64. liger_kernel-0.5.5/src/liger_kernel/transformers/__init__.py +0 -27
  65. liger_kernel-0.5.5/test/chunked_loss/test_grpo_loss.py +0 -275
  66. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  67. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  68. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.github/pull_request_template.md +0 -0
  69. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.github/workflows/docs.yml +0 -0
  70. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.github/workflows/intel-ci.yml +0 -0
  71. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.github/workflows/nvi-ci.yml +0 -0
  72. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.github/workflows/publish-nightly.yml +0 -0
  73. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.github/workflows/publish-release.yml +0 -0
  74. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/.gitignore +0 -0
  75. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/LICENSE +0 -0
  76. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/Makefile +0 -0
  77. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/NOTICE +0 -0
  78. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/README.md +0 -0
  79. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/__init__.py +0 -0
  80. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/data/all_benchmark_data.csv +0 -0
  81. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/__init__.py +0 -0
  82. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  83. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  84. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  85. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  86. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_embedding.py +0 -0
  87. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  88. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  89. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_geglu.py +0 -0
  90. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_group_norm.py +0 -0
  91. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_jsd.py +0 -0
  92. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_kl_div.py +0 -0
  93. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  94. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  95. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  96. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  97. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  98. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_rope.py +0 -0
  99. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  100. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_swiglu.py +0 -0
  101. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/benchmark_tvd.py +0 -0
  102. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/benchmark/scripts/utils.py +0 -0
  103. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/dev/fmt-requirements.txt +0 -0
  104. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/Examples.md +0 -0
  105. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/Getting-Started.md +0 -0
  106. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/High-Level-APIs.md +0 -0
  107. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/Low-Level-APIs.md +0 -0
  108. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/acknowledgement.md +0 -0
  109. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/contributing.md +0 -0
  110. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/images/banner.GIF +0 -0
  111. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/images/compose.gif +0 -0
  112. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/images/e2e-memory.png +0 -0
  113. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/images/e2e-tps.png +0 -0
  114. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/images/logo-banner.png +0 -0
  115. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/images/patch.gif +0 -0
  116. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/images/post-training.png +0 -0
  117. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/index.md +0 -0
  118. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/docs/license.md +0 -0
  119. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/alignment/accelerate_config.yaml +0 -0
  120. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/alignment/run_orpo.py +0 -0
  121. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/README.md +0 -0
  122. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/callback.py +0 -0
  123. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/config/fsdp_config.json +0 -0
  124. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  125. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  126. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  127. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/img/llama_tps.png +0 -0
  128. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  129. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/img/qwen_tps.png +0 -0
  130. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/launch_on_modal.py +0 -0
  131. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/requirements.txt +0 -0
  132. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/run_benchmarks.sh +0 -0
  133. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/run_gemma.sh +0 -0
  134. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/run_llama.sh +0 -0
  135. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/run_qwen.sh +0 -0
  136. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/run_qwen2_vl.sh +0 -0
  137. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/training.py +0 -0
  138. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/huggingface/training_multimodal.py +0 -0
  139. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/lightning/README.md +0 -0
  140. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/lightning/requirements.txt +0 -0
  141. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/lightning/training.py +0 -0
  142. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/README.md +0 -0
  143. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/callback.py +0 -0
  144. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  145. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  146. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  147. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  148. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  149. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  150. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  151. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  152. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  153. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/medusa_util.py +0 -0
  154. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/requirements.txt +0 -0
  155. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  156. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/examples/medusa/train.py +0 -0
  157. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/licenses/LICENSE-Apache-2.0 +0 -0
  158. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  159. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  160. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/licenses/LICENSE-MIT-llmc +0 -0
  161. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/licenses/LICENSE-MIT-triton +0 -0
  162. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/mkdocs.yml +0 -0
  163. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/setup.cfg +0 -0
  164. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/setup.py +0 -0
  165. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/__init__.py +0 -0
  166. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/README.md +0 -0
  167. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  168. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  169. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  170. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  171. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  172. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  173. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  174. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  175. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/env_report.py +0 -0
  176. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/__init__.py +0 -0
  177. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  178. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  179. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  180. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/geglu.py +0 -0
  181. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/group_norm.py +0 -0
  182. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  183. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/rope.py +0 -0
  184. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/swiglu.py +0 -0
  185. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/tvd.py +0 -0
  186. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/ops/utils.py +0 -0
  187. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/auto_model.py +0 -0
  188. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  189. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  190. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  191. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  192. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/geglu.py +0 -0
  193. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/group_norm.py +0 -0
  194. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/jsd.py +0 -0
  195. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/kl_div.py +0 -0
  196. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/layer_norm.py +0 -0
  197. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/model/__init__.py +0 -0
  198. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  199. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/rms_norm.py +0 -0
  200. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/rope.py +0 -0
  201. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/swiglu.py +0 -0
  202. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  203. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  204. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  205. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/transformers/tvd.py +0 -0
  206. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/triton/__init__.py +0 -0
  207. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/triton/monkey_patch.py +0 -0
  208. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel/utils.py +0 -0
  209. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  210. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/requires.txt +0 -0
  211. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/src/liger_kernel.egg-info/top_level.txt +0 -0
  212. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/__init__.py +0 -0
  213. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/chunked_loss/__init__.py +0 -0
  214. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/chunked_loss/test_cpo_loss.py +0 -0
  215. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/chunked_loss/test_dpo_loss.py +0 -0
  216. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/chunked_loss/test_simpo_loss.py +0 -0
  217. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/conftest.py +0 -0
  218. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/convergence/__init__.py +0 -0
  219. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/convergence/bf16/__init__.py +0 -0
  220. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/convergence/fp32/__init__.py +0 -0
  221. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  222. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  223. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  224. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  225. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/resources/tiny_shakespeare.txt +0 -0
  226. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  227. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  228. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  229. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_auto_model.py +0 -0
  230. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_cross_entropy.py +0 -0
  231. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_embedding.py +0 -0
  232. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_flex_attention.py +0 -0
  233. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  234. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_fused_linear_jsd.py +0 -0
  235. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_geglu.py +0 -0
  236. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_group_norm.py +0 -0
  237. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_kl_div.py +0 -0
  238. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_layer_norm.py +0 -0
  239. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_mm_int8int2.py +0 -0
  240. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_qwen2vl_mrope.py +0 -0
  241. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_rms_norm.py +0 -0
  242. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_rope.py +0 -0
  243. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_swiglu.py +0 -0
  244. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_trainer_integration.py +0 -0
  245. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_transformers.py +0 -0
  246. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/transformers/test_tvd.py +0 -0
  247. {liger_kernel-0.5.5 → liger_kernel-0.5.7}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -49,7 +49,7 @@ jobs:
49
49
  needs: [checkstyle]
50
50
  strategy:
51
51
  matrix:
52
- rocm_version: ['6.2', '6.3']
52
+ rocm_version: ['6.3']
53
53
 
54
54
  steps:
55
55
  - name: Checkout code
@@ -62,6 +62,7 @@ jobs:
62
62
 
63
63
  - name: Setup Dependencies
64
64
  run: |
65
+ rocm-smi
65
66
  python -m pip install --upgrade pip
66
67
  pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm_version }}
67
68
 
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: liger_kernel
3
- Version: 0.5.5
3
+ Version: 0.5.7
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -45,6 +45,7 @@ Requires-Dist: datasets>=2.19.2; extra == "dev"
45
45
  Requires-Dist: seaborn; extra == "dev"
46
46
  Requires-Dist: mkdocs; extra == "dev"
47
47
  Requires-Dist: mkdocs-material; extra == "dev"
48
+ Dynamic: license-file
48
49
  Dynamic: provides-extra
49
50
  Dynamic: requires-dist
50
51
 
@@ -115,6 +116,7 @@ Dynamic: requires-dist
115
116
  <details>
116
117
  <summary>Latest News 🔥</summary>
117
118
 
119
+ - [2025/03/06] We release a joint blog post on TorchTune × Liger - [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
118
120
  - [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
119
121
  - [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
120
122
  - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
@@ -177,7 +179,7 @@ y = orpo_loss(lm_head.weight, x, target)
177
179
  - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
178
180
  - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
179
181
  - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
180
- - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
182
+ - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
181
183
 
182
184
  ## Installation
183
185
 
@@ -312,6 +314,9 @@ loss.backward()
312
314
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
313
315
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
314
316
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
317
+ | Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
318
+ | Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
319
+ | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
320
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
316
321
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
317
322
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -386,8 +391,8 @@ loss.backward()
386
391
  ## Contact
387
392
 
388
393
  - For issues, create a Github ticket in this repository
389
- - For open discussion, join [our discord channel](https://discord.gg/gpumode)
390
- - For formal collaboration, send an email to yannchen@linkedin.com
394
+ - For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
395
+ - For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
391
396
 
392
397
  ## Cite this work
393
398
 
@@ -406,7 +411,7 @@ Biblatex entry:
406
411
  ```
407
412
 
408
413
  ## Star History
409
- [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
414
+ [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://www.star-history.com/#linkedin/Liger-Kernel&Date)
410
415
 
411
416
  <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
412
417
  <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
@@ -65,6 +65,7 @@
65
65
  <details>
66
66
  <summary>Latest News 🔥</summary>
67
67
 
68
+ - [2025/03/06] We release a joint blog post on TorchTune × Liger - [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
68
69
  - [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
69
70
  - [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
70
71
  - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
@@ -127,7 +128,7 @@ y = orpo_loss(lm_head.weight, x, target)
127
128
  - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
128
129
  - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
129
130
  - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
130
- - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
131
+ - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
131
132
 
132
133
  ## Installation
133
134
 
@@ -262,6 +263,9 @@ loss.backward()
262
263
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
263
264
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
264
265
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
266
+ | Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
267
+ | Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
268
+ | PaliGemma, PaliGemma2, & PaliGemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
265
269
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
266
270
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
267
271
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -336,8 +340,8 @@ loss.backward()
336
340
  ## Contact
337
341
 
338
342
  - For issues, create a GitHub issue in this repository
339
- - For open discussion, join [our discord channel](https://discord.gg/gpumode)
340
- - For formal collaboration, send an email to yannchen@linkedin.com
343
+ - For open discussion, join [our Discord channel on GPU MODE](https://discord.com/channels/1189498204333543425/1275130785933951039)
344
+ - For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
341
345
 
342
346
  ## Cite this work
343
347
 
@@ -356,7 +360,7 @@ Biblatex entry:
356
360
  ```
357
361
 
358
362
  ## Star History
359
- [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
363
+ [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://www.star-history.com/#linkedin/Liger-Kernel&Date)
360
364
 
361
365
  <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
362
366
  <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
@@ -8,8 +8,8 @@ import matplotlib.pyplot as plt
8
8
  import pandas as pd
9
9
  import seaborn as sns
10
10
 
11
- DATA_PATH = "data/all_benchmark_data.csv"
12
- VISUALIZATIONS_PATH = "visualizations/"
11
+ DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv"))
12
+ VISUALIZATIONS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "visualizations/"))
13
13
 
14
14
 
15
15
  @dataclass
@@ -0,0 +1,139 @@
import os
import sys

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.utils import infer_device

device = infer_device()

# Make the repository root (two levels above this script) importable so the
# reference implementations in ``test/transformers/test_dyt.py`` can be
# imported lazily inside the benchmark functions below.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

# Kernel providers this script knows how to benchmark.
_DYT_PROVIDERS = ("liger", "torch", "torch_compile")


def _setup_dyt(input: SingleBenchmarkRunInput):
    """Build the DyT module, input tensors and forward closure for one benchmark run.

    Args:
        input: benchmark run description. ``input.x`` is the number of rows
            (batch_size * seq_len), ``input.kernel_provider`` selects the
            implementation, and ``input.extra_benchmark_config`` supplies
            ``hidden_size`` and ``dtype``.

    Returns:
        ``(fwd, x, dy)``. ``fwd()`` runs one forward pass of the selected
        module on ``x``. ``x`` is a ``(BT, hidden_size)`` tensor with
        ``requires_grad=True``. ``dy`` is a random upstream gradient shaped
        like ``x``.

    Raises:
        ValueError: if ``input.kernel_provider`` is not one of ``_DYT_PROVIDERS``.
            Without this check, an unknown provider would make ``fwd()`` return
            ``None`` and fail later with an unrelated error.
    """
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    provider = input.kernel_provider
    if provider not in _DYT_PROVIDERS:
        raise ValueError(f"Unsupported kernel provider: {provider!r}; expected one of {_DYT_PROVIDERS}")

    BT = input.x
    hidden_size = input.extra_benchmark_config["hidden_size"]
    dtype = input.extra_benchmark_config["dtype"]

    # Only the selected implementation is instantiated, so unused modules are
    # not allocated on the device.
    if provider == "liger":
        dyt = LigerDyT(hidden_size=hidden_size).to(device)
    elif provider == "torch":
        dyt = TorchDyT(hidden_size=hidden_size).to(device)
    else:  # "torch_compile"
        dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))

    x = torch.randn((BT, hidden_size), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        return dyt(x)

    return fwd, x, dy


def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time DyT forward, backward, or forward+backward for one provider.

    ``input.kernel_operation_mode`` must be ``"forward"``, ``"backward"`` or
    ``"full"``. The 20th/50th/80th percentile timings (ms) reported by
    ``triton.testing.do_bench`` are returned.

    Raises:
        ValueError: for an unknown provider or operation mode. Without the
            mode check, an unknown mode would leave the timings unbound.
    """
    fwd, x, dy = _setup_dyt(input)
    mode = input.kernel_operation_mode

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    elif mode == "backward":
        # Build the graph once and replay only the backward pass;
        # retain_graph keeps it alive across repetitions.
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward(dy)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    else:
        raise ValueError(f"Unsupported kernel operation mode: {mode!r}")

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of one DyT forward+backward pass for one provider.

    Returns the 20th/50th/80th percentile values reported by ``_test_memory``.

    Raises:
        ValueError: for an unknown provider.
    """
    fwd, _, dy = _setup_dyt(input)

    def full():
        y = fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "dyt",
        "x_name": "BT",
        "x_label": "batch_size * seq_len",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": list(_DYT_PROVIDERS),
        "extra_benchmark_configs": [{"hidden_size": 4096, "dtype": torch.float32}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_dyt,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_dyt,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
@@ -14,7 +14,7 @@ app = modal.App("liger_tests", image=image)
14
14
  repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
15
15
 
16
16
 
17
- @app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
17
+ @app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
18
18
  def liger_tests():
19
19
  import subprocess
20
20
 
@@ -14,7 +14,7 @@ app = modal.App("liger_tests_bwd", image=image)
14
14
  repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
15
15
 
16
16
 
17
- @app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
17
+ @app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
18
18
  def liger_bwd_tests():
19
19
  import subprocess
20
20
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel"
7
- version = "0.5.5"
7
+ version = "0.5.7"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2
2
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
3
4
  from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
4
5
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
5
6
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
11
12
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
12
13
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
13
14
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
15
+ liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
@@ -115,9 +115,24 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
115
115
  student_logits_chunk /= temperature
116
116
  teacher_logits_chunk /= temperature
117
117
 
118
+ # If the teacher and student token size is different, pad student logits to match the teacher's.
119
+ # This only applies to cases where they share exactly the same vocab and tokenizer just
120
+ # that teacher logit is padded for some training efficiency such as
121
+ # https://huggingface.co/Qwen/Qwen1.5-72B-Chat/discussions/1#662883f568adf59b07b176d2
122
+ teacher_vocab_size = teacher_weight.shape[0]
123
+ student_vocab_size = student_weight.shape[0]
124
+ if teacher_vocab_size > student_vocab_size:
125
+ pad_size = teacher_vocab_size - student_vocab_size
126
+ pad_tensor = torch.zeros(
127
+ (*student_logits_chunk.shape[:-1], pad_size),
128
+ dtype=student_logits_chunk.dtype,
129
+ device=student_logits_chunk.device,
130
+ )
131
+ student_logits_chunk = torch.cat([student_logits_chunk, pad_tensor], dim=-1)
132
+
118
133
  hard_loss /= full_target.shape[0]
119
134
 
120
- soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
135
+ soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
121
136
  soft_loss /= full_target.shape[0]
122
137
 
123
138
  loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -180,9 +195,9 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
180
195
  ignore_index=ignore_index,
181
196
  weight_hard_loss=weight_hard_loss,
182
197
  weight_soft_loss=weight_soft_loss,
183
- beta=beta,
184
198
  compute_ce_loss=compute_ce_loss,
185
199
  temperature=temperature,
200
+ beta=beta,
186
201
  **loss_kwargs,
187
202
  )
188
203