liger-kernel-nightly 0.6.2.dev20251011154427__tar.gz → 0.6.2.dev20251014053719__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (292)
  1. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/PKG-INFO +1 -1
  2. liger_kernel_nightly-0.6.2.dev20251014053719/benchmark/scripts/benchmark_poly_norm.py +197 -0
  3. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/pyproject.toml +1 -1
  4. liger_kernel_nightly-0.6.2.dev20251014053719/src/liger_kernel/ops/poly_norm.py +386 -0
  5. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/__init__.py +2 -0
  6. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/functional.py +5 -0
  7. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/monkey_patch.py +5 -2
  8. liger_kernel_nightly-0.6.2.dev20251014053719/src/liger_kernel/transformers/poly_norm.py +42 -0
  9. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  10. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
  11. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_monkey_patch.py +26 -4
  12. liger_kernel_nightly-0.6.2.dev20251014053719/test/transformers/test_poly_norm.py +281 -0
  13. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  14. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  15. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/pull_request_template.md +0 -0
  16. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/amd-ci.yml +0 -0
  17. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/benchmark.yml +0 -0
  18. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/docs.yml +0 -0
  19. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/intel-ci.yml +0 -0
  20. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/nvi-ci.yml +0 -0
  21. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/publish-nightly.yml +0 -0
  22. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.github/workflows/publish-release.yml +0 -0
  23. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/.gitignore +0 -0
  24. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/LICENSE +0 -0
  25. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/Makefile +0 -0
  26. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/NOTICE +0 -0
  27. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/README.md +0 -0
  28. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/README.md +0 -0
  29. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/__init__.py +0 -0
  30. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/benchmarks_visualizer.py +0 -0
  31. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/data/all_benchmark_data.csv +0 -0
  32. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/__init__.py +0 -0
  33. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  35. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_distill_cosine_loss.py +0 -0
  36. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  37. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  38. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_dyt.py +0 -0
  39. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_embedding.py +0 -0
  40. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_fused_add_rms_norm.py +0 -0
  41. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  42. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  43. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_fused_neighborhood_attention.py +0 -0
  44. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_geglu.py +0 -0
  45. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_group_norm.py +0 -0
  46. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_grpo_loss.py +0 -0
  47. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_jsd.py +0 -0
  48. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_kl_div.py +0 -0
  49. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  50. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  51. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_llama4_rope.py +0 -0
  52. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_multi_token_attention.py +0 -0
  53. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  54. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  55. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  56. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_rope.py +0 -0
  57. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  58. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_softmax.py +0 -0
  59. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_sparse_multi_token_attention.py +0 -0
  60. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_sparsemax.py +0 -0
  61. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_swiglu.py +0 -0
  62. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/benchmark_tvd.py +0 -0
  63. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/benchmark/scripts/utils.py +0 -0
  64. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/dev/fmt-requirements.txt +0 -0
  65. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/dev/modal/benchmarks.py +0 -0
  66. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/dev/modal/tests.py +0 -0
  67. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/dev/modal/tests_bwd.py +0 -0
  68. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/Examples.md +0 -0
  69. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/Getting-Started.md +0 -0
  70. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/High-Level-APIs.md +0 -0
  71. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/Low-Level-APIs.md +0 -0
  72. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/acknowledgement.md +0 -0
  73. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/contributing.md +0 -0
  74. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/banner.GIF +0 -0
  75. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/compose.gif +0 -0
  76. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/e2e-memory.png +0 -0
  77. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/e2e-tps.png +0 -0
  78. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/logo-banner.png +0 -0
  79. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/patch.gif +0 -0
  80. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/images/post-training.png +0 -0
  81. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/index.md +0 -0
  82. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/docs/license.md +0 -0
  83. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/alignment/accelerate_config.yaml +0 -0
  84. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/alignment/run_orpo.py +0 -0
  85. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/README.md +0 -0
  86. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/callback.py +0 -0
  87. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/config/fsdp_config.json +0 -0
  88. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  89. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  90. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  91. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/llama_tps.png +0 -0
  92. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  93. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/img/qwen_tps.png +0 -0
  94. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/launch_on_modal.py +0 -0
  95. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/requirements.txt +0 -0
  96. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_benchmarks.sh +0 -0
  97. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_gemma.sh +0 -0
  98. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_llama.sh +0 -0
  99. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_qwen.sh +0 -0
  100. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/run_qwen2_vl.sh +0 -0
  101. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/training.py +0 -0
  102. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/huggingface/training_multimodal.py +0 -0
  103. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/lightning/README.md +0 -0
  104. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/lightning/requirements.txt +0 -0
  105. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/lightning/training.py +0 -0
  106. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/README.md +0 -0
  107. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/callback.py +0 -0
  108. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  109. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  110. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  111. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  112. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  113. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  114. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  115. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  116. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  117. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/medusa_util.py +0 -0
  118. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/requirements.txt +0 -0
  119. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  120. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/examples/medusa/train.py +0 -0
  121. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-Apache-2.0 +0 -0
  122. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  123. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  124. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-MIT-llmc +0 -0
  125. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/licenses/LICENSE-MIT-triton +0 -0
  126. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/mkdocs.yml +0 -0
  127. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/setup.cfg +0 -0
  128. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/setup.py +0 -0
  129. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/__init__.py +0 -0
  130. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/README.md +0 -0
  131. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  132. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +0 -0
  133. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  134. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  135. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/functional.py +0 -0
  136. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  137. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  138. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  139. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  140. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  141. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  142. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  143. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  144. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  145. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/env_report.py +0 -0
  146. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/__init__.py +0 -0
  147. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/cross_entropy.py +0 -0
  148. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/dyt.py +0 -0
  149. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  150. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  151. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/fused_add_rms_norm.py +0 -0
  152. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  153. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  154. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/fused_neighborhood_attention.py +0 -0
  155. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/geglu.py +0 -0
  156. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/group_norm.py +0 -0
  157. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/grpo_loss.py +0 -0
  158. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/jsd.py +0 -0
  159. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/kl_div.py +0 -0
  160. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/layer_norm.py +0 -0
  161. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/llama4_rope.py +0 -0
  162. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/multi_token_attention.py +0 -0
  163. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  164. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/rms_norm.py +0 -0
  165. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/rope.py +0 -0
  166. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/softmax.py +0 -0
  167. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/sparsemax.py +0 -0
  168. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/swiglu.py +0 -0
  169. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/tvd.py +0 -0
  170. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/ops/utils.py +0 -0
  171. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/auto_model.py +0 -0
  172. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  173. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/dyt.py +0 -0
  174. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/experimental/__init__.py +0 -0
  175. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  176. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fsdp.py +0 -0
  177. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fused_add_rms_norm.py +0 -0
  178. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  179. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  180. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/fused_neighborhood_attention.py +0 -0
  181. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/geglu.py +0 -0
  182. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/group_norm.py +0 -0
  183. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/grpo_loss.py +0 -0
  184. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/jsd.py +0 -0
  185. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/kl_div.py +0 -0
  186. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/layer_norm.py +0 -0
  187. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/llama4_rope.py +0 -0
  188. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/__init__.py +0 -0
  189. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/falcon_h1.py +0 -0
  190. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/gemma.py +0 -0
  191. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  192. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  193. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/glm4.py +0 -0
  194. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/glm4v.py +0 -0
  195. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/glm4v_moe.py +0 -0
  196. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/internvl.py +0 -0
  197. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/llama.py +0 -0
  198. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/llama4.py +0 -0
  199. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/llava.py +0 -0
  200. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  201. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/mistral.py +0 -0
  202. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  203. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/mllama.py +0 -0
  204. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  205. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  206. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/phi3.py +0 -0
  207. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  208. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  209. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  210. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  211. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  212. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/model/smollm3.py +0 -0
  213. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/multi_token_attention.py +0 -0
  214. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  215. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/rms_norm.py +0 -0
  216. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/rope.py +0 -0
  217. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/softmax.py +0 -0
  218. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/sparsemax.py +0 -0
  219. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/swiglu.py +0 -0
  220. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  221. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  222. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  223. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/transformers/tvd.py +0 -0
  224. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/triton/__init__.py +0 -0
  225. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/triton/monkey_patch.py +0 -0
  226. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel/utils.py +0 -0
  227. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  228. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  229. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  230. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/__init__.py +0 -0
  231. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/__init__.py +0 -0
  232. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_cosine_loss.py +0 -0
  233. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_cpo_loss.py +0 -0
  234. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_dpo_loss.py +0 -0
  235. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_grpo_loss.py +0 -0
  236. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_jsd_loss.py +0 -0
  237. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_kto_loss.py +0 -0
  238. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_orpo_loss.py +0 -0
  239. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/chunked_loss/test_simpo_loss.py +0 -0
  240. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/conftest.py +0 -0
  241. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/__init__.py +0 -0
  242. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/bf16/__init__.py +0 -0
  243. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/bf16/test_mini_models.py +0 -0
  244. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  245. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  246. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/fp32/__init__.py +0 -0
  247. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/fp32/test_mini_models.py +0 -0
  248. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  249. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  250. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  251. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  252. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  253. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  254. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  255. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/OpenGVLab/InternVL3-1B-hf/tokenizer_config.json +0 -0
  256. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  257. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  258. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  259. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +0 -0
  260. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  261. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/tiny_shakespeare.txt +0 -0
  262. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  263. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  264. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  265. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_auto_model.py +0 -0
  266. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_cross_entropy.py +0 -0
  267. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_dyt.py +0 -0
  268. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_embedding.py +0 -0
  269. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_flex_attention.py +0 -0
  270. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_fused_add_rms_norm.py +0 -0
  271. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  272. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_fused_linear_jsd.py +0 -0
  273. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_fused_neighborhood_attention.py +0 -0
  274. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_geglu.py +0 -0
  275. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_group_norm.py +0 -0
  276. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_grpo_loss.py +0 -0
  277. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_jsd.py +0 -0
  278. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_kl_div.py +0 -0
  279. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_layer_norm.py +0 -0
  280. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_mm_int8int2.py +0 -0
  281. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_multi_token_attention.py +0 -0
  282. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_qwen2vl_mrope.py +0 -0
  283. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_rms_norm.py +0 -0
  284. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_rope.py +0 -0
  285. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_softmax.py +0 -0
  286. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_sparsemax.py +0 -0
  287. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_swiglu.py +0 -0
  288. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_trainer_integration.py +0 -0
  289. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_transformers.py +0 -0
  290. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/transformers/test_tvd.py +0 -0
  291. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/triton/test_triton_monkey_patch.py +0 -0
  292. {liger_kernel_nightly-0.6.2.dev20251011154427 → liger_kernel_nightly-0.6.2.dev20251014053719}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.6.2.dev20251011154427
3
+ Version: 0.6.2.dev20251014053719
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -0,0 +1,197 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import triton
4
+
5
+ from utils import QUANTILES
6
+ from utils import SingleBenchmarkRunInput
7
+ from utils import SingleBenchmarkRunOutput
8
+ from utils import _test_memory
9
+ from utils import parse_benchmark_script_args
10
+ from utils import run_benchmarks
11
+
12
+ from liger_kernel.transformers.poly_norm import LigerPolyNorm
13
+ from liger_kernel.utils import infer_device
14
+
15
+ device = infer_device()
16
+
17
+
18
+ class NaivePolyNorm(nn.Module):
19
+ """
20
+ Naive PyTorch implementation of PolyNorm.
21
+
22
+ Reference:
23
+ https://github.com/BryceZhuo/PolyCom/
24
+
25
+ PolyNorm formula:
26
+ y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
27
+ where norm(u) = u / sqrt(mean(u²) + ε)
28
+ """
29
+
30
+ def __init__(self, eps=1e-6):
31
+ super().__init__()
32
+ # Align with PolyCom reference: (1/3, 1/3, 1/3) and bias=1.0
33
+ self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
34
+ self.bias = nn.Parameter(torch.tensor(1.0))
35
+ self.variance_epsilon = eps
36
+
37
+ def _norm(self, x):
38
+ """RMSNorm operation"""
39
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
40
+
41
+ def forward(self, hidden_states):
42
+ """
43
+ Forward pass of PolyNorm
44
+
45
+ Args:
46
+ hidden_states: input tensor of shape (..., H)
47
+
48
+ Returns:
49
+ output tensor of same shape as input
50
+ """
51
+ input_dtype = hidden_states.dtype
52
+ hidden_states = hidden_states.to(torch.float32)
53
+
54
+ # Compute powers
55
+ x_pow3 = hidden_states**3
56
+ x_pow2 = hidden_states**2
57
+ x_pow1 = hidden_states**1
58
+
59
+ # Normalize each power
60
+ norm_x3 = self._norm(x_pow3)
61
+ norm_x2 = self._norm(x_pow2)
62
+ norm_x1 = self._norm(x_pow1)
63
+
64
+ # Weighted sum with bias
65
+ output = self.weight[0] * norm_x3 + self.weight[1] * norm_x2 + self.weight[2] * norm_x1 + self.bias
66
+
67
+ return output.to(input_dtype)
68
+
69
+
70
+ def bench_speed_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
71
+ N = input.x
72
+ provider = input.kernel_provider
73
+ mode = input.kernel_operation_mode
74
+
75
+ extra_benchmark_config = input.extra_benchmark_config
76
+ M = extra_benchmark_config["M"]
77
+ eps = extra_benchmark_config["eps"]
78
+ dtype = extra_benchmark_config["dtype"]
79
+
80
+ x_shape = (M, N)
81
+
82
+ triton_poly = LigerPolyNorm(eps=eps).to(device)
83
+ naive_poly = NaivePolyNorm(eps=eps).to(device)
84
+
85
+ x = torch.randn(x_shape, dtype=dtype, device=device)
86
+ dy = torch.randn_like(x)
87
+ x.requires_grad_(True)
88
+
89
+ # utility functions
90
+
91
+ def y_fwd():
92
+ if provider == "liger":
93
+ return triton_poly(x)
94
+
95
+ if provider == "huggingface":
96
+ return naive_poly(x)
97
+
98
+ if mode == "forward":
99
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(
100
+ y_fwd,
101
+ grad_to_none=[x],
102
+ rep=500,
103
+ quantiles=QUANTILES,
104
+ )
105
+ elif mode == "backward":
106
+ y = y_fwd()
107
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(
108
+ lambda: y.backward(dy, retain_graph=True),
109
+ grad_to_none=[x],
110
+ rep=500,
111
+ quantiles=QUANTILES,
112
+ )
113
+ elif mode == "full":
114
+
115
+ def full():
116
+ y = y_fwd()
117
+ y.backward(dy, retain_graph=True)
118
+
119
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(
120
+ full,
121
+ grad_to_none=[x],
122
+ rep=500,
123
+ quantiles=QUANTILES,
124
+ )
125
+
126
+ return SingleBenchmarkRunOutput(
127
+ y_20=ms_20,
128
+ y_50=ms_50,
129
+ y_80=ms_80,
130
+ )
131
+
132
+
133
+ def bench_memory_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
134
+ N = input.x
135
+ provider = input.kernel_provider
136
+
137
+ extra_benchmark_config = input.extra_benchmark_config
138
+ M = extra_benchmark_config["M"]
139
+ eps = extra_benchmark_config["eps"]
140
+ dtype = extra_benchmark_config["dtype"]
141
+
142
+ x_shape = (M, N)
143
+
144
+ triton_poly = LigerPolyNorm(eps=eps).to(device)
145
+ naive_poly = NaivePolyNorm(eps=eps).to(device)
146
+
147
+ x = torch.randn(x_shape, dtype=dtype, device=device)
148
+ dy = torch.randn_like(x)
149
+ x.requires_grad_(True)
150
+
151
+ # utility functions
152
+ def y_fwd():
153
+ if provider == "liger":
154
+ return triton_poly(x)
155
+ if provider == "huggingface":
156
+ return naive_poly(x)
157
+
158
+ def full():
159
+ y = y_fwd()
160
+ y.backward(dy, retain_graph=True)
161
+
162
+ mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
163
+
164
+ return SingleBenchmarkRunOutput(
165
+ y_20=mem_20,
166
+ y_50=mem_50,
167
+ y_80=mem_80,
168
+ )
169
+
170
+
171
+ if __name__ == "__main__":
172
+ args = parse_benchmark_script_args()
173
+
174
+ common_configs = {
175
+ "kernel_name": "poly_norm",
176
+ "x_name": "H",
177
+ "x_label": "hidden size",
178
+ "x_values": [2**i for i in range(10, 16)],
179
+ "kernel_providers": ["liger", "huggingface"],
180
+ "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}],
181
+ "overwrite": args.overwrite,
182
+ }
183
+
184
+ run_benchmarks(
185
+ bench_test_fn=bench_speed_poly_norm,
186
+ kernel_operation_modes=["forward", "full", "backward"],
187
+ metric_name="speed",
188
+ metric_unit="ms",
189
+ **common_configs,
190
+ )
191
+ run_benchmarks(
192
+ bench_test_fn=bench_memory_poly_norm,
193
+ kernel_operation_modes=["full"],
194
+ metric_name="memory",
195
+ metric_unit="MB",
196
+ **common_configs,
197
+ )
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.6.2.dev20251011154427"
7
+ version = "0.6.2.dev20251014053719"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,386 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import compare_version
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+
11
+ if compare_version("triton", operator.ge, "3.0.0"):
12
+ try:
13
+ from triton.language.extra.libdevice import rsqrt
14
+ except ModuleNotFoundError:
15
+ from triton.language.extra.cuda.libdevice import rsqrt
16
+ else:
17
+ from triton.language.math import rsqrt
18
+
19
+
20
+ @triton.jit
21
+ def _poly_norm_forward_kernel(
22
+ Y_ptr,
23
+ Y_row_stride,
24
+ X_ptr,
25
+ X_row_stride,
26
+ W_ptr, # weight: [3] for [w0, w1, w2]
27
+ B_ptr, # bias: scalar
28
+ RSTD_ptr, # cache rstd for backward: shape (n_rows, 3)
29
+ RSTD_row_stride,
30
+ n_cols,
31
+ eps,
32
+ BLOCK_SIZE: tl.constexpr,
33
+ ):
34
+ """
35
+ PolyNorm formula:
36
+ y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
37
+ where norm(u) = u / sqrt(mean(u²) + ε)
38
+
39
+ Reference:
40
+ 1. https://github.com/BryceZhuo/PolyCom/
41
+ 2. https://arxiv.org/pdf/2411.03884
42
+
43
+ Cache rstd values for backward pass
44
+ """
45
+ row_idx = tl.program_id(0).to(tl.int64)
46
+ col_offsets = tl.arange(0, BLOCK_SIZE)
47
+ mask = col_offsets < n_cols
48
+
49
+ # Load pointers
50
+ Y_ptr += row_idx * Y_row_stride
51
+ X_ptr += row_idx * X_row_stride
52
+ RSTD_ptr += row_idx * RSTD_row_stride
53
+
54
+ # Load input row
55
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)
56
+
57
+ # Load weights and bias
58
+ w0 = tl.load(W_ptr + 0)
59
+ w1 = tl.load(W_ptr + 1)
60
+ w2 = tl.load(W_ptr + 2)
61
+ b = tl.load(B_ptr)
62
+
63
+ # Compute x³, x², x
64
+ X_pow3 = X_row * X_row * X_row
65
+ X_pow2 = X_row * X_row
66
+ X_pow1 = X_row
67
+
68
+ # Compute norm(x³): norm(u) = u * rsqrt(mean(u²) + eps)
69
+ mean_square_3 = tl.sum(X_pow3 * X_pow3, axis=0) / n_cols
70
+ rstd_3 = rsqrt(mean_square_3 + eps)
71
+ norm_x3 = X_pow3 * rstd_3
72
+
73
+ # Compute norm(x²)
74
+ mean_square_2 = tl.sum(X_pow2 * X_pow2, axis=0) / n_cols
75
+ rstd_2 = rsqrt(mean_square_2 + eps)
76
+ norm_x2 = X_pow2 * rstd_2
77
+
78
+ # Compute norm(x)
79
+ mean_square_1 = tl.sum(X_pow1 * X_pow1, axis=0) / n_cols
80
+ rstd_1 = rsqrt(mean_square_1 + eps)
81
+ norm_x1 = X_pow1 * rstd_1
82
+
83
+ # Cache rstd values for backward
84
+ tl.store(RSTD_ptr + 0, rstd_3)
85
+ tl.store(RSTD_ptr + 1, rstd_2)
86
+ tl.store(RSTD_ptr + 2, rstd_1)
87
+
88
+ # Compute output: y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
89
+ Y_row = w0 * norm_x3 + w1 * norm_x2 + w2 * norm_x1 + b
90
+
91
+ # Store output
92
+ tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
93
+
94
+
95
+ @triton.jit
96
+ def _poly_norm_backward_kernel(
97
+ dY_ptr,
98
+ dY_row_stride,
99
+ dX_ptr,
100
+ dX_row_stride,
101
+ X_ptr,
102
+ X_row_stride,
103
+ W_ptr,
104
+ RSTD_ptr,
105
+ RSTD_row_stride,
106
+ dW_ptr, # shape: (n_programs, 3)
107
+ dW_row_stride,
108
+ dB_ptr, # shape: (n_programs,)
109
+ n_rows,
110
+ n_cols,
111
+ rows_per_program: tl.constexpr,
112
+ BLOCK_SIZE: tl.constexpr,
113
+ ):
114
+ """
115
+ PolyNorm Backward Kernel Gradient:
116
+ ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
117
+
118
+ where:
119
+ - D_p = RMS(x^p) = 1/rstd_p
120
+ - S_p = sum(grad * x^p) over the row
121
+ - d = n_cols
122
+ - p ∈ {3, 2, 1}
123
+ """
124
+ row_block_id = tl.program_id(0).to(tl.int64)
125
+ row_start = row_block_id * rows_per_program
126
+ row_end = min((row_block_id + 1) * rows_per_program, n_rows)
127
+ col_offsets = tl.arange(0, BLOCK_SIZE)
128
+ mask = col_offsets < n_cols
129
+
130
+ # Initialize accumulators for weight and bias gradients (scalars)
131
+ dW0_acc = 0.0
132
+ dW1_acc = 0.0
133
+ dW2_acc = 0.0
134
+ dB_acc = 0.0
135
+
136
+ # Load weights
137
+ w0 = tl.load(W_ptr + 0).to(tl.float32)
138
+ w1 = tl.load(W_ptr + 1).to(tl.float32)
139
+ w2 = tl.load(W_ptr + 2).to(tl.float32)
140
+
141
+ dY_ptr += row_start * dY_row_stride
142
+ dX_ptr += row_start * dX_row_stride
143
+ X_ptr += row_start * X_row_stride
144
+ RSTD_ptr += row_start * RSTD_row_stride
145
+
146
+ for _ in range(row_start, row_end):
147
+ # Load input and gradient
148
+ dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
149
+ X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
150
+
151
+ # Load cached rstd values
152
+ rstd_3 = tl.load(RSTD_ptr + 0).to(tl.float32)
153
+ rstd_2 = tl.load(RSTD_ptr + 1).to(tl.float32)
154
+ rstd_1 = tl.load(RSTD_ptr + 2).to(tl.float32)
155
+
156
+ # Compute powers
157
+ X_pow3 = X_row * X_row * X_row
158
+ X_pow2 = X_row * X_row
159
+ X_pow1 = X_row
160
+
161
+ # Accumulate bias gradient: dB = sum(dY)
162
+ dB_acc += tl.sum(dY_row, axis=0)
163
+
164
+ # Compute gradient w.r.t. input using closed-form formula
165
+ # For p=3: ∂L/∂x from w0 * norm(x³)
166
+ S_3 = tl.sum(dY_row * X_pow3, axis=0) # scalar
167
+ grad_x_3 = w0 * (
168
+ 3.0 * X_pow2 * rstd_3 * dY_row
169
+ - (3.0 / n_cols) * X_row * X_row * X_row * X_row * X_row * (rstd_3 * rstd_3 * rstd_3) * S_3
170
+ )
171
+
172
+ # For p=2: ∂L/∂x from w1 * norm(x²)
173
+ S_2 = tl.sum(dY_row * X_pow2, axis=0) # scalar
174
+ grad_x_2 = w1 * (
175
+ 2.0 * X_row * rstd_2 * dY_row - (2.0 / n_cols) * X_row * X_row * X_row * (rstd_2 * rstd_2 * rstd_2) * S_2
176
+ )
177
+
178
+ # For p=1: ∂L/∂x from w2 * norm(x)
179
+ S_1 = tl.sum(dY_row * X_pow1, axis=0) # scalar
180
+ grad_x_1 = w2 * (1.0 * rstd_1 * dY_row - (1.0 / n_cols) * X_row * (rstd_1 * rstd_1 * rstd_1) * S_1)
181
+
182
+ # Accumulate weight gradients using closed-form: dW_p = rstd_p * S_p
183
+ dW0_acc += rstd_3 * S_3
184
+ dW1_acc += rstd_2 * S_2
185
+ dW2_acc += rstd_1 * S_1
186
+
187
+ # Total gradient
188
+ dX_row = grad_x_3 + grad_x_2 + grad_x_1
189
+
190
+ # Store gradient
191
+ tl.store(dX_ptr + col_offsets, dX_row, mask=mask)
192
+
193
+ # Update pointers
194
+ dY_ptr += dY_row_stride
195
+ dX_ptr += dX_row_stride
196
+ X_ptr += X_row_stride
197
+ RSTD_ptr += RSTD_row_stride
198
+
199
+ # Store accumulated gradients (scalars)
200
+ tl.store(dW_ptr + row_block_id * dW_row_stride + 0, dW0_acc)
201
+ tl.store(dW_ptr + row_block_id * dW_row_stride + 1, dW1_acc)
202
+ tl.store(dW_ptr + row_block_id * dW_row_stride + 2, dW2_acc)
203
+ tl.store(dB_ptr + row_block_id, dB_acc)
204
+
205
+
206
+ def poly_norm_forward(X, W, B, eps=1e-6):
207
+ """
208
+ PolyNorm Forward Pass
209
+
210
+ Args:
211
+ X: input tensor of shape (*, H) where H is hidden dimension
212
+ W: weight tensor of shape (3,) for [w0, w1, w2]
213
+ B: bias scalar tensor
214
+ eps: epsilon for numerical stability
215
+
216
+ Returns:
217
+ Y: output tensor of same shape as X
218
+ X: reshaped input (for backward)
219
+ RSTD: cached rstd values (for backward)
220
+ BLOCK_SIZE: block size used
221
+ num_warps: number of warps used
222
+ """
223
+ shape = X.shape
224
+ dim = shape[-1]
225
+ X = X.view(-1, dim)
226
+ n_rows, n_cols = X.shape
227
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
228
+
229
+ # RSTD is to cache rstd for each row
230
+ Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
231
+ RSTD = torch.empty((n_rows, 3), dtype=torch.float32, device=X.device)
232
+
233
+ # Check constraints
234
+ assert W.shape[0] == 3, "Weight tensor must have shape (3,)"
235
+ assert B.numel() == 1, "Bias must be a scalar"
236
+
237
+ # XPU-specific optimization
238
+ kernel_args = {}
239
+ if X.device.type == "xpu":
240
+ kernel_args["grf_mode"] = "large"
241
+
242
+ # Launch kernel
243
+ _poly_norm_forward_kernel[(n_rows,)](
244
+ Y,
245
+ Y.stride(0),
246
+ X,
247
+ X.stride(0),
248
+ W,
249
+ B,
250
+ RSTD,
251
+ RSTD.stride(0),
252
+ n_cols,
253
+ eps,
254
+ BLOCK_SIZE=BLOCK_SIZE,
255
+ num_warps=num_warps,
256
+ **kernel_args,
257
+ )
258
+
259
+ return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps
260
+
261
+
262
+ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
263
+ """
264
+ PolyNorm Backward Pass
265
+
266
+ Args:
267
+ dY: gradient of output
268
+ X: input tensor (already reshaped to 2D)
269
+ W: weight tensor
270
+ RSTD: cached rstd values from forward
271
+ BLOCK_SIZE: block size from forward
272
+ num_warps: number of warps from forward
273
+ in_place: whether to in-place modify dY to store dX (saves memory)
274
+
275
+ Returns:
276
+ dX: gradient w.r.t. input
277
+ dW: gradient w.r.t. weight
278
+ dB: gradient w.r.t. bias
279
+ """
280
+ shape = dY.shape
281
+ dim = shape[-1]
282
+ dY = dY.view(-1, dim)
283
+ n_rows, n_cols = dY.shape
284
+
285
+ # Get number of SMs for parallelization
286
+ import math
287
+
288
+ sm_count = 1
289
+ if X.device.type == "cuda":
290
+ sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
291
+ elif X.device.type == "xpu":
292
+ sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
293
+
294
+ # Allocate or reuse gradients
295
+ if in_place is True:
296
+ dX = dY
297
+ else:
298
+ dX = torch.zeros_like(dY)
299
+
300
+ _dW = torch.empty((sm_count, 3), dtype=torch.float32, device=W.device)
301
+ _dB = torch.empty((sm_count,), dtype=torch.float32, device=W.device)
302
+
303
+ rows_per_program = math.ceil(n_rows / sm_count)
304
+ grid = (sm_count,)
305
+
306
+ # XPU-specific optimization
307
+ kernel_args = {}
308
+ if X.device.type == "xpu":
309
+ kernel_args["grf_mode"] = "large"
310
+
311
+ # Launch backward kernel
312
+ _poly_norm_backward_kernel[grid](
313
+ dY,
314
+ dY.stride(0),
315
+ dX,
316
+ dX.stride(0),
317
+ X,
318
+ X.stride(0),
319
+ W,
320
+ RSTD,
321
+ RSTD.stride(0),
322
+ _dW,
323
+ _dW.stride(0),
324
+ _dB,
325
+ n_rows,
326
+ n_cols,
327
+ rows_per_program,
328
+ BLOCK_SIZE=BLOCK_SIZE,
329
+ num_warps=num_warps,
330
+ **kernel_args,
331
+ )
332
+
333
+ # Reduce gradients across SMs
334
+ dX = dX.view(*shape)
335
+ dW = _dW.sum(dim=0).to(W.dtype)
336
+ dB = _dB.sum().to(W.dtype)
337
+
338
+ return dX, dW, dB
339
+
340
+
341
+ class LigerPolyNormFunction(torch.autograd.Function):
342
+ """
343
+ PolyNorm Function with forward and backward pass
344
+
345
+ PolyNorm formula:
346
+ y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
347
+ where norm(u) = u / sqrt(mean(u²) + ε)
348
+
349
+ Backward uses closed-form gradient:
350
+ ∂L/∂x_i = Σ_p w_p * [p*x_i^(p-1) * grad_i/D_p - (p/d)*x_i^(2p-1) * S_p/(D_p³)]
351
+ """
352
+
353
+ @staticmethod
354
+ @ensure_contiguous
355
+ def forward(ctx, X, W, B, eps=1e-6, in_place=True):
356
+ """
357
+ Args:
358
+ X: input tensor of shape (B, T, H) or (BxT, H)
359
+ W: weight tensor of shape (3,) for [w0, w1, w2]
360
+ B: bias scalar
361
+ eps: epsilon for numerical stability
362
+ in_place: whether to in-place modify grad_output in backward (saves memory)
363
+
364
+ Returns:
365
+ Y: output tensor of same shape as X
366
+ """
367
+ Y, X, RSTD, BLOCK_SIZE, num_warps = poly_norm_forward(X, W, B, eps)
368
+ ctx.BLOCK_SIZE = BLOCK_SIZE
369
+ ctx.num_warps = num_warps
370
+ ctx.in_place = in_place
371
+ ctx.save_for_backward(X, W, RSTD)
372
+ return Y
373
+
374
+ @staticmethod
375
+ @ensure_contiguous
376
+ def backward(ctx, grad_output):
377
+ """
378
+ Args:
379
+ grad_output: gradient of output
380
+
381
+ Returns:
382
+ dX, dW, dB: gradients w.r.t. X, W, B
383
+ """
384
+ X, W, RSTD = ctx.saved_tensors
385
+ dX, dW, dB = poly_norm_backward(grad_output, X, W, RSTD, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place)
386
+ return dX, dW, dB, None, None
@@ -15,6 +15,7 @@ from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401
15
15
  from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb # noqa: F401
16
16
  from liger_kernel.transformers.llama4_rope import liger_llama4_vision_rotary_pos_emb # noqa: F401
17
17
  from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention # noqa: F401
18
+ from liger_kernel.transformers.poly_norm import LigerPolyNorm # noqa: F401
18
19
  from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
19
20
  from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
20
21
  from liger_kernel.transformers.softmax import LigerSoftmax # noqa: F401
@@ -137,6 +138,7 @@ __all__ = [
137
138
  "LigerJSD",
138
139
  "LigerLayerNorm",
139
140
  "LigerFusedAddRMSNorm",
141
+ "LigerPolyNorm",
140
142
  "LigerRMSNorm",
141
143
  "liger_rotary_pos_emb",
142
144
  "liger_llama4_text_rotary_pos_emb",
@@ -12,6 +12,7 @@ from liger_kernel.ops.jsd import LigerJSDFunction
12
12
  from liger_kernel.ops.kl_div import LigerKLDivLossFunction
13
13
  from liger_kernel.ops.layer_norm import LigerLayerNormFunction
14
14
  from liger_kernel.ops.multi_token_attention import LigerMultiTokenAttentionFunction
15
+ from liger_kernel.ops.poly_norm import LigerPolyNormFunction
15
16
  from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
16
17
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
17
18
  from liger_kernel.ops.rope import LigerRopeFunction
@@ -258,6 +259,10 @@ def liger_rms_norm(X, W, eps, offset: float = 0.0, casting_mode: str = "llama",
258
259
  return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
259
260
 
260
261
 
262
+ def liger_poly_norm(X, W, B, eps=1e-6, in_place=True):
263
+ return LigerPolyNormFunction.apply(X, W, B, eps, in_place)
264
+
265
+
261
266
  def liger_fused_add_rms_norm(X, R, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True):
262
267
  return LigerFusedAddRMSNormFunction.apply(X, R, W, eps, offset, casting_mode, in_place)
263
268
 
@@ -469,7 +469,7 @@ def apply_liger_kernel_to_llama4(
469
469
  `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
470
470
  If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
471
471
  rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
472
- swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
472
+ swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
473
473
  model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
474
474
  loaded. Default is None.
475
475
  """
@@ -522,7 +522,10 @@ def apply_liger_kernel_to_llama4(
522
522
  _patch_rms_norm_module(text_model.norm)
523
523
  for decoder_layer in text_model.layers:
524
524
  if swiglu:
525
- _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
525
+ if decoder_layer.is_moe_layer:
526
+ _patch_swiglu_module(decoder_layer.feed_forward.shared_expert, LigerSwiGLUMLP)
527
+ else:
528
+ _patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
526
529
  if rms_norm:
527
530
  _patch_rms_norm_module(decoder_layer.input_layernorm)
528
531
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
@@ -0,0 +1,42 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from liger_kernel.ops.poly_norm import LigerPolyNormFunction
5
+
6
+
7
+ class LigerPolyNorm(nn.Module):
8
+ """
9
+ PolyNorm layer wrapper for Liger kernel.
10
+
11
+ PolyNorm formula:
12
+ y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
13
+ where norm(u) = u / sqrt(mean(u²) + ε)
14
+
15
+ Reference:
16
+ https://github.com/BryceZhuo/PolyCom/
17
+
18
+ Args:
19
+ eps: epsilon for numerical stability (default: 1e-6)
20
+ in_place: whether to in-place modify grad_output in backward to save memory (default: False).
21
+ Set to True to save memory if grad_output is not needed elsewhere.
22
+ """
23
+
24
+ def __init__(self, eps=1e-6, in_place=True):
25
+ super().__init__()
26
+ # Align with PolyCom reference: initialize weights to (1/3, 1/3, 1/3) and bias to 1.0
27
+ self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
28
+ self.bias = nn.Parameter(torch.tensor(1.0))
29
+ self.variance_epsilon = eps
30
+ self.in_place = in_place
31
+
32
+ def forward(self, hidden_states):
33
+ return LigerPolyNormFunction.apply(
34
+ hidden_states,
35
+ self.weight,
36
+ self.bias,
37
+ self.variance_epsilon,
38
+ self.in_place,
39
+ )
40
+
41
+ def extra_repr(self):
42
+ return f"weight_shape={tuple(self.weight.shape)}, eps={self.variance_epsilon}, in_place={self.in_place}"