liger-kernel-nightly 0.6.2.dev20251011154226__tar.gz → 0.6.2.dev20251013144132__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_cross_entropy.py +4 -1
  3. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +25 -19
  4. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/pyproject.toml +1 -1
  5. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/cross_entropy.py +55 -52
  6. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/fused_linear_cross_entropy.py +3 -2
  7. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/monkey_patch.py +5 -2
  8. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  9. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_cross_entropy.py +45 -0
  10. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_fused_linear_cross_entropy.py +113 -0
  11. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_monkey_patch.py +26 -4
  12. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  13. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  14. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/pull_request_template.md +0 -0
  15. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/workflows/amd-ci.yml +0 -0
  16. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/workflows/benchmark.yml +0 -0
  17. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/workflows/docs.yml +0 -0
  18. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/workflows/intel-ci.yml +0 -0
  19. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/workflows/nvi-ci.yml +0 -0
  20. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/workflows/publish-nightly.yml +0 -0
  21. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.github/workflows/publish-release.yml +0 -0
  22. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/.gitignore +0 -0
  23. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/LICENSE +0 -0
  24. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/Makefile +0 -0
  25. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/NOTICE +0 -0
  26. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/README.md +0 -0
  27. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/README.md +0 -0
  28. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/__init__.py +0 -0
  29. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/benchmarks_visualizer.py +0 -0
  30. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/data/all_benchmark_data.csv +0 -0
  31. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/__init__.py +0 -0
  32. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  33. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_distill_cosine_loss.py +0 -0
  34. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  35. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  36. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_dyt.py +0 -0
  37. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_embedding.py +0 -0
  38. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_fused_add_rms_norm.py +0 -0
  39. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  40. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_fused_neighborhood_attention.py +0 -0
  41. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_geglu.py +0 -0
  42. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_group_norm.py +0 -0
  43. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_grpo_loss.py +0 -0
  44. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_jsd.py +0 -0
  45. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_kl_div.py +0 -0
  46. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  47. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  48. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_llama4_rope.py +0 -0
  49. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_multi_token_attention.py +0 -0
  50. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  51. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  52. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  53. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_rope.py +0 -0
  54. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  55. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_softmax.py +0 -0
  56. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_sparse_multi_token_attention.py +0 -0
  57. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_sparsemax.py +0 -0
  58. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_swiglu.py +0 -0
  59. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/benchmark_tvd.py +0 -0
  60. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/benchmark/scripts/utils.py +0 -0
  61. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/dev/fmt-requirements.txt +0 -0
  62. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/dev/modal/benchmarks.py +0 -0
  63. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/dev/modal/tests.py +0 -0
  64. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/dev/modal/tests_bwd.py +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/Examples.md +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/Getting-Started.md +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/High-Level-APIs.md +0 -0
  68. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/Low-Level-APIs.md +0 -0
  69. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/acknowledgement.md +0 -0
  70. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/contributing.md +0 -0
  71. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/images/banner.GIF +0 -0
  72. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/images/compose.gif +0 -0
  73. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/images/e2e-memory.png +0 -0
  74. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/images/e2e-tps.png +0 -0
  75. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/images/logo-banner.png +0 -0
  76. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/images/patch.gif +0 -0
  77. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/images/post-training.png +0 -0
  78. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/index.md +0 -0
  79. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/docs/license.md +0 -0
  80. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/alignment/accelerate_config.yaml +0 -0
  81. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/alignment/run_orpo.py +0 -0
  82. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/README.md +0 -0
  83. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/callback.py +0 -0
  84. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/config/fsdp_config.json +0 -0
  85. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  86. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  87. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  88. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/img/llama_tps.png +0 -0
  89. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  90. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/img/qwen_tps.png +0 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/launch_on_modal.py +0 -0
  92. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/requirements.txt +0 -0
  93. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/run_benchmarks.sh +0 -0
  94. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/run_gemma.sh +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/run_llama.sh +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/run_qwen.sh +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/run_qwen2_vl.sh +0 -0
  98. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/training.py +0 -0
  99. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/huggingface/training_multimodal.py +0 -0
  100. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/lightning/README.md +0 -0
  101. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/lightning/requirements.txt +0 -0
  102. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/lightning/training.py +0 -0
  103. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/README.md +0 -0
  104. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/callback.py +0 -0
  105. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  106. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  107. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  108. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  109. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  110. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  111. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  112. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  113. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  114. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/medusa_util.py +0 -0
  115. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/requirements.txt +0 -0
  116. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  117. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/examples/medusa/train.py +0 -0
  118. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/licenses/LICENSE-Apache-2.0 +0 -0
  119. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  120. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  121. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/licenses/LICENSE-MIT-llmc +0 -0
  122. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/licenses/LICENSE-MIT-triton +0 -0
  123. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/mkdocs.yml +0 -0
  124. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/setup.cfg +0 -0
  125. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/setup.py +0 -0
  126. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/__init__.py +0 -0
  127. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/README.md +0 -0
  128. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  129. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +0 -0
  130. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  131. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  132. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/functional.py +0 -0
  133. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  134. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  135. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  136. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  137. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  138. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  139. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  140. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  141. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  142. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/env_report.py +0 -0
  143. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/__init__.py +0 -0
  144. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/dyt.py +0 -0
  145. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  146. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  147. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/fused_add_rms_norm.py +0 -0
  148. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  149. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/fused_neighborhood_attention.py +0 -0
  150. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/geglu.py +0 -0
  151. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/group_norm.py +0 -0
  152. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/grpo_loss.py +0 -0
  153. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/jsd.py +0 -0
  154. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/kl_div.py +0 -0
  155. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/layer_norm.py +0 -0
  156. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/llama4_rope.py +0 -0
  157. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/multi_token_attention.py +0 -0
  158. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  159. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/rms_norm.py +0 -0
  160. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/rope.py +0 -0
  161. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/softmax.py +0 -0
  162. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/sparsemax.py +0 -0
  163. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/swiglu.py +0 -0
  164. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/tvd.py +0 -0
  165. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/ops/utils.py +0 -0
  166. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/__init__.py +0 -0
  167. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/auto_model.py +0 -0
  168. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  169. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/dyt.py +0 -0
  170. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/experimental/__init__.py +0 -0
  171. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  172. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/fsdp.py +0 -0
  173. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/functional.py +0 -0
  174. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/fused_add_rms_norm.py +0 -0
  175. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  176. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  177. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/fused_neighborhood_attention.py +0 -0
  178. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/geglu.py +0 -0
  179. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/group_norm.py +0 -0
  180. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/grpo_loss.py +0 -0
  181. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/jsd.py +0 -0
  182. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/kl_div.py +0 -0
  183. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/layer_norm.py +0 -0
  184. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/llama4_rope.py +0 -0
  185. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/__init__.py +0 -0
  186. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/falcon_h1.py +0 -0
  187. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/gemma.py +0 -0
  188. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  189. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  190. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/glm4.py +0 -0
  191. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/glm4v.py +0 -0
  192. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/glm4v_moe.py +0 -0
  193. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/internvl.py +0 -0
  194. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/llama.py +0 -0
  195. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/llama4.py +0 -0
  196. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/llava.py +0 -0
  197. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  198. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/mistral.py +0 -0
  199. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  200. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/mllama.py +0 -0
  201. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  202. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  203. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/phi3.py +0 -0
  204. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  205. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  206. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  207. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  208. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  209. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/model/smollm3.py +0 -0
  210. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/multi_token_attention.py +0 -0
  211. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  212. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/rms_norm.py +0 -0
  213. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/rope.py +0 -0
  214. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/softmax.py +0 -0
  215. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/sparsemax.py +0 -0
  216. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/swiglu.py +0 -0
  217. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  218. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  219. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  220. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/transformers/tvd.py +0 -0
  221. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/triton/__init__.py +0 -0
  222. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/triton/monkey_patch.py +0 -0
  223. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel/utils.py +0 -0
  224. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  225. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  226. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  227. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  228. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/__init__.py +0 -0
  229. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/chunked_loss/__init__.py +0 -0
  230. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/chunked_loss/test_cosine_loss.py +0 -0
  231. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/chunked_loss/test_cpo_loss.py +0 -0
  232. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/chunked_loss/test_dpo_loss.py +0 -0
  233. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/chunked_loss/test_grpo_loss.py +0 -0
  234. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/chunked_loss/test_jsd_loss.py +0 -0
  235. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/chunked_loss/test_kto_loss.py +0 -0
  236. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/chunked_loss/test_orpo_loss.py +0 -0
  237. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/chunked_loss/test_simpo_loss.py +0 -0
  238. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/conftest.py +0 -0
  239. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/convergence/__init__.py +0 -0
  240. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/convergence/bf16/__init__.py +0 -0
  241. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/convergence/bf16/test_mini_models.py +0 -0
  242. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  243. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  244. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/convergence/fp32/__init__.py +0 -0
  245. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/convergence/fp32/test_mini_models.py +0 -0
  246. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  247. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  248. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  249. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  250. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  251. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  252. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  253. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json +0 -0
  254. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  255. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  256. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  257. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +0 -0
  258. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  259. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/tiny_shakespeare.txt +0 -0
  260. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  261. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  262. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  263. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_auto_model.py +0 -0
  264. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_dyt.py +0 -0
  265. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_embedding.py +0 -0
  266. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_flex_attention.py +0 -0
  267. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_fused_add_rms_norm.py +0 -0
  268. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_fused_linear_jsd.py +0 -0
  269. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_fused_neighborhood_attention.py +0 -0
  270. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_geglu.py +0 -0
  271. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_group_norm.py +0 -0
  272. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_grpo_loss.py +0 -0
  273. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_jsd.py +0 -0
  274. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_kl_div.py +0 -0
  275. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_layer_norm.py +0 -0
  276. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_mm_int8int2.py +0 -0
  277. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_multi_token_attention.py +0 -0
  278. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_qwen2vl_mrope.py +0 -0
  279. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_rms_norm.py +0 -0
  280. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_rope.py +0 -0
  281. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_softmax.py +0 -0
  282. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_sparsemax.py +0 -0
  283. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_swiglu.py +0 -0
  284. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_trainer_integration.py +0 -0
  285. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_transformers.py +0 -0
  286. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/transformers/test_tvd.py +0 -0
  287. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/triton/test_triton_monkey_patch.py +0 -0
  288. {liger_kernel_nightly-0.6.2.dev20251011154226 → liger_kernel_nightly-0.6.2.dev20251013144132}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.2.dev20251011154226
3
+ Version: 0.6.2.dev20251013144132
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -70,6 +70,9 @@ def bench_speed_cross_entropy(
70
70
 
71
71
  if mode == "forward":
72
72
  ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
73
+ elif mode == "no-grad-forward":
74
+ with torch.no_grad():
75
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
73
76
  elif mode == "backward":
74
77
  y = fwd()
75
78
 
@@ -109,7 +112,7 @@ if __name__ == "__main__":
109
112
 
110
113
  run_benchmarks(
111
114
  bench_test_fn=bench_speed_cross_entropy,
112
- kernel_operation_modes=["forward", "backward", "full"],
115
+ kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
113
116
  metric_name="speed",
114
117
  metric_unit="ms",
115
118
  **common_configs,
@@ -59,26 +59,26 @@ def bench_memory_fused_linear_cross_entropy(
59
59
  dtype = input.extra_benchmark_config["dtype"]
60
60
  provider = input.kernel_provider
61
61
 
62
- torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
63
- liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
64
- liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
62
+ lm_head_ce = None
63
+ if provider == "liger":
64
+ lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
65
+ elif provider == "liger-fp32-accum":
66
+ lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
67
+ else:
68
+ lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
65
69
 
66
70
  _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
67
71
  target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
68
72
 
69
73
  def fwd():
70
- if provider == "liger":
71
- return liger_lm_head_ce(_input, target)
72
- elif provider == "liger-fp32-accum":
73
- return liger_lm_head_ce_fp32_accum(_input, target)
74
- elif provider == "huggingface":
75
- return torch_lm_head_ce(_input, target)
74
+ return lm_head_ce(_input, target)
76
75
 
77
76
  def full():
78
77
  y = fwd()
79
78
  y.backward()
80
79
 
81
80
  mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
81
+
82
82
  return SingleBenchmarkRunOutput(
83
83
  y_20=mem_20,
84
84
  y_50=mem_50,
@@ -101,20 +101,19 @@ def bench_speed_fused_linear_cross_entropy(
101
101
  provider = input.kernel_provider
102
102
  mode = input.kernel_operation_mode
103
103
 
104
- torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
105
- liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
106
- liger_lm_head_ce_fp32_accum = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
104
+ lm_head_ce = None
105
+ if provider == "liger":
106
+ lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
107
+ elif provider == "liger-fp32-accum":
108
+ lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype, accum_dtype=torch.float32).to(device)
109
+ else:
110
+ lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
107
111
 
108
112
  _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
109
113
  target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
110
114
 
111
115
  def fwd():
112
- if provider == "liger":
113
- return liger_lm_head_ce(_input, target)
114
- elif provider == "liger-fp32-accum":
115
- return liger_lm_head_ce_fp32_accum(_input, target)
116
- elif provider == "huggingface":
117
- return torch_lm_head_ce(_input, target)
116
+ return lm_head_ce(_input, target)
118
117
 
119
118
  if mode == "forward":
120
119
  ms_50, ms_20, ms_80 = triton.testing.do_bench(
@@ -122,6 +121,13 @@ def bench_speed_fused_linear_cross_entropy(
122
121
  rep=100,
123
122
  quantiles=QUANTILES,
124
123
  )
124
+ elif mode == "no-grad-forward":
125
+ with torch.no_grad():
126
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(
127
+ fwd,
128
+ rep=100,
129
+ quantiles=QUANTILES,
130
+ )
125
131
  elif mode == "backward":
126
132
  y = fwd()
127
133
 
@@ -164,7 +170,7 @@ if __name__ == "__main__":
164
170
 
165
171
  run_benchmarks(
166
172
  bench_test_fn=bench_speed_fused_linear_cross_entropy,
167
- kernel_operation_modes=["forward", "backward", "full"],
173
+ kernel_operation_modes=["forward", "backward", "full", "no-grad-forward"],
168
174
  metric_name="speed",
169
175
  metric_unit="ms",
170
176
  **common_configs,
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.6.2.dev20251011154226"
7
+ version = "0.6.2.dev20251013144132"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
45
45
  BLOCK_SIZE: tl.constexpr,
46
46
  HAS_WEIGHT: tl.constexpr,
47
47
  HAS_SOFTCAPPING: tl.constexpr,
48
+ HAS_GRADIENTS: tl.constexpr,
48
49
  ):
49
50
  """
50
51
  This kernel computes both cross entropy loss and the gradient of the input.
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
72
73
  BLOCK_SIZE (int): The block size for Triton operations.
73
74
  HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
74
75
  HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
76
+ HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in forward pass.
75
77
  """
76
78
 
77
79
  # https://github.com/triton-lang/triton/issues/1058
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
155
157
  # For 'sum' reduction, no normalization is applied:
156
158
  # dx_y = softmax(x_y) - 1
157
159
  # dx_i = softmax(x_i), for i ≠ y
158
-
159
- for i in range(0, n_cols, BLOCK_SIZE):
160
- X_offsets = i + tl.arange(0, BLOCK_SIZE)
161
- X_block = tl.load(
162
- X_ptr + X_offsets,
163
- mask=X_offsets < n_cols,
164
- other=float("-inf"),
165
- # Ensure float32 precision for softmax calculation
166
- ).cast(tl.float32)
167
- if HAS_SOFTCAPPING:
168
- intermediate = tanh(X_block / softcap)
169
- X_block = softcap * intermediate
170
-
171
- if not HAS_WEIGHT:
172
- # softmax(x_i)
173
- X_block = tl.exp(X_block - m) / d
174
- # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
175
- X_block += 2 * lse_square_scale * lse * X_block
176
- # smoothing term
177
- X_block += -eps
178
- # special handle dx_y
179
- X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
180
- # reduction scale
181
- if reduction == "mean":
182
- X_block = X_block / n_non_ignore
183
- else:
184
- weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
185
- softmax_X = tl.exp(X_block - m) / d
186
- # derivative of original_loss
187
- dloss_ori = (1 - label_smoothing) * softmax_X
188
- # specially handle dx_y
189
- dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
190
- dloss_ori = dloss_ori * weight_y
191
- # derivative of smooth_loss
192
- dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
193
- # derivative of z-loss
194
- dz_loss = 2 * lse_square_scale * lse * softmax_X
195
- # reduction scale
196
- if reduction == "mean":
197
- dloss_ori = dloss_ori / sum_non_ignore_weight
198
- dloss_smooth = dloss_smooth / sum_non_ignore_weight
199
- # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
200
- dz_loss = dz_loss / n_non_ignore
201
- # derivative of total_loss
202
- X_block = dloss_ori + dloss_smooth + dz_loss
203
-
204
- # chain rule softcapping
205
- # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
206
- if HAS_SOFTCAPPING:
207
- X_block = X_block * (1 - intermediate * intermediate)
208
-
209
- tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
160
+ if HAS_GRADIENTS:
161
+ for i in range(0, n_cols, BLOCK_SIZE):
162
+ X_offsets = i + tl.arange(0, BLOCK_SIZE)
163
+ X_block = tl.load(
164
+ X_ptr + X_offsets,
165
+ mask=X_offsets < n_cols,
166
+ other=float("-inf"),
167
+ # Ensure float32 precision for softmax calculation
168
+ ).cast(tl.float32)
169
+ if HAS_SOFTCAPPING:
170
+ intermediate = tanh(X_block / softcap)
171
+ X_block = softcap * intermediate
172
+
173
+ if not HAS_WEIGHT:
174
+ # softmax(x_i)
175
+ X_block = tl.exp(X_block - m) / d
176
+ # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
177
+ X_block += 2 * lse_square_scale * lse * X_block
178
+ # smoothing term
179
+ X_block += -eps
180
+ # special handle dx_y
181
+ X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
182
+ # reduction scale
183
+ if reduction == "mean":
184
+ X_block = X_block / n_non_ignore
185
+ else:
186
+ weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
187
+ softmax_X = tl.exp(X_block - m) / d
188
+ # derivative of original_loss
189
+ dloss_ori = (1 - label_smoothing) * softmax_X
190
+ # specially handle dx_y
191
+ dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
192
+ dloss_ori = dloss_ori * weight_y
193
+ # derivative of smooth_loss
194
+ dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
195
+ # derivative of z-loss
196
+ dz_loss = 2 * lse_square_scale * lse * softmax_X
197
+ # reduction scale
198
+ if reduction == "mean":
199
+ dloss_ori = dloss_ori / sum_non_ignore_weight
200
+ dloss_smooth = dloss_smooth / sum_non_ignore_weight
201
+ # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
202
+ dz_loss = dz_loss / n_non_ignore
203
+ # derivative of total_loss
204
+ X_block = dloss_ori + dloss_smooth + dz_loss
205
+
206
+ # chain rule softcapping
207
+ # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
208
+ if HAS_SOFTCAPPING:
209
+ X_block = X_block * (1 - intermediate * intermediate)
210
+
211
+ tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
210
212
 
211
213
  # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
212
214
  # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -332,6 +334,7 @@ def cross_entropy_forward(
332
334
  BLOCK_SIZE=BLOCK_SIZE,
333
335
  HAS_WEIGHT=True if weight is not None else False,
334
336
  HAS_SOFTCAPPING=True if softcap is not None else False,
337
+ HAS_GRADIENTS=_input.requires_grad,
335
338
  # TODO: 32 seems to give the best performance
336
339
  # Performance is quite sensitive to num_warps
337
340
  num_warps=32 if not is_hip() else 16,
@@ -150,6 +150,7 @@ def fused_linear_cross_entropy_forward(
150
150
  RETURN_Z_LOSS=return_z_loss,
151
151
  HAS_WEIGHT=True if ce_weight is not None else False,
152
152
  HAS_SOFTCAPPING=True if softcap is not None else False,
153
+ HAS_GRADIENTS=_input.requires_grad,
153
154
  BLOCK_SIZE=BLOCK_SIZE,
154
155
  num_warps=32 if not is_hip() else 16,
155
156
  )
@@ -173,10 +174,10 @@ def fused_linear_cross_entropy_forward(
173
174
 
174
175
  grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
175
176
 
176
- if grad_weight is not None:
177
+ if grad_weight is not None and _input.requires_grad:
177
178
  grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
178
179
 
179
- if bias is not None:
180
+ if bias is not None and _input.requires_grad:
180
181
  torch.add(
181
182
  input=grad_bias,
182
183
  other=grad_logits_chunk.sum(dim=0),
@@ -469,7 +469,7 @@ def apply_liger_kernel_to_llama4(
469
469
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
470
470
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
471
471
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
472
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
472
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
473
473
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
474
474
  loaded. Default is None.
475
475
  """
@@ -522,7 +522,10 @@ def apply_liger_kernel_to_llama4(
522
522
  _patch_rms_norm_module(text_model.norm)
523
523
  for decoder_layer in text_model.layers:
524
524
  if swiglu:
525
- _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
525
+ if decoder_layer.is_moe_layer:
526
+ _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
527
+ else:
528
+ _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
526
529
  if rms_norm:
527
530
  _patch_rms_norm_module(decoder_layer.input_layernorm)
528
531
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.2.dev20251011154226
3
+ Version: 0.6.2.dev20251013144132
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -455,6 +455,28 @@ def _test_correctness_not_last_layer_with_other_params_once(
455
455
  assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
456
456
 
457
457
 
458
+ def _test_correctness_with_forward_only(target_ce, B, T, V, reduction, dtype, scalar, atol, rtol):
459
+ torch.manual_seed(0)
460
+ torch_ce = CrossEntropyLoss(reduction=reduction)
461
+
462
+ _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
463
+ _input = _tensor.detach().clone()
464
+ _input2 = _tensor.detach().clone()
465
+
466
+ target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
467
+
468
+ with torch.no_grad():
469
+ output = torch_ce(_input, target)
470
+ output2 = target_ce(_input2, target)
471
+ assert torch.allclose(output, output2, atol=atol, rtol=rtol)
472
+
473
+ try:
474
+ # Try running backward on liger output
475
+ output2.backward(gradient=torch.ones_like(output))
476
+ except RuntimeError as e:
477
+ assert "does not require grad" in str(e)
478
+
479
+
458
480
  def _test_correctness_functional(
459
481
  B,
460
482
  T,
@@ -1014,6 +1036,7 @@ def test_float32_internal():
1014
1036
  RETURN_Z_LOSS=0, # False
1015
1037
  HAS_WEIGHT=False,
1016
1038
  HAS_SOFTCAPPING=False,
1039
+ HAS_GRADIENTS=True,
1017
1040
  BLOCK_SIZE=BLOCK_SIZE,
1018
1041
  num_warps=32 if not is_hip() else 16,
1019
1042
  )
@@ -1042,6 +1065,7 @@ def test_float32_internal():
1042
1065
  RETURN_Z_LOSS=0, # False
1043
1066
  HAS_WEIGHT=False,
1044
1067
  HAS_SOFTCAPPING=False,
1068
+ HAS_GRADIENTS=True,
1045
1069
  BLOCK_SIZE=BLOCK_SIZE,
1046
1070
  num_warps=32 if not is_hip() else 16,
1047
1071
  )
@@ -1061,3 +1085,24 @@ def test_float32_internal():
1061
1085
  def test_correctness_with_out_of_bounds_target_once(B, T, V, ignore_index):
1062
1086
  liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index)
1063
1087
  _test_correctness_with_out_of_bounds_target_once(liger_ce, B, T, V, ignore_index)
1088
+
1089
+
1090
+ @pytest.mark.parametrize(
1091
+ "B, T, V, ignore_index",
1092
+ [
1093
+ (2, 4096, 32000, -100),
1094
+ (3, 423, 32000, 2),
1095
+ ],
1096
+ )
1097
+ @pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
1098
+ @pytest.mark.parametrize(
1099
+ "dtype, scalar, atol, rtol",
1100
+ [
1101
+ (torch.float32, 1.0, 1e-4, 1e-4),
1102
+ (torch.float16, 1.0, 1e-2, 1e-2),
1103
+ (torch.bfloat16, 1.0, 1e-2, 1e-2),
1104
+ ],
1105
+ )
1106
+ def test_correctness_with_forward_only(B, T, V, ignore_index, reduction, dtype, scalar, atol, rtol):
1107
+ liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
1108
+ _test_correctness_with_forward_only(liger_ce, B, T, V, reduction, dtype, scalar, atol, rtol)
@@ -231,6 +231,119 @@ def test_correctness(
231
231
  )
232
232
 
233
233
 
234
+ @pytest.mark.parametrize(
235
+ "B, T, H, V",
236
+ [
237
+ (8, 128, 1024, 4096),
238
+ (4, 47, 31, 123), # random shape
239
+ ],
240
+ )
241
+ @pytest.mark.parametrize(
242
+ "reduction, scalar, dtype, atol, rtol",
243
+ [
244
+ ("mean", 1.0, torch.bfloat16, 5e-3, 5e-2),
245
+ ("mean", 1.0, torch.float32, 1e-5, 5e-4),
246
+ ("sum", 1.0, torch.bfloat16, 5e-0, 5e1),
247
+ ("sum", 1.0, torch.float32, 1e-3, 5e-2),
248
+ ],
249
+ )
250
+ @pytest.mark.parametrize("bias", [True, False])
251
+ @pytest.mark.parametrize(
252
+ "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap, return_z_loss, accum_dtype",
253
+ [
254
+ (False, 0, -100, 0, None, False, None),
255
+ # Pass non-default values once to ensure all params work along
256
+ (True, 0.1, 42, 1e-4, 30.0, True, torch.float32),
257
+ ],
258
+ )
259
+ def test_correctness_with_forward_only(
260
+ B,
261
+ T,
262
+ H,
263
+ V,
264
+ scalar,
265
+ dtype,
266
+ bias,
267
+ has_ce_weight,
268
+ lse_square_scale,
269
+ label_smoothing,
270
+ ignore_index,
271
+ reduction,
272
+ softcap,
273
+ return_z_loss,
274
+ accum_dtype,
275
+ atol,
276
+ rtol,
277
+ ):
278
+ if has_ce_weight:
279
+ ce_weight = torch.rand(V, device=device, dtype=torch.float32)
280
+ else:
281
+ ce_weight = None
282
+ torch_lm_head_ce = TorchLMHeadCE(
283
+ H=H,
284
+ V=V,
285
+ bias=bias,
286
+ ce_weight=ce_weight,
287
+ lse_square_scale=lse_square_scale,
288
+ label_smoothing=label_smoothing,
289
+ ignore_index=ignore_index,
290
+ reduction=reduction,
291
+ softcap=softcap,
292
+ return_z_loss=return_z_loss,
293
+ dtype=dtype,
294
+ ).to(device)
295
+ liger_lm_head_ce = LigerLMHeadCE(
296
+ H=H,
297
+ V=V,
298
+ bias=bias,
299
+ ce_weight=ce_weight,
300
+ lse_square_scale=lse_square_scale,
301
+ label_smoothing=label_smoothing,
302
+ ignore_index=ignore_index,
303
+ reduction=reduction,
304
+ softcap=softcap,
305
+ return_z_loss=return_z_loss,
306
+ dtype=dtype,
307
+ accum_dtype=accum_dtype,
308
+ ).to(device)
309
+
310
+ # init the linear in all CEs with the same weights
311
+ torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(V, H, device=device, dtype=dtype)
312
+
313
+ if bias:
314
+ torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand(V, device=device, dtype=dtype)
315
+
316
+ _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
317
+ _input1 = _tensor.detach().clone().requires_grad_(True)
318
+ _input2 = _tensor.detach().clone().requires_grad_(True)
319
+
320
+ target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
321
+ # Assign some random number of elements as ignore_index
322
+ num_elements_to_assign = torch.randint(
323
+ 1, B * T // 2, (1,)
324
+ ).item() # Random number of elements to set to ignore_index
325
+ indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] # Randomly select indices
326
+ target[indices_to_assign] = ignore_index
327
+
328
+ with torch.no_grad():
329
+ if return_z_loss:
330
+ output1, z_output1 = torch_lm_head_ce(_input1, target)
331
+ output2, z_output2 = liger_lm_head_ce(_input2, target)
332
+ else:
333
+ output1 = torch_lm_head_ce(_input1, target)
334
+ output2 = liger_lm_head_ce(_input2, target)
335
+
336
+ assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)
337
+ if return_z_loss:
338
+ assert_verbose_allclose(z_output1, z_output2, atol=atol, rtol=rtol)
339
+
340
+ try:
341
+ grad_output = torch.rand_like(output1)
342
+ output2.backward(gradient=grad_output)
343
+ except RuntimeError as e:
344
+ assert "does not require grad" in str(e)
345
+
346
+
234
347
  @pytest.mark.parametrize(
235
348
  "B, T, H, V",
236
349
  [
@@ -600,13 +600,19 @@ def test_apply_liger_kernel_to_instance_for_llama4_for_causal_lm():
600
600
  intermediate_size=64,
601
601
  hidden_act="silu",
602
602
  num_hidden_layers=2,
603
+ moe_layers=[1],
603
604
  )
604
605
  dummy_model_instance = Llama4ForCausalLM._from_config(config)
605
606
 
606
607
  # Check that model instance variables are not yet patched with Liger modules
607
608
  assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward)
608
609
  for layer in dummy_model_instance.model.layers:
609
- assert inspect.getsource(layer.feed_forward.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
610
+ if layer.is_moe_layer:
611
+ assert inspect.getsource(layer.feed_forward.shared_expert.forward) != inspect.getsource(
612
+ LigerSwiGLUMLP.forward
613
+ )
614
+ else:
615
+ assert inspect.getsource(layer.feed_forward.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
610
616
  assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
611
617
  assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
612
618
 
@@ -616,7 +622,12 @@ def test_apply_liger_kernel_to_instance_for_llama4_for_causal_lm():
616
622
  # Check that the model's instance variables were correctly patched with Liger modules
617
623
  assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward)
618
624
  for layer in dummy_model_instance.model.layers:
619
- assert inspect.getsource(layer.feed_forward.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
625
+ if layer.is_moe_layer:
626
+ assert inspect.getsource(layer.feed_forward.shared_expert.forward) == inspect.getsource(
627
+ LigerSwiGLUMLP.forward
628
+ )
629
+ else:
630
+ assert inspect.getsource(layer.feed_forward.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
620
631
  assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
621
632
  assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
622
633
 
@@ -642,6 +653,7 @@ def test_apply_liger_kernel_to_instance_for_llama4_for_conditional_generation():
642
653
  intermediate_size=64,
643
654
  hidden_act="silu",
644
655
  num_hidden_layers=2,
656
+ moe_layers=[1],
645
657
  ),
646
658
  vision_config=transformers.models.llama4.configuration_llama4.Llama4VisionConfig(
647
659
  rms_norm_eps=1e-5,
@@ -662,7 +674,12 @@ def test_apply_liger_kernel_to_instance_for_llama4_for_conditional_generation():
662
674
  LigerRMSNorm.forward
663
675
  )
664
676
  for layer in dummy_model_instance.language_model.model.layers:
665
- assert inspect.getsource(layer.feed_forward.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
677
+ if layer.is_moe_layer:
678
+ assert inspect.getsource(layer.feed_forward.shared_expert.forward) != inspect.getsource(
679
+ LigerSwiGLUMLP.forward
680
+ )
681
+ else:
682
+ assert inspect.getsource(layer.feed_forward.forward) != inspect.getsource(LigerSwiGLUMLP.forward)
666
683
  assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
667
684
  assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward)
668
685
 
@@ -686,7 +703,12 @@ def test_apply_liger_kernel_to_instance_for_llama4_for_conditional_generation():
686
703
  LigerRMSNorm.forward
687
704
  )
688
705
  for layer in dummy_model_instance.language_model.model.layers:
689
- assert inspect.getsource(layer.feed_forward.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
706
+ if layer.is_moe_layer:
707
+ assert inspect.getsource(layer.feed_forward.shared_expert.forward) == inspect.getsource(
708
+ LigerSwiGLUMLP.forward
709
+ )
710
+ else:
711
+ assert inspect.getsource(layer.feed_forward.forward) == inspect.getsource(LigerSwiGLUMLP.forward)
690
712
  assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
691
713
  assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward)
692
714