liger-kernel 0.5.8.tar.gz → 0.5.9.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/intel-ci.yml +24 -11
  2. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/PKG-INFO +3 -1
  3. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/README.md +2 -0
  4. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/mkdocs.yml +3 -3
  5. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/pyproject.toml +1 -1
  6. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/setup.py +0 -1
  7. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/dpo_loss.py +8 -1
  8. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/cross_entropy.py +4 -1
  9. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
  10. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/__init__.py +6 -0
  11. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
  12. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/gemma.py +8 -4
  13. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/gemma2.py +8 -4
  14. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/gemma3.py +3 -1
  15. liger_kernel-0.5.9/src/liger_kernel/transformers/model/glm4.py +125 -0
  16. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/llama.py +8 -4
  17. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/mistral.py +8 -4
  18. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/mixtral.py +8 -4
  19. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/mllama.py +8 -4
  20. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/olmo2.py +8 -4
  21. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/phi3.py +8 -4
  22. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/qwen2.py +8 -4
  23. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/qwen2_5_vl.py +3 -1
  24. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/qwen2_vl.py +3 -1
  25. liger_kernel-0.5.9/src/liger_kernel/transformers/model/qwen3.py +118 -0
  26. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/monkey_patch.py +121 -0
  27. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/PKG-INFO +3 -1
  28. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/SOURCES.txt +2 -0
  29. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/bf16/test_mini_models.py +119 -0
  30. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/bf16/test_mini_models_multimodal.py +3 -4
  31. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/bf16/test_mini_models_with_logits.py +119 -1
  32. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/fp32/test_mini_models.py +113 -0
  33. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/fp32/test_mini_models_multimodal.py +7 -6
  34. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/fp32/test_mini_models_with_logits.py +113 -1
  35. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_cross_entropy.py +108 -2
  36. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_fused_linear_cross_entropy.py +5 -6
  37. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_monkey_patch.py +100 -0
  38. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_rms_norm.py +1 -1
  39. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/utils.py +24 -0
  40. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  41. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  42. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/pull_request_template.md +0 -0
  43. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/amd-ci.yml +0 -0
  44. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/docs.yml +0 -0
  45. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/nvi-ci.yml +0 -0
  46. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/publish-nightly.yml +0 -0
  47. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.github/workflows/publish-release.yml +0 -0
  48. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/.gitignore +0 -0
  49. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/LICENSE +0 -0
  50. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/Makefile +0 -0
  51. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/NOTICE +0 -0
  52. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/README.md +0 -0
  53. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/__init__.py +0 -0
  54. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/benchmarks_visualizer.py +0 -0
  55. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/data/all_benchmark_data.csv +0 -0
  56. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/__init__.py +0 -0
  57. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  58. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  59. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  60. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  61. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_dyt.py +0 -0
  62. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_embedding.py +0 -0
  63. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  64. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  65. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_geglu.py +0 -0
  66. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_group_norm.py +0 -0
  67. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_jsd.py +0 -0
  68. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_kl_div.py +0 -0
  69. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  70. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  71. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  72. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  73. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  74. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_rope.py +0 -0
  75. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  76. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_swiglu.py +0 -0
  77. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/benchmark_tvd.py +0 -0
  78. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/benchmark/scripts/utils.py +0 -0
  79. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/dev/fmt-requirements.txt +0 -0
  80. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/dev/modal/tests.py +0 -0
  81. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/dev/modal/tests_bwd.py +0 -0
  82. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/Examples.md +0 -0
  83. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/Getting-Started.md +0 -0
  84. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/High-Level-APIs.md +0 -0
  85. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/Low-Level-APIs.md +0 -0
  86. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/acknowledgement.md +0 -0
  87. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/contributing.md +0 -0
  88. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/banner.GIF +0 -0
  89. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/compose.gif +0 -0
  90. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/e2e-memory.png +0 -0
  91. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/e2e-tps.png +0 -0
  92. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/logo-banner.png +0 -0
  93. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/patch.gif +0 -0
  94. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/images/post-training.png +0 -0
  95. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/index.md +0 -0
  96. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/docs/license.md +0 -0
  97. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/alignment/accelerate_config.yaml +0 -0
  98. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/alignment/run_orpo.py +0 -0
  99. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/README.md +0 -0
  100. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/callback.py +0 -0
  101. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/config/fsdp_config.json +0 -0
  102. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  103. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  104. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  105. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/llama_tps.png +0 -0
  106. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  107. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/img/qwen_tps.png +0 -0
  108. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/launch_on_modal.py +0 -0
  109. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/requirements.txt +0 -0
  110. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_benchmarks.sh +0 -0
  111. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_gemma.sh +0 -0
  112. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_llama.sh +0 -0
  113. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_qwen.sh +0 -0
  114. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/run_qwen2_vl.sh +0 -0
  115. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/training.py +0 -0
  116. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/huggingface/training_multimodal.py +0 -0
  117. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/lightning/README.md +0 -0
  118. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/lightning/requirements.txt +0 -0
  119. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/lightning/training.py +0 -0
  120. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/README.md +0 -0
  121. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/callback.py +0 -0
  122. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  123. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  124. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  125. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  126. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  127. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  128. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  129. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  130. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  131. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/medusa_util.py +0 -0
  132. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/requirements.txt +0 -0
  133. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  134. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/examples/medusa/train.py +0 -0
  135. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-Apache-2.0 +0 -0
  136. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  137. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  138. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-MIT-llmc +0 -0
  139. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/licenses/LICENSE-MIT-triton +0 -0
  140. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/setup.cfg +0 -0
  141. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/__init__.py +0 -0
  142. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/README.md +0 -0
  143. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  144. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  145. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/functional.py +0 -0
  146. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  147. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  148. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  149. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  150. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  151. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  152. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  153. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  154. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  155. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/env_report.py +0 -0
  156. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/__init__.py +0 -0
  157. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/dyt.py +0 -0
  158. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  159. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  160. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  161. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/geglu.py +0 -0
  162. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/group_norm.py +0 -0
  163. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/jsd.py +0 -0
  164. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/kl_div.py +0 -0
  165. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/layer_norm.py +0 -0
  166. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  167. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/rms_norm.py +0 -0
  168. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/rope.py +0 -0
  169. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/swiglu.py +0 -0
  170. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/tvd.py +0 -0
  171. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/ops/utils.py +0 -0
  172. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/auto_model.py +0 -0
  173. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  174. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/dyt.py +0 -0
  175. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  176. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/functional.py +0 -0
  177. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  178. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/geglu.py +0 -0
  179. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  180. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/group_norm.py +0 -0
  181. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/jsd.py +0 -0
  182. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/kl_div.py +0 -0
  183. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/layer_norm.py +0 -0
  184. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/__init__.py +0 -0
  185. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/llava.py +0 -0
  186. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  187. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  188. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  189. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/rms_norm.py +0 -0
  190. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/rope.py +0 -0
  191. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/swiglu.py +0 -0
  192. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  193. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  194. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  195. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/transformers/tvd.py +0 -0
  196. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/triton/__init__.py +0 -0
  197. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/triton/monkey_patch.py +0 -0
  198. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel/utils.py +0 -0
  199. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/dependency_links.txt +0 -0
  200. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/requires.txt +0 -0
  201. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/src/liger_kernel.egg-info/top_level.txt +0 -0
  202. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/__init__.py +0 -0
  203. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/__init__.py +0 -0
  204. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_cpo_loss.py +0 -0
  205. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_dpo_loss.py +0 -0
  206. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_grpo_loss.py +0 -0
  207. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_jsd_loss.py +0 -0
  208. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_kto_loss.py +0 -0
  209. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_orpo_loss.py +0 -0
  210. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/chunked_loss/test_simpo_loss.py +0 -0
  211. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/conftest.py +0 -0
  212. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/__init__.py +0 -0
  213. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/bf16/__init__.py +0 -0
  214. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/convergence/fp32/__init__.py +0 -0
  215. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  216. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  217. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  218. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  219. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  220. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  221. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  222. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  223. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  224. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/tiny_shakespeare.txt +0 -0
  225. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  226. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  227. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  228. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_auto_model.py +0 -0
  229. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_dyt.py +0 -0
  230. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_embedding.py +0 -0
  231. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_flex_attention.py +0 -0
  232. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_fused_linear_jsd.py +0 -0
  233. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_geglu.py +0 -0
  234. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_group_norm.py +0 -0
  235. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_jsd.py +0 -0
  236. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_kl_div.py +0 -0
  237. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_layer_norm.py +0 -0
  238. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_mm_int8int2.py +0 -0
  239. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_qwen2vl_mrope.py +0 -0
  240. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_rope.py +0 -0
  241. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_swiglu.py +0 -0
  242. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_trainer_integration.py +0 -0
  243. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_transformers.py +0 -0
  244. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/transformers/test_tvd.py +0 -0
  245. {liger_kernel-0.5.8 → liger_kernel-0.5.9}/test/triton/test_triton_monkey_patch.py +0 -0
--- liger_kernel-0.5.8/.github/workflows/intel-ci.yml
+++ liger_kernel-0.5.9/.github/workflows/intel-ci.yml
@@ -45,27 +45,40 @@ jobs:
       run: make checkstyle
 
   tests:
-    runs-on: linux-max1550-gpu-8
+    runs-on: linux-max1550-pvc-8
     needs: [checkstyle]
-
+    if: success()
+    container:
+      image: intel/oneapi-basekit:2025.0.1-0-devel-ubuntu24.04
+      options: --privileged -v /dev/dri/by-path:/dev/dri/by-path --device=/dev/dri --ipc=host
     steps:
+      - name: Set up python
+        shell: bash
+        run: |
+          apt-get update && \
+          apt-get install -y python3.12-venv python3-pip && \
+          ln -sf /usr/bin/python3 /usr/bin/python && \
+          apt-get clean && rm -rf /var/lib/apt/lists/*
+
       - name: Checkout code
         uses: actions/checkout@v3
-
-      - name: Set up Python
-        uses: actions/setup-python@v3
-        with:
-          python-version: '3.10'
-
+
       - name: Setup Dependencies
+        shell: bash
         run: |
-          python -m pip install --upgrade pip
+          python -m venv test-env
+          . test-env/bin/activate
           pip install -e .[dev] --extra-index-url https://download.pytorch.org/whl/test/xpu
-
+
       - name: List Python Environments
-        run: python -m pip list
+        shell: bash
+        run: |
+          . test-env/bin/activate
+          python -m pip list
 
       - name: Run Unit Tests
+        shell: bash
        run: |
+          . test-env/bin/activate
           make test
           make test-convergence
--- liger_kernel-0.5.8/PKG-INFO
+++ liger_kernel-0.5.9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: liger_kernel
-Version: 0.5.8
+Version: 0.5.9
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
         Copyright 2024 LinkedIn Corporation
@@ -320,9 +320,11 @@ loss.backward()
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
 | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
 ## Low-level APIs
--- liger_kernel-0.5.8/README.md
+++ liger_kernel-0.5.9/README.md
@@ -269,9 +269,11 @@ loss.backward()
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
 | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 
 
 ## Low-level APIs
--- liger_kernel-0.5.8/mkdocs.yml
+++ liger_kernel-0.5.9/mkdocs.yml
@@ -1,6 +1,6 @@
 site_name: Liger-Kernel Docs
-site_url: https://ligerkernel-io.github.io/ligerkernel
-site_author: Parag Ekbote
+# site_url: ...
+# site_author: LinkedIn
 site_description: Efficient Triton Kernels for LLM Training
 theme:
   name: material
@@ -66,4 +66,4 @@ edit_uri: edit/main/docs/
 extra:
   social:
     - icon: simple/github
-      link: https://github.com/linkedin/Liger-Kernel
+      link: https://github.com/linkedin/Liger-Kernel
--- liger_kernel-0.5.8/pyproject.toml
+++ liger_kernel-0.5.9/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel"
-version = "0.5.8"
+version = "0.5.9"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
--- liger_kernel-0.5.8/setup.py
+++ liger_kernel-0.5.9/setup.py
@@ -24,7 +24,6 @@ def get_default_dependencies():
     elif platform == "xpu":
         return [
             "torch>=2.6.0",
-            "pytorch-triton-xpu>=3.2.0",
         ]
 
 
--- liger_kernel-0.5.8/src/liger_kernel/chunked_loss/dpo_loss.py
+++ liger_kernel-0.5.9/src/liger_kernel/chunked_loss/dpo_loss.py
@@ -68,6 +68,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
         compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
+        average_log_prob=False,
         chunk_size=1,
     ):
         """
@@ -85,6 +86,7 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
             compute_nll_loss (bool): Whether to compute the NLL loss
             compiled (bool): Whether to use torch compile
             use_ref_model (bool): Whether to use a reference model
+            average_log_prob (bool): Whether to average the log probability per non-masked token
             chunk_size (int): Size of chunks for processing.
         Returns:
             torch.Tensor: Computed loss
@@ -104,13 +106,14 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
+           average_log_prob=average_log_prob,
            chunk_size=chunk_size,
        )
 
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None, None, None, None
 
 
 class LigerFusedLinearDPOLoss(torch.nn.Module):
@@ -125,6 +128,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = True,
+        average_log_prob: bool = True,
         chunk_size: int = 1,
     ):
         """
@@ -134,6 +138,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             compute_nll_loss (bool): Whether to compute the NLL loss.
             compiled (bool): Whether to use the torch compiled kernel.
             use_ref_model (bool): Whether to use a reference model for the DPO loss.
+            average_log_prob (bool): Whether to average the log probability per non-masked token.
             chunk_size (int): Size of chunks for processing.
         """
         super().__init__()
@@ -142,6 +147,7 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.use_ref_model = use_ref_model
+        self.average_log_prob = average_log_prob
         self.chunk_size = chunk_size
 
     def forward(
@@ -167,5 +173,6 @@ class LigerFusedLinearDPOLoss(torch.nn.Module):
             self.compute_nll_loss,
             self.compiled,
             self.use_ref_model,
+            self.average_log_prob,
             self.chunk_size,
         )
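The net effect of the dpo_loss.py change: a new average_log_prob knob that defaults to False on the raw autograd Function but True on the LigerFusedLinearDPOLoss module, plus one extra None in backward to cover the added forward argument. A minimal construction sketch; the ignore_index and beta kwargs are assumed from the pre-existing signature, which this hunk does not show:

```python
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss

# average_log_prob=True (the module default) length-normalizes the summed
# per-token log-probabilities; False keeps plain sums, matching the
# autograd Function's own default.
dpo_loss = LigerFusedLinearDPOLoss(
    ignore_index=-100,       # assumed pre-existing kwarg, not shown in this hunk
    beta=0.1,                # assumed pre-existing kwarg, not shown in this hunk
    use_ref_model=True,
    average_log_prob=False,  # new in 0.5.9
)
```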
--- liger_kernel-0.5.8/src/liger_kernel/ops/cross_entropy.py
+++ liger_kernel-0.5.9/src/liger_kernel/ops/cross_entropy.py
@@ -351,7 +351,10 @@ def cross_entropy_backward(_input, grad_output):
     # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
     if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
         pass
-
+    # If reduction is 'none'
+    elif grad_output.ndim > 0:
+        _input = _input * grad_output.unsqueeze(dim=1)
+    # If reduction is ['mean', 'sum'], grad_output is just a scalar
     # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
     # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
     else:
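For intuition, the new elif covers the case where grad_output is a vector (one element per unreduced loss term) rather than the scalar produced by 'mean'/'sum'. A plain-PyTorch sketch of the broadcast it performs, with illustrative shapes:

```python
import torch

# With reduction="none", autograd supplies one upstream gradient per token,
# so each cached per-token gradient row is scaled by its own element.
cached_grad = torch.randn(8, 32)  # d(loss_i)/d(logits_i): 8 tokens, 32-way vocab
grad_output = torch.rand(8)       # upstream gradient, one entry per loss element
scaled = cached_grad * grad_output.unsqueeze(dim=1)  # broadcast over the vocab dim
assert scaled.shape == cached_grad.shape
```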
--- liger_kernel-0.5.8/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ liger_kernel-0.5.9/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -143,9 +143,10 @@ def fused_linear_cross_entropy_forward(
             alpha=1.0,
         )
 
-    if reduction == "none":
-        loss = loss_1d
-        z_loss = z_loss_1d if return_z_loss else None
+    # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
+    # if reduction == "none":
+    #     loss = loss_1d
+    #     z_loss = z_loss_1d if return_z_loss else None
 
     else:
         loss = torch.sum(loss_1d)
--- liger_kernel-0.5.8/src/liger_kernel/transformers/__init__.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/__init__.py
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
@@ -38,6 +39,7 @@ if TYPE_CHECKING:
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
 
 
 # Check if 'transformers' is installed
@@ -79,6 +81,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_gemma2",
         "apply_liger_kernel_to_gemma3",
         "apply_liger_kernel_to_gemma3_text",
+        "apply_liger_kernel_to_glm4",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_llama",
         "apply_liger_kernel_to_llava",
@@ -91,6 +94,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen2",
         "apply_liger_kernel_to_qwen2_5_vl",
         "apply_liger_kernel_to_qwen2_vl",
+        "apply_liger_kernel_to_qwen3",
     }
 
     if name in monkey_patch_symbols:
@@ -129,6 +133,7 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_gemma2",
             "apply_liger_kernel_to_gemma3",
             "apply_liger_kernel_to_gemma3_text",
+            "apply_liger_kernel_to_glm4",
             "apply_liger_kernel_to_granite",
             "apply_liger_kernel_to_llama",
             "apply_liger_kernel_to_llava",
@@ -141,5 +146,6 @@ if _TRANSFORMERS_AVAILABLE:
             "apply_liger_kernel_to_qwen2",
             "apply_liger_kernel_to_qwen2_5_vl",
             "apply_liger_kernel_to_qwen2_vl",
+            "apply_liger_kernel_to_qwen3",
         ]
     )
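Both new entry points follow the established pattern: call them once before instantiating the model so the patched Hugging Face modules are picked up at construction time. A usage sketch; the checkpoint name is only an example, and kernel-selection kwargs are assumed to mirror the other apply_* functions:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_glm4
from liger_kernel.transformers import apply_liger_kernel_to_qwen3

# Patch first, then load: the swapped-in RoPE/RMSNorm/SwiGLU/fused-loss
# modules take effect for models created afterwards.
apply_liger_kernel_to_qwen3()
apply_liger_kernel_to_glm4()

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B")  # example checkpoint
```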
--- liger_kernel-0.5.8/src/liger_kernel/transformers/fused_linear_cross_entropy.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -23,8 +23,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
         assert reduction in {
             "mean",
             "sum",
-            "none",
-        }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
+        }, f"reduction must be 'mean' or 'sum'. Got: {reduction}"
         assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
         self.ce_weight = ce_weight
         self.ignore_index = ignore_index
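This tightening matches the ops-level change above: with the reduction='none' path commented out, the module now rejects it at construction time instead of failing later in the kernel. A quick sketch of the new behavior:

```python
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

flce = LigerFusedLinearCrossEntropyLoss(reduction="mean")  # still supported, as is "sum"

try:
    LigerFusedLinearCrossEntropyLoss(reduction="none")
except AssertionError as err:
    print(err)  # reduction must be 'mean' or 'sum'. Got: none
```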
--- liger_kernel-0.5.8/src/liger_kernel/transformers/model/gemma.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/model/gemma.py
@@ -200,21 +200,25 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
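The same kept_hidden_states pattern recurs in every patched model below, so its semantics are worth spelling out once. A self-contained sketch with illustrative shapes:

```python
import torch

hidden_states = torch.randn(2, 10, 4)  # (batch, seq, hidden), illustrative

def keep(logits_to_keep):
    # Mirrors the slicing in the patched forwards: an int keeps the last N
    # positions (0 keeps all, since slice(-0, None) == slice(0, None)),
    # while a 1-D tensor selects explicit positions (packed-sequence case).
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

assert keep(0).shape[1] == 10                       # full sequence
assert keep(1).shape[1] == 1                        # last token only
assert keep(torch.tensor([0, 5, 9])).shape[1] == 3  # explicit indices
```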
--- liger_kernel-0.5.8/src/liger_kernel/transformers/model/gemma2.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/model/gemma2.py
@@ -212,23 +212,27 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
             logits = torch.tanh(logits)
--- liger_kernel-0.5.8/src/liger_kernel/transformers/model/gemma3.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/model/gemma3.py
@@ -104,13 +104,15 @@ def causal_forward(
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
             hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             final_logit_softcapping=self.config.final_logit_softcapping,
             **loss_kwargs,
--- /dev/null
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/model/glm4.py
@@ -0,0 +1,125 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
+from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+
+
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+@add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
+@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Glm4ForCausalLM
+
+    >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
+    >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
+    ```
+    """
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = loss_kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    # if in training mode, don't materialize logits
+    if self.training and (labels is not None or shift_labels is not None):
+        loss = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **loss_kwargs,
+        )
+
+    else:  # if in inference mode materialize logits
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                vocab_size=self.config.vocab_size,
+                **loss_kwargs,
+            )
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
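The monkey_patch.py hunk that registers this forward (+121 lines per the file list) is not expanded in this diff, but based on how the other models are wired, apply_liger_kernel_to_glm4 presumably rebinds the forward roughly like this. A sketch under that assumption, not the actual implementation:

```python
from transformers.models.glm4 import modeling_glm4

from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward

# Rebind the causal-LM forward so training goes through the fused
# linear-cross-entropy path above instead of materializing full logits.
modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
```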
--- liger_kernel-0.5.8/src/liger_kernel/transformers/model/llama.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/model/llama.py
@@ -209,25 +209,29 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
     if self.config.pretraining_tp > 1:
         raise Exception("Liger Kernel does not support pretraining_tp!!")
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
--- liger_kernel-0.5.8/src/liger_kernel/transformers/model/mistral.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/model/mistral.py
@@ -91,22 +91,26 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
 
     else:
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
 
     loss = None
     if labels is not None:
--- liger_kernel-0.5.8/src/liger_kernel/transformers/model/mixtral.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/model/mixtral.py
@@ -225,22 +225,26 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
 
     loss = None
     if labels is not None:
--- liger_kernel-0.5.8/src/liger_kernel/transformers/model/mllama.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/model/mllama.py
@@ -215,22 +215,26 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
--- liger_kernel-0.5.8/src/liger_kernel/transformers/model/olmo2.py
+++ liger_kernel-0.5.9/src/liger_kernel/transformers/model/olmo2.py
@@ -88,22 +88,26 @@ def lce_forward(
     )
 
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
+    shift_labels = loss_kwargs.pop("shift_labels", None)
     logits = None
     loss = None
     # if in training mode, don't materialize logits
-    if self.training and (labels is not None):
+    if self.training and (labels is not None or shift_labels is not None):
         loss = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
+            shift_labels=shift_labels,
             hidden_size=self.config.hidden_size,
             **loss_kwargs,
         )
 
     else:  # if in inference mode materialize logits
-        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-        logits = self.lm_head(hidden_states[:, slice_indices, :])
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
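Across all of these model files, the recurring change is the same: the fused-loss path now also triggers when a caller supplies pre-shifted labels via shift_labels (popped from loss_kwargs), not only when labels is given. A sketch of how a trainer might exercise the new path; the model call is hypothetical and the shapes are illustrative:

```python
import torch

input_ids = torch.randint(0, 100, (1, 8))
labels = input_ids.clone()

# Shift once on the caller side: position t predicts token t+1, and the tail
# is padded with ignore_index (-100) so the last position contributes no loss.
shift_labels = torch.cat([labels[:, 1:], torch.full((1, 1), -100)], dim=1)

# loss = patched_model(input_ids=input_ids, shift_labels=shift_labels).loss  # hypothetical call
```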