liger-kernel-nightly 0.5.5.dev20250331170510__tar.gz → 0.5.5.dev20250402185606__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243)
  1. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/dev/modal/tests.py +1 -1
  3. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/dev/modal/tests_bwd.py +1 -1
  4. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/pyproject.toml +1 -1
  5. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/functional.py +2 -0
  6. liger_kernel_nightly-0.5.5.dev20250402185606/src/liger_kernel/chunked_loss/fused_linear_ppo.py +330 -0
  7. liger_kernel_nightly-0.5.5.dev20250402185606/src/liger_kernel/chunked_loss/grpo_loss.py +236 -0
  8. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/cross_entropy.py +3 -2
  9. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  10. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -1
  11. liger_kernel_nightly-0.5.5.dev20250402185606/test/chunked_loss/test_grpo_loss.py +470 -0
  12. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_orpo_loss.py +6 -0
  13. liger_kernel_nightly-0.5.5.dev20250331170510/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -240
  14. liger_kernel_nightly-0.5.5.dev20250331170510/src/liger_kernel/chunked_loss/grpo_loss.py +0 -194
  15. liger_kernel_nightly-0.5.5.dev20250331170510/test/chunked_loss/test_grpo_loss.py +0 -275
  16. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  17. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  18. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/pull_request_template.md +0 -0
  19. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/amd-ci.yml +0 -0
  20. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/docs.yml +0 -0
  21. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/intel-ci.yml +0 -0
  22. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/nvi-ci.yml +0 -0
  23. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/publish-nightly.yml +0 -0
  24. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.github/workflows/publish-release.yml +0 -0
  25. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/.gitignore +0 -0
  26. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/LICENSE +0 -0
  27. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/Makefile +0 -0
  28. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/NOTICE +0 -0
  29. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/README.md +0 -0
  30. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/README.md +0 -0
  31. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/__init__.py +0 -0
  32. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/benchmarks_visualizer.py +0 -0
  33. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/data/all_benchmark_data.csv +0 -0
  34. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/__init__.py +0 -0
  35. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  36. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  38. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  39. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_dyt.py +0 -0
  40. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_embedding.py +0 -0
  41. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  42. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  43. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_geglu.py +0 -0
  44. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_group_norm.py +0 -0
  45. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_jsd.py +0 -0
  46. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_kl_div.py +0 -0
  47. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  48. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  49. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  50. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  51. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  52. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_rope.py +0 -0
  53. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  54. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_swiglu.py +0 -0
  55. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/benchmark_tvd.py +0 -0
  56. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/benchmark/scripts/utils.py +0 -0
  57. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/dev/fmt-requirements.txt +0 -0
  58. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/Examples.md +0 -0
  59. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/Getting-Started.md +0 -0
  60. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/High-Level-APIs.md +0 -0
  61. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/Low-Level-APIs.md +0 -0
  62. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/acknowledgement.md +0 -0
  63. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/contributing.md +0 -0
  64. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/banner.GIF +0 -0
  65. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/compose.gif +0 -0
  66. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/e2e-memory.png +0 -0
  67. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/e2e-tps.png +0 -0
  68. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/logo-banner.png +0 -0
  69. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/patch.gif +0 -0
  70. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/images/post-training.png +0 -0
  71. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/index.md +0 -0
  72. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/docs/license.md +0 -0
  73. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/alignment/accelerate_config.yaml +0 -0
  74. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/alignment/run_orpo.py +0 -0
  75. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/README.md +0 -0
  76. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/callback.py +0 -0
  77. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/config/fsdp_config.json +0 -0
  78. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  79. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  80. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  81. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/llama_tps.png +0 -0
  82. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  83. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/img/qwen_tps.png +0 -0
  84. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/launch_on_modal.py +0 -0
  85. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/requirements.txt +0 -0
  86. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_benchmarks.sh +0 -0
  87. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_gemma.sh +0 -0
  88. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_llama.sh +0 -0
  89. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_qwen.sh +0 -0
  90. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/run_qwen2_vl.sh +0 -0
  91. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/training.py +0 -0
  92. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/huggingface/training_multimodal.py +0 -0
  93. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/lightning/README.md +0 -0
  94. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/lightning/requirements.txt +0 -0
  95. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/lightning/training.py +0 -0
  96. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/README.md +0 -0
  97. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/callback.py +0 -0
  98. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  99. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  100. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  101. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  102. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  103. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  104. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  105. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  106. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  107. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/medusa_util.py +0 -0
  108. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/requirements.txt +0 -0
  109. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  110. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/examples/medusa/train.py +0 -0
  111. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-Apache-2.0 +0 -0
  112. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-MIT-llmc +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/licenses/LICENSE-MIT-triton +0 -0
  116. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/mkdocs.yml +0 -0
  117. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/setup.cfg +0 -0
  118. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/setup.py +0 -0
  119. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/__init__.py +0 -0
  120. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/README.md +0 -0
  121. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  122. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  123. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  124. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  125. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  126. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  127. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  131. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/env_report.py +0 -0
  132. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/__init__.py +0 -0
  133. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/dyt.py +0 -0
  134. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  135. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  136. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  137. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  138. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/geglu.py +0 -0
  139. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/group_norm.py +0 -0
  140. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/jsd.py +0 -0
  141. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/kl_div.py +0 -0
  142. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/layer_norm.py +0 -0
  143. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  144. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/rms_norm.py +0 -0
  145. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/rope.py +0 -0
  146. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/swiglu.py +0 -0
  147. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/tvd.py +0 -0
  148. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/ops/utils.py +0 -0
  149. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/__init__.py +0 -0
  150. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/auto_model.py +0 -0
  151. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  152. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/dyt.py +0 -0
  153. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  154. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/functional.py +0 -0
  155. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  156. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  157. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/geglu.py +0 -0
  158. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/group_norm.py +0 -0
  159. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/jsd.py +0 -0
  160. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/kl_div.py +0 -0
  161. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/layer_norm.py +0 -0
  162. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/__init__.py +0 -0
  163. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/gemma.py +0 -0
  164. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  165. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/llama.py +0 -0
  166. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/llava.py +0 -0
  167. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  168. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/mistral.py +0 -0
  169. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  170. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/mllama.py +0 -0
  171. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  172. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  173. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/phi3.py +0 -0
  174. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  175. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  176. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  177. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  178. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  179. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/rms_norm.py +0 -0
  180. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/rope.py +0 -0
  181. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/swiglu.py +0 -0
  182. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  183. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  184. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  185. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/transformers/tvd.py +0 -0
  186. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/triton/__init__.py +0 -0
  187. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/triton/monkey_patch.py +0 -0
  188. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel/utils.py +0 -0
  189. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  190. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  191. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  192. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/__init__.py +0 -0
  193. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/__init__.py +0 -0
  194. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_cpo_loss.py +0 -0
  195. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_dpo_loss.py +0 -0
  196. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_jsd_loss.py +0 -0
  197. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_kto_loss.py +0 -0
  198. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/chunked_loss/test_simpo_loss.py +0 -0
  199. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/conftest.py +0 -0
  200. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/__init__.py +0 -0
  201. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/bf16/__init__.py +0 -0
  202. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/bf16/test_mini_models.py +0 -0
  203. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  204. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  205. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/fp32/__init__.py +0 -0
  206. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/fp32/test_mini_models.py +0 -0
  207. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  208. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  209. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  210. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  211. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  212. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  213. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  214. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  215. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  216. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  217. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/tiny_shakespeare.txt +0 -0
  218. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  219. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  220. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  221. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_auto_model.py +0 -0
  222. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_cross_entropy.py +0 -0
  223. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_dyt.py +0 -0
  224. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_embedding.py +0 -0
  225. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_flex_attention.py +0 -0
  226. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  227. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_fused_linear_jsd.py +0 -0
  228. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_geglu.py +0 -0
  229. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_group_norm.py +0 -0
  230. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_jsd.py +0 -0
  231. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_kl_div.py +0 -0
  232. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_layer_norm.py +0 -0
  233. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_mm_int8int2.py +0 -0
  234. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_monkey_patch.py +0 -0
  235. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_qwen2vl_mrope.py +0 -0
  236. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_rms_norm.py +0 -0
  237. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_rope.py +0 -0
  238. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_swiglu.py +0 -0
  239. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_trainer_integration.py +0 -0
  240. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_transformers.py +0 -0
  241. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/transformers/test_tvd.py +0 -0
  242. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/triton/test_triton_monkey_patch.py +0 -0
  243. {liger_kernel_nightly-0.5.5.dev20250331170510 → liger_kernel_nightly-0.5.5.dev20250402185606}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250331170510
3
+ Version: 0.5.5.dev20250402185606
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -14,7 +14,7 @@ app = modal.App("liger_tests", image=image)
14
14
  repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
15
15
 
16
16
 
17
- @app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
17
+ @app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
18
18
  def liger_tests():
19
19
  import subprocess
20
20
 
@@ -14,7 +14,7 @@ app = modal.App("liger_tests_bwd", image=image)
14
14
  repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
15
15
 
16
16
 
17
- @app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
17
+ @app.function(gpu="A10G", mounts=[repo], timeout=60 * 30)
18
18
  def liger_bwd_tests():
19
19
  import subprocess
20
20
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.5.dev20250331170510"
7
+ version = "0.5.5.dev20250402185606"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
2
2
  from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
3
+ from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction
3
4
  from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
4
5
  from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction
5
6
  from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
@@ -11,3 +12,4 @@ liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply
11
12
  liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
12
13
  liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
13
14
  liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply
15
+ liger_fused_linear_grpo = LigerFusedLinearGRPOFunction.apply
@@ -0,0 +1,330 @@
1
+ from abc import abstractmethod
2
+ from functools import partial
3
+
4
+ import torch
5
+ import torch._dynamo.config
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class LigerFusedLinearPPOBase(torch.autograd.Function):
10
+ @abstractmethod
11
+ def ppo_loss_fn(*args, **kwargs):
12
+ """
13
+ To be extended by subclasses.
14
+ """
15
+ raise NotImplementedError("PPO loss function must be implemented.")
16
+
17
+ @staticmethod
18
+ def forward(
19
+ cls,
20
+ ctx,
21
+ _input,
22
+ weight,
23
+ selected_token_ids,
24
+ attention_mask,
25
+ advantages,
26
+ bias=None,
27
+ ref_per_token_logps=None,
28
+ old_per_token_logps=None,
29
+ ref_input=None,
30
+ ref_weight=None,
31
+ ref_bias=None,
32
+ epsilon_low=0.2,
33
+ epsilon_high=0.2,
34
+ beta=0.04,
35
+ temperature=1.0,
36
+ compiled=True,
37
+ use_ref_model=False,
38
+ chunk_size=1,
39
+ ):
40
+ """Chunked forward pass for PPO loss computation.
41
+
42
+ Args:
43
+ cls: The class
44
+ ctx: Context for backward
45
+ _input: Input tensor
46
+ weight: Weight tensor
47
+ selected_token_ids: Selected token ids tensor
48
+ attention_mask: Attention mask tensor
49
+ advantages: Advantages tensor
50
+ bias: Bias tensor
51
+ ref_per_token_logps: Reference model log probs per token tensor
52
+ old_per_token_logps: Old per token log probabilities tensor
53
+ ref_input: Reference model input tensor
54
+ ref_weight: Reference model weight tensor
55
+ ref_bias: Reference model bias tensor
56
+ epsilon_low: Lower bound for clipping the importance sampling ratio
57
+ epsilon_high: Upper bound for clipping the importance sampling ratio
58
+ beta: Weight for the KL penalty
59
+ temperature: Temperature for the logits
60
+ compiled: Whether to use torch compile
61
+ use_ref_model: Whether to use a reference model
62
+ chunk_size: Size of chunks for processing in other loss modules
63
+ """
64
+ if use_ref_model:
65
+ assert ref_per_token_logps is not None or ref_input is not None, (
66
+ "If use_ref_model is True, ref_per_token_logps or ref_input must be provided"
67
+ )
68
+ if ref_per_token_logps is not None and ref_input is not None:
69
+ raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
70
+ # Initialize accumulators
71
+ loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
72
+ grad_weight = torch.zeros_like(weight) # [V, H]
73
+ grad_inputs = []
74
+ grad_bias = torch.zeros_like(bias) if bias is not None else None # [V]
75
+ aggregated_metrics = []
76
+
77
+ # Create a partial function with fixed arguments
78
+ compute_loss = partial(
79
+ LigerFusedLinearPPOBase._compute_chunk_loss,
80
+ ref_weight=ref_weight,
81
+ ref_bias=ref_bias,
82
+ full_attention_mask=attention_mask,
83
+ epsilon_low=epsilon_low,
84
+ epsilon_high=epsilon_high,
85
+ beta=beta,
86
+ temperature=temperature,
87
+ use_ref_model=use_ref_model,
88
+ ppo_loss_fn=cls.ppo_loss_fn,
89
+ )
90
+
91
+ def fused_fwd_bwd(
92
+ input_chunk,
93
+ selected_token_ids_chunk,
94
+ attention_mask_chunk,
95
+ advantages_chunk,
96
+ ref_per_token_logps_chunk,
97
+ old_per_token_logps_chunk,
98
+ ref_input_chunk,
99
+ ):
100
+ """Fused forward and backward for a chunk."""
101
+ argnums = (0, 1, 5) if bias is not None else (0, 1)
102
+ return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
103
+ input_chunk, # arg 0
104
+ weight, # arg 1
105
+ selected_token_ids_chunk, # arg 2
106
+ attention_mask_chunk, # arg 3
107
+ advantages_chunk, # arg 4
108
+ bias, # arg 5
109
+ ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
110
+ old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
111
+ ref_input_chunk=ref_input_chunk, # arg 8
112
+ )
113
+
114
+ def accumulate_chunk(
115
+ input_chunk,
116
+ selected_token_ids_chunk,
117
+ attention_mask_chunk,
118
+ advantages_chunk,
119
+ ref_per_token_logps_chunk=None,
120
+ old_per_token_logps_chunk=None,
121
+ ref_input_chunk=None,
122
+ ):
123
+ (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
124
+ input_chunk,
125
+ selected_token_ids_chunk,
126
+ attention_mask_chunk,
127
+ advantages_chunk,
128
+ ref_per_token_logps_chunk,
129
+ old_per_token_logps_chunk,
130
+ ref_input_chunk,
131
+ )
132
+ if bias is not None:
133
+ grad_bias.add_(chunk_grad_bias[0])
134
+
135
+ # Accumulate gradients and loss
136
+ grad_weight.add_(chunk_grad_weight)
137
+ grad_inputs.append(chunk_grad_input)
138
+ loss_acc.add_(chunk_loss)
139
+ # Initialize storage for metrics on first chunk
140
+ if len(aggregated_metrics) == 0:
141
+ for metric in chunk_metrics:
142
+ if metric.ndim == 0:
143
+ aggregated_metrics.append(torch.zeros((), device=metric.device))
144
+ else:
145
+ aggregated_metrics.append([])
146
+
147
+ # Accumulate metrics
148
+ for i, metric in enumerate(chunk_metrics):
149
+ if metric.ndim == 0:
150
+ aggregated_metrics[i].add_(metric)
151
+ else:
152
+ aggregated_metrics[i].append(metric)
153
+
154
+ if compiled:
155
+ # TODO: Figure out what is better to compile here
156
+ # accumulate_chunk = torch.compile(accumulate_chunk)
157
+ fused_fwd_bwd = torch.compile(fused_fwd_bwd)
158
+
159
+ # Process input in chunks based on chunk_size
160
+ chunks = max(1, _input.shape[0] // chunk_size)
161
+ _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
162
+ _selected_token_ids_chunks = torch.chunk(selected_token_ids, chunks=chunks, dim=0)
163
+ _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
164
+ _advantages_chunks = torch.chunk(advantages, chunks=chunks, dim=0)
165
+ _ref_per_token_logps_chunks = (
166
+ torch.chunk(ref_per_token_logps, chunks=chunks, dim=0)
167
+ if use_ref_model and ref_per_token_logps is not None
168
+ else [None] * chunks
169
+ )
170
+ _old_per_token_logps_chunks = (
171
+ torch.chunk(old_per_token_logps, chunks=chunks, dim=0)
172
+ if old_per_token_logps is not None
173
+ else [None] * chunks
174
+ )
175
+ # if ref_log_probs is not none, then we don't need ref_input to calculate the log probs
176
+ _ref_input_chunks = (
177
+ torch.chunk(ref_input, chunks=chunks, dim=0)
178
+ if use_ref_model and ref_per_token_logps is None
179
+ else [None] * chunks
180
+ )
181
+
182
+ for (
183
+ input_chunk,
184
+ selected_token_ids_chunk,
185
+ attention_mask_chunk,
186
+ advantages_chunk,
187
+ ref_per_token_logps_chunk,
188
+ old_per_token_logps_chunk,
189
+ ref_input_chunk,
190
+ ) in zip(
191
+ _input_chunks,
192
+ _selected_token_ids_chunks,
193
+ _attention_mask_chunks,
194
+ _advantages_chunks,
195
+ _ref_per_token_logps_chunks,
196
+ _old_per_token_logps_chunks,
197
+ _ref_input_chunks,
198
+ ):
199
+ # Mark dynamic dimensions
200
+ torch._dynamo.mark_dynamic(input_chunk, 1)
201
+ torch._dynamo.mark_dynamic(selected_token_ids_chunk, 1)
202
+ torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
203
+ if ref_per_token_logps_chunk is not None:
204
+ torch._dynamo.mark_dynamic(ref_per_token_logps_chunk, 1)
205
+ if ref_input_chunk is not None:
206
+ torch._dynamo.mark_dynamic(ref_input_chunk, 1)
207
+ if old_per_token_logps_chunk is not None:
208
+ torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
209
+
210
+ accumulate_chunk(
211
+ input_chunk,
212
+ selected_token_ids_chunk,
213
+ attention_mask_chunk,
214
+ advantages_chunk,
215
+ ref_per_token_logps_chunk,
216
+ old_per_token_logps_chunk,
217
+ ref_input_chunk,
218
+ )
219
+
220
+ # Combine gradients
221
+ grad_input = torch.cat(grad_inputs, dim=0)
222
+
223
+ # Save for backward
224
+ ctx.save_for_backward(grad_input, grad_weight, grad_bias)
225
+
226
+ # Finalize metrics
227
+ final_metrics = []
228
+ for metric in aggregated_metrics:
229
+ if isinstance(metric, list):
230
+ final_metrics.append(torch.cat(metric, dim=0))
231
+ else:
232
+ final_metrics.append(metric)
233
+
234
+ return loss_acc, tuple(final_metrics)
235
+
236
+ @staticmethod
237
+ def _compute_chunk_loss(
238
+ input_chunk,
239
+ weight,
240
+ selected_token_ids_chunk,
241
+ attention_mask_chunk,
242
+ advantages_chunk,
243
+ bias=None,
244
+ ref_per_token_logps_chunk=None,
245
+ old_per_token_logps_chunk=None,
246
+ ref_input_chunk=None,
247
+ ref_weight=None,
248
+ ref_bias=None,
249
+ full_attention_mask=None,
250
+ epsilon_low=0.2,
251
+ epsilon_high=0.2,
252
+ beta=0.04,
253
+ temperature=1.0,
254
+ use_ref_model=False,
255
+ ppo_loss_fn=None,
256
+ ):
257
+ """Compute loss for a single chunk."""
258
+ # Get policy log probabilities using chunk_forward
259
+ log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
260
+
261
+ # Get reference log probabilities if needed
262
+ ref_log_probs = None
263
+ if use_ref_model and ref_per_token_logps_chunk is None:
264
+ with torch.no_grad():
265
+ ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
266
+ ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
267
+ )
268
+
269
+ # Compute chunk loss and metrics using the provided loss function
270
+ chunk_loss, chunk_metrics = ppo_loss_fn(
271
+ log_probs=log_probs,
272
+ selected_token_ids=selected_token_ids_chunk,
273
+ attention_mask=attention_mask_chunk,
274
+ advantages=advantages_chunk,
275
+ full_attention_mask=full_attention_mask,
276
+ ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
277
+ old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
278
+ ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None
279
+ epsilon_low=epsilon_low,
280
+ epsilon_high=epsilon_high,
281
+ beta=beta,
282
+ )
283
+
284
+ return chunk_loss, chunk_metrics
285
+
286
+ @staticmethod
287
+ def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
288
+ """Forward pass computation for a single chunk without explicit reshaping."""
289
+ # Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
290
+ logits = torch.matmul(input_chunk, weight.t())
291
+ if bias is not None:
292
+ logits = logits + bias # Broadcasts bias to [B, T, V]
293
+ if temperature != 1.0:
294
+ logits = logits / temperature
295
+
296
+ # Compute log probabilities using softmax over the last dimension
297
+ log_probs = F.log_softmax(logits.float(), dim=-1)
298
+
299
+ return log_probs, logits
300
+
301
+ @staticmethod
302
+ def backward(ctx, grad_output, *grad_metrics):
303
+ """Backward pass for PPO loss."""
304
+ grad_input, grad_weight, grad_bias = ctx.saved_tensors
305
+ if grad_output != 1.0:
306
+ grad_input = grad_input * grad_output
307
+ grad_weight = grad_weight * grad_output
308
+ if grad_bias is not None:
309
+ grad_bias = grad_bias * grad_output
310
+
311
+ return (
312
+ grad_input,
313
+ grad_weight,
314
+ None, # grad_selected_token_ids
315
+ None, # grad_attention_mask
316
+ None, # grad_advantages
317
+ grad_bias,
318
+ None, # grad_ref_per_token_logps
319
+ None, # grad_old_per_token_logps
320
+ None, # grad_ref_input
321
+ None, # grad_ref_weight
322
+ None, # grad_ref_bias
323
+ None, # grad_epsilon_low
324
+ None, # grad_epsilon_high
325
+ None, # grad_beta
326
+ None, # grad_temperature
327
+ None, # grad_compiled
328
+ None, # grad_use_ref_model
329
+ None, # grad_chunk_size
330
+ )
@@ -0,0 +1,236 @@
1
+ import torch
2
+
3
+ from liger_kernel.chunked_loss.fused_linear_ppo import LigerFusedLinearPPOBase
4
+
5
+
6
def k3_loss_fn(log_p, log_q):
    """Return the elementwise k3 estimate of KL[q, p] from log-probabilities.

    With ``log_ratio = log_p - log_q`` the estimate is
    ``exp(log_ratio) - log_ratio - 1``.
    ref: http://joschu.net/blog/kl-approx.html
    """
    log_ratio = log_p - log_q
    return torch.exp(log_ratio) - log_ratio - 1.0
10
+
11
+
12
def clip_coef_fn(coef, epsilon_low, epsilon_high):
    """Clamp ``coef`` elementwise to the interval [1 - epsilon_low, 1 + epsilon_high]."""
    lower_bound = 1 - epsilon_low
    upper_bound = 1 + epsilon_high
    return torch.clamp(coef, min=lower_bound, max=upper_bound)
14
+
15
+
16
class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
    """GRPO flavour of the chunked fused-linear PPO autograd function."""

    @staticmethod
    def ppo_loss_fn(
        log_probs,
        selected_token_ids,
        attention_mask,
        advantages,
        full_attention_mask,
        ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
        old_per_token_logps=None,
        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
        epsilon_low=0.2,
        epsilon_high=0.2,
        beta=0.04,
        **kwargs,
    ):
        """GRPO Loss Function matching GRPOTrainer implementation.

        Returns:
            Tuple ``(loss, metrics)``. ``metrics`` is a list holding the masked mean
            KL estimate (only when ``beta != 0.0``) followed by the masked fraction
            of clipped tokens; both are normalized like the loss.
        """
        gather_index = selected_token_ids.unsqueeze(-1)
        per_token_logps = log_probs.gather(dim=-1, index=gather_index).squeeze(-1)  # (batch_size, seq_len)

        # Get reference model probabilities
        if ref_per_token_logps is None:
            if ref_log_probs is None:
                # No reference available: fall back to the detached policy log-probs.
                ref_per_token_logps = per_token_logps.detach()
            else:
                with torch.no_grad():
                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=gather_index).squeeze(-1)

        # Compute policy gradient loss with importance sampling ratio
        if old_per_token_logps is None:
            old_per_token_logps = per_token_logps.detach()
        token_advantages = advantages.unsqueeze(1)
        ratio = torch.exp(per_token_logps - old_per_token_logps)
        clipped_ratio = clip_coef_fn(ratio, epsilon_low, epsilon_high)
        unclipped_objective = ratio * token_advantages
        clipped_objective = clipped_ratio * token_advantages
        per_token_loss = -torch.min(unclipped_objective, clipped_objective)

        apply_kl = beta != 0.0
        if apply_kl:
            # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
            kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
            # Combine losses
            per_token_loss = per_token_loss + beta * kl_div

        # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
        # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
        # and TRL GRPO implementation
        # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
        token_count = torch.clamp(full_attention_mask.sum(), min=1.0)
        loss = (per_token_loss * attention_mask).sum() / token_count

        # Calculate metrics
        metrics = []
        if apply_kl:
            metrics.append((kl_div * attention_mask).sum() / token_count)
        clipped_below = (ratio < 1 - epsilon_low) & (token_advantages < 0)
        clipped_above = (ratio > 1 + epsilon_high) & (token_advantages > 0)
        is_clipped = clipped_below | clipped_above
        metrics.append((is_clipped * attention_mask).sum() / token_count)
        return loss, metrics

    @classmethod
    def forward(
        cls,
        ctx,
        _input,
        weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        beta=0.04,
        epsilon_low=0.2,
        epsilon_high=0.2,
        temperature=1.0,
        compiled=True,
        use_ref_model=True,
        chunk_size=1,
    ):
        """
        Fused linear layer with GRPO loss.
        Args:
            _input (torch.Tensor): Input tensor. Presumably (batch_size, seq_len, hidden_size), since the
                loss gathers per-token log-probs with (batch_size, seq_len) ids - confirm with caller.
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size)
            selected_token_ids (torch.Tensor): Selected token ids tensor. Shape: (batch_size, seq_len)
            attention_mask (torch.Tensor): Attention mask tensor. Shape: (batch_size, seq_len)
            advantages (torch.Tensor): Advantages tensor. Shape: (batch_size,)
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,)
            ref_per_token_logps: Reference model log probs per token tensor. Shape:(batch_size, seq_len)
            old_per_token_logps: Old policy log probs per token tensor, forwarded unchanged to the base class.
            ref_input (torch.Tensor, optional): Reference model input tensor (same layout as _input).
            ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
            ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
            beta (float): Weight for the KL penalty
            epsilon_low (float): Lower clipping margin for the importance sampling ratio
            epsilon_high (float): Upper clipping margin for the importance sampling ratio
            temperature (float): Temperature for the logits
            compiled (bool): Whether to use torch compile
            use_ref_model (bool): Whether to use a reference model
            chunk_size (int): Size of chunks for processing.
        Returns:
            Whatever the base class ``forward`` returns for these arguments.
        """
        forwarded = dict(
            cls=cls,
            ctx=ctx,
            _input=_input,
            weight=weight,
            selected_token_ids=selected_token_ids,
            attention_mask=attention_mask,
            advantages=advantages,
            bias=bias,
            ref_per_token_logps=ref_per_token_logps,
            old_per_token_logps=old_per_token_logps,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            beta=beta,
            epsilon_low=epsilon_low,
            epsilon_high=epsilon_high,
            temperature=temperature,
            compiled=compiled,
            use_ref_model=use_ref_model,
            chunk_size=chunk_size,
        )
        return super().forward(**forwarded)

    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for GRPO loss.

        Args:
            grad_output: Gradient of the loss (scalar)
            grad_metrics: Gradients of the metrics (not used in backward computation)
        """
        base_grads = LigerFusedLinearPPOBase.backward(ctx, grad_output)
        # grad_input, grad_weight, grad_selected_token_ids, grad_attention_mask, grad_advantages, grad_bias
        leading = tuple(base_grads[:6])
        # One None each for: ref_per_token_logps, old_per_token_logps, ref_input, ref_weight,
        # ref_bias, beta, epsilon_low, epsilon_high, temperature, compiled, use_ref_model, chunk_size
        return leading + (None,) * 12
169
+
170
+
171
class LigerFusedLinearGRPOLoss(torch.nn.Module):
    """Fused linear layer with GRPO loss."""

    def __init__(
        self,
        beta: float = 0.04,
        compiled: bool = True,
        use_ref_model: bool = True,
        chunk_size: int = 1,
        epsilon_low: float = 0.2,
        epsilon_high: float = 0.2,
        temperature: float = 1.0,
    ):
        """
        Args:
            beta (float): Weight for the KL penalty.
            compiled (bool): Whether to use torch compile.
            use_ref_model (bool): Whether to use a reference model.
            chunk_size (int): Size of chunks for processing.
            epsilon_low (float): Lower bound for the importance sampling ratio.
            epsilon_high (float): Upper bound for the importance sampling ratio.
            temperature (float): Temperature for the logits.
        """
        super().__init__()
        # Loss hyperparameters
        self.beta = beta
        self.epsilon_low = epsilon_low
        self.epsilon_high = epsilon_high
        self.temperature = temperature
        # Execution options
        self.compiled = compiled
        self.use_ref_model = use_ref_model
        self.chunk_size = chunk_size

    def forward(
        self,
        _input,
        lin_weight,
        selected_token_ids,
        attention_mask,
        advantages,
        bias=None,
        ref_per_token_logps=None,
        old_per_token_logps=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        """Apply ``LigerFusedLinearGRPOFunction`` with this module's stored settings."""
        tensor_args = (
            _input,
            lin_weight,
            selected_token_ids,
            attention_mask,
            advantages,
            bias,
            ref_per_token_logps,
            old_per_token_logps,
            ref_input,
            ref_weight,
            ref_bias,
        )
        # Positional order expected by the autograd function after the tensor arguments.
        settings = (
            self.beta,
            self.epsilon_low,
            self.epsilon_high,
            self.temperature,
            self.compiled,
            self.use_ref_model,
            self.chunk_size,
        )
        return LigerFusedLinearGRPOFunction.apply(*tensor_args, *settings)
@@ -9,6 +9,7 @@ import triton.language as tl
9
9
  from liger_kernel.ops.utils import compare_version
10
10
  from liger_kernel.ops.utils import element_mul_kernel
11
11
  from liger_kernel.ops.utils import is_hip
12
+ from liger_kernel.utils import infer_device
12
13
 
13
14
  if compare_version("triton", operator.ge, "3.0.0"):
14
15
  try:
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
59
60
  z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
60
61
  loss_stride (int): The stride of the loss tensor.
61
62
  n_cols (int): The number of columns in the input tensor.
62
- n_non_ignore (flaot): The number of non-ignored elements in the batch.
63
+ n_non_ignore (float): The number of non-ignored elements in the batch.
63
64
  sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
64
65
  weight_sum (float): The sum of weight tensor.
65
66
  ignore_index (int): The index to ignore in the target.
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
258
259
  # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
259
260
  # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
260
261
  # The optimal maximum block size depends on your hardware, your kernel, and your dtype
261
- MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning
262
+ MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
262
263
 
263
264
 
264
265
  def cross_entropy_forward(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250331170510
3
+ Version: 0.5.5.dev20250402185606
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -112,8 +112,8 @@ src/liger_kernel/chunked_loss/cpo_loss.py
112
112
  src/liger_kernel/chunked_loss/dpo_loss.py
113
113
  src/liger_kernel/chunked_loss/functional.py
114
114
  src/liger_kernel/chunked_loss/fused_linear_distillation.py
115
+ src/liger_kernel/chunked_loss/fused_linear_ppo.py
115
116
  src/liger_kernel/chunked_loss/fused_linear_preference.py
116
- src/liger_kernel/chunked_loss/fused_linear_rlhf.py
117
117
  src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py
118
118
  src/liger_kernel/chunked_loss/grpo_loss.py
119
119
  src/liger_kernel/chunked_loss/jsd_loss.py