liger-kernel-nightly 0.6.2.dev20251011154226__tar.gz → 0.6.2.dev20251011154427__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_cross_entropy.py +4 -1
  3. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +25 -19
  4. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/pyproject.toml +1 -1
  5. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/cross_entropy.py +55 -52
  6. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/fused_linear_cross_entropy.py +3 -2
  7. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  8. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_cross_entropy.py +45 -0
  9. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_fused_linear_cross_entropy.py +113 -0
  10. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  11. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  12. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/pull_request_template.md +0 -0
  13. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/workflows/amd-ci.yml +0 -0
  14. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/workflows/benchmark.yml +0 -0
  15. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/workflows/docs.yml +0 -0
  16. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/workflows/intel-ci.yml +0 -0
  17. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/workflows/nvi-ci.yml +0 -0
  18. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/workflows/publish-nightly.yml +0 -0
  19. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.github/workflows/publish-release.yml +0 -0
  20. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/.gitignore +0 -0
  21. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/LICENSE +0 -0
  22. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/Makefile +0 -0
  23. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/NOTICE +0 -0
  24. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/README.md +0 -0
  25. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/README.md +0 -0
  26. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/__init__.py +0 -0
  27. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/benchmarks_visualizer.py +0 -0
  28. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/data/all_benchmark_data.csv +0 -0
  29. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/__init__.py +0 -0
  30. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_distill_cosine_loss.py +0 -0
  32. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  33. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_dyt.py +0 -0
  35. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_embedding.py +0 -0
  36. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_fused_add_rms_norm.py +0 -0
  37. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  38. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_fused_neighborhood_attention.py +0 -0
  39. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_geglu.py +0 -0
  40. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_group_norm.py +0 -0
  41. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_grpo_loss.py +0 -0
  42. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_jsd.py +0 -0
  43. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_kl_div.py +0 -0
  44. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  45. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  46. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_llama4_rope.py +0 -0
  47. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_multi_token_attention.py +0 -0
  48. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  49. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  50. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  51. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_rope.py +0 -0
  52. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  53. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_softmax.py +0 -0
  54. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_sparse_multi_token_attention.py +0 -0
  55. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_sparsemax.py +0 -0
  56. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_swiglu.py +0 -0
  57. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/benchmark_tvd.py +0 -0
  58. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/benchmark/scripts/utils.py +0 -0
  59. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/dev/fmt-requirements.txt +0 -0
  60. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/dev/modal/benchmarks.py +0 -0
  61. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/dev/modal/tests.py +0 -0
  62. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/dev/modal/tests_bwd.py +0 -0
  63. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/Examples.md +0 -0
  64. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/Getting-Started.md +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/High-Level-APIs.md +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/Low-Level-APIs.md +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/acknowledgement.md +0 -0
  68. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/contributing.md +0 -0
  69. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/images/banner.GIF +0 -0
  70. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/images/compose.gif +0 -0
  71. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/images/e2e-memory.png +0 -0
  72. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/images/e2e-tps.png +0 -0
  73. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/images/logo-banner.png +0 -0
  74. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/images/patch.gif +0 -0
  75. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/images/post-training.png +0 -0
  76. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/index.md +0 -0
  77. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/docs/license.md +0 -0
  78. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/alignment/accelerate_config.yaml +0 -0
  79. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/alignment/run_orpo.py +0 -0
  80. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/README.md +0 -0
  81. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/callback.py +0 -0
  82. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/config/fsdp_config.json +0 -0
  83. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  84. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  85. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  86. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/img/llama_tps.png +0 -0
  87. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  88. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/img/qwen_tps.png +0 -0
  89. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/launch_on_modal.py +0 -0
  90. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/requirements.txt +0 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/run_benchmarks.sh +0 -0
  92. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/run_gemma.sh +0 -0
  93. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/run_llama.sh +0 -0
  94. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/run_qwen.sh +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/run_qwen2_vl.sh +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/training.py +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/huggingface/training_multimodal.py +0 -0
  98. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/lightning/README.md +0 -0
  99. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/lightning/requirements.txt +0 -0
  100. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/lightning/training.py +0 -0
  101. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/README.md +0 -0
  102. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/callback.py +0 -0
  103. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  104. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  105. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  106. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  107. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  108. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  109. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  110. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  111. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  112. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/medusa_util.py +0 -0
  113. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/requirements.txt +0 -0
  114. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  115. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/examples/medusa/train.py +0 -0
  116. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/licenses/LICENSE-Apache-2.0 +0 -0
  117. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  118. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  119. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/licenses/LICENSE-MIT-llmc +0 -0
  120. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/licenses/LICENSE-MIT-triton +0 -0
  121. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/mkdocs.yml +0 -0
  122. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/setup.cfg +0 -0
  123. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/setup.py +0 -0
  124. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/__init__.py +0 -0
  125. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/README.md +0 -0
  126. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  127. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +0 -0
  128. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/functional.py +0 -0
  131. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  132. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  133. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  134. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  135. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  136. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  137. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  138. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  139. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  140. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/env_report.py +0 -0
  141. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/__init__.py +0 -0
  142. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/dyt.py +0 -0
  143. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  144. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  145. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/fused_add_rms_norm.py +0 -0
  146. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  147. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/fused_neighborhood_attention.py +0 -0
  148. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/geglu.py +0 -0
  149. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/group_norm.py +0 -0
  150. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/grpo_loss.py +0 -0
  151. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/jsd.py +0 -0
  152. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/kl_div.py +0 -0
  153. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/layer_norm.py +0 -0
  154. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/llama4_rope.py +0 -0
  155. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/multi_token_attention.py +0 -0
  156. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  157. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/rms_norm.py +0 -0
  158. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/rope.py +0 -0
  159. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/softmax.py +0 -0
  160. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/sparsemax.py +0 -0
  161. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/swiglu.py +0 -0
  162. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/tvd.py +0 -0
  163. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/ops/utils.py +0 -0
  164. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/__init__.py +0 -0
  165. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/auto_model.py +0 -0
  166. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  167. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/dyt.py +0 -0
  168. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/experimental/__init__.py +0 -0
  169. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  170. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/fsdp.py +0 -0
  171. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/functional.py +0 -0
  172. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/fused_add_rms_norm.py +0 -0
  173. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  174. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  175. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/fused_neighborhood_attention.py +0 -0
  176. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/geglu.py +0 -0
  177. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/group_norm.py +0 -0
  178. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/grpo_loss.py +0 -0
  179. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/jsd.py +0 -0
  180. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/kl_div.py +0 -0
  181. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/layer_norm.py +0 -0
  182. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/llama4_rope.py +0 -0
  183. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/__init__.py +0 -0
  184. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/falcon_h1.py +0 -0
  185. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/gemma.py +0 -0
  186. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  187. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  188. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/glm4.py +0 -0
  189. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/glm4v.py +0 -0
  190. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/glm4v_moe.py +0 -0
  191. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/internvl.py +0 -0
  192. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/llama.py +0 -0
  193. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/llama4.py +0 -0
  194. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/llava.py +0 -0
  195. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  196. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/mistral.py +0 -0
  197. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  198. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/mllama.py +0 -0
  199. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  200. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  201. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/phi3.py +0 -0
  202. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  203. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  204. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  205. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  206. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  207. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/model/smollm3.py +0 -0
  208. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  209. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/multi_token_attention.py +0 -0
  210. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  211. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/rms_norm.py +0 -0
  212. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/rope.py +0 -0
  213. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/softmax.py +0 -0
  214. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/sparsemax.py +0 -0
  215. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/swiglu.py +0 -0
  216. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  217. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  218. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  219. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/transformers/tvd.py +0 -0
  220. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/triton/__init__.py +0 -0
  221. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/triton/monkey_patch.py +0 -0
  222. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel/utils.py +0 -0
  223. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  224. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  225. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  226. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  227. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/__init__.py +0 -0
  228. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/chunked_loss/__init__.py +0 -0
  229. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/chunked_loss/test_cosine_loss.py +0 -0
  230. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/chunked_loss/test_cpo_loss.py +0 -0
  231. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/chunked_loss/test_dpo_loss.py +0 -0
  232. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/chunked_loss/test_grpo_loss.py +0 -0
  233. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/chunked_loss/test_jsd_loss.py +0 -0
  234. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/chunked_loss/test_kto_loss.py +0 -0
  235. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/chunked_loss/test_orpo_loss.py +0 -0
  236. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/chunked_loss/test_simpo_loss.py +0 -0
  237. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/conftest.py +0 -0
  238. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/convergence/__init__.py +0 -0
  239. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/convergence/bf16/__init__.py +0 -0
  240. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/convergence/bf16/test_mini_models.py +0 -0
  241. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  242. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  243. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/convergence/fp32/__init__.py +0 -0
  244. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/convergence/fp32/test_mini_models.py +0 -0
  245. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  246. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  247. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  248. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  249. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  250. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  251. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  252. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json +0 -0
  253. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  254. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  255. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  256. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +0 -0
  257. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  258. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/tiny_shakespeare.txt +0 -0
  259. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  260. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  261. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  262. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_auto_model.py +0 -0
  263. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_dyt.py +0 -0
  264. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_embedding.py +0 -0
  265. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_flex_attention.py +0 -0
  266. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_fused_add_rms_norm.py +0 -0
  267. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_fused_linear_jsd.py +0 -0
  268. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_fused_neighborhood_attention.py +0 -0
  269. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_geglu.py +0 -0
  270. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_group_norm.py +0 -0
  271. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_grpo_loss.py +0 -0
  272. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_jsd.py +0 -0
  273. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_kl_div.py +0 -0
  274. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_layer_norm.py +0 -0
  275. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_mm_int8int2.py +0 -0
  276. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_monkey_patch.py +0 -0
  277. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_multi_token_attention.py +0 -0
  278. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_qwen2vl_mrope.py +0 -0
  279. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_rms_norm.py +0 -0
  280. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_rope.py +0 -0
  281. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_softmax.py +0 -0
  282. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_sparsemax.py +0 -0
  283. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_swiglu.py +0 -0
  284. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_trainer_integration.py +0 -0
  285. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_transformers.py +0 -0
  286. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/transformers/test_tvd.py +0 -0
  287. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/triton/test_triton_monkey_patch.py +0 -0
  288. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251011154427}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.2.dev20251011154226
3
+ Version: 0.6.2.dev20251011154427
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -70,6 +70,9 @@ def bench_speed_cross_entropy(
70
70
 
71
71
  if mode == "forward":
72
72
  ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
73
+ elif mode == "no-grad-forward":
74
+ with torch.no_grad():
75
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
73
76
  elif mode == "backward":
74
77
  y = fwd()
75
78
 
@@ -109,7 +112,7 @@ if __name__ == "__main__":
109
112
 
110
113
  run_benchmarks(
111
114
  bench_test_fn=bench_speed_cross_entropy,
112
- kernel_operation_modes=["forward", "backward", "full"],
115
+ kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
113
116
  metric_name="speed",
114
117
  metric_unit="ms",
115
118
  **common_configs,
@@ -59,26 +59,26 @@ def bench_memory_fused_linear_cross_entropy(
59
59
  dtype = input.extra_benchmark_config["dtype"]
60
60
  provider = input.kernel_provider
61
61
 
62
- torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
63
- liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
64
- liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
62
+ lm_head_ce = None
63
+ if provider == "liger":
64
+ lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
65
+ elif provider == "liger-fp32-accum":
66
+ lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
67
+ else:
68
+ lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
65
69
 
66
70
  _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
67
71
  target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
68
72
 
69
73
  def fwd():
70
- if provider == "liger":
71
- return liger_lm_head_ce(_input, target)
72
- elif provider == "liger-fp32-accum":
73
- return liger_lm_head_ce_fp32_accum(_input, target)
74
- elif provider == "huggingface":
75
- return torch_lm_head_ce(_input, target)
74
+ return lm_head_ce(_input, target)
76
75
 
77
76
  def full():
78
77
  y = fwd()
79
78
  y.backward()
80
79
 
81
80
  mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
81
+
82
82
  return SingleBenchmarkRunOutput(
83
83
  y_20=mem_20,
84
84
  y_50=mem_50,
@@ -101,20 +101,19 @@ def bench_speed_fused_linear_cross_entropy(
101
101
  provider = input.kernel_provider
102
102
  mode = input.kernel_operation_mode
103
103
 
104
- torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
105
- liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
106
- liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
104
+ lm_head_ce = None
105
+ if provider == "liger":
106
+ lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
107
+ elif provider == "liger-fp32-accum":
108
+ lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
109
+ else:
110
+ lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
107
111
 
108
112
  _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
109
113
  target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
110
114
 
111
115
  def fwd():
112
- if provider == "liger":
113
- return liger_lm_head_ce(_input, target)
114
- elif provider == "liger-fp32-accum":
115
- return liger_lm_head_ce_fp32_accum(_input, target)
116
- elif provider == "huggingface":
117
- return torch_lm_head_ce(_input, target)
116
+ return lm_head_ce(_input, target)
118
117
 
119
118
  if mode == "forward":
120
119
  ms_50, ms_20, ms_80 = triton.testing.do_bench(
@@ -122,6 +121,13 @@ def bench_speed_fused_linear_cross_entropy(
122
121
  rep=100,
123
122
  quantiles=QUANTILES,
124
123
  )
124
+ elif mode == "no-grad-forward":
125
+ with torch.no_grad():
126
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(
127
+ fwd,
128
+ rep=100,
129
+ quantiles=QUANTILES,
130
+ )
125
131
  elif mode == "backward":
126
132
  y = fwd()
127
133
 
@@ -164,7 +170,7 @@ if __name__ == "__main__":
164
170
 
165
171
  run_benchmarks(
166
172
  bench_test_fn=bench_speed_fused_linear_cross_entropy,
167
- kernel_operation_modes=["forward", "backward", "full"],
173
+ kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
168
174
  metric_name="speed",
169
175
  metric_unit="ms",
170
176
  **common_configs,
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.6.2.dev20251011154226"
7
+ version = "0.6.2.dev20251011154427"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
45
45
  BLOCK_SIZE: tl.constexpr,
46
46
  HAS_WEIGHT: tl.constexpr,
47
47
  HAS_SOFTCAPPING: tl.constexpr,
48
+ HAS_GRADIENTS: tl.constexpr,
48
49
  ):
49
50
  """
50
51
  This kernel computes both cross entropy loss and the gradient of the input.
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
72
73
  BLOCK_SIZE (int): The block size for Triton operations.
73
74
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
74
75
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
76
+ HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
75
77
  """
76
78
 
77
79
  # https://github.com/triton-lang/triton/issues/1058
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
155
157
  # For 'sum' reduction, no normalization is applied:
156
158
  # dx_y = softmax(x_y) - 1
157
159
  # dx_i = softmax(x_i), for i ≠ y
158
-
159
- for i in range(0, n_cols, BLOCK_SIZE):
160
- X_offsets = i + tl.arange(0, BLOCK_SIZE)
161
- X_block = tl.load(
162
- X_ptr + X_offsets,
163
- mask=X_offsets < n_cols,
164
- other=float("-inf"),
165
- # Ensure float32 precision for softmax calculation
166
- ).cast(tl.float32)
167
- if HAS_SOFTCAPPING:
168
- intermediate = tanh(X_block / softcap)
169
- X_block = softcap * intermediate
170
-
171
- if not HAS_WEIGHT:
172
- # softmax(x_i)
173
- X_block = tl.exp(X_block - m) / d
174
- # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
175
- X_block += 2 * lse_square_scale * lse * X_block
176
- # smoothing term
177
- X_block += -eps
178
- # special handle dx_y
179
- X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
180
- # reduction scale
181
- if reduction == "mean":
182
- X_block = X_block / n_non_ignore
183
- else:
184
- weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
185
- softmax_X = tl.exp(X_block - m) / d
186
- # derivative of original_loss
187
- dloss_ori = (1 - label_smoothing) * softmax_X
188
- # specially handle dx_y
189
- dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
190
- dloss_ori = dloss_ori * weight_y
191
- # derivative of smooth_loss
192
- dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
193
- # derivative of z-loss
194
- dz_loss = 2 * lse_square_scale * lse * softmax_X
195
- # reduction scale
196
- if reduction == "mean":
197
- dloss_ori = dloss_ori / sum_non_ignore_weight
198
- dloss_smooth = dloss_smooth / sum_non_ignore_weight
199
- # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
200
- dz_loss = dz_loss / n_non_ignore
201
- # derivative of total_loss
202
- X_block = dloss_ori + dloss_smooth + dz_loss
203
-
204
- # chain rule softcapping
205
- # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
206
- if HAS_SOFTCAPPING:
207
- X_block = X_block * (1 - intermediate * intermediate)
208
-
209
- tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
160
+ if HAS_GRADIENTS:
161
+ for i in range(0, n_cols, BLOCK_SIZE):
162
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
163
+ X_block = tl.load(
164
+ X_ptr + X_offsets,
165
+ mask=X_offsets < n_cols,
166
+ other=float("-inf"),
167
+ # Ensure float32 precision for softmax calculation
168
+ ).cast(tl.float32)
169
+ if HAS_SOFTCAPPING:
170
+ intermediate = tanh(X_block / softcap)
171
+ X_block = softcap * intermediate
172
+
173
+ if not HAS_WEIGHT:
174
+ # softmax(x_i)
175
+ X_block = tl.exp(X_block - m) / d
176
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
177
+ X_block += 2 * lse_square_scale * lse * X_block
178
+ # smoothing term
179
+ X_block += -eps
180
+ # special handle dx_y
181
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
182
+ # reduction scale
183
+ if reduction == "mean":
184
+ X_block = X_block / n_non_ignore
185
+ else:
186
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
187
+ softmax_X = tl.exp(X_block - m) / d
188
+ # derivative of original_loss
189
+ dloss_ori = (1 - label_smoothing) * softmax_X
190
+ # specially handle dx_y
191
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
192
+ dloss_ori = dloss_ori * weight_y
193
+ # derivative of smooth_loss
194
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
195
+ # derivative of z-loss
196
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
197
+ # reduction scale
198
+ if reduction == "mean":
199
+ dloss_ori = dloss_ori / sum_non_ignore_weight
200
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
201
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
202
+ dz_loss = dz_loss / n_non_ignore
203
+ # derivative of total_loss
204
+ X_block = dloss_ori + dloss_smooth + dz_loss
205
+
206
+ # chain rule softcapping
207
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
208
+ if HAS_SOFTCAPPING:
209
+ X_block = X_block * (1 - intermediate * intermediate)
210
+
211
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
210
212
 
211
213
  # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
212
214
  # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -332,6 +334,7 @@ def cross_entropy_forward(
332
334
  BLOCK_SIZE=BLOCK_SIZE,
333
335
  HAS_WEIGHT=True if weight is not None else False,
334
336
  HAS_SOFTCAPPING=True if softcap is not None else False,
337
+ HAS_GRADIENTS=_input.requires_grad,
335
338
  # TODO: 32 seems to give the best performance
336
339
  # Performance is quite sensitive to num_warps
337
340
  num_warps=32 if not is_hip() else 16,
@@ -150,6 +150,7 @@ def fused_linear_cross_entropy_forward(
150
150
  RETURN_Z_LOSS=return_z_loss,
151
151
  HAS_WEIGHT=True if ce_weight is not None else False,
152
152
  HAS_SOFTCAPPING=True if softcap is not None else False,
153
+ HAS_GRADIENTS=_input.requires_grad,
153
154
  BLOCK_SIZE=BLOCK_SIZE,
154
155
  num_warps=32 if not is_hip() else 16,
155
156
  )
@@ -173,10 +174,10 @@ def fused_linear_cross_entropy_forward(
173
174
 
174
175
  grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
175
176
 
176
- if grad_weight is not None:
177
+ if grad_weight is not None and _input.requires_grad:
177
178
  grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
178
179
 
179
- if bias is not None:
180
+ if bias is not None and _input.requires_grad:
180
181
  torch.add(
181
182
  input=grad_bias,
182
183
  other=grad_logits_chunk.sum(dim=0),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.2.dev20251011154226
3
+ Version: 0.6.2.dev20251011154427
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -455,6 +455,28 @@ def _test_correctness_not_last_layer_with_other_params_once(
455
455
  assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
456
456
 
457
457
 
458
+ def _test_correctness_with_forward_only(target_ce, B, T, V, reduction, dtype, scalar, atol, rtol):
459
+ torch.manual_seed(0)
460
+ torch_ce = CrossEntropyLoss(reduction=reduction)
461
+
462
+ _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
463
+ _input = _tensor.detach().clone()
464
+ _input2 = _tensor.detach().clone()
465
+
466
+ target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
467
+
468
+ with torch.no_grad():
469
+ output = torch_ce(_input, target)
470
+ output2 = target_ce(_input2, target)
471
+ assert torch.allclose(output, output2, atol=atol, rtol=rtol)
472
+
473
+ try:
474
+ # Try running backward on liger output
475
+ output2.backward(gradient=torch.ones_like(output))
476
+ except RuntimeError as e:
477
+ assert "does not require grad" in str(e)
478
+
479
+
458
480
  def _test_correctness_functional(
459
481
  B,
460
482
  T,
@@ -1014,6 +1036,7 @@ def test_float32_internal():
1014
1036
  RETURN_Z_LOSS=0, # False
1015
1037
  HAS_WEIGHT=False,
1016
1038
  HAS_SOFTCAPPING=False,
1039
+ HAS_GRADIENTS=True,
1017
1040
  BLOCK_SIZE=BLOCK_SIZE,
1018
1041
  num_warps=32 if not is_hip() else 16,
1019
1042
  )
@@ -1042,6 +1065,7 @@ def test_float32_internal():
1042
1065
  RETURN_Z_LOSS=0, # False
1043
1066
  HAS_WEIGHT=False,
1044
1067
  HAS_SOFTCAPPING=False,
1068
+ HAS_GRADIENTS=True,
1045
1069
  BLOCK_SIZE=BLOCK_SIZE,
1046
1070
  num_warps=32 if not is_hip() else 16,
1047
1071
  )
@@ -1061,3 +1085,24 @@ def test_float32_internal():
1061
1085
  def test_correctness_with_out_of_bounds_target_once(B, T, V, ignore_index):
1062
1086
  liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index)
1063
1087
  _test_correctness_with_out_of_bounds_target_once(liger_ce, B, T, V, ignore_index)
1088
+
1089
+
1090
+ @pytest.mark.parametrize(
1091
+ "B, T, V, ignore_index",
1092
+ [
1093
+ (2, 4096, 32000, -100),
1094
+ (3, 423, 32000, 2),
1095
+ ],
1096
+ )
1097
+ @pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
1098
+ @pytest.mark.parametrize(
1099
+ "dtype, scalar, atol, rtol",
1100
+ [
1101
+ (torch.float32, 1.0, 1e-4, 1e-4),
1102
+ (torch.float16, 1.0, 1e-2, 1e-2),
1103
+ (torch.bfloat16, 1.0, 1e-2, 1e-2),
1104
+ ],
1105
+ )
1106
+ def test_correctness_with_forward_only(B, T, V, ignore_index, reduction, dtype, scalar, atol, rtol):
1107
+ liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
1108
+ _test_correctness_with_forward_only(liger_ce, B, T, V, reduction, dtype, scalar, atol, rtol)
@@ -231,6 +231,119 @@ def test_correctness(
231
231
  )
232
232
 
233
233
 
234
+ @pytest.mark.parametrize(
235
+ "B, T, H, V",
236
+ [
237
+ (8, 128, 1024, 4096),
238
+ (4, 47, 31, 123), # random shape
239
+ ],
240
+ )
241
+ @pytest.mark.parametrize(
242
+ "reduction, scalar, dtype, atol, rtol",
243
+ [
244
+ ("mean", 1.0, torch.bfloat16, 5e-3, 5e-2),
245
+ ("mean", 1.0, torch.float32, 1e-5, 5e-4),
246
+ ("sum", 1.0, torch.bfloat16, 5e-0, 5e1),
247
+ ("sum", 1.0, torch.float32, 1e-3, 5e-2),
248
+ ],
249
+ )
250
+ @pytest.mark.parametrize("bias", [True, False])
251
+ @pytest.mark.parametrize(
252
+ "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap, return_z_loss, accum_dtype",
253
+ [
254
+ (False, 0, -100, 0, None, False, None),
255
+ # Pass non-default values once to ensure all params work along
256
+ (True, 0.1, 42, 1e-4, 30.0, True, torch.float32),
257
+ ],
258
+ )
259
+ def test_correctness_with_forward_only(
260
+ B,
261
+ T,
262
+ H,
263
+ V,
264
+ scalar,
265
+ dtype,
266
+ bias,
267
+ has_ce_weight,
268
+ lse_square_scale,
269
+ label_smoothing,
270
+ ignore_index,
271
+ reduction,
272
+ softcap,
273
+ return_z_loss,
274
+ accum_dtype,
275
+ atol,
276
+ rtol,
277
+ ):
278
+ if has_ce_weight:
279
+ ce_weight = torch.rand(V, device=device, dtype=torch.float32)
280
+ else:
281
+ ce_weight = None
282
+ torch_lm_head_ce = TorchLMHeadCE(
283
+ H=H,
284
+ V=V,
285
+ bias=bias,
286
+ ce_weight=ce_weight,
287
+ lse_square_scale=lse_square_scale,
288
+ label_smoothing=label_smoothing,
289
+ ignore_index=ignore_index,
290
+ reduction=reduction,
291
+ softcap=softcap,
292
+ return_z_loss=return_z_loss,
293
+ dtype=dtype,
294
+ ).to(device)
295
+ liger_lm_head_ce = LigerLMHeadCE(
296
+ H=H,
297
+ V=V,
298
+ bias=bias,
299
+ ce_weight=ce_weight,
300
+ lse_square_scale=lse_square_scale,
301
+ label_smoothing=label_smoothing,
302
+ ignore_index=ignore_index,
303
+ reduction=reduction,
304
+ softcap=softcap,
305
+ return_z_loss=return_z_loss,
306
+ dtype=dtype,
307
+ accum_dtype=accum_dtype,
308
+ ).to(device)
309
+
310
+ # init the linear in all CEs with the same weights
311
+ torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(V, H, device=device, dtype=dtype)
312
+
313
+ if bias:
314
+ torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand(V, device=device, dtype=dtype)
315
+
316
+ _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
317
+ _input1 = _tensor.detach().clone().requires_grad_(True)
318
+ _input2 = _tensor.detach().clone().requires_grad_(True)
319
+
320
+ target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
321
+ # Assign some random number of elements as ignore_index
322
+ num_elements_to_assign = torch.randint(
323
+ 1, B * T // 2, (1,)
324
+ ).item() # Random number of elements to set to ignore_index
325
+ indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices
326
+ target[indices_to_assign] = ignore_index
327
+
328
+ with torch.no_grad():
329
+ if return_z_loss:
330
+ output1, z_output1 = torch_lm_head_ce(_input1, target)
331
+ output2, z_output2 = liger_lm_head_ce(_input2, target)
332
+ else:
333
+ output1 = torch_lm_head_ce(_input1, target)
334
+ output2 = liger_lm_head_ce(_input2, target)
335
+
336
+ assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)
337
+ if return_z_loss:
338
+ assert_verbose_allclose(z_output1, z_output2, atol=atol, rtol=rtol)
339
+
340
+ try:
341
+ grad_output = torch.rand_like(output1)
342
+ output2.backward(gradient=grad_output)
343
+ except RuntimeError as e:
344
+ assert "does not require grad" in str(e)
345
+
346
+
234
347
  @pytest.mark.parametrize(
235
348
  "B, T, H, V",
236
349
  [