liger-kernel-nightly 0.5.9.dev20250519011716.tar.gz → 0.5.9.dev20250519015630.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (254)
  1. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/pyproject.toml +1 -1
  3. liger_kernel_nightly-0.5.9.dev20250519015630/src/liger_kernel/ops/grpo_loss.py +310 -0
  4. liger_kernel_nightly-0.5.9.dev20250519015630/src/liger_kernel/transformers/grpo_loss.py +98 -0
  5. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  6. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/SOURCES.txt +3 -0
  7. liger_kernel_nightly-0.5.9.dev20250519015630/test/transformers/test_grpo_loss.py +190 -0
  8. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  9. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  10. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/pull_request_template.md +0 -0
  11. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/amd-ci.yml +0 -0
  12. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/docs.yml +0 -0
  13. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/intel-ci.yml +0 -0
  14. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/nvi-ci.yml +0 -0
  15. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/publish-nightly.yml +0 -0
  16. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/publish-release.yml +0 -0
  17. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.gitignore +0 -0
  18. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/.idea/workspace.xml +0 -0
  19. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/LICENSE +0 -0
  20. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/Makefile +0 -0
  21. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/NOTICE +0 -0
  22. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/README.md +0 -0
  23. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/README.md +0 -0
  24. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/__init__.py +0 -0
  25. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/benchmarks_visualizer.py +0 -0
  26. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/data/all_benchmark_data.csv +0 -0
  27. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/__init__.py +0 -0
  28. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  29. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  30. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  32. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_dyt.py +0 -0
  33. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_embedding.py +0 -0
  34. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  35. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  36. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_geglu.py +0 -0
  37. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_group_norm.py +0 -0
  38. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_jsd.py +0 -0
  39. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_kl_div.py +0 -0
  40. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  41. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  42. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  43. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  44. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  45. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_rope.py +0 -0
  46. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  47. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_sparsemax.py +0 -0
  48. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_swiglu.py +0 -0
  49. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_tvd.py +0 -0
  50. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/utils.py +0 -0
  51. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/dev/fmt-requirements.txt +0 -0
  52. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/dev/modal/tests.py +0 -0
  53. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/dev/modal/tests_bwd.py +0 -0
  54. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/Examples.md +0 -0
  55. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/Getting-Started.md +0 -0
  56. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/High-Level-APIs.md +0 -0
  57. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/Low-Level-APIs.md +0 -0
  58. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/acknowledgement.md +0 -0
  59. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/contributing.md +0 -0
  60. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/banner.GIF +0 -0
  61. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/compose.gif +0 -0
  62. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/e2e-memory.png +0 -0
  63. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/e2e-tps.png +0 -0
  64. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/logo-banner.png +0 -0
  65. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/patch.gif +0 -0
  66. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/post-training.png +0 -0
  67. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/index.md +0 -0
  68. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/license.md +0 -0
  69. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/alignment/accelerate_config.yaml +0 -0
  70. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/alignment/run_orpo.py +0 -0
  71. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/README.md +0 -0
  72. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/callback.py +0 -0
  73. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/config/fsdp_config.json +0 -0
  74. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  75. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  76. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  77. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/llama_tps.png +0 -0
  78. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  79. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/qwen_tps.png +0 -0
  80. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/launch_on_modal.py +0 -0
  81. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/requirements.txt +0 -0
  82. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_benchmarks.sh +0 -0
  83. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_gemma.sh +0 -0
  84. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_llama.sh +0 -0
  85. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_qwen.sh +0 -0
  86. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_qwen2_vl.sh +0 -0
  87. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/training.py +0 -0
  88. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/training_multimodal.py +0 -0
  89. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/lightning/README.md +0 -0
  90. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/lightning/requirements.txt +0 -0
  91. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/lightning/training.py +0 -0
  92. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/README.md +0 -0
  93. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/callback.py +0 -0
  94. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  95. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  96. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  97. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  98. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  99. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  100. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  101. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  102. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  103. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/medusa_util.py +0 -0
  104. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/requirements.txt +0 -0
  105. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  106. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/train.py +0 -0
  107. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-Apache-2.0 +0 -0
  108. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  109. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  110. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-MIT-llmc +0 -0
  111. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-MIT-triton +0 -0
  112. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/mkdocs.yml +0 -0
  113. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/setup.cfg +0 -0
  114. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/setup.py +0 -0
  115. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/__init__.py +0 -0
  116. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/README.md +0 -0
  117. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  118. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  119. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  120. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/functional.py +0 -0
  121. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  122. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  123. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  124. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  125. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  126. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/env_report.py +0 -0
  131. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/__init__.py +0 -0
  132. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/cross_entropy.py +0 -0
  133. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/dyt.py +0 -0
  134. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  135. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  136. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  137. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  138. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/geglu.py +0 -0
  139. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/group_norm.py +0 -0
  140. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/jsd.py +0 -0
  141. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/kl_div.py +0 -0
  142. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/layer_norm.py +0 -0
  143. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  144. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/rms_norm.py +0 -0
  145. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/rope.py +0 -0
  146. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/sparsemax.py +0 -0
  147. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/swiglu.py +0 -0
  148. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/tvd.py +0 -0
  149. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/utils.py +0 -0
  150. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/__init__.py +0 -0
  151. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/auto_model.py +0 -0
  152. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  153. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/dyt.py +0 -0
  154. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  155. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/functional.py +0 -0
  156. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  157. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  158. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/geglu.py +0 -0
  159. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  160. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/group_norm.py +0 -0
  161. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/jsd.py +0 -0
  162. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/kl_div.py +0 -0
  163. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/layer_norm.py +0 -0
  164. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/__init__.py +0 -0
  165. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/gemma.py +0 -0
  166. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  167. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  168. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/glm4.py +0 -0
  169. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/llama.py +0 -0
  170. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/llava.py +0 -0
  171. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  172. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/mistral.py +0 -0
  173. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  174. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/mllama.py +0 -0
  175. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  176. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  177. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/phi3.py +0 -0
  178. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  179. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  180. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  181. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  182. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  183. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  184. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  185. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/rms_norm.py +0 -0
  186. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/rope.py +0 -0
  187. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/sparsemax.py +0 -0
  188. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/swiglu.py +0 -0
  189. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  190. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  191. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  192. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/tvd.py +0 -0
  193. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/triton/__init__.py +0 -0
  194. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/triton/monkey_patch.py +0 -0
  195. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/utils.py +0 -0
  196. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  197. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  198. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  199. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/__init__.py +0 -0
  200. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/__init__.py +0 -0
  201. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_cpo_loss.py +0 -0
  202. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_dpo_loss.py +0 -0
  203. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_grpo_loss.py +0 -0
  204. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_jsd_loss.py +0 -0
  205. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_kto_loss.py +0 -0
  206. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_orpo_loss.py +0 -0
  207. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_simpo_loss.py +0 -0
  208. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/conftest.py +0 -0
  209. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/__init__.py +0 -0
  210. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/bf16/__init__.py +0 -0
  211. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/bf16/test_mini_models.py +0 -0
  212. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  213. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  214. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/fp32/__init__.py +0 -0
  215. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/fp32/test_mini_models.py +0 -0
  216. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  217. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  218. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  219. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  220. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  221. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  222. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  223. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  224. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  225. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  226. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  227. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/tiny_shakespeare.txt +0 -0
  228. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  229. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  230. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  231. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_auto_model.py +0 -0
  232. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_cross_entropy.py +0 -0
  233. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_dyt.py +0 -0
  234. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_embedding.py +0 -0
  235. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_flex_attention.py +0 -0
  236. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  237. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_fused_linear_jsd.py +0 -0
  238. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_geglu.py +0 -0
  239. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_group_norm.py +0 -0
  240. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_jsd.py +0 -0
  241. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_kl_div.py +0 -0
  242. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_layer_norm.py +0 -0
  243. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_mm_int8int2.py +0 -0
  244. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_monkey_patch.py +0 -0
  245. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_qwen2vl_mrope.py +0 -0
  246. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_rms_norm.py +0 -0
  247. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_rope.py +0 -0
  248. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_sparsemax.py +0 -0
  249. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_swiglu.py +0 -0
  250. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_trainer_integration.py +0 -0
  251. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_transformers.py +0 -0
  252. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_tvd.py +0 -0
  253. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/triton/test_triton_monkey_patch.py +0 -0
  254. {liger_kernel_nightly-0.5.9.dev20250519011716 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250519011716
3
+ Version: 0.5.9.dev20250519015630
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.9.dev20250519011716"
7
+ version = "0.5.9.dev20250519015630"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,310 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
@triton.jit
def _selective_log_softmax_kernel(
    LOGITS,
    INPUT_IDS,
    LOG_P,
    MASK,
    TEMPERATURE,
    stride_input_ids_b,
    L: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr = 4096,
):
    """Write log_softmax(LOGITS[b, l] / TEMPERATURE)[INPUT_IDS[b, l]] into LOG_P[b, l].

    One program handles one (batch row, position) pair; the launch site in this file uses a
    (B, L) grid. LOGITS is addressed as a contiguous (B, L + 1, N) buffer, LOG_P as a
    contiguous (B, L) buffer, and INPUT_IDS / MASK share the row stride `stride_input_ids_b`.
    Positions whose MASK value is 0 return early and leave LOG_P untouched.
    """
    # int64 program ids so the flattened offsets below cannot overflow int32 for large B * L * N.
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    LOGITS += off_b * (L + 1) * N + off_l * N
    INPUT_IDS += off_b * stride_input_ids_b + off_l
    LOG_P += off_b * L + off_l

    if MASK is not None:
        # MASK is addressed with the same row stride as INPUT_IDS.
        MASK += off_b * stride_input_ids_b + off_l
        not_skip = tl.load(MASK)
        if not_skip == 0:
            return

    # Streaming log-sum-exp over the vocabulary in BLOCK_N-sized chunks:
    # m_i is the running maximum and l_i the running sum of exp(logit - m_i).
    m_i = float("-inf")
    l_i = 0.0
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        # Rescale the previous partial sum to the new maximum.
        alpha = tl.exp(m_i - new_m_i)
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    # log p(token) = logit[token] / T - logsumexp(logits / T), computed in float32.
    ids = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
    logp = x - lse
    tl.store(LOG_P, logp)
46
+
47
+
48
# Computes old_logp / ref_logp without materialising a full log-softmax over the vocabulary
# (the author reports ~10 GB lower peak memory). No gradient is tracked.
@torch.no_grad()
def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
    """Return the float32 log-probability of each selected token, shape ``(B, L)``.

    Args:
        logits: contiguous ``(B, L + 1, N)`` logits. Position ``l`` scores the token at
            ``input_ids[:, -L + l]``; the last position is not used.
        input_ids: ``(B, S)`` token ids with ``S >= L``; only the last ``L`` columns are used.
        temperature: the logits are divided by this before the softmax.
        mask: optional ``(B, S')`` mask with ``S' >= L``. Positions whose mask value (over the
            last ``L`` columns) is 0 are skipped and stay 0 in the result.

    Returns:
        A ``(B, L)`` float32 tensor on ``logits.device``.

    Raises:
        ValueError: if ``input_ids`` or ``mask`` has fewer than ``L`` columns.
    """
    assert logits.is_contiguous()
    B, L_ADD_1, N = logits.shape
    L = L_ADD_1 - 1
    if L <= 0:
        # Nothing to score. This also avoids `x[:, -0:]`, which would select every column.
        return torch.zeros(B, 0, dtype=torch.float32, device=logits.device)
    if input_ids.size(1) < L:
        raise ValueError(f"input_ids has {input_ids.size(1)} columns but logits score {L} positions")
    if mask is not None and mask.size(1) < L:
        raise ValueError(f"mask has {mask.size(1)} columns but logits score {L} positions")

    # The kernel addresses input_ids and mask with one shared row stride (input_ids.stride(0)).
    # Give both the same compact (B, L) layout so that stride is valid for each of them.
    input_ids = input_ids[:, -L:].contiguous()
    if mask is not None:
        mask = mask[:, -L:].contiguous()

    log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
    kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
    _selective_log_softmax_kernel[(B, L)](
        logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
    )
    return log_p
63
+
64
+
65
# Autotuning is disabled; the launch site passes fixed BLOCK_N / num_stages / num_warps.
# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
#                   for BLOCK_N in [2048, 4096, 8192]
#                   for ns in [1, 2, 4]
#                   for nw in [1, 2, 4, 8, 16]],
#                   key=['N'])
@triton.jit
def _grpo_loss_fwd_kernel(
    LOGITS,
    OLD_LOGP,
    REF_LOGP,
    INPUT_IDS,
    COMPLETION_MASK,
    ADVANTAGES,
    LOSS,
    LSE,
    KL,
    IS_CLIPPED,
    TEMPERATURE,
    BETA: tl.constexpr,
    EPS_LOW,
    EPS_HIGH,
    L: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr = 4096,
):
    """Forward of the fused GRPO loss for one (batch row, completion position) program.

    Computes logp = logits[idx] / T - logsumexp(logits / T) with a streaming log-sum-exp.
    It then evaluates the clipped surrogate
        -min(r * A, clamp(r, 1 - EPS_LOW, 1 + EPS_HIGH) * A),   r = exp(logp - old_logp),
    and, when BETA != 0, adds BETA * KL with KL = exp(ref - logp) - (ref - logp) - 1.
    Writes LOSS, LSE (reused by the backward kernel), IS_CLIPPED and, if BETA != 0, KL.
    Positions with COMPLETION_MASK == 0 return immediately and write nothing.
    All per-token buffers are addressed as contiguous (B, L); LOGITS as contiguous (B, L + 1, N).
    """
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            return

    LOGITS += off_b * (L + 1) * N + off_l * N
    INPUT_IDS += off_b * L + off_l
    # One advantage per batch row; assumes ADVANTAGES has unit stride.
    ADVANTAGES += off_b
    LOSS += off_b * L + off_l
    LSE += off_b * L + off_l
    IS_CLIPPED += off_b * L + off_l

    # Streaming log-sum-exp over the vocabulary (running max m_i, running sum l_i).
    m_i = float("-inf")
    l_i = 0.0
    for start in range(0, N, BLOCK_N):
        cols = start + tl.arange(0, BLOCK_N)
        logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
        new_m_i = tl.maximum(m_i, tl.max(logits))
        alpha = tl.exp(m_i - new_m_i)
        l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
        m_i = new_m_i
    lse = m_i + tl.log(l_i)

    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    # Without OLD_LOGP the importance ratio is exactly 1 (single optimisation step per batch).
    if OLD_LOGP is None:
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
    is_clipped = per_token_loss1 < per_token_loss2

    # BETA is constexpr, so this branch (and the KL/REF_LOGP pointers) is compiled out when 0.
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        KL += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
        per_token_loss += BETA * kl
        tl.store(KL, kl)

    tl.store(LOSS, per_token_loss)
    tl.store(LSE, lse)
    tl.store(IS_CLIPPED, is_clipped)
144
+
145
+
146
# Autotuning is disabled; the launch site passes fixed BLOCK_N / num_stages / num_warps.
# @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
#                   for BLOCK_N in [2048, 4096, 8192]
#                   for ns in [1, 2, 4]
#                   for nw in [1, 2, 4, 8, 16]],
#                   key=['N'])
@triton.jit
def _grpo_loss_bwd_kernel(
    DLOSS,
    DLOGITS,
    LOGITS,
    OLD_LOGP,
    REF_LOGP,
    INPUT_IDS,
    ADVANTAGES,
    COMPLETION_MASK,
    LSE,
    TEMPERATURE,
    BETA: tl.constexpr,
    EPS_LOW,
    EPS_HIGH,
    loss_stride0,
    loss_stride1,
    L: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr = 4096,
):
    """Backward of the fused GRPO loss: write d(loss)/d(logits) for one (batch, position) row.

    Recomputes logp from LOGITS and the forward's LSE. It then forms d(loss)/d(logp), scaled
    by the upstream gradient DLOSS (read with explicit strides, so broadcast/expanded
    gradients work), and writes (onehot(idx) - softmax) * d(loss)/d(logp) / T over the
    whole vocabulary row. Masked rows get an all-zero gradient. DLOGITS may alias LOGITS
    (in-place mode of the caller).
    """
    off_b = tl.program_id(0).cast(tl.int64)
    off_l = tl.program_id(1).cast(tl.int64)

    DLOGITS += off_b * (L + 1) * N + off_l * N
    if COMPLETION_MASK is not None:
        COMPLETION_MASK += off_b * L + off_l
        not_skip = tl.load(COMPLETION_MASK)
        if not_skip == 0:
            # Masked-out position: its whole gradient row is zero.
            for start in range(0, N, BLOCK_N):
                cols = tl.arange(0, BLOCK_N) + start
                tl.store(DLOGITS + cols, 0.0, mask=cols < N)
            return

    LOGITS += off_b * (L + 1) * N + off_l * N
    DLOSS += off_b * loss_stride0 + off_l * loss_stride1
    INPUT_IDS += off_b * L + off_l
    ADVANTAGES += off_b
    LSE += off_b * L + off_l

    dloss = tl.load(DLOSS).to(tl.float32)
    lse = tl.load(LSE).to(tl.float32)

    idx = tl.load(INPUT_IDS)
    x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
    logp = x - lse
    if OLD_LOGP is None:
        old_logp = logp
    else:
        OLD_LOGP += off_b * L + off_l
        old_logp = tl.load(OLD_LOGP).to(tl.float32)
    coef_1 = tl.exp(logp - old_logp)
    coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
    advantage = tl.load(ADVANTAGES).to(tl.float32)
    per_token_loss1 = coef_1 * advantage
    per_token_loss2 = coef_2 * advantage
    # Gradient of -min(...) flows only through the unclipped term; the clipped term is
    # constant in logp once the ratio is clamped.
    mask = per_token_loss2 >= per_token_loss1

    # d(-coef_1 * A)/d(logp) = -coef_1 * A = -per_token_loss1.
    dlogp = -per_token_loss1 * mask
    if BETA != 0.0:
        REF_LOGP += off_b * L + off_l
        ref_logp = tl.load(REF_LOGP).to(tl.float32)
        # d/d(logp) [exp(r - logp) - (r - logp) - 1] = 1 - exp(r - logp).
        dlogp += BETA * (1 - tl.exp(ref_logp - logp))

    # Chain rule through logits / TEMPERATURE.
    dlogp = dlogp * dloss / TEMPERATURE
    # NOTE(review): presumably keeps the scalar LOGITS reads above ordered before the loop
    # below, which may overwrite the same memory when DLOGITS aliases LOGITS - confirm.
    tl.debug_barrier()
    for start_n in tl.range(0, N, BLOCK_N):
        cols = start_n + tl.arange(0, BLOCK_N)
        # Each chunk is loaded before it is stored, so in-place aliasing is safe per chunk.
        logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
        probs = tl.exp(logits - lse)
        # d(logp)/d(logit_j) = onehot(idx)_j - softmax_j (the 1/T factor is already in dlogp).
        dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
        tl.store(DLOGITS + cols, dlogits, mask=cols < N)
223
+
224
+
225
class GrpoLossFunction(torch.autograd.Function):
    """Fused GRPO per-token loss computed directly from logits with Triton kernels.

    ``forward`` returns ``(loss, kl, is_clipped)``, each a float32 ``(B, L)`` tensor
    (``kl`` is ``None`` when ``beta == 0``). ``backward`` produces a gradient for ``logits``
    only. Upstream gradients for ``kl`` are not propagated; the KL term already enters
    ``loss`` through ``beta``.
    """

    @staticmethod
    def forward(
        ctx,
        logits,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        inplace,
    ):
        """Run the fused forward kernel.

        Args:
            logits: contiguous ``(B, L + 1, N)`` logits. Position ``l`` scores
                ``completion_ids[:, l]``; the last position is ignored.
            old_logp: contiguous ``(B, L)`` log-probs of the sampling policy, or ``None``
                to use the current (detached) log-probs, i.e. an importance ratio of 1.
            ref_logp: contiguous ``(B, L)`` reference log-probs; required when ``beta != 0``.
            completion_ids: contiguous ``(B, L)`` token ids.
            advantages: one advantage per batch row (``B`` elements).
            completion_mask: optional contiguous ``(B, L)`` mask. Positions equal to 0 are
                skipped and their outputs stay 0.
            temperature: logits are divided by this before the softmax.
            beta: KL penalty coefficient.
            eps_low, eps_high: ratio clipping range ``[1 - eps_low, 1 + eps_high]``.
            inplace: if True, backward writes the gradient into ``logits``' own storage.
        """
        assert logits.is_contiguous() and completion_ids.is_contiguous()
        assert old_logp is None or old_logp.is_contiguous()
        assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True

        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1

        if completion_mask is not None:
            assert completion_mask.is_contiguous()

        # Both kernels address advantages as ADVANTAGES + b (unit stride). A strided view
        # (e.g. a column slice of a larger tensor) would otherwise be read at wrong offsets.
        advantages = advantages.contiguous()

        loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
        lse = torch.zeros_like(loss)
        is_clipped = torch.zeros_like(loss)
        kl = torch.zeros_like(loss) if beta != 0.0 else None
        kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
        _grpo_loss_fwd_kernel[(B, L)](
            logits,
            old_logp,
            ref_logp,
            completion_ids,
            completion_mask,
            advantages,
            loss,
            lse,
            kl,
            is_clipped,
            temperature,
            beta,
            eps_low,
            eps_high,
            L,
            N,
            **kwargs,
        )
        # lse is cached so the backward kernel does not repeat the vocabulary reduction.
        ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
        ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
        # is_clipped is a 0/1 indicator for metrics; it has no meaningful gradient.
        ctx.mark_non_differentiable(is_clipped)
        return loss, kl, is_clipped

    @staticmethod
    def backward(ctx, *args):
        """Return the gradient w.r.t. ``logits`` (``None`` for every other input).

        ``args[0]`` is the upstream gradient of ``loss``; gradients for the other outputs
        are ignored.
        """
        dloss = args[0]
        logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
        temperature, beta, eps_low, eps_high, inplace = ctx.infos
        B, L_ADD_1, N = logits.shape
        L = L_ADD_1 - 1
        # In-place mode reuses the logits buffer for the gradient to save memory. The logits
        # are destroyed, so a second backward (e.g. retain_graph=True) would read gradients
        # instead of logits.
        dlogits = logits.data if inplace else torch.empty_like(logits)
        kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
        _grpo_loss_bwd_kernel[(B, L)](
            dloss,
            dlogits,
            logits,
            old_logp,
            ref_logp,
            completion_ids,
            advantages,
            completion_mask,
            lse,
            temperature,
            beta,
            eps_low,
            eps_high,
            *dloss.stride(),
            L,
            N,
            **kwargs,
        )
        # The last logits position does not score any completion token.
        dlogits[:, -1, :] = 0
        return dlogits, None, None, None, None, None, None, None, None, None, None
@@ -0,0 +1,98 @@
1
+ from liger_kernel.ops.grpo_loss import GrpoLossFunction
2
+
3
+
4
def triton_grpo_loss(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask=None,
    temperature=0.9,
    beta=0.04,
    eps_low=0.2,
    eps_high=0.4,
    inplace=True,
):
    """Compute the fused GRPO per-token loss with Triton.

    Thin convenience wrapper that validates the mandatory inputs and forwards every
    argument, in order, to ``GrpoLossFunction.apply``.

    Returns:
        ``(per_token_loss, per_token_kl, is_clipped)``; ``per_token_kl`` is ``None`` when
        ``beta == 0``.

    Raises:
        AssertionError: if ``logits``, ``completion_ids`` or ``advantages`` is ``None``.
    """
    mandatory = (logits, completion_ids, advantages)
    assert all(t is not None for t in mandatory), "must provide logits、completion_ids and advantages"

    fn_args = (
        logits,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        inplace,
    )
    return GrpoLossFunction.apply(*fn_args)
34
+
35
+
36
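+ # Demo of how to use grpo_loss in GRPOTrainer (requires trl==0.16). The snippet below is only a string; to use it, copy it into a module that also imports `fused_selective_log_softmax` (from liger_kernel.ops.grpo_loss) and `triton_grpo_loss`.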
37
+ """
38
+ import torch
39
+ import trl
40
+ assert trl.__version__.startswith("0.16"), "please pip install trl==0.16"
41
+ from trl.extras.profiling import profiling_decorator
42
+
43
+ @profiling_decorator
44
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
45
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
46
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
47
+ return fused_selective_log_softmax(logits, input_ids, self.temperature, mask=attention_mask)
48
+
49
+ @profiling_decorator
50
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
51
+ if return_outputs:
52
+ raise ValueError("The GRPOTrainer does not support returning outputs")
53
+ # Compute the per-token log probabilities for the model
54
+
55
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
56
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
57
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
58
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
59
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
60
+ logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
61
+
62
+ ref_per_token_logps = inputs["ref_per_token_logps"]
63
+ advantages = inputs["advantages"]
64
+ old_per_token_logps = inputs["old_per_token_logps"]
65
+
66
+
67
+ per_token_loss, per_token_kl, is_clipped = triton_grpo_loss(logits,
68
+ old_per_token_logps,
69
+ ref_per_token_logps,
70
+ completion_ids,
71
+ advantages,
72
+ completion_mask,
73
+ self.temperature,
74
+ self.beta,
75
+ self.epsilon_low,
76
+ self.epsilon_high,)
77
+ loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
78
+
79
+ # Log the metrics
80
+ mode = "eval" if self.control.should_evaluate else "train"
81
+
82
+ if self.beta != 0.0:
83
+ mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
84
+ self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
85
+
86
+ clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
87
+ self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
88
+ return loss
89
+
90
+ trl.GRPOTrainer._get_per_token_logps = _get_per_token_logps
91
+ trl.GRPOTrainer.compute_loss = compute_loss
92
+ trigger = None
93
+ """
94
+
95
+ # Add this line at the very top of grpo.py in open-r1 (note: `trigger` is only defined once the demo above is turned from a string into real code):
96
+ """
97
+ from liger_kernel.transformers.grpo_loss import trigger
98
+ """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250519011716
3
+ Version: 0.5.9.dev20250519015630
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -129,6 +129,7 @@ src/liger_kernel/ops/fused_linear_cross_entropy.py
129
129
  src/liger_kernel/ops/fused_linear_jsd.py
130
130
  src/liger_kernel/ops/geglu.py
131
131
  src/liger_kernel/ops/group_norm.py
132
+ src/liger_kernel/ops/grpo_loss.py
132
133
  src/liger_kernel/ops/jsd.py
133
134
  src/liger_kernel/ops/kl_div.py
134
135
  src/liger_kernel/ops/layer_norm.py
@@ -151,6 +152,7 @@ src/liger_kernel/transformers/fused_linear_jsd.py
151
152
  src/liger_kernel/transformers/geglu.py
152
153
  src/liger_kernel/transformers/gema3_rms.py
153
154
  src/liger_kernel/transformers/group_norm.py
155
+ src/liger_kernel/transformers/grpo_loss.py
154
156
  src/liger_kernel/transformers/jsd.py
155
157
  src/liger_kernel/transformers/kl_div.py
156
158
  src/liger_kernel/transformers/layer_norm.py
@@ -233,6 +235,7 @@ test/transformers/test_fused_linear_cross_entropy.py
233
235
  test/transformers/test_fused_linear_jsd.py
234
236
  test/transformers/test_geglu.py
235
237
  test/transformers/test_group_norm.py
238
+ test/transformers/test_grpo_loss.py
236
239
  test/transformers/test_jsd.py
237
240
  test/transformers/test_kl_div.py
238
241
  test/transformers/test_layer_norm.py
@@ -0,0 +1,190 @@
1
+ import pytest
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from test.utils import infer_device
6
+ from test.utils import set_seed
7
+
8
+ from liger_kernel.ops.grpo_loss import fused_selective_log_softmax
9
+ from liger_kernel.transformers.grpo_loss import triton_grpo_loss
10
+
11
+
12
def compare(x, y, extra_str=""):
    """Print and return the max and mean elementwise relative difference of two tensors.

    The relative difference is ``|x - y| / (max(|x|, |y|) + 1e-5)``. It is always computed
    in float32, so bf16/fp16 inputs do not lose precision in the subtraction itself.
    If either tensor is ``None`` (e.g. the KL output when ``beta == 0``), nothing is
    printed and ``None`` is returned.

    Returns:
        ``(max_diff, mean_diff)`` as Python floats, or ``None``.
    """
    if x is None or y is None:
        return None
    x, y = x.float(), y.float()
    diff = (x - y).abs() / (torch.max(x.abs(), y.abs()) + 1e-5)
    max_diff, mean_diff = diff.max().item(), diff.mean().item()
    print(f"{extra_str}Max difference: {max_diff}, Mean difference: {mean_diff}")
    return max_diff, mean_diff
20
+
21
+
22
@torch.no_grad()
def selective_log_softmax(logits, input_ids, temperature=0.9):
    """Reference PyTorch per-token log-probabilities of ``input_ids`` under ``logits``.

    The final logits position is dropped (it would predict a token past the sequence).
    The remaining ``K`` positions are matched with the last ``K`` columns of ``input_ids``,
    scaled by ``1 / temperature``, and the selected token's log-softmax is returned.
    """
    # Exclude the final position: it corresponds to the next-token prediction.
    shifted = logits[:, :-1, :]
    keep = shifted.size(1)
    targets = input_ids[:, -keep:]
    scaled = shifted[:, -keep:] / temperature

    if scaled.dtype not in (torch.float32, torch.float64):
        # logsumexp is numerically unstable in low precision (e.g. bf16), so take a
        # per-row log_softmax instead; iterating rows keeps peak memory down.
        rows = [
            F.log_softmax(row, dim=-1).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
            for row, labels in zip(scaled, targets)
        ]
        return torch.stack(rows)

    # log_softmax(x)_i = x_i - logsumexp(x); logsumexp is taken row by row to limit peak memory.
    picked = torch.gather(scaled, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)
    lse = torch.stack([torch.logsumexp(row, dim=-1) for row in scaled])
    return picked - lse
44
+
45
+
46
def torch_grpo_loss(
    logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high
):
    """Reference PyTorch GRPO per-token loss used to validate the Triton implementation.

    Returns ``(per_token_loss, per_token_kl, is_clipped)``. The loss and KL are multiplied by
    ``completion_mask`` when one is given; ``per_token_kl`` is ``None`` when ``beta == 0``;
    ``is_clipped`` is a float 0/1 tensor and is not masked.
    """
    assert logits.is_contiguous() and completion_ids.is_contiguous()
    assert old_logp is None or old_logp.is_contiguous()
    assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True

    # Drop the last position, then pair each remaining position with its completion token.
    scaled = logits[:, :-1] / temperature
    targets = completion_ids[:, -scaled.size(1) :]
    logp = torch.stack(
        [
            torch.gather(row.log_softmax(dim=-1), dim=1, index=ids.unsqueeze(1)).squeeze(1)
            for row, ids in zip(scaled, targets)
        ]
    )

    # Without old log-probs the importance ratio is 1 (detached, so no gradient through it).
    baseline = logp.detach() if old_logp is None else old_logp
    ratio = torch.exp(logp - baseline)
    clipped_ratio = torch.clamp(ratio, 1 - eps_low, 1 + eps_high)
    adv = advantages.unsqueeze(1)
    unclipped_term = ratio * adv
    clipped_term = clipped_ratio * adv
    loss = -torch.min(unclipped_term, clipped_term)
    if completion_mask is not None:
        loss = loss * completion_mask

    kl = None
    if beta != 0.0:
        delta = ref_logp - logp
        kl = torch.exp(delta) - delta - 1
        if completion_mask is not None:
            kl *= completion_mask
        loss = loss + beta * kl
    return loss, kl, (unclipped_term < clipped_term).float()
83
+
84
+
85
# Seed the random generators once at import time so the random test inputs below are
# reproducible (presumably torch and Python RNGs - see test.utils.set_seed).
set_seed(42)
# Device on which every test tensor is allocated, as chosen by test.utils.infer_device().
device = infer_device()
87
+
88
+
89
@pytest.mark.parametrize(
    "temperature, B, T, V",
    [
        (0.9, 1, 1024, 64000),
        (0.7, 1, 1024, 151936),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.bfloat16, 1e-5, 1e-5),
    ],
)
def test_selective_log_softmax(B, T, V, temperature, dtype, atol, rtol):
    """Check the Triton selective log-softmax against the float32 PyTorch reference.

    All variants start from the same ``dtype`` logits. The torch-``dtype`` result is only
    reported, because the reference falls back to a low-precision log_softmax. The Triton
    kernel upcasts to float32 internally, so it must match the float32 reference to within
    ``(atol, rtol)``.
    """
    # T + 1 positions: the model is run with logits_to_keep + 1.
    _input = torch.randn(B, T + 1, V, device=device, dtype=dtype)

    logit1 = _input.clone()
    logit2 = _input.clone()
    logit3 = _input.clone().float()

    # A 100-token prompt followed by T completion tokens.
    input_ids = torch.randint(0, V - 1, (B, 100 + T), dtype=torch.int64, device=device)

    torch_bf16_logp = selective_log_softmax(logit1, input_ids, temperature)
    triton_bf16_logp = fused_selective_log_softmax(logit2, input_ids, temperature)
    torch_fp32_logp = selective_log_softmax(logit3, input_ids, temperature)

    print("\n" + "=" * 20 + " selective_log_softmax " + "=" * 20)
    compare(torch_bf16_logp, torch_fp32_logp, "torch-bf16 vs torch-fp32, ")
    compare(triton_bf16_logp, torch_fp32_logp, "triton-bf16 vs torch-fp32, ")

    # Triton and the fp32 reference see identical (bf16-representable) values and both
    # compute in float32, so they must agree to float32 precision.
    torch.testing.assert_close(triton_bf16_logp, torch_fp32_logp, atol=atol, rtol=rtol)
122
+
123
+
124
@pytest.mark.parametrize(
    "temperature, num_iteration, beta, eps_low, eps_high",
    [(0.7, num_iteration, beta, 0.2, 0.4) for num_iteration in [1, 5] for beta in [0.0, 0.04]],
)
@pytest.mark.parametrize(
    "B, T, V",
    [
        (1, 1024, 151936),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        # Tolerances for the float32 outputs (loss, KL). The KL term exponentiates
        # log-prob differences, which amplifies float32 rounding slightly.
        (torch.bfloat16, 1e-4, 1e-4),
    ],
)
def test_grpo_loss(B, T, V, temperature, num_iteration, beta, eps_low, eps_high, dtype, atol, rtol):
    """Check the Triton GRPO loss (forward and logits gradient) against the fp32 reference.

    ``num_iteration > 1`` exercises the path with explicit old log-probs, and ``beta != 0``
    exercises the KL term. The torch-``dtype`` run is only reported for context.
    """
    _input = torch.randn(B, T + 1, V, device=device, dtype=dtype)

    logits1 = _input.clone().requires_grad_(True)
    logits2 = _input.clone().requires_grad_(True)
    logits3 = _input.clone().float().requires_grad_(True)

    completion_ids = torch.randint(0, V - 1, (B, T), dtype=torch.int64, device=device)
    completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)
    # The last 100 completion positions are padding.
    completion_mask[:, -100:] = 0

    # float32 on purpose: fused_selective_log_softmax returns float32 log-probs even for bf16 logits.
    ref_logp = torch.randn(B, T, device=device, dtype=torch.float32) if beta != 0.0 else None
    old_logp = torch.randn(B, T, device=device, dtype=torch.float32) if num_iteration > 1 else None
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    loss1, kl1, _ = torch_grpo_loss(
        logits1, old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high
    )

    loss2, kl2, _ = triton_grpo_loss(
        logits2,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        inplace=True,
    )

    loss3, kl3, _ = torch_grpo_loss(
        logits3, old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high
    )

    dy = torch.randn_like(loss3)
    loss1.backward(dy)
    loss2.backward(dy)
    loss3.backward(dy)

    print("\n" + "=" * 20 + " grpo_loss " + "=" * 20)
    compare(loss1, loss3, "per_token_loss: torch-bf16 vs torch-fp32, ")
    compare(kl1, kl3, "per_token_kl: torch-bf16 vs torch-fp32, ")
    compare(logits1.grad, logits3.grad, "logits.grad: torch-bf16 vs torch-fp32, ")
    compare(loss2, loss3, "per_token_loss: triton-bf16 vs torch-fp32, ")
    compare(kl2, kl3, "per_token_kl: triton-bf16 vs torch-fp32, ")
    compare(logits2.grad, logits3.grad, "logits.grad: triton-bf16 vs torch-fp32, ")

    # The Triton path computes in float32 from the same bf16 values as the fp32 reference,
    # so its float32 outputs must agree with it.
    torch.testing.assert_close(loss2, loss3, atol=atol, rtol=rtol)
    if beta != 0.0:
        torch.testing.assert_close(kl2, kl3, atol=atol, rtol=rtol)
    # With inplace=True the gradient is written into the bf16 logits buffer, so allow
    # bf16 rounding (about 2**-9 relative).
    torch.testing.assert_close(logits2.grad.float(), logits3.grad, atol=1e-2, rtol=1e-2)