liger-kernel 0.5.4__tar.gz → 0.5.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (229) hide show
  1. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.github/workflows/amd-ci.yml +4 -1
  2. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/PKG-INFO +3 -2
  3. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/README.md +2 -1
  4. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/data/all_benchmark_data.csv +30 -31
  5. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_distill_jsd_loss.py +2 -0
  6. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_kto_loss.py +4 -4
  7. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/pyproject.toml +1 -1
  8. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/cpo_loss.py +51 -11
  9. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/dpo_loss.py +30 -4
  10. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +3 -3
  11. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_preference.py +2 -2
  12. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +33 -6
  13. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +112 -17
  14. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/grpo_loss.py +37 -3
  15. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/jsd_loss.py +31 -6
  16. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/kto_loss.py +50 -12
  17. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/orpo_loss.py +37 -5
  18. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/simpo_loss.py +47 -11
  19. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/cross_entropy.py +4 -0
  20. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/__init__.py +1 -0
  21. liger_kernel-0.5.5/src/liger_kernel/transformers/model/qwen2_5_vl.py +205 -0
  22. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/monkey_patch.py +68 -0
  23. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/utils.py +1 -3
  24. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/PKG-INFO +3 -2
  25. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/SOURCES.txt +2 -0
  26. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/chunked_loss/test_jsd_loss.py +49 -10
  27. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/chunked_loss/test_kto_loss.py +85 -8
  28. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/convergence/bf16/test_mini_models.py +86 -0
  29. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/convergence/bf16/test_mini_models_multimodal.py +100 -1
  30. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/convergence/bf16/test_mini_models_with_logits.py +108 -22
  31. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/convergence/fp32/test_mini_models.py +83 -0
  32. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/convergence/fp32/test_mini_models_multimodal.py +99 -1
  33. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/convergence/fp32/test_mini_models_with_logits.py +83 -0
  34. liger_kernel-0.5.5/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +63 -0
  35. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_cross_entropy.py +39 -0
  36. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_monkey_patch.py +68 -0
  37. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/utils.py +18 -1
  38. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  39. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  40. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.github/pull_request_template.md +0 -0
  41. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.github/workflows/docs.yml +0 -0
  42. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.github/workflows/intel-ci.yml +0 -0
  43. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.github/workflows/nvi-ci.yml +0 -0
  44. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.github/workflows/publish-nightly.yml +0 -0
  45. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.github/workflows/publish-release.yml +0 -0
  46. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/.gitignore +0 -0
  47. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/LICENSE +0 -0
  48. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/Makefile +0 -0
  49. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/NOTICE +0 -0
  50. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/README.md +0 -0
  51. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/__init__.py +0 -0
  52. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/benchmarks_visualizer.py +0 -0
  53. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/__init__.py +0 -0
  54. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  55. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  56. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  57. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_embedding.py +0 -0
  58. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  59. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  60. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_geglu.py +0 -0
  61. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_group_norm.py +0 -0
  62. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_jsd.py +0 -0
  63. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_kl_div.py +0 -0
  64. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  65. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  66. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  67. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  68. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_rope.py +0 -0
  69. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  70. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_swiglu.py +0 -0
  71. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/benchmark_tvd.py +0 -0
  72. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/benchmark/scripts/utils.py +0 -0
  73. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/dev/fmt-requirements.txt +0 -0
  74. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/dev/modal/tests.py +0 -0
  75. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/dev/modal/tests_bwd.py +0 -0
  76. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/Examples.md +0 -0
  77. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/Getting-Started.md +0 -0
  78. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/High-Level-APIs.md +0 -0
  79. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/Low-Level-APIs.md +0 -0
  80. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/acknowledgement.md +0 -0
  81. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/contributing.md +0 -0
  82. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/images/banner.GIF +0 -0
  83. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/images/compose.gif +0 -0
  84. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/images/e2e-memory.png +0 -0
  85. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/images/e2e-tps.png +0 -0
  86. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/images/logo-banner.png +0 -0
  87. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/images/patch.gif +0 -0
  88. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/images/post-training.png +0 -0
  89. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/index.md +0 -0
  90. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/docs/license.md +0 -0
  91. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/alignment/accelerate_config.yaml +0 -0
  92. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/alignment/run_orpo.py +0 -0
  93. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/README.md +0 -0
  94. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/callback.py +0 -0
  95. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/config/fsdp_config.json +0 -0
  96. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  97. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  98. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  99. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/img/llama_tps.png +0 -0
  100. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  101. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/img/qwen_tps.png +0 -0
  102. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/launch_on_modal.py +0 -0
  103. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/requirements.txt +0 -0
  104. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/run_benchmarks.sh +0 -0
  105. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/run_gemma.sh +0 -0
  106. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/run_llama.sh +0 -0
  107. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/run_qwen.sh +0 -0
  108. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/run_qwen2_vl.sh +0 -0
  109. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/training.py +0 -0
  110. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/huggingface/training_multimodal.py +0 -0
  111. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/lightning/README.md +0 -0
  112. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/lightning/requirements.txt +0 -0
  113. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/lightning/training.py +0 -0
  114. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/README.md +0 -0
  115. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/callback.py +0 -0
  116. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  117. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  118. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  119. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  120. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  121. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  122. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  123. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  124. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  125. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/medusa_util.py +0 -0
  126. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/requirements.txt +0 -0
  127. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  128. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/examples/medusa/train.py +0 -0
  129. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/licenses/LICENSE-Apache-2.0 +0 -0
  130. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  131. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  132. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-llmc +0 -0
  133. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/licenses/LICENSE-MIT-triton +0 -0
  134. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/mkdocs.yml +0 -0
  135. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/setup.cfg +0 -0
  136. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/setup.py +0 -0
  137. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/__init__.py +0 -0
  138. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/README.md +0 -0
  139. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  140. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/chunked_loss/functional.py +0 -0
  141. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/env_report.py +0 -0
  142. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/__init__.py +0 -0
  143. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  144. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  145. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  146. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  147. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/geglu.py +0 -0
  148. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/group_norm.py +0 -0
  149. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/jsd.py +0 -0
  150. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/kl_div.py +0 -0
  151. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/layer_norm.py +0 -0
  152. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  153. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/rms_norm.py +0 -0
  154. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/rope.py +0 -0
  155. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/swiglu.py +0 -0
  156. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/tvd.py +0 -0
  157. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/ops/utils.py +0 -0
  158. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/auto_model.py +0 -0
  159. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  160. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  161. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/functional.py +0 -0
  162. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  163. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  164. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/geglu.py +0 -0
  165. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/group_norm.py +0 -0
  166. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/jsd.py +0 -0
  167. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/kl_div.py +0 -0
  168. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/layer_norm.py +0 -0
  169. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/__init__.py +0 -0
  170. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/gemma.py +0 -0
  171. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  172. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/llama.py +0 -0
  173. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/mistral.py +0 -0
  174. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  175. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/mllama.py +0 -0
  176. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  177. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/phi3.py +0 -0
  178. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  179. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  180. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  181. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/rms_norm.py +0 -0
  182. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/rope.py +0 -0
  183. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/swiglu.py +0 -0
  184. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  185. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  186. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  187. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/transformers/tvd.py +0 -0
  188. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/triton/__init__.py +0 -0
  189. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel/triton/monkey_patch.py +0 -0
  190. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  191. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/requires.txt +0 -0
  192. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/src/liger_kernel.egg-info/top_level.txt +0 -0
  193. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/__init__.py +0 -0
  194. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/chunked_loss/__init__.py +0 -0
  195. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/chunked_loss/test_cpo_loss.py +0 -0
  196. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/chunked_loss/test_dpo_loss.py +0 -0
  197. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/chunked_loss/test_grpo_loss.py +0 -0
  198. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/chunked_loss/test_orpo_loss.py +0 -0
  199. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/chunked_loss/test_simpo_loss.py +0 -0
  200. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/conftest.py +0 -0
  201. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/convergence/__init__.py +0 -0
  202. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/convergence/bf16/__init__.py +0 -0
  203. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/convergence/fp32/__init__.py +0 -0
  204. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  205. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  206. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  207. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare.txt +0 -0
  208. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  209. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  210. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  211. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_auto_model.py +0 -0
  212. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_embedding.py +0 -0
  213. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_flex_attention.py +0 -0
  214. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  215. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_fused_linear_jsd.py +0 -0
  216. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_geglu.py +0 -0
  217. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_group_norm.py +0 -0
  218. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_jsd.py +0 -0
  219. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_kl_div.py +0 -0
  220. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_layer_norm.py +0 -0
  221. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_mm_int8int2.py +0 -0
  222. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_qwen2vl_mrope.py +0 -0
  223. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_rms_norm.py +0 -0
  224. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_rope.py +0 -0
  225. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_swiglu.py +0 -0
  226. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_trainer_integration.py +0 -0
  227. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_transformers.py +0 -0
  228. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/transformers/test_tvd.py +0 -0
  229. {liger_kernel-0.5.4 → liger_kernel-0.5.5}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -47,6 +47,9 @@ jobs:
47
47
  tests:
48
48
  runs-on: linux-mi300-gpu-1
49
49
  needs: [checkstyle]
50
+ strategy:
51
+ matrix:
52
+ rocm_version: ['6.2', '6.3']
50
53
 
51
54
  steps:
52
55
  - name: Checkout code
@@ -60,7 +63,7 @@ jobs:
60
63
  - name: Setup Dependencies
61
64
  run: |
62
65
  python -m pip install --upgrade pip
63
- pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
66
+ pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/nightly/rocm${{ matrix.rocm_version }}
64
67
 
65
68
  - name: List Python Environments
66
69
  run: python -m pip list
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: liger_kernel
3
- Version: 0.5.4
3
+ Version: 0.5.5
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -154,7 +154,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
154
154
  We provide optimized post training kernels like DPO, ORPO, SimPO, and more which can reduce memory usage by up to 80%. You can easily use them as python modules.
155
155
 
156
156
  ```python
157
- from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
157
+ from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
158
158
  orpo_loss = LigerFusedLinearORPOLoss()
159
159
  y = orpo_loss(lm_head.weight, x, target)
160
160
  ```
@@ -314,6 +314,7 @@ loss.backward()
314
314
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
315
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
316
316
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
317
+ | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
317
318
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
318
319
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
319
320
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -104,7 +104,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
104
104
  We provide optimized post training kernels like DPO, ORPO, SimPO, and more which can reduce memory usage by up to 80%. You can easily use them as python modules.
105
105
 
106
106
  ```python
107
- from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
107
+ from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss
108
108
  orpo_loss = LigerFusedLinearORPOLoss()
109
109
  y = orpo_loss(lm_head.weight, x, target)
110
110
  ```
@@ -264,6 +264,7 @@ loss.backward()
264
264
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
265
265
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
266
266
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
267
+ | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
267
268
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
268
269
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
269
270
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -751,36 +751,6 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314
751
751
  fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
752
752
  fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
753
753
  fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
754
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,7.841599941253662,7.801983833312988,7.849664211273193,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
755
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,15.568096160888672,15.555737495422363,16.054176330566406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
756
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,31.145376205444336,30.750951766967773,31.5398006439209,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
757
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,61.49708938598633,61.49708938598633,61.49708938598633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
758
- kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,122.01449584960938,122.01449584960938,122.01449584960938,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:06,0.5.2
759
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.892335891723633,7.8687615394592285,8.03729248046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
760
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,14.16302490234375,13.813311576843262,15.860223770141602,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
761
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,25.56470489501953,25.564167022705078,25.641658782958984,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
762
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,53.0928955078125,53.0928955078125,53.0928955078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
763
- kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,108.76080322265625,108.76080322265625,108.76080322265625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:28,0.5.2
764
- kto_loss,liger,full,speed,ms,B,Batch Size (B),2,8.662687301635742,8.488287925720215,9.611334800720215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
765
- kto_loss,liger,full,speed,ms,B,Batch Size (B),4,18.40096092224121,17.99224281311035,18.57883644104004,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
766
- kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.09159851074219,31.708070755004883,32.475128173828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
767
- kto_loss,liger,full,speed,ms,B,Batch Size (B),16,69.30239868164062,69.30239868164062,69.30239868164062,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
768
- kto_loss,liger,full,speed,ms,B,Batch Size (B),32,124.2437744140625,124.2437744140625,124.2437744140625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:33:50,0.5.2
769
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,11.449472427368164,11.407564163208008,11.773555755615234,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
770
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,20.871471405029297,20.862951278686523,20.879276275634766,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
771
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,41.16409683227539,40.760780334472656,41.567413330078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
772
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,77.720703125,77.720703125,77.720703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
773
- kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,156.25794982910156,156.25794982910156,156.25794982910156,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:11,0.5.2
774
- kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2027.48583984375,2027.48583984375,2027.48583984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
775
- kto_loss,liger,full,memory,MB,B,Batch Size (B),4,2789.736328125,2789.736328125,2789.736328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
776
- kto_loss,liger,full,memory,MB,B,Batch Size (B),8,2801.751953125,2801.751953125,2801.751953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
777
- kto_loss,liger,full,memory,MB,B,Batch Size (B),16,2825.783203125,2825.783203125,2825.783203125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
778
- kto_loss,liger,full,memory,MB,B,Batch Size (B),32,2873.845703125,2873.845703125,2873.845703125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:36,0.5.2
779
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,3786.7373046875,3786.7373046875,3786.7373046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
780
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.25390625,5544.25390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
781
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
782
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
783
- kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2
784
754
  distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
785
755
  distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
786
756
  distill_jsd_loss,liger,forward,speed,ms,BT,B x T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2
@@ -805,4 +775,33 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,
805
775
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
806
776
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
807
777
  distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
808
-
778
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,3.9951679706573486,3.991487979888916,4.002252578735352,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
779
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,7.8037919998168945,7.788575649261475,7.808595180511475,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
780
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,15.43172836303711,15.430015563964844,15.4335355758667,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
781
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,30.66864013671875,30.66431999206543,30.670501708984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
782
+ kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,61.1163215637207,61.1163215637207,61.1163215637207,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:22:44,0.5.4
783
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,3.8766400814056396,3.8680384159088135,3.8897151947021484,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
784
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,7.213727951049805,7.206470489501953,7.229574680328369,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
785
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,13.828800201416016,13.810944557189941,13.834943771362305,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
786
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,27.0930233001709,27.08517074584961,27.09713363647461,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
787
+ kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,54.13715362548828,54.13715362548828,54.13715362548828,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:01,0.5.4
788
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),2,4.782928466796875,4.677459239959717,5.3430914878845215,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
789
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),4,8.517248153686523,8.481344223022461,8.561504364013672,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
790
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),8,16.547504425048828,16.513471603393555,16.678144454956055,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
791
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),16,31.891263961791992,31.819705963134766,32.274131774902344,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
792
+ kto_loss,liger,full,speed,ms,B,Batch Size (B),32,62.953758239746094,62.953758239746094,62.953758239746094,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:18,0.5.4
793
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,6.201632022857666,6.163315296173096,6.314668655395508,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
794
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,11.156224250793457,11.142304420471191,11.207296371459961,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
795
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,21.249855041503906,21.231891632080078,21.264543533325195,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
796
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,41.55686569213867,41.536956787109375,41.57677459716797,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
797
+ kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,81.56924438476562,81.56924438476562,81.56924438476562,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:35,0.5.4
798
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),2,2585.73876953125,2585.73876953125,2585.73876953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
799
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3348.9892578125,3348.9892578125,3348.9892578125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
800
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),8,3361.0048828125,3361.0048828125,3361.0048828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
801
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),16,3385.0361328125,3385.0361328125,3385.0361328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
802
+ kto_loss,liger,full,memory,MB,B,Batch Size (B),32,3433.0986328125,3433.0986328125,3433.0986328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:23:55,0.5.4
803
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4341.74951171875,4341.74951171875,4341.74951171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
804
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.26513671875,6099.26513671875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
805
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
806
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
807
+ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
@@ -81,6 +81,8 @@ class LigerJSDLoss(torch.nn.Module):
81
81
  teacher,
82
82
  self.teacher_lin.weight,
83
83
  target,
84
+ self.student_lin.bias,
85
+ self.teacher_lin.bias,
84
86
  self.weight_hard_loss,
85
87
  self.weight_soft_loss,
86
88
  )
@@ -149,7 +149,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
149
149
  y=target,
150
150
  preference_labels=preference_labels,
151
151
  kl=kl,
152
- )
152
+ )[0]
153
153
  elif provider == "huggingface":
154
154
  return torch_kto_loss(
155
155
  x=_input,
@@ -157,7 +157,7 @@ def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
157
157
  y=target,
158
158
  preference_labels=preference_labels,
159
159
  kl=kl,
160
- )
160
+ )[0]
161
161
 
162
162
  def full():
163
163
  y = fwd()
@@ -230,7 +230,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
230
230
  y=target,
231
231
  preference_labels=preference_labels,
232
232
  kl=kl,
233
- )
233
+ )[0]
234
234
  elif provider == "huggingface":
235
235
  return torch_kto_loss(
236
236
  x=_input,
@@ -238,7 +238,7 @@ def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
238
238
  y=target,
239
239
  preference_labels=preference_labels,
240
240
  kl=kl,
241
- )
241
+ )[0]
242
242
 
243
243
  if mode == "forward":
244
244
  ms_50, ms_20, ms_80 = triton.testing.do_bench(
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel"
7
- version = "0.5.4"
7
+ version = "0.5.5"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -39,8 +39,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
39
39
 
40
40
  return loss, chosen_rewards, rejected_rewards
41
41
 
42
- @staticmethod
42
+ @classmethod
43
43
  def forward(
44
+ cls,
44
45
  ctx,
45
46
  _input,
46
47
  weight,
@@ -52,27 +53,48 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
52
53
  label_smoothing=0.0,
53
54
  compute_nll_loss=True,
54
55
  compiled=True,
56
+ average_log_prob=False,
57
+ chunk_size=1,
55
58
  ):
56
- return LigerFusedLinearPreferenceBase.forward(
57
- ctx,
58
- _input,
59
- weight,
60
- target,
61
- bias,
62
- loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
59
+ """
60
+ Fused linear layer with CPO loss.
61
+ Args:
62
+ _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
63
+ weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
64
+ target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
65
+ bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
66
+ ignore_index (int): Index to ignore in loss computation
67
+ beta (float): Weight for the odds ratio loss
68
+ alpha (float): Weight for the alpha parameter
69
+ label_smoothing (float): Label smoothing factor
70
+ compute_nll_loss (bool): Whether to compute the NLL loss
71
+ compiled (bool): Whether to use torch compile
72
+ average_log_prob (bool): Whether to average the log probability per non-masked token
73
+ chunk_size (int): Size of chunks for processing.
74
+ Returns:
75
+ torch.Tensor: Computed loss
76
+ """
77
+ return super().forward(
78
+ cls=cls,
79
+ ctx=ctx,
80
+ _input=_input,
81
+ weight=weight,
82
+ target=target,
83
+ bias=bias,
63
84
  ignore_index=ignore_index,
64
85
  alpha=alpha,
65
86
  beta=beta,
66
87
  label_smoothing=label_smoothing,
67
88
  compute_nll_loss=compute_nll_loss,
68
- average_log_prob=False,
89
+ average_log_prob=average_log_prob,
69
90
  compiled=compiled,
91
+ chunk_size=chunk_size,
70
92
  )
71
93
 
72
94
  @staticmethod
73
95
  def backward(ctx, *grad_output):
74
96
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
75
- return *grads, None, None, None, None, None, None
97
+ return *grads, None, None, None, None, None, None, None, None
76
98
 
77
99
 
78
100
  class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -88,11 +110,19 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
88
110
  label_smoothing: float = 0.0,
89
111
  compute_nll_loss: bool = True,
90
112
  compiled: bool = True,
113
+ average_log_prob: bool = False,
114
+ chunk_size: int = 1,
91
115
  ):
92
116
  """
93
117
  Args:
94
118
  ignore_index (int): Index to ignore in the loss.
95
119
  beta (float): Weight for the odds ratio loss.
120
+ alpha (float): Weight for the alpha parameter.
121
+ label_smoothing (float): Label smoothing factor.
122
+ compute_nll_loss (bool): Whether to compute the NLL loss.
123
+ compiled (bool): Whether to use the torch compiled kernel.
124
+ average_log_prob (bool): Whether to average the log probability per non-masked token.
125
+ chunk_size (int): Size of chunks for processing.
96
126
  """
97
127
  super().__init__()
98
128
  self.ignore_index = ignore_index
@@ -101,8 +131,16 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
101
131
  self.label_smoothing = label_smoothing
102
132
  self.compute_nll_loss = compute_nll_loss
103
133
  self.compiled = compiled
134
+ self.average_log_prob = average_log_prob
135
+ self.chunk_size = chunk_size
104
136
 
105
- def forward(self, lin_weight, _input, target, bias=None):
137
+ def forward(
138
+ self,
139
+ lin_weight,
140
+ _input,
141
+ target,
142
+ bias=None,
143
+ ):
106
144
  return LigerFusedLinearCPOFunction.apply(
107
145
  _input,
108
146
  lin_weight,
@@ -114,4 +152,6 @@ class LigerFusedLinearCPOLoss(torch.nn.Module):
114
152
  self.label_smoothing,
115
153
  self.compute_nll_loss,
116
154
  self.compiled,
155
+ self.average_log_prob,
156
+ self.chunk_size,
117
157
  )
@@ -52,8 +52,9 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
52
52
  loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
53
53
  return loss, chosen_rewards, rejected_rewards
54
54
 
55
- @staticmethod
55
+ @classmethod
56
56
  def forward(
57
+ cls,
57
58
  ctx,
58
59
  _input,
59
60
  weight,
@@ -67,14 +68,34 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
67
68
  compute_nll_loss=False,
68
69
  compiled=True,
69
70
  use_ref_model=True,
71
+ chunk_size=1,
70
72
  ):
71
- return LigerFusedLinearPreferenceBase.forward(
73
+ """
74
+ Fused linear layer with DPO loss.
75
+ Args:
76
+ _input (torch.Tensor): Input tensor. Shape: (batch_size * seq_len, hidden_size)
77
+ weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
78
+ target (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
79
+ bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
80
+ ref_input (torch.Tensor, optional): Reference model input tensor. Shape: (batch_size * seq_len, hidden_size)
81
+ ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
82
+ ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
83
+ ignore_index (int): Index to ignore in loss computation
84
+ beta (float): Weight for the odds ratio loss
85
+ compute_nll_loss (bool): Whether to compute the NLL loss
86
+ compiled (bool): Whether to use torch compile
87
+ use_ref_model (bool): Whether to use a reference model
88
+ chunk_size (int): Size of chunks for processing.
89
+ Returns:
90
+ torch.Tensor: Computed loss
91
+ """
92
+ return super().forward(
93
+ cls=cls,
72
94
  ctx=ctx,
73
95
  _input=_input,
74
96
  weight=weight,
75
97
  target=target,
76
98
  bias=bias,
77
- loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
78
99
  ignore_index=ignore_index,
79
100
  beta=beta,
80
101
  compute_nll_loss=compute_nll_loss,
@@ -83,12 +104,13 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
83
104
  ref_input=ref_input,
84
105
  ref_weight=ref_weight,
85
106
  ref_bias=ref_bias,
107
+ chunk_size=chunk_size,
86
108
  )
87
109
 
88
110
  @staticmethod
89
111
  def backward(ctx, *grad_output):
90
112
  grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
91
- return *grads, None, None, None, None, None, None, None, None
113
+ return *grads, None, None, None, None, None, None, None, None, None
92
114
 
93
115
 
94
116
  class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -103,6 +125,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
103
125
  compute_nll_loss: bool = False,
104
126
  compiled: bool = True,
105
127
  use_ref_model: bool = True,
128
+ chunk_size: int = 1,
106
129
  ):
107
130
  """
108
131
  Args:
@@ -111,6 +134,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
111
134
  compute_nll_loss (bool): Whether to compute the NLL loss.
112
135
  compiled (bool): Whether to use the torch compiled kernel.
113
136
  use_ref_model (bool): Whether to use a reference model for the DPO loss.
137
+ chunk_size (int): Size of chunks for processing.
114
138
  """
115
139
  super().__init__()
116
140
  self.ignore_index = ignore_index
@@ -118,6 +142,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
118
142
  self.compute_nll_loss = compute_nll_loss
119
143
  self.compiled = compiled
120
144
  self.use_ref_model = use_ref_model
145
+ self.chunk_size = chunk_size
121
146
 
122
147
  def forward(
123
148
  self,
@@ -142,4 +167,5 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
142
167
  self.compute_nll_loss,
143
168
  self.compiled,
144
169
  self.use_ref_model,
170
+ self.chunk_size,
145
171
  )
@@ -125,6 +125,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
125
125
 
126
126
  @staticmethod
127
127
  def forward(
128
+ cls,
128
129
  ctx,
129
130
  student_input,
130
131
  student_weight,
@@ -133,7 +134,6 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
133
134
  target,
134
135
  student_bias=None,
135
136
  teacher_bias=None,
136
- loss_fn=None,
137
137
  chunk_size=1024,
138
138
  ignore_index=-100,
139
139
  weight_hard_loss=0.5,
@@ -175,7 +175,7 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
175
175
 
176
176
  loss_func_to_call = partial(
177
177
  LigerFusedLinearDistillationBase._compute_loss,
178
- distillation_loss_fn=loss_fn,
178
+ distillation_loss_fn=cls.distillation_loss_fn,
179
179
  full_target=target,
180
180
  ignore_index=ignore_index,
181
181
  weight_hard_loss=weight_hard_loss,
@@ -263,4 +263,4 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
263
263
  grad_weight = grad_weight * grad_output
264
264
  grad_bias = grad_bias * grad_output if grad_bias is not None else None
265
265
 
266
- return grad_input, grad_weight, None, grad_bias
266
+ return grad_input, grad_weight, None, None, None, grad_bias
@@ -16,12 +16,12 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
16
16
 
17
17
  @staticmethod
18
18
  def forward(
19
+ cls,
19
20
  ctx,
20
21
  _input,
21
22
  weight,
22
23
  target,
23
24
  bias=None,
24
- loss_fn=None,
25
25
  chunk_size=1,
26
26
  ignore_index=-100,
27
27
  alpha=1.0,
@@ -89,7 +89,7 @@ class LigerFusedLinearPreferenceBase(torch.autograd.Function):
89
89
 
90
90
  compute_loss = partial(
91
91
  LigerFusedLinearPreferenceBase._compute_loss,
92
- preference_loss_fn=loss_fn,
92
+ preference_loss_fn=cls.preference_loss_fn,
93
93
  ignore_index=ignore_index,
94
94
  alpha=alpha,
95
95
  beta=beta,
@@ -1,3 +1,4 @@
1
+ from abc import abstractmethod
1
2
  from functools import partial
2
3
 
3
4
  import torch
@@ -5,15 +6,22 @@ import torch.nn.functional as F
5
6
 
6
7
 
7
8
  class LigerFusedLinearRLHFBase(torch.autograd.Function):
9
+ @abstractmethod
10
+ def rlhf_loss_fn(*args, **kwargs):
11
+ """
12
+ To be extended by subclasses.
13
+ """
14
+ raise NotImplementedError("RLHF loss function must be implemented.")
15
+
8
16
  @staticmethod
9
17
  def forward(
18
+ cls,
10
19
  ctx,
11
20
  _input,
12
21
  weight,
13
22
  attention_mask,
14
23
  rewards,
15
24
  bias=None,
16
- loss_fn=None,
17
25
  num_generations=4,
18
26
  beta=0.1,
19
27
  compiled=True,
@@ -21,8 +29,27 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
21
29
  ref_input=None,
22
30
  ref_weight=None,
23
31
  ref_bias=None,
32
+ chunk_size=1,
24
33
  ):
25
- """Chunked forward pass for RLHF loss computation."""
34
+ """Chunked forward pass for RLHF loss computation.
35
+
36
+ Args:
37
+ cls: The class
38
+ ctx: Context for backward
39
+ _input: Input tensor
40
+ weight: Weight tensor
41
+ attention_mask: Attention mask tensor
42
+ rewards: Rewards tensor
43
+ bias: Bias tensor
44
+ num_generations: Number of generations per prompt
45
+ beta: Weight for the KL penalty
46
+ compiled: Whether to use torch compile
47
+ use_ref_model: Whether to use a reference model
48
+ ref_input: Reference model input tensor
49
+ ref_weight: Reference model weight tensor
50
+ ref_bias: Reference model bias tensor
51
+ chunk_size: Size of chunks for processing in other loss modules
52
+ """
26
53
  # Save for backward
27
54
  ctx.beta = beta
28
55
  ctx.rewards = rewards
@@ -41,7 +68,7 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
41
68
  use_ref_model=use_ref_model,
42
69
  ref_weight=ref_weight,
43
70
  ref_bias=ref_bias,
44
- rlhf_loss_fn=loss_fn,
71
+ rlhf_loss_fn=cls.rlhf_loss_fn,
45
72
  )
46
73
 
47
74
  def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
@@ -98,7 +125,7 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
98
125
  if compiled:
99
126
  accumulate_chunk = torch.compile(accumulate_chunk)
100
127
 
101
- # Process input in chunks
128
+ # Process input in chunks based on num_generations
102
129
  chunks = max(1, _input.shape[0] // num_generations)
103
130
  _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
104
131
  _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
@@ -202,12 +229,12 @@ class LigerFusedLinearRLHFBase(torch.autograd.Function):
202
229
  None, # grad_attention_mask
203
230
  None, # grad_rewards
204
231
  grad_bias,
205
- None, # grad_loss_fn
206
- None, # grad_chunk_size
232
+ None, # grad_num_generations
207
233
  None, # grad_beta
208
234
  None, # grad_compiled
209
235
  None, # grad_use_ref_model
210
236
  None, # grad_ref_input
211
237
  None, # grad_ref_weight
212
238
  None, # grad_ref_bias
239
+ None, # grad_chunk_size
213
240
  )