liger-kernel-nightly 0.5.9.dev20250517045825__tar.gz → 0.5.9.dev20250519015630__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (255):
  1. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_dyt.py +37 -34
  3. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/pyproject.toml +1 -1
  4. liger_kernel_nightly-0.5.9.dev20250519015630/src/liger_kernel/ops/dyt.py +159 -0
  5. liger_kernel_nightly-0.5.9.dev20250519015630/src/liger_kernel/ops/grpo_loss.py +310 -0
  6. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/dyt.py +5 -3
  7. liger_kernel_nightly-0.5.9.dev20250519015630/src/liger_kernel/transformers/grpo_loss.py +98 -0
  8. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  9. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/SOURCES.txt +3 -0
  10. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_dyt.py +40 -20
  11. liger_kernel_nightly-0.5.9.dev20250519015630/test/transformers/test_grpo_loss.py +190 -0
  12. liger_kernel_nightly-0.5.9.dev20250517045825/src/liger_kernel/ops/dyt.py +0 -225
  13. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  14. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  15. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/pull_request_template.md +0 -0
  16. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/amd-ci.yml +0 -0
  17. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/docs.yml +0 -0
  18. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/intel-ci.yml +0 -0
  19. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/nvi-ci.yml +0 -0
  20. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/publish-nightly.yml +0 -0
  21. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.github/workflows/publish-release.yml +0 -0
  22. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.gitignore +0 -0
  23. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/.idea/workspace.xml +0 -0
  24. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/LICENSE +0 -0
  25. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/Makefile +0 -0
  26. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/NOTICE +0 -0
  27. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/README.md +0 -0
  28. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/README.md +0 -0
  29. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/__init__.py +0 -0
  30. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/benchmarks_visualizer.py +0 -0
  31. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/data/all_benchmark_data.csv +0 -0
  32. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/__init__.py +0 -0
  33. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  35. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  36. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  37. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_embedding.py +0 -0
  38. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  39. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  40. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_geglu.py +0 -0
  41. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_group_norm.py +0 -0
  42. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_jsd.py +0 -0
  43. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_kl_div.py +0 -0
  44. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  45. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  46. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  47. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  48. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  49. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_rope.py +0 -0
  50. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  51. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_sparsemax.py +0 -0
  52. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_swiglu.py +0 -0
  53. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/benchmark_tvd.py +0 -0
  54. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/benchmark/scripts/utils.py +0 -0
  55. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/dev/fmt-requirements.txt +0 -0
  56. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/dev/modal/tests.py +0 -0
  57. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/dev/modal/tests_bwd.py +0 -0
  58. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/Examples.md +0 -0
  59. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/Getting-Started.md +0 -0
  60. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/High-Level-APIs.md +0 -0
  61. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/Low-Level-APIs.md +0 -0
  62. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/acknowledgement.md +0 -0
  63. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/contributing.md +0 -0
  64. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/banner.GIF +0 -0
  65. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/compose.gif +0 -0
  66. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/e2e-memory.png +0 -0
  67. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/e2e-tps.png +0 -0
  68. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/logo-banner.png +0 -0
  69. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/patch.gif +0 -0
  70. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/images/post-training.png +0 -0
  71. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/index.md +0 -0
  72. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/docs/license.md +0 -0
  73. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/alignment/accelerate_config.yaml +0 -0
  74. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/alignment/run_orpo.py +0 -0
  75. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/README.md +0 -0
  76. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/callback.py +0 -0
  77. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/config/fsdp_config.json +0 -0
  78. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  79. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  80. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  81. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/llama_tps.png +0 -0
  82. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  83. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/img/qwen_tps.png +0 -0
  84. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/launch_on_modal.py +0 -0
  85. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/requirements.txt +0 -0
  86. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_benchmarks.sh +0 -0
  87. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_gemma.sh +0 -0
  88. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_llama.sh +0 -0
  89. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_qwen.sh +0 -0
  90. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/run_qwen2_vl.sh +0 -0
  91. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/training.py +0 -0
  92. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/huggingface/training_multimodal.py +0 -0
  93. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/lightning/README.md +0 -0
  94. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/lightning/requirements.txt +0 -0
  95. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/lightning/training.py +0 -0
  96. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/README.md +0 -0
  97. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/callback.py +0 -0
  98. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  99. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  100. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  101. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  102. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  103. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  104. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  105. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  106. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  107. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/medusa_util.py +0 -0
  108. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/requirements.txt +0 -0
  109. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  110. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/examples/medusa/train.py +0 -0
  111. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-Apache-2.0 +0 -0
  112. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  113. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  114. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-MIT-llmc +0 -0
  115. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/licenses/LICENSE-MIT-triton +0 -0
  116. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/mkdocs.yml +0 -0
  117. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/setup.cfg +0 -0
  118. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/setup.py +0 -0
  119. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/__init__.py +0 -0
  120. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/README.md +0 -0
  121. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  122. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  123. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  124. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/functional.py +0 -0
  125. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  126. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  127. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  128. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  129. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  131. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  132. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  133. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  134. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/env_report.py +0 -0
  135. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/__init__.py +0 -0
  136. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/cross_entropy.py +0 -0
  137. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  138. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  139. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  140. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  141. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/geglu.py +0 -0
  142. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/group_norm.py +0 -0
  143. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/jsd.py +0 -0
  144. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/kl_div.py +0 -0
  145. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/layer_norm.py +0 -0
  146. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  147. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/rms_norm.py +0 -0
  148. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/rope.py +0 -0
  149. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/sparsemax.py +0 -0
  150. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/swiglu.py +0 -0
  151. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/tvd.py +0 -0
  152. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/ops/utils.py +0 -0
  153. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/__init__.py +0 -0
  154. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/auto_model.py +0 -0
  155. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  156. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  157. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/functional.py +0 -0
  158. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  159. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  160. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/geglu.py +0 -0
  161. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  162. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/group_norm.py +0 -0
  163. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/jsd.py +0 -0
  164. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/kl_div.py +0 -0
  165. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/layer_norm.py +0 -0
  166. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/__init__.py +0 -0
  167. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/gemma.py +0 -0
  168. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  169. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  170. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/glm4.py +0 -0
  171. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/llama.py +0 -0
  172. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/llava.py +0 -0
  173. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  174. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/mistral.py +0 -0
  175. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  176. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/mllama.py +0 -0
  177. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  178. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  179. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/phi3.py +0 -0
  180. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  181. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  182. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  183. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  184. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  185. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  186. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  187. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/rms_norm.py +0 -0
  188. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/rope.py +0 -0
  189. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/sparsemax.py +0 -0
  190. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/swiglu.py +0 -0
  191. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  192. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  193. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  194. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/transformers/tvd.py +0 -0
  195. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/triton/__init__.py +0 -0
  196. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/triton/monkey_patch.py +0 -0
  197. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel/utils.py +0 -0
  198. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  199. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  200. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  201. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/__init__.py +0 -0
  202. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/__init__.py +0 -0
  203. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_cpo_loss.py +0 -0
  204. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_dpo_loss.py +0 -0
  205. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_grpo_loss.py +0 -0
  206. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_jsd_loss.py +0 -0
  207. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_kto_loss.py +0 -0
  208. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_orpo_loss.py +0 -0
  209. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/chunked_loss/test_simpo_loss.py +0 -0
  210. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/conftest.py +0 -0
  211. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/__init__.py +0 -0
  212. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/bf16/__init__.py +0 -0
  213. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/bf16/test_mini_models.py +0 -0
  214. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  215. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  216. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/fp32/__init__.py +0 -0
  217. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/fp32/test_mini_models.py +0 -0
  218. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  219. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  220. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  221. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  222. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  223. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  224. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  225. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  226. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  227. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  228. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  229. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/tiny_shakespeare.txt +0 -0
  230. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  231. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  232. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  233. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_auto_model.py +0 -0
  234. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_cross_entropy.py +0 -0
  235. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_embedding.py +0 -0
  236. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_flex_attention.py +0 -0
  237. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  238. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_fused_linear_jsd.py +0 -0
  239. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_geglu.py +0 -0
  240. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_group_norm.py +0 -0
  241. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_jsd.py +0 -0
  242. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_kl_div.py +0 -0
  243. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_layer_norm.py +0 -0
  244. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_mm_int8int2.py +0 -0
  245. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_monkey_patch.py +0 -0
  246. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_qwen2vl_mrope.py +0 -0
  247. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_rms_norm.py +0 -0
  248. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_rope.py +0 -0
  249. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_sparsemax.py +0 -0
  250. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_swiglu.py +0 -0
  251. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_trainer_integration.py +0 -0
  252. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_transformers.py +0 -0
  253. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/transformers/test_tvd.py +0 -0
  254. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/triton/test_triton_monkey_patch.py +0 -0
  255. {liger_kernel_nightly-0.5.9.dev20250517045825 → liger_kernel_nightly-0.5.9.dev20250519015630}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250517045825
3
+ Version: 0.5.9.dev20250519015630
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -22,17 +22,18 @@ def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
22
22
  from test.transformers.test_dyt import LigerDyT
23
23
  from test.transformers.test_dyt import TorchDyT
24
24
 
25
- BT = input.x
25
+ hidden_size = input.x
26
26
  provider = input.kernel_provider
27
27
  mode = input.kernel_operation_mode
28
28
  extra_benchmark_config = input.extra_benchmark_config
29
- hidden_size = extra_benchmark_config["hidden_size"]
29
+ BT = extra_benchmark_config["BT"]
30
+ beta = extra_benchmark_config["beta"]
30
31
  dtype = extra_benchmark_config["dtype"]
31
32
 
32
33
  x_shape = (BT, hidden_size)
33
- torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
34
- torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
35
- triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
34
+ torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
35
+ torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
36
+ triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)
36
37
 
37
38
  x = torch.randn(x_shape, dtype=dtype, device=device)
38
39
  dy = torch.randn_like(x)
@@ -75,16 +76,17 @@ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
75
76
  from test.transformers.test_dyt import LigerDyT
76
77
  from test.transformers.test_dyt import TorchDyT
77
78
 
78
- BT = input.x
79
+ hidden_size = input.x
79
80
  provider = input.kernel_provider
80
81
  extra_benchmark_config = input.extra_benchmark_config
81
- hidden_size = extra_benchmark_config["hidden_size"]
82
+ BT = extra_benchmark_config["BT"]
83
+ beta = extra_benchmark_config["beta"]
82
84
  dtype = extra_benchmark_config["dtype"]
83
85
 
84
86
  x_shape = (BT, hidden_size)
85
- torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
86
- torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
87
- triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
87
+ torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
88
+ torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
89
+ triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)
88
90
 
89
91
  x = torch.randn(x_shape, dtype=dtype, device=device)
90
92
  dy = torch.randn_like(x)
@@ -113,27 +115,28 @@ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
113
115
  if __name__ == "__main__":
114
116
  args = parse_benchmark_script_args()
115
117
 
116
- common_configs = {
117
- "kernel_name": "dyt",
118
- "x_name": "BT",
119
- "x_label": "batch_size * seq_len",
120
- "x_values": [2**i for i in range(10, 15)],
121
- "kernel_providers": ["liger", "torch", "torch_compile"],
122
- "extra_benchmark_configs": [{"hidden_size": 4096, "dtype": torch.float32}],
123
- "overwrite": args.overwrite,
124
- }
125
-
126
- run_benchmarks(
127
- bench_test_fn=bench_speed_dyt,
128
- kernel_operation_modes=["forward", "backward", "full"],
129
- metric_name="speed",
130
- metric_unit="ms",
131
- **common_configs,
132
- )
133
- run_benchmarks(
134
- bench_test_fn=bench_memory_dyt,
135
- kernel_operation_modes=["full"],
136
- metric_name="memory",
137
- metric_unit="MB",
138
- **common_configs,
139
- )
118
+ for beta in [False, True]:
119
+ common_configs = {
120
+ "kernel_name": f"dyt_beta={beta}",
121
+ "x_name": "hidden_size",
122
+ "x_label": "hidden_size",
123
+ "x_values": [1024 * i for i in range(1, 17)],
124
+ "kernel_providers": ["liger", "torch", "torch_compile"],
125
+ "extra_benchmark_configs": [{"BT": 4096, "dtype": torch.bfloat16, "beta": beta}],
126
+ "overwrite": args.overwrite,
127
+ }
128
+
129
+ run_benchmarks(
130
+ bench_test_fn=bench_speed_dyt,
131
+ kernel_operation_modes=["forward", "backward", "full"],
132
+ metric_name="speed",
133
+ metric_unit="ms",
134
+ **common_configs,
135
+ )
136
+ run_benchmarks(
137
+ bench_test_fn=bench_memory_dyt,
138
+ kernel_operation_modes=["full"],
139
+ metric_name="memory",
140
+ metric_unit="MB",
141
+ **common_configs,
142
+ )
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.9.dev20250517045825"
7
+ version = "0.5.9.dev20250519015630"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,159 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from triton.language.extra.libdevice import tanh
8
+
9
+ from liger_kernel.ops.utils import compare_version
10
+ from liger_kernel.ops.utils import ensure_contiguous
11
+ from liger_kernel.ops.utils import infer_device
12
+
13
+ if compare_version("triton", operator.ge, "3.0.0"):
14
+ try:
15
+ # typical import path with dispatch available
16
+ from triton.language.extra.libdevice import tanh
17
+ except ModuleNotFoundError:
18
+ # for working with NGC containers
19
+ from triton.language.extra.cuda.libdevice import tanh
20
+ else:
21
+ from triton.language.math import tanh
22
+
23
+
24
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
25
+ # for bn in [1024, 2048, 4096]
26
+ # for ns in [1,2,4]
27
+ # for nw in [4, 8, 16, 32]
28
+ # ],
29
+ # key=['N'])
30
+ @triton.jit
31
+ def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
32
+ col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
33
+ mask = col < N
34
+ row_id = tl.cast(tl.program_id(1), tl.int64)
35
+
36
+ X += row_id * N
37
+ Y += row_id * N
38
+ alpha = tl.load(Alpha).to(tl.float32)
39
+
40
+ gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
41
+
42
+ x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
43
+
44
+ tanh_x = tanh(alpha * x)
45
+ y = tanh_x * gamma
46
+ if HAVE_BETA:
47
+ beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
48
+ y += beta
49
+ tl.store(Y + col, y, mask=mask)
50
+
51
+
52
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
53
+ # for bn in [1024, 2048, 4096]
54
+ # for ns in [1,2,4]
55
+ # for nw in [4, 8, 16]
56
+ # ],
57
+ # key=['N'])
58
+ @triton.jit
59
+ def _dyt_bwd_kernel(
60
+ DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
61
+ ):
62
+ col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
63
+ mask = col < N
64
+ start_row_id = tl.cast(tl.program_id(1), tl.int64)
65
+
66
+ alpha = tl.load(Alpha).to(tl.float32)
67
+ da = 0.0
68
+ gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
69
+ dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
70
+ if HAVE_BETA:
71
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
72
+ for row_id in range(start_row_id, M, tl.num_programs(1)):
73
+ x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
74
+ dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
75
+ tanh_x = tanh(alpha * x)
76
+ if HAVE_BETA:
77
+ db += dy
78
+ dg += dy * tanh_x
79
+ tmp = (1 - tanh_x * tanh_x) * dy * gamma
80
+ da += tl.sum(x * tmp, 0)
81
+ dx = alpha * tmp
82
+ tl.store(DX + row_id * N + col, dx, mask=mask)
83
+
84
+ tl.store(DG + start_row_id * N + col, dg, mask=mask)
85
+ if HAVE_BETA:
86
+ tl.store(DB + start_row_id * N + col, db, mask=mask)
87
+ tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
88
+
89
+
90
+ def liger_dyt_fwd(x, alpha, gamma, beta):
91
+ assert x.is_contiguous()
92
+ HAVE_BETA = True if beta is not None else False
93
+ input_shape = x.shape
94
+ x = x.view(-1, input_shape[-1])
95
+ M, N = x.shape
96
+
97
+ y = torch.empty_like(x)
98
+
99
+ if N >= 4096:
100
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
101
+ else:
102
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
103
+
104
+ grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
105
+ _dyt_fwd_kernel[(grid)](
106
+ x,
107
+ y,
108
+ alpha,
109
+ gamma,
110
+ beta,
111
+ HAVE_BETA,
112
+ N,
113
+ **kwargs,
114
+ )
115
+ return y.view(input_shape)
116
+
117
+
118
+ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
119
+ assert dy.is_contiguous()
120
+ input_shape = x.shape
121
+ x = x.view(-1, input_shape[-1])
122
+ M, N = x.shape
123
+ HAVE_BETA = True if beta is not None else False
124
+
125
+ device = infer_device()
126
+ if device == "cuda":
127
+ NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
128
+ elif device == "xpu":
129
+ NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
130
+
131
+ da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
132
+ dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
133
+ db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
134
+ dx = torch.empty_like(dy)
135
+
136
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
137
+ grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
138
+ _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
139
+ if HAVE_BETA:
140
+ db = db.sum(0).to(x.dtype)
141
+ dg = dg.sum(0).to(gamma.dtype)
142
+ da = da.sum().to(x.dtype).unsqueeze(0)
143
+ return dx.view(input_shape), da, dg, db
144
+
145
+
146
+ class LigerDyTFunction(torch.autograd.Function):
147
+ @staticmethod
148
+ @ensure_contiguous
149
+ def forward(ctx, x, alpha, gamma, beta):
150
+ y = liger_dyt_fwd(x, alpha, gamma, beta)
151
+ ctx.save_for_backward(x, alpha, gamma, beta)
152
+ return y
153
+
154
+ @staticmethod
155
+ @ensure_contiguous
156
+ def backward(ctx, dy):
157
+ x, alpha, gamma, beta = ctx.saved_tensors
158
+ dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
159
+ return dx, dalpha, dgamma, dbeta
@@ -0,0 +1,310 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def _selective_log_softmax_kernel(
8
+ LOGITS,
9
+ INPUT_IDS,
10
+ LOG_P,
11
+ MASK,
12
+ TEMPERATURE,
13
+ stride_input_ids_b,
14
+ L: tl.constexpr,
15
+ N: tl.constexpr,
16
+ BLOCK_N: tl.constexpr = 4096,
17
+ ):
18
+ off_b = tl.program_id(0).cast(tl.int64)
19
+ off_l = tl.program_id(1).cast(tl.int64)
20
+
21
+ LOGITS += off_b * (L + 1) * N + off_l * N
22
+ INPUT_IDS += off_b * stride_input_ids_b + off_l
23
+ LOG_P += off_b * L + off_l
24
+
25
+ if MASK is not None:
26
+ MASK += off_b * stride_input_ids_b + off_l
27
+ not_skip = tl.load(MASK)
28
+ if not_skip == 0:
29
+ return
30
+
31
+ m_i = float("-inf")
32
+ l_i = 0.0
33
+ for start in range(0, N, BLOCK_N):
34
+ cols = start + tl.arange(0, BLOCK_N)
35
+ logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
36
+ new_m_i = tl.maximum(m_i, tl.max(logits))
37
+ alpha = tl.exp(m_i - new_m_i)
38
+ l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
39
+ m_i = new_m_i
40
+ lse = m_i + tl.log(l_i)
41
+
42
+ ids = tl.load(INPUT_IDS)
43
+ x = tl.load(LOGITS + ids).to(tl.float32) / TEMPERATURE
44
+ logp = x - lse
45
+ tl.store(LOG_P, logp)
46
+
47
+
48
+ # compue old_logp and ref_logp, it reduce 10G peak Memory. it does not requires grad
49
+ @torch.no_grad
50
+ def fused_selective_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 0.9, mask=None):
51
+ assert logits.is_contiguous()
52
+ B, L_ADD_1, N = logits.shape
53
+ L = L_ADD_1 - 1
54
+ input_ids = input_ids[:, -L:]
55
+ if mask is not None:
56
+ mask = mask[:, -L:]
57
+ log_p = torch.zeros(B, L, dtype=torch.float32, device=logits.device)
58
+ kwargs = {"BLOCK_N": 2048, "num_stages": 4, "num_warps": 1}
59
+ _selective_log_softmax_kernel[(B, L)](
60
+ logits, input_ids, log_p, mask, temperature, input_ids.stride(0), L, N, **kwargs
61
+ )
62
+ return log_p
63
+
64
+
65
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
66
+ # for BLOCK_N in [2048, 4096, 8192]
67
+ # for ns in [1, 2, 4]
68
+ # for nw in [1, 2, 4, 8, 16]],
69
+ # key=['N'])
70
+ @triton.jit
71
+ def _grpo_loss_fwd_kernel(
72
+ LOGITS,
73
+ OLD_LOGP,
74
+ REF_LOGP,
75
+ INPUT_IDS,
76
+ COMPLETION_MASK,
77
+ ADVANTAGES,
78
+ LOSS,
79
+ LSE,
80
+ KL,
81
+ IS_CLIPPED,
82
+ TEMPERATURE,
83
+ BETA: tl.constexpr,
84
+ EPS_LOW,
85
+ EPS_HIGH,
86
+ L: tl.constexpr,
87
+ N: tl.constexpr,
88
+ BLOCK_N: tl.constexpr = 4096,
89
+ ):
90
+ off_b = tl.program_id(0).cast(tl.int64)
91
+ off_l = tl.program_id(1).cast(tl.int64)
92
+
93
+ if COMPLETION_MASK is not None:
94
+ COMPLETION_MASK += off_b * L + off_l
95
+ not_skip = tl.load(COMPLETION_MASK)
96
+ if not_skip == 0:
97
+ return
98
+
99
+ LOGITS += off_b * (L + 1) * N + off_l * N
100
+ INPUT_IDS += off_b * L + off_l
101
+ ADVANTAGES += off_b
102
+ LOSS += off_b * L + off_l
103
+ LSE += off_b * L + off_l
104
+ IS_CLIPPED += off_b * L + off_l
105
+
106
+ m_i = float("-inf")
107
+ l_i = 0.0
108
+ for start in range(0, N, BLOCK_N):
109
+ cols = start + tl.arange(0, BLOCK_N)
110
+ logits = tl.load(LOGITS + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
111
+ new_m_i = tl.maximum(m_i, tl.max(logits))
112
+ alpha = tl.exp(m_i - new_m_i)
113
+ l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
114
+ m_i = new_m_i
115
+ lse = m_i + tl.log(l_i)
116
+
117
+ idx = tl.load(INPUT_IDS)
118
+ x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
119
+ logp = x - lse
120
+ if OLD_LOGP is None:
121
+ old_logp = logp
122
+ else:
123
+ OLD_LOGP += off_b * L + off_l
124
+ old_logp = tl.load(OLD_LOGP).to(tl.float32)
125
+ coef_1 = tl.exp(logp - old_logp)
126
+ coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
127
+ advantage = tl.load(ADVANTAGES).to(tl.float32)
128
+ per_token_loss1 = coef_1 * advantage
129
+ per_token_loss2 = coef_2 * advantage
130
+ per_token_loss = -tl.minimum(per_token_loss1, per_token_loss2)
131
+ is_clipped = per_token_loss1 < per_token_loss2
132
+
133
+ if BETA != 0.0:
134
+ REF_LOGP += off_b * L + off_l
135
+ KL += off_b * L + off_l
136
+ ref_logp = tl.load(REF_LOGP).to(tl.float32)
137
+ kl = tl.exp(ref_logp - logp) - (ref_logp - logp) - 1
138
+ per_token_loss += BETA * kl
139
+ tl.store(KL, kl)
140
+
141
+ tl.store(LOSS, per_token_loss)
142
+ tl.store(LSE, lse)
143
+ tl.store(IS_CLIPPED, is_clipped)
144
+
145
+
146
+ # @triton.autotune([triton.Config({"BLOCK_N":BLOCK_N}, num_stages=ns, num_warps=nw)
147
+ # for BLOCK_N in [2048, 4096, 8192]
148
+ # for ns in [1, 2, 4]
149
+ # for nw in [1, 2, 4, 8, 16]],
150
+ # key=['N'])
151
+ @triton.jit
152
+ def _grpo_loss_bwd_kernel(
153
+ DLOSS,
154
+ DLOGITS,
155
+ LOGITS,
156
+ OLD_LOGP,
157
+ REF_LOGP,
158
+ INPUT_IDS,
159
+ ADVANTAGES,
160
+ COMPLETION_MASK,
161
+ LSE,
162
+ TEMPERATURE,
163
+ BETA: tl.constexpr,
164
+ EPS_LOW,
165
+ EPS_HIGH,
166
+ loss_stride0,
167
+ loss_stride1,
168
+ L: tl.constexpr,
169
+ N: tl.constexpr,
170
+ BLOCK_N: tl.constexpr = 4096,
171
+ ):
172
+ off_b = tl.program_id(0).cast(tl.int64)
173
+ off_l = tl.program_id(1).cast(tl.int64)
174
+
175
+ DLOGITS += off_b * (L + 1) * N + off_l * N
176
+ if COMPLETION_MASK is not None:
177
+ COMPLETION_MASK += off_b * L + off_l
178
+ not_skip = tl.load(COMPLETION_MASK)
179
+ if not_skip == 0:
180
+ for start in range(0, N, BLOCK_N):
181
+ cols = tl.arange(0, BLOCK_N) + start
182
+ tl.store(DLOGITS + cols, 0.0, mask=cols < N)
183
+ return
184
+
185
+ LOGITS += off_b * (L + 1) * N + off_l * N
186
+ DLOSS += off_b * loss_stride0 + off_l * loss_stride1
187
+ INPUT_IDS += off_b * L + off_l
188
+ ADVANTAGES += off_b
189
+ LSE += off_b * L + off_l
190
+
191
+ dloss = tl.load(DLOSS).to(tl.float32)
192
+ lse = tl.load(LSE).to(tl.float32)
193
+
194
+ idx = tl.load(INPUT_IDS)
195
+ x = tl.load(LOGITS + idx).to(tl.float32) / TEMPERATURE
196
+ logp = x - lse
197
+ if OLD_LOGP is None:
198
+ old_logp = logp
199
+ else:
200
+ OLD_LOGP += off_b * L + off_l
201
+ old_logp = tl.load(OLD_LOGP).to(tl.float32)
202
+ coef_1 = tl.exp(logp - old_logp)
203
+ coef_2 = tl.clamp(coef_1, 1 - EPS_LOW, 1 + EPS_HIGH)
204
+ advantage = tl.load(ADVANTAGES).to(tl.float32)
205
+ per_token_loss1 = coef_1 * advantage
206
+ per_token_loss2 = coef_2 * advantage
207
+ mask = per_token_loss2 >= per_token_loss1
208
+
209
+ dlogp = -per_token_loss1 * mask
210
+ if BETA != 0.0:
211
+ REF_LOGP += off_b * L + off_l
212
+ ref_logp = tl.load(REF_LOGP).to(tl.float32)
213
+ dlogp += BETA * (1 - tl.exp(ref_logp - logp))
214
+
215
+ dlogp = dlogp * dloss / TEMPERATURE
216
+ tl.debug_barrier()
217
+ for start_n in tl.range(0, N, BLOCK_N):
218
+ cols = start_n + tl.arange(0, BLOCK_N)
219
+ logits = tl.load(LOGITS + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
220
+ probs = tl.exp(logits - lse)
221
+ dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
222
+ tl.store(DLOGITS + cols, dlogits, mask=cols < N)
223
+
224
+
225
+ class GrpoLossFunction(torch.autograd.Function):
226
+ @staticmethod
227
+ def forward(
228
+ ctx,
229
+ logits,
230
+ old_logp,
231
+ ref_logp,
232
+ completion_ids,
233
+ advantages,
234
+ completion_mask,
235
+ temperature,
236
+ beta,
237
+ eps_low,
238
+ eps_high,
239
+ inplace,
240
+ ):
241
+ assert logits.is_contiguous() and completion_ids.is_contiguous()
242
+ assert old_logp is None or old_logp.is_contiguous()
243
+ assert (ref_logp is not None and ref_logp.is_contiguous()) if beta != 0.0 else True
244
+
245
+ B, L_ADD_1, N = logits.shape
246
+ L = L_ADD_1 - 1
247
+
248
+ if completion_mask is not None:
249
+ assert completion_mask.is_contiguous()
250
+
251
+ loss = torch.zeros(B, L, device=logits.device, dtype=torch.float32)
252
+ lse = torch.zeros_like(loss)
253
+ is_clipped = torch.zeros_like(loss)
254
+ kl = torch.zeros_like(loss) if beta != 0.0 else None
255
+ kwargs = {"BLOCK_N": 2048, "num_stages": 2, "num_warps": 1}
256
+ _grpo_loss_fwd_kernel[(B, L)](
257
+ logits,
258
+ old_logp,
259
+ ref_logp,
260
+ completion_ids,
261
+ completion_mask,
262
+ advantages,
263
+ loss,
264
+ lse,
265
+ kl,
266
+ is_clipped,
267
+ temperature,
268
+ beta,
269
+ eps_low,
270
+ eps_high,
271
+ L,
272
+ N,
273
+ **kwargs,
274
+ )
275
+ ctx.save_for_backward(logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse)
276
+ ctx.infos = (temperature, beta, eps_low, eps_high, inplace)
277
+ # return loss
278
+ return loss, kl, is_clipped
279
+
280
+ @staticmethod
281
+ def backward(ctx, *args):
282
+ dloss = args[0]
283
+ # print(dloss.shape)
284
+ logits, old_logp, ref_logp, completion_ids, advantages, completion_mask, lse = ctx.saved_tensors
285
+ temperature, beta, eps_low, eps_high, inplace = ctx.infos
286
+ B, L_ADD_1, N = logits.shape
287
+ L = L_ADD_1 - 1
288
+ dlogits = logits.data if inplace else torch.empty_like(logits)
289
+ kwargs = {"BLOCK_N": 4096, "num_stages": 1, "num_warps": 16}
290
+ _grpo_loss_bwd_kernel[(B, L)](
291
+ dloss,
292
+ dlogits,
293
+ logits,
294
+ old_logp,
295
+ ref_logp,
296
+ completion_ids,
297
+ advantages,
298
+ completion_mask,
299
+ lse,
300
+ temperature,
301
+ beta,
302
+ eps_low,
303
+ eps_high,
304
+ *dloss.stride(),
305
+ L,
306
+ N,
307
+ **kwargs,
308
+ )
309
+ dlogits[:, -1, :] = 0
310
+ return dlogits, None, None, None, None, None, None, None, None, None, None
@@ -5,16 +5,18 @@ from liger_kernel.ops.dyt import LigerDyTFunction
5
5
 
6
6
 
7
7
  class LigerDyT(nn.Module):
8
- def __init__(self, hidden_size, init_alpha=0.5):
8
+ def __init__(self, hidden_size, beta=True, init_alpha=0.5):
9
9
  super().__init__()
10
10
  self.hidden_size = hidden_size
11
11
  self.init_alpha = init_alpha
12
12
  self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
13
13
  self.gamma = nn.Parameter(torch.ones(hidden_size))
14
- self.beta = nn.Parameter(torch.zeros(hidden_size))
14
+ self.beta = None
15
+ if beta:
16
+ self.beta = nn.Parameter(torch.zeros(hidden_size))
15
17
 
16
18
  def forward(self, x):
17
19
  return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
18
20
 
19
21
  def extra_repr(self):
20
- return f"{self.hidden_size}, init_alpha={self.init_alpha}"
22
+ return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta}"