liger-kernel-nightly 0.5.6.dev20250408223717__tar.gz → 0.5.6.dev20250411210855__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243)
  1. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +15 -0
  4. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/grpo_loss.py +33 -1
  5. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/kl_div.py +13 -6
  6. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  7. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/chunked_loss/test_grpo_loss.py +35 -3
  8. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  9. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  10. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.github/pull_request_template.md +0 -0
  11. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.github/workflows/amd-ci.yml +0 -0
  12. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.github/workflows/docs.yml +0 -0
  13. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.github/workflows/intel-ci.yml +0 -0
  14. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.github/workflows/nvi-ci.yml +0 -0
  15. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.github/workflows/publish-nightly.yml +0 -0
  16. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.github/workflows/publish-release.yml +0 -0
  17. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/.gitignore +0 -0
  18. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/LICENSE +0 -0
  19. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/Makefile +0 -0
  20. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/NOTICE +0 -0
  21. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/README.md +0 -0
  22. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/README.md +0 -0
  23. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/__init__.py +0 -0
  24. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/benchmarks_visualizer.py +0 -0
  25. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/data/all_benchmark_data.csv +0 -0
  26. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/__init__.py +0 -0
  27. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  28. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  29. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  30. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_dyt.py +0 -0
  32. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_embedding.py +0 -0
  33. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  34. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  35. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_geglu.py +0 -0
  36. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_group_norm.py +0 -0
  37. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_jsd.py +0 -0
  38. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_kl_div.py +0 -0
  39. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  40. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  41. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  43. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  44. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_rope.py +0 -0
  45. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  46. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_swiglu.py +0 -0
  47. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/benchmark_tvd.py +0 -0
  48. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/benchmark/scripts/utils.py +0 -0
  49. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/dev/fmt-requirements.txt +0 -0
  50. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/dev/modal/tests.py +0 -0
  51. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/dev/modal/tests_bwd.py +0 -0
  52. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/Examples.md +0 -0
  53. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/Getting-Started.md +0 -0
  54. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/High-Level-APIs.md +0 -0
  55. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/Low-Level-APIs.md +0 -0
  56. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/acknowledgement.md +0 -0
  57. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/contributing.md +0 -0
  58. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/images/banner.GIF +0 -0
  59. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/images/compose.gif +0 -0
  60. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/images/e2e-memory.png +0 -0
  61. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/images/e2e-tps.png +0 -0
  62. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/images/logo-banner.png +0 -0
  63. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/images/patch.gif +0 -0
  64. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/images/post-training.png +0 -0
  65. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/index.md +0 -0
  66. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/docs/license.md +0 -0
  67. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/alignment/accelerate_config.yaml +0 -0
  68. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/alignment/run_orpo.py +0 -0
  69. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/README.md +0 -0
  70. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/callback.py +0 -0
  71. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/config/fsdp_config.json +0 -0
  72. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  73. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  74. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  75. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/img/llama_tps.png +0 -0
  76. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  77. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/img/qwen_tps.png +0 -0
  78. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/launch_on_modal.py +0 -0
  79. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/requirements.txt +0 -0
  80. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/run_benchmarks.sh +0 -0
  81. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/run_gemma.sh +0 -0
  82. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/run_llama.sh +0 -0
  83. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/run_qwen.sh +0 -0
  84. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/run_qwen2_vl.sh +0 -0
  85. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/training.py +0 -0
  86. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/huggingface/training_multimodal.py +0 -0
  87. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/lightning/README.md +0 -0
  88. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/lightning/requirements.txt +0 -0
  89. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/lightning/training.py +0 -0
  90. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/README.md +0 -0
  91. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/callback.py +0 -0
  92. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  93. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  94. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  95. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  96. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  97. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  98. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  99. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  100. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  101. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/medusa_util.py +0 -0
  102. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/requirements.txt +0 -0
  103. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  104. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/examples/medusa/train.py +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/licenses/LICENSE-Apache-2.0 +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  108. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/licenses/LICENSE-MIT-llmc +0 -0
  109. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/licenses/LICENSE-MIT-triton +0 -0
  110. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/mkdocs.yml +0 -0
  111. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/setup.cfg +0 -0
  112. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/setup.py +0 -0
  113. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/__init__.py +0 -0
  114. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/README.md +0 -0
  115. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  116. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  117. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  118. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/functional.py +0 -0
  119. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  120. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  121. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  122. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  123. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  124. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  125. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  126. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/env_report.py +0 -0
  127. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/__init__.py +0 -0
  128. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/cross_entropy.py +0 -0
  129. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/dyt.py +0 -0
  130. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  131. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  132. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  133. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  134. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/geglu.py +0 -0
  135. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/group_norm.py +0 -0
  136. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/jsd.py +0 -0
  137. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/layer_norm.py +0 -0
  138. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  139. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/rms_norm.py +0 -0
  140. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/rope.py +0 -0
  141. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/swiglu.py +0 -0
  142. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/tvd.py +0 -0
  143. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/ops/utils.py +0 -0
  144. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/__init__.py +0 -0
  145. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/auto_model.py +0 -0
  146. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  147. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/dyt.py +0 -0
  148. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  149. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/functional.py +0 -0
  150. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  151. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  152. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/geglu.py +0 -0
  153. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  154. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/group_norm.py +0 -0
  155. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/jsd.py +0 -0
  156. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/kl_div.py +0 -0
  157. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/layer_norm.py +0 -0
  158. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/__init__.py +0 -0
  159. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/gemma.py +0 -0
  160. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  161. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  162. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/llama.py +0 -0
  163. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/llava.py +0 -0
  164. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  165. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/mistral.py +0 -0
  166. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  167. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/mllama.py +0 -0
  168. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  169. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  170. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/phi3.py +0 -0
  171. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  172. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  173. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  174. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  175. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  176. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/rms_norm.py +0 -0
  177. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/rope.py +0 -0
  178. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/swiglu.py +0 -0
  179. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  180. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  181. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  182. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/transformers/tvd.py +0 -0
  183. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/triton/__init__.py +0 -0
  184. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/triton/monkey_patch.py +0 -0
  185. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel/utils.py +0 -0
  186. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  187. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  188. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  189. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  190. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/__init__.py +0 -0
  191. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/chunked_loss/__init__.py +0 -0
  192. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/chunked_loss/test_cpo_loss.py +0 -0
  193. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/chunked_loss/test_dpo_loss.py +0 -0
  194. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/chunked_loss/test_jsd_loss.py +0 -0
  195. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/chunked_loss/test_kto_loss.py +0 -0
  196. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/chunked_loss/test_orpo_loss.py +0 -0
  197. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/chunked_loss/test_simpo_loss.py +0 -0
  198. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/conftest.py +0 -0
  199. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/convergence/__init__.py +0 -0
  200. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/convergence/bf16/__init__.py +0 -0
  201. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/convergence/bf16/test_mini_models.py +0 -0
  202. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  203. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  204. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/convergence/fp32/__init__.py +0 -0
  205. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/convergence/fp32/test_mini_models.py +0 -0
  206. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  207. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  208. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  209. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  210. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  211. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  212. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  213. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  214. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  215. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  216. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  217. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/tiny_shakespeare.txt +0 -0
  218. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  219. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  220. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  221. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_auto_model.py +0 -0
  222. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_cross_entropy.py +0 -0
  223. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_dyt.py +0 -0
  224. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_embedding.py +0 -0
  225. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_flex_attention.py +0 -0
  226. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  227. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_fused_linear_jsd.py +0 -0
  228. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_geglu.py +0 -0
  229. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_group_norm.py +0 -0
  230. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_jsd.py +0 -0
  231. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_kl_div.py +0 -0
  232. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_layer_norm.py +0 -0
  233. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_mm_int8int2.py +0 -0
  234. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_monkey_patch.py +0 -0
  235. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_qwen2vl_mrope.py +0 -0
  236. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_rms_norm.py +0 -0
  237. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_rope.py +0 -0
  238. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_swiglu.py +0 -0
  239. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_trainer_integration.py +0 -0
  240. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_transformers.py +0 -0
  241. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/transformers/test_tvd.py +0 -0
  242. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/triton/test_triton_monkey_patch.py +0 -0
  243. {liger_kernel_nightly-0.5.6.dev20250408223717 → liger_kernel_nightly-0.5.6.dev20250411210855}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.6.dev20250408223717
3
+ Version: 0.5.6.dev20250411210855
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.6.dev20250408223717"
7
+ version = "0.5.6.dev20250411210855"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -32,6 +32,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
32
32
  epsilon_low=0.2,
33
33
  epsilon_high=0.2,
34
34
  beta=0.04,
35
+ loss_type="bnpo",
36
+ max_completion_length=None,
35
37
  temperature=1.0,
36
38
  compiled=True,
37
39
  use_ref_model=False,
@@ -57,6 +59,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
57
59
  epsilon_low: Lower bound for clipping the importance sampling ratio
58
60
  epsilon_high: Upper bound for clipping the importance sampling ratio
59
61
  beta: Weight for the KL penalty
62
+ loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo")
63
+ max_completion_length: Maximum completion length required for "dr_grpo"
60
64
  temperature: Temperature for the logits
61
65
  compiled: Whether to use torch compile
62
66
  use_ref_model: Whether to use a reference model
@@ -68,6 +72,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
68
72
  )
69
73
  if ref_per_token_logps is not None and ref_input is not None:
70
74
  raise Warning("Both ref_per_token_logps and ref_input are provided. Using ref_per_token_logps.")
75
+ if loss_type == "dr_grpo":
76
+ assert max_completion_length is not None, "max_completion_length must be provided for loss_type 'dr_grpo'"
71
77
  # Initialize accumulators
72
78
  loss_acc = torch.zeros((), device=_input.device, dtype=torch.float32)
73
79
  grad_weight = torch.zeros_like(weight) # [V, H]
@@ -84,6 +90,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
84
90
  epsilon_low=epsilon_low,
85
91
  epsilon_high=epsilon_high,
86
92
  beta=beta,
93
+ loss_type=loss_type,
94
+ max_completion_length=max_completion_length,
87
95
  temperature=temperature,
88
96
  use_ref_model=use_ref_model,
89
97
  ppo_loss_fn=cls.ppo_loss_fn,
@@ -251,6 +259,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
251
259
  epsilon_low=0.2,
252
260
  epsilon_high=0.2,
253
261
  beta=0.04,
262
+ loss_type="bnpo",
263
+ max_completion_length=None,
254
264
  temperature=1.0,
255
265
  use_ref_model=False,
256
266
  ppo_loss_fn=None,
@@ -280,6 +290,8 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
280
290
  epsilon_low=epsilon_low,
281
291
  epsilon_high=epsilon_high,
282
292
  beta=beta,
293
+ loss_type=loss_type,
294
+ max_completion_length=max_completion_length,
283
295
  )
284
296
 
285
297
  return chunk_loss, chunk_metrics
@@ -303,6 +315,7 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
303
315
  def backward(ctx, grad_output, *grad_metrics):
304
316
  """Backward pass for PPO loss."""
305
317
  grad_input, grad_weight, grad_bias = ctx.saved_tensors
318
+
306
319
  if grad_output != 1.0:
307
320
  grad_input = grad_input * grad_output
308
321
  grad_weight = grad_weight * grad_output
@@ -328,4 +341,6 @@ class LigerFusedLinearPPOBase(torch.autograd.Function):
328
341
  None, # grad_compiled
329
342
  None, # grad_use_ref_model
330
343
  None, # grad_chunk_size
344
+ None, # grad_loss_type
345
+ None, # grad_max_completion_length
331
346
  )
@@ -27,6 +27,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
27
27
  epsilon_low=0.2,
28
28
  epsilon_high=0.2,
29
29
  beta=0.04,
30
+ loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
31
+ max_completion_length=None, # Required for dr_grpo
30
32
  **kwargs,
31
33
  ):
32
34
  """GRPO Loss Function matching GRPOTrainer implementation."""
@@ -61,7 +63,21 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
61
63
  # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
62
64
  # and TRL GRPO implementation
63
65
  # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
64
- loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
66
+ if loss_type == "grpo":
67
+ # Average per-sequence loss
68
+ loss = (
69
+ (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
70
+ ).sum() / full_attention_mask.shape[0]
71
+ elif loss_type == "bnpo":
72
+ # Batch Normalized Per-token loss (original implementation)
73
+ loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
74
+ elif loss_type == "dr_grpo":
75
+ # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
76
+ if max_completion_length is None:
77
+ raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
78
+ loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
79
+ else:
80
+ raise ValueError(f"Unknown loss type: {loss_type}")
65
81
 
66
82
  # Calculate metrics
67
83
  metrics = []
@@ -91,6 +107,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
91
107
  beta=0.04,
92
108
  epsilon_low=0.2,
93
109
  epsilon_high=0.2,
110
+ loss_type="bnpo",
111
+ max_completion_length=None,
94
112
  temperature=1.0,
95
113
  compiled=True,
96
114
  use_ref_model=True,
@@ -110,6 +128,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
110
128
  ref_weight (torch.Tensor, optional): Reference model weight tensor. Shape: (vocab_size, hidden_size)
111
129
  ref_bias (torch.Tensor, optional): Reference model bias tensor. Shape: (vocab_size,)
112
130
  beta (float): Weight for the KL penalty
131
+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
132
+ max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
113
133
  temperature (float): Temperature for the logits
114
134
  compiled (bool): Whether to use torch compile
115
135
  use_ref_model (bool): Whether to use a reference model
@@ -134,6 +154,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
134
154
  beta=beta,
135
155
  epsilon_low=epsilon_low,
136
156
  epsilon_high=epsilon_high,
157
+ loss_type=loss_type,
158
+ max_completion_length=max_completion_length,
137
159
  temperature=temperature,
138
160
  compiled=compiled,
139
161
  use_ref_model=use_ref_model,
@@ -161,6 +183,8 @@ class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
161
183
  None, # grad_beta
162
184
  None, # grad_epsilon_low
163
185
  None, # grad_epsilon_high
186
+ None, # grad_loss_type (string, not differentiable)
187
+ None, # grad_max_completion_length (int, not differentiable)
164
188
  None, # grad_temperature
165
189
  None, # grad_compiled
166
190
  None, # grad_use_ref_model
@@ -179,6 +203,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
179
203
  chunk_size: int = 1,
180
204
  epsilon_low: float = 0.2,
181
205
  epsilon_high: float = 0.2,
206
+ loss_type: str = "bnpo",
207
+ max_completion_length: int | None = None,
182
208
  temperature: float = 1.0,
183
209
  ):
184
210
  """
@@ -189,6 +215,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
189
215
  chunk_size (int): Size of chunks for processing.
190
216
  epsilon_low (float): Lower bound for the importance sampling ratio.
191
217
  epsilon_high (float): Upper bound for the importance sampling ratio.
218
+ loss_type (str): Type of loss calculation ("grpo", "bnpo", "dr_grpo"). Defaults to "bnpo".
219
+ max_completion_length (int, optional): Maximum completion length, required for "dr_grpo". Defaults to None.
192
220
  temperature (float): Temperature for the logits.
193
221
  """
194
222
  super().__init__()
@@ -198,6 +226,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
198
226
  self.chunk_size = chunk_size
199
227
  self.epsilon_low = epsilon_low
200
228
  self.epsilon_high = epsilon_high
229
+ self.loss_type = loss_type
230
+ self.max_completion_length = max_completion_length
201
231
  self.temperature = temperature
202
232
 
203
233
  def forward(
@@ -229,6 +259,8 @@ class LigerFusedLinearGRPOLoss(torch.nn.Module):
229
259
  self.beta,
230
260
  self.epsilon_low,
231
261
  self.epsilon_high,
262
+ self.loss_type,
263
+ self.max_completion_length,
232
264
  self.temperature,
233
265
  self.compiled,
234
266
  self.use_ref_model,
@@ -6,6 +6,7 @@ import triton.language as tl
6
6
 
7
7
  from liger_kernel.ops.utils import ensure_contiguous
8
8
  from liger_kernel.ops.utils import is_hip
9
+ from liger_kernel.utils import infer_device
9
10
 
10
11
 
11
12
  def get_num_warps(BLOCK_SIZE):
@@ -115,9 +116,12 @@ def _kldiv_kernel_backward(
115
116
 
116
117
  def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
117
118
  BT, V = y_pred.shape
118
-
119
- BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
120
- num_warps = get_num_warps(BLOCK_SIZE)
119
+ BLOCK_SIZE = (
120
+ min(8192, triton.next_power_of_2(V))
121
+ if infer_device() == "xpu"
122
+ else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
123
+ )
124
+ num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
121
125
 
122
126
  grid = (BT,)
123
127
  reduction = _str_to_reduction_mode[reduction]
@@ -155,9 +159,12 @@ def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]
155
159
 
156
160
  def kldiv_backward_triton(target, grad_output, new_grads, log_target):
157
161
  BT, V = target.shape
158
-
159
- BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
160
- num_warps = get_num_warps(BLOCK_SIZE)
162
+ BLOCK_SIZE = (
163
+ min(8192, triton.next_power_of_2(V))
164
+ if infer_device() == "xpu"
165
+ else min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
166
+ )
167
+ num_warps = 32 if infer_device() == "xpu" else get_num_warps(BLOCK_SIZE)
161
168
 
162
169
  grid = (BT,)
163
170
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.6.dev20250408223717
3
+ Version: 0.5.6.dev20250411210855
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -27,6 +27,8 @@ class TorchLMHeadGRPO(torch.nn.Module):
27
27
  epsilon_high: float = 0.2,
28
28
  temperature: float = 1.0,
29
29
  use_ref_model: bool = True,
30
+ loss_type: str = "bnpo",
31
+ max_completion_length: int | None = None,
30
32
  ):
31
33
  super().__init__()
32
34
  self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
@@ -36,6 +38,10 @@ class TorchLMHeadGRPO(torch.nn.Module):
36
38
  self.epsilon_high = epsilon_high
37
39
  self.temperature = temperature
38
40
  self.use_ref_model = use_ref_model
41
+ self.loss_type = loss_type
42
+ self.max_completion_length = max_completion_length
43
+ if self.loss_type == "dr_grpo":
44
+ assert self.max_completion_length is not None, "max_completion_length must be provided for dr_grpo"
39
45
 
40
46
  def forward(
41
47
  self,
@@ -89,8 +95,15 @@ class TorchLMHeadGRPO(torch.nn.Module):
89
95
  kl_div = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1.0
90
96
  per_token_loss = per_token_loss + self.beta * kl_div
91
97
 
92
- # Apply masking and normalize
93
- loss = (per_token_loss * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0)
98
+ # Apply masking and calculate loss based on loss_type
99
+ if self.loss_type == "grpo":
100
+ loss = ((per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)).mean()
101
+ elif self.loss_type == "bnpo":
102
+ loss = (per_token_loss * attention_mask).sum() / torch.clamp(attention_mask.sum(), min=1.0)
103
+ elif self.loss_type == "dr_grpo":
104
+ loss = (per_token_loss * attention_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
105
+ else:
106
+ raise ValueError(f"Unknown loss type: {self.loss_type}")
94
107
 
95
108
  # Compute metrics
96
109
  metrics = []
@@ -115,6 +128,8 @@ class LigerLMHeadGRPO(torch.nn.Module):
115
128
  epsilon_high: float = 0.2,
116
129
  temperature: float = 1.0,
117
130
  use_ref_model: bool = True,
131
+ loss_type: str = "bnpo",
132
+ max_completion_length: int | None = None,
118
133
  ):
119
134
  super().__init__()
120
135
  self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
@@ -126,6 +141,8 @@ class LigerLMHeadGRPO(torch.nn.Module):
126
141
  temperature=temperature,
127
142
  use_ref_model=use_ref_model,
128
143
  compiled=True,
144
+ loss_type=loss_type,
145
+ max_completion_length=max_completion_length,
129
146
  )
130
147
 
131
148
  def forward(
@@ -186,6 +203,7 @@ class LigerLMHeadGRPO(torch.nn.Module):
186
203
  ],
187
204
  )
188
205
  @pytest.mark.parametrize("old_per_token_logps", [True, False])
206
+ @pytest.mark.parametrize("loss_type", ["bnpo", "grpo", "dr_grpo"])
189
207
  def test_correctness(
190
208
  B,
191
209
  T,
@@ -203,9 +221,12 @@ def test_correctness(
203
221
  use_ref_per_token_logps,
204
222
  use_ref_model,
205
223
  old_per_token_logps,
224
+ loss_type,
206
225
  ):
207
226
  # Reset torch compiler cache for each parameter of the test case
208
227
  torch.compiler.reset()
228
+ max_completion_length = T if loss_type == "dr_grpo" else None
229
+
209
230
  torch_lm_head_grpo = TorchLMHeadGRPO(
210
231
  H=H,
211
232
  V=V,
@@ -216,6 +237,8 @@ def test_correctness(
216
237
  epsilon_high=epsilon_high,
217
238
  temperature=temperature,
218
239
  use_ref_model=use_ref_model,
240
+ loss_type=loss_type,
241
+ max_completion_length=max_completion_length,
219
242
  )
220
243
  liger_lm_head_grpo = LigerLMHeadGRPO(
221
244
  H=H,
@@ -227,6 +250,8 @@ def test_correctness(
227
250
  epsilon_high=epsilon_high,
228
251
  temperature=temperature,
229
252
  use_ref_model=use_ref_model,
253
+ loss_type=loss_type,
254
+ max_completion_length=max_completion_length,
230
255
  )
231
256
 
232
257
  # Initialize weights
@@ -319,7 +344,7 @@ def test_correctness(
319
344
  loss1.backward()
320
345
  loss2.backward()
321
346
 
322
- # Check gradients match
347
+ # Check gradients match for loss_type
323
348
  assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
324
349
  assert_verbose_allclose(
325
350
  torch_lm_head_grpo.lin.weight.grad,
@@ -351,6 +376,7 @@ def test_correctness(
351
376
  ],
352
377
  )
353
378
  @pytest.mark.parametrize("bias", [True, False])
379
+ @pytest.mark.parametrize("loss_type", ["bnpo", "grpo", "dr_grpo"])
354
380
  def test_functional_correctness(
355
381
  B,
356
382
  T,
@@ -361,9 +387,11 @@ def test_functional_correctness(
361
387
  atol,
362
388
  rtol,
363
389
  bias,
390
+ loss_type,
364
391
  ):
365
392
  # Reset torch compiler cache for each parameter of the test case
366
393
  torch.compiler.reset()
394
+ max_completion_length = T if loss_type == "dr_grpo" else None
367
395
  _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
368
396
  input1 = _input.detach().clone().requires_grad_(True)
369
397
  input2 = _input.detach().clone().requires_grad_(True)
@@ -418,6 +446,8 @@ def test_functional_correctness(
418
446
  0.04,
419
447
  0.2,
420
448
  0.2,
449
+ loss_type,
450
+ max_completion_length,
421
451
  1.0,
422
452
  True,
423
453
  True,
@@ -439,6 +469,8 @@ def test_functional_correctness(
439
469
  0.04,
440
470
  0.2,
441
471
  0.2,
472
+ loss_type,
473
+ max_completion_length,
442
474
  1.0,
443
475
  True,
444
476
  True,