liger-kernel-nightly 0.5.9.dev20250517045713__tar.gz → 0.5.9.dev20250519011716__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252):
  1. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/PKG-INFO +3 -2
  2. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/README.md +2 -1
  3. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_dyt.py +37 -34
  4. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/Low-Level-APIs.md +9 -1
  5. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/pyproject.toml +1 -1
  6. liger_kernel_nightly-0.5.9.dev20250519011716/src/liger_kernel/ops/dyt.py +159 -0
  7. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/dyt.py +5 -3
  8. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel_nightly.egg-info/PKG-INFO +3 -2
  9. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_dyt.py +40 -20
  10. liger_kernel_nightly-0.5.9.dev20250517045713/src/liger_kernel/ops/dyt.py +0 -225
  11. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  12. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  13. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.github/pull_request_template.md +0 -0
  14. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.github/workflows/amd-ci.yml +0 -0
  15. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.github/workflows/docs.yml +0 -0
  16. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.github/workflows/intel-ci.yml +0 -0
  17. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.github/workflows/nvi-ci.yml +0 -0
  18. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.github/workflows/publish-nightly.yml +0 -0
  19. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.github/workflows/publish-release.yml +0 -0
  20. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.gitignore +0 -0
  21. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/.idea/workspace.xml +0 -0
  22. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/LICENSE +0 -0
  23. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/Makefile +0 -0
  24. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/NOTICE +0 -0
  25. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/README.md +0 -0
  26. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/__init__.py +0 -0
  27. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/benchmarks_visualizer.py +0 -0
  28. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/data/all_benchmark_data.csv +0 -0
  29. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/__init__.py +0 -0
  30. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  33. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_embedding.py +0 -0
  35. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  36. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  37. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_geglu.py +0 -0
  38. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_group_norm.py +0 -0
  39. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_jsd.py +0 -0
  40. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_kl_div.py +0 -0
  41. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  43. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  44. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  45. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  46. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_rope.py +0 -0
  47. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  48. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_sparsemax.py +0 -0
  49. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_swiglu.py +0 -0
  50. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/benchmark_tvd.py +0 -0
  51. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/benchmark/scripts/utils.py +0 -0
  52. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/dev/fmt-requirements.txt +0 -0
  53. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/dev/modal/tests.py +0 -0
  54. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/dev/modal/tests_bwd.py +0 -0
  55. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/Examples.md +0 -0
  56. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/Getting-Started.md +0 -0
  57. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/High-Level-APIs.md +0 -0
  58. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/acknowledgement.md +0 -0
  59. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/contributing.md +0 -0
  60. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/images/banner.GIF +0 -0
  61. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/images/compose.gif +0 -0
  62. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/images/e2e-memory.png +0 -0
  63. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/images/e2e-tps.png +0 -0
  64. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/images/logo-banner.png +0 -0
  65. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/images/patch.gif +0 -0
  66. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/images/post-training.png +0 -0
  67. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/index.md +0 -0
  68. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/docs/license.md +0 -0
  69. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/alignment/accelerate_config.yaml +0 -0
  70. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/alignment/run_orpo.py +0 -0
  71. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/README.md +0 -0
  72. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/callback.py +0 -0
  73. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/config/fsdp_config.json +0 -0
  74. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  75. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  76. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  77. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/img/llama_tps.png +0 -0
  78. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  79. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/img/qwen_tps.png +0 -0
  80. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/launch_on_modal.py +0 -0
  81. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/requirements.txt +0 -0
  82. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/run_benchmarks.sh +0 -0
  83. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/run_gemma.sh +0 -0
  84. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/run_llama.sh +0 -0
  85. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/run_qwen.sh +0 -0
  86. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/run_qwen2_vl.sh +0 -0
  87. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/training.py +0 -0
  88. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/huggingface/training_multimodal.py +0 -0
  89. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/lightning/README.md +0 -0
  90. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/lightning/requirements.txt +0 -0
  91. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/lightning/training.py +0 -0
  92. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/README.md +0 -0
  93. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/callback.py +0 -0
  94. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  95. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  96. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  97. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  98. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  99. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  100. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  101. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  102. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  103. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/medusa_util.py +0 -0
  104. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/requirements.txt +0 -0
  105. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  106. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/examples/medusa/train.py +0 -0
  107. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/licenses/LICENSE-Apache-2.0 +0 -0
  108. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  109. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  110. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/licenses/LICENSE-MIT-llmc +0 -0
  111. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/licenses/LICENSE-MIT-triton +0 -0
  112. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/mkdocs.yml +0 -0
  113. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/setup.cfg +0 -0
  114. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/setup.py +0 -0
  115. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/__init__.py +0 -0
  116. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/README.md +0 -0
  117. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  118. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  119. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  120. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/functional.py +0 -0
  121. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  122. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  123. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  124. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  125. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  126. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/env_report.py +0 -0
  131. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/__init__.py +0 -0
  132. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/cross_entropy.py +0 -0
  133. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  134. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  135. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  136. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  137. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/geglu.py +0 -0
  138. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/group_norm.py +0 -0
  139. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/jsd.py +0 -0
  140. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/kl_div.py +0 -0
  141. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/layer_norm.py +0 -0
  142. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  143. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/rms_norm.py +0 -0
  144. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/rope.py +0 -0
  145. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/sparsemax.py +0 -0
  146. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/swiglu.py +0 -0
  147. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/tvd.py +0 -0
  148. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/ops/utils.py +0 -0
  149. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/__init__.py +0 -0
  150. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/auto_model.py +0 -0
  151. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  152. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  153. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/functional.py +0 -0
  154. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  155. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  156. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/geglu.py +0 -0
  157. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  158. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/group_norm.py +0 -0
  159. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/jsd.py +0 -0
  160. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/kl_div.py +0 -0
  161. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/layer_norm.py +0 -0
  162. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/__init__.py +0 -0
  163. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/gemma.py +0 -0
  164. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  165. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  166. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/glm4.py +0 -0
  167. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/llama.py +0 -0
  168. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/llava.py +0 -0
  169. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  170. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/mistral.py +0 -0
  171. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  172. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/mllama.py +0 -0
  173. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  174. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  175. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/phi3.py +0 -0
  176. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  177. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  178. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  179. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  180. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  181. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  182. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  183. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/rms_norm.py +0 -0
  184. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/rope.py +0 -0
  185. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/sparsemax.py +0 -0
  186. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/swiglu.py +0 -0
  187. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  188. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  189. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  190. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/transformers/tvd.py +0 -0
  191. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/triton/__init__.py +0 -0
  192. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/triton/monkey_patch.py +0 -0
  193. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel/utils.py +0 -0
  194. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  195. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  196. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  197. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  198. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/chunked_loss/__init__.py +0 -0
  200. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/chunked_loss/test_cpo_loss.py +0 -0
  201. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/chunked_loss/test_dpo_loss.py +0 -0
  202. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/chunked_loss/test_grpo_loss.py +0 -0
  203. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/chunked_loss/test_jsd_loss.py +0 -0
  204. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/chunked_loss/test_kto_loss.py +0 -0
  205. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/chunked_loss/test_orpo_loss.py +0 -0
  206. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/chunked_loss/test_simpo_loss.py +0 -0
  207. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/conftest.py +0 -0
  208. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/convergence/__init__.py +0 -0
  209. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/convergence/bf16/__init__.py +0 -0
  210. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/convergence/bf16/test_mini_models.py +0 -0
  211. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  212. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  213. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/convergence/fp32/__init__.py +0 -0
  214. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/convergence/fp32/test_mini_models.py +0 -0
  215. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  216. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  217. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  218. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  219. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  220. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  221. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  222. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  223. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  224. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  225. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  226. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/tiny_shakespeare.txt +0 -0
  227. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  228. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  229. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  230. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_auto_model.py +0 -0
  231. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_cross_entropy.py +0 -0
  232. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_embedding.py +0 -0
  233. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_flex_attention.py +0 -0
  234. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  235. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_fused_linear_jsd.py +0 -0
  236. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_geglu.py +0 -0
  237. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_group_norm.py +0 -0
  238. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_jsd.py +0 -0
  239. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_kl_div.py +0 -0
  240. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_layer_norm.py +0 -0
  241. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_mm_int8int2.py +0 -0
  242. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_monkey_patch.py +0 -0
  243. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_qwen2vl_mrope.py +0 -0
  244. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_rms_norm.py +0 -0
  245. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_rope.py +0 -0
  246. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_sparsemax.py +0 -0
  247. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_swiglu.py +0 -0
  248. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_trainer_integration.py +0 -0
  249. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_transformers.py +0 -0
  250. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/transformers/test_tvd.py +0 -0
  251. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/triton/test_triton_monkey_patch.py +0 -0
  252. {liger_kernel_nightly-0.5.9.dev20250517045713 → liger_kernel_nightly-0.5.9.dev20250519011716}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250517045713
3
+ Version: 0.5.9.dev20250519011716
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -322,7 +322,8 @@ loss.backward()
322
322
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
323
323
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
324
324
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
325
- | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
325
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
326
+ | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
326
327
 
327
328
 
328
329
  ### Alignment Kernels
@@ -274,7 +274,8 @@ loss.backward()
274
274
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
275
275
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
276
276
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
277
- | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
277
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
278
+ | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
278
279
 
279
280
 
280
281
  ### Alignment Kernels
@@ -22,17 +22,18 @@ def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
22
22
  from test.transformers.test_dyt import LigerDyT
23
23
  from test.transformers.test_dyt import TorchDyT
24
24
 
25
- BT = input.x
25
+ hidden_size = input.x
26
26
  provider = input.kernel_provider
27
27
  mode = input.kernel_operation_mode
28
28
  extra_benchmark_config = input.extra_benchmark_config
29
- hidden_size = extra_benchmark_config["hidden_size"]
29
+ BT = extra_benchmark_config["BT"]
30
+ beta = extra_benchmark_config["beta"]
30
31
  dtype = extra_benchmark_config["dtype"]
31
32
 
32
33
  x_shape = (BT, hidden_size)
33
- torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
34
- torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
35
- triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
34
+ torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
35
+ torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
36
+ triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)
36
37
 
37
38
  x = torch.randn(x_shape, dtype=dtype, device=device)
38
39
  dy = torch.randn_like(x)
@@ -75,16 +76,17 @@ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
75
76
  from test.transformers.test_dyt import LigerDyT
76
77
  from test.transformers.test_dyt import TorchDyT
77
78
 
78
- BT = input.x
79
+ hidden_size = input.x
79
80
  provider = input.kernel_provider
80
81
  extra_benchmark_config = input.extra_benchmark_config
81
- hidden_size = extra_benchmark_config["hidden_size"]
82
+ BT = extra_benchmark_config["BT"]
83
+ beta = extra_benchmark_config["beta"]
82
84
  dtype = extra_benchmark_config["dtype"]
83
85
 
84
86
  x_shape = (BT, hidden_size)
85
- torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
86
- torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
87
- triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
87
+ torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta).to(device)
88
+ torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size, beta=beta).to(device))
89
+ triton_dyt = LigerDyT(hidden_size=hidden_size, beta=beta).to(device)
88
90
 
89
91
  x = torch.randn(x_shape, dtype=dtype, device=device)
90
92
  dy = torch.randn_like(x)
@@ -113,27 +115,28 @@ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
113
115
  if __name__ == "__main__":
114
116
  args = parse_benchmark_script_args()
115
117
 
116
- common_configs = {
117
- "kernel_name": "dyt",
118
- "x_name": "BT",
119
- "x_label": "batch_size * seq_len",
120
- "x_values": [2**i for i in range(10, 15)],
121
- "kernel_providers": ["liger", "torch", "torch_compile"],
122
- "extra_benchmark_configs": [{"hidden_size": 4096, "dtype": torch.float32}],
123
- "overwrite": args.overwrite,
124
- }
125
-
126
- run_benchmarks(
127
- bench_test_fn=bench_speed_dyt,
128
- kernel_operation_modes=["forward", "backward", "full"],
129
- metric_name="speed",
130
- metric_unit="ms",
131
- **common_configs,
132
- )
133
- run_benchmarks(
134
- bench_test_fn=bench_memory_dyt,
135
- kernel_operation_modes=["full"],
136
- metric_name="memory",
137
- metric_unit="MB",
138
- **common_configs,
139
- )
118
+ for beta in [False, True]:
119
+ common_configs = {
120
+ "kernel_name": f"dyt_beta={beta}",
121
+ "x_name": "hidden_size",
122
+ "x_label": "hidden_size",
123
+ "x_values": [1024 * i for i in range(1, 17)],
124
+ "kernel_providers": ["liger", "torch", "torch_compile"],
125
+ "extra_benchmark_configs": [{"BT": 4096, "dtype": torch.bfloat16, "beta": beta}],
126
+ "overwrite": args.overwrite,
127
+ }
128
+
129
+ run_benchmarks(
130
+ bench_test_fn=bench_speed_dyt,
131
+ kernel_operation_modes=["forward", "backward", "full"],
132
+ metric_name="speed",
133
+ metric_unit="ms",
134
+ **common_configs,
135
+ )
136
+ run_benchmarks(
137
+ bench_test_fn=bench_memory_dyt,
138
+ kernel_operation_modes=["full"],
139
+ metric_name="memory",
140
+ metric_unit="MB",
141
+ **common_configs,
142
+ )
@@ -8,7 +8,9 @@
8
8
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
9
9
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
10
10
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
11
- | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
11
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
12
+ | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
13
+
12
14
 
13
15
  ### RMS Norm
14
16
 
@@ -49,6 +51,12 @@ This kernel combines linear transformations with cross-entropy loss calculations
49
51
  !!! Example "Try it out"
50
52
  You can experiment as shown in this example [here](https://colab.research.google.com/drive/1Z2QtvaIiLm5MWOs7X6ZPS1MN3hcIJFbj?usp=sharing)
51
53
 
54
+ ### Sparsemax
55
+
56
+ Sparsemax is a sparse alternative to softmax that produces sparse probability distributions. This kernel implements an efficient version of the sparsemax operation that can be used as a drop-in replacement for softmax in attention mechanisms or classification tasks.
57
+
58
+ The implementation achieves significant speed improvements and memory savings compared to standard PyTorch implementations, particularly for large input tensors.
59
+
52
60
  ## Alignment Kernels
53
61
 
54
62
  | **Kernel** | **API** |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.9.dev20250517045713"
7
+ version = "0.5.9.dev20250519011716"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,159 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from triton.language.extra.libdevice import tanh
8
+
9
+ from liger_kernel.ops.utils import compare_version
10
+ from liger_kernel.ops.utils import ensure_contiguous
11
+ from liger_kernel.ops.utils import infer_device
12
+
13
+ if compare_version("triton", operator.ge, "3.0.0"):
14
+ try:
15
+ # typical import path with dispatch available
16
+ from triton.language.extra.libdevice import tanh
17
+ except ModuleNotFoundError:
18
+ # for working with NGC containers
19
+ from triton.language.extra.cuda.libdevice import tanh
20
+ else:
21
+ from triton.language.math import tanh
22
+
23
+
24
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
25
+ # for bn in [1024, 2048, 4096]
26
+ # for ns in [1,2,4]
27
+ # for nw in [4, 8, 16, 32]
28
+ # ],
29
+ # key=['N'])
30
+ @triton.jit
31
+ def _dyt_fwd_kernel(X, Y, Alpha, Gamma, Beta, HAVE_BETA: tl.constexpr, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024):
32
+ col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
33
+ mask = col < N
34
+ row_id = tl.cast(tl.program_id(1), tl.int64)
35
+
36
+ X += row_id * N
37
+ Y += row_id * N
38
+ alpha = tl.load(Alpha).to(tl.float32)
39
+
40
+ gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
41
+
42
+ x = tl.load(X + col, mask=mask, other=0.0).to(tl.float32)
43
+
44
+ tanh_x = tanh(alpha * x)
45
+ y = tanh_x * gamma
46
+ if HAVE_BETA:
47
+ beta = tl.load(Beta + col, mask=mask, other=0.0).to(tl.float32)
48
+ y += beta
49
+ tl.store(Y + col, y, mask=mask)
50
+
51
+
52
+ # @triton.autotune([triton.Config({"BLOCK_N":bn}, num_stages=ns, num_warps=nw)
53
+ # for bn in [1024, 2048, 4096]
54
+ # for ns in [1,2,4]
55
+ # for nw in [4, 8, 16]
56
+ # ],
57
+ # key=['N'])
58
+ @triton.jit
59
+ def _dyt_bwd_kernel(
60
+ DY, DX, DA, DG, DB, X, Alpha, Gamma, HAVE_BETA: tl.constexpr, M, N: tl.constexpr, BLOCK_N: tl.constexpr = 1024
61
+ ):
62
+ col = tl.cast(tl.program_id(0), tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
63
+ mask = col < N
64
+ start_row_id = tl.cast(tl.program_id(1), tl.int64)
65
+
66
+ alpha = tl.load(Alpha).to(tl.float32)
67
+ da = 0.0
68
+ gamma = tl.load(Gamma + col, mask=mask, other=0.0).to(tl.float32)
69
+ dg = tl.zeros((BLOCK_N,), dtype=tl.float32)
70
+ if HAVE_BETA:
71
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
72
+ for row_id in range(start_row_id, M, tl.num_programs(1)):
73
+ x = tl.load(X + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
74
+ dy = tl.load(DY + row_id * N + col, mask=mask, other=0.0).to(tl.float32)
75
+ tanh_x = tanh(alpha * x)
76
+ if HAVE_BETA:
77
+ db += dy
78
+ dg += dy * tanh_x
79
+ tmp = (1 - tanh_x * tanh_x) * dy * gamma
80
+ da += tl.sum(x * tmp, 0)
81
+ dx = alpha * tmp
82
+ tl.store(DX + row_id * N + col, dx, mask=mask)
83
+
84
+ tl.store(DG + start_row_id * N + col, dg, mask=mask)
85
+ if HAVE_BETA:
86
+ tl.store(DB + start_row_id * N + col, db, mask=mask)
87
+ tl.store(DA + start_row_id * tl.cdiv(N, 512) + tl.program_id(0), da)
88
+
89
+
90
+ def liger_dyt_fwd(x, alpha, gamma, beta):
91
+ assert x.is_contiguous()
92
+ HAVE_BETA = True if beta is not None else False
93
+ input_shape = x.shape
94
+ x = x.view(-1, input_shape[-1])
95
+ M, N = x.shape
96
+
97
+ y = torch.empty_like(x)
98
+
99
+ if N >= 4096:
100
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 2048), "num_warps": 4, "num_stages": 1}
101
+ else:
102
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 4, "num_stages": 1}
103
+
104
+ grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), M)
105
+ _dyt_fwd_kernel[(grid)](
106
+ x,
107
+ y,
108
+ alpha,
109
+ gamma,
110
+ beta,
111
+ HAVE_BETA,
112
+ N,
113
+ **kwargs,
114
+ )
115
+ return y.view(input_shape)
116
+
117
+
118
+ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
119
+ assert dy.is_contiguous()
120
+ input_shape = x.shape
121
+ x = x.view(-1, input_shape[-1])
122
+ M, N = x.shape
123
+ HAVE_BETA = True if beta is not None else False
124
+
125
+ device = infer_device()
126
+ if device == "cuda":
127
+ NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
128
+ elif device == "xpu":
129
+ NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
130
+
131
+ da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
132
+ dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
133
+ db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
134
+ dx = torch.empty_like(dy)
135
+
136
+ kwargs = {"BLOCK_N": min(triton.next_power_of_2(N), 1024), "num_warps": 8, "num_stages": 2}
137
+ grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), NUM_SMS)
138
+ _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, **kwargs)
139
+ if HAVE_BETA:
140
+ db = db.sum(0).to(x.dtype)
141
+ dg = dg.sum(0).to(gamma.dtype)
142
+ da = da.sum().to(x.dtype).unsqueeze(0)
143
+ return dx.view(input_shape), da, dg, db
144
+
145
+
146
+ class LigerDyTFunction(torch.autograd.Function):
147
+ @staticmethod
148
+ @ensure_contiguous
149
+ def forward(ctx, x, alpha, gamma, beta):
150
+ y = liger_dyt_fwd(x, alpha, gamma, beta)
151
+ ctx.save_for_backward(x, alpha, gamma, beta)
152
+ return y
153
+
154
+ @staticmethod
155
+ @ensure_contiguous
156
+ def backward(ctx, dy):
157
+ x, alpha, gamma, beta = ctx.saved_tensors
158
+ dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta)
159
+ return dx, dalpha, dgamma, dbeta
@@ -5,16 +5,18 @@ from liger_kernel.ops.dyt import LigerDyTFunction
5
5
 
6
6
 
7
7
  class LigerDyT(nn.Module):
8
- def __init__(self, hidden_size, init_alpha=0.5):
8
+ def __init__(self, hidden_size, beta=True, init_alpha=0.5):
9
9
  super().__init__()
10
10
  self.hidden_size = hidden_size
11
11
  self.init_alpha = init_alpha
12
12
  self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
13
13
  self.gamma = nn.Parameter(torch.ones(hidden_size))
14
- self.beta = nn.Parameter(torch.zeros(hidden_size))
14
+ self.beta = None
15
+ if beta:
16
+ self.beta = nn.Parameter(torch.zeros(hidden_size))
15
17
 
16
18
  def forward(self, x):
17
19
  return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
18
20
 
19
21
  def extra_repr(self):
20
- return f"{self.hidden_size}, init_alpha={self.init_alpha}"
22
+ return f"{self.hidden_size}, init_alpha={self.init_alpha}, beta={self.beta}"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250517045713
3
+ Version: 0.5.9.dev20250519011716
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -322,7 +322,8 @@ loss.backward()
322
322
  | SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
323
323
  | GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
324
324
  | CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
325
- | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
325
+ | Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
326
+ | Sparsemax | `liger_kernel.transformers.LigerSparsemax` |
326
327
 
327
328
 
328
329
  ### Alignment Kernels
@@ -1,6 +1,5 @@
1
1
  import pytest
2
2
  import torch
3
- import torch.nn as nn
4
3
 
5
4
  from test.utils import assert_verbose_allclose
6
5
  from test.utils import infer_device
@@ -12,22 +11,37 @@ from liger_kernel.transformers.dyt import LigerDyT
12
11
  from liger_kernel.transformers.functional import liger_dyt
13
12
 
14
13
 
15
- class TorchDyT(nn.Module):
16
- def __init__(self, hidden_size, init_alpha=0.5):
14
+ # @torch.compile
15
+ def torch_dyt_with_beta(x, alpha, gamma, beta):
16
+ return gamma * torch.tanh(x * alpha) + beta
17
+
18
+
19
+ # @torch.compile
20
+ def torch_dyt_without_beta(x, alpha, gamma):
21
+ return gamma * torch.tanh(x * alpha)
22
+
23
+
24
+ class TorchDyT(torch.nn.Module):
25
+ def __init__(self, hidden_size, beta=True, init_alpha=0.5):
17
26
  super().__init__()
18
- self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
19
- self.gamma = nn.Parameter(torch.ones(hidden_size))
20
- self.beta = nn.Parameter(torch.zeros(hidden_size))
27
+ self.alpha = torch.nn.Parameter(torch.ones(1) * init_alpha)
28
+ self.gamma = torch.nn.Parameter(torch.ones(hidden_size))
29
+ self.beta = None
30
+ if beta:
31
+ self.beta = torch.nn.Parameter(torch.zeros(hidden_size))
21
32
 
22
33
  def forward(self, x):
23
- return self.gamma * torch.tanh(self.alpha * x) + self.beta
34
+ if self.beta is None:
35
+ return torch_dyt_without_beta(x, self.alpha, self.gamma)
36
+ return torch_dyt_with_beta(x, self.alpha, self.gamma, self.beta)
24
37
 
25
38
 
26
39
  set_seed(42)
27
40
  device = infer_device()
28
41
 
29
42
 
30
- @pytest.mark.parametrize("init_alpha", [0.5, 0.2, 1.0])
43
+ @pytest.mark.parametrize("beta", [False, True])
44
+ @pytest.mark.parametrize("init_alpha", [0.5])
31
45
  @pytest.mark.parametrize(
32
46
  "B, T, hidden_size",
33
47
  [
@@ -43,7 +57,7 @@ device = infer_device()
43
57
  (torch.float32, 1e-5, 1e-5),
44
58
  ],
45
59
  )
46
- def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol):
60
+ def test_liger_dyt_correctness(B, T, hidden_size, beta, init_alpha, dtype, atol, rtol):
47
61
  _input = torch.randn(B, T, hidden_size, device=device, dtype=dtype)
48
62
 
49
63
  x1 = _input.clone().requires_grad_(True)
@@ -52,17 +66,19 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
52
66
  # initialize weights
53
67
  alpha = torch.randn(1, device=device, dtype=dtype)
54
68
  gamma = torch.randn(hidden_size, device=device, dtype=dtype)
55
- beta = torch.randn(hidden_size, device=device, dtype=dtype)
69
+ beta_data = torch.randn(hidden_size, device=device, dtype=dtype)
56
70
 
57
- torch_dyt = TorchDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
71
+ torch_dyt = TorchDyT(hidden_size=hidden_size, beta=beta, init_alpha=init_alpha).to(device).to(dtype)
58
72
  torch_dyt.alpha.data = alpha.clone()
59
73
  torch_dyt.gamma.data = gamma.clone()
60
- torch_dyt.beta.data = beta.clone()
74
+ if beta:
75
+ torch_dyt.beta.data = beta_data.clone()
61
76
 
62
- liger_dyt = LigerDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
77
+ liger_dyt = LigerDyT(hidden_size=hidden_size, beta=beta, init_alpha=init_alpha).to(device).to(dtype)
63
78
  liger_dyt.alpha.data = alpha.clone()
64
79
  liger_dyt.gamma.data = gamma.clone()
65
- liger_dyt.beta.data = beta.clone()
80
+ if beta:
81
+ liger_dyt.beta.data = beta_data.clone()
66
82
 
67
83
  torch_output = torch_dyt(x1)
68
84
  liger_output = liger_dyt(x2)
@@ -76,9 +92,11 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
76
92
  assert_verbose_allclose(x1.grad, x2.grad, rtol=rtol, atol=atol)
77
93
  assert_verbose_allclose(torch_dyt.alpha.grad, liger_dyt.alpha.grad, rtol=rtol, atol=atol)
78
94
  assert_verbose_allclose(torch_dyt.gamma.grad, liger_dyt.gamma.grad, rtol=rtol, atol=atol)
79
- assert_verbose_allclose(torch_dyt.beta.grad, liger_dyt.beta.grad, rtol=rtol, atol=atol)
95
+ if beta:
96
+ assert_verbose_allclose(torch_dyt.beta.grad, liger_dyt.beta.grad, rtol=rtol, atol=atol)
80
97
 
81
98
 
99
+ @pytest.mark.parametrize("beta", [False, True])
82
100
  @pytest.mark.parametrize(
83
101
  "B, T, hidden_size",
84
102
  [
@@ -102,7 +120,7 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
102
120
  ),
103
121
  ],
104
122
  )
105
- def test_liger_dyt_functional(B, T, hidden_size, dtype, atol, rtol):
123
+ def test_liger_dyt_functional(B, T, hidden_size, beta, dtype, atol, rtol):
106
124
  _input = torch.randn(B, T, hidden_size, device=device, dtype=dtype)
107
125
 
108
126
  x1 = _input.clone().requires_grad_(True)
@@ -111,15 +129,16 @@ def test_liger_dyt_functional(B, T, hidden_size, dtype, atol, rtol):
111
129
  # initialize weights
112
130
  alpha = torch.randn(1, device=device, dtype=dtype)
113
131
  gamma = torch.randn(hidden_size, device=device, dtype=dtype)
114
- beta = torch.randn(hidden_size, device=device, dtype=dtype)
132
+ beta_data = torch.randn(hidden_size, device=device, dtype=dtype)
115
133
 
116
134
  alpha1 = alpha.clone().requires_grad_(True)
117
135
  gamma1 = gamma.clone().requires_grad_(True)
118
- beta1 = beta.clone().requires_grad_(True)
136
+ beta1 = beta_data.clone().requires_grad_(True) if beta else None
119
137
 
120
138
  alpha2 = alpha.clone().requires_grad_(True)
121
139
  gamma2 = gamma.clone().requires_grad_(True)
122
- beta2 = beta.clone().requires_grad_(True)
140
+
141
+ beta2 = beta_data.clone().requires_grad_(True) if beta else None
123
142
 
124
143
  output1 = liger_dyt(x1, alpha=alpha1, gamma=gamma1, beta=beta1)
125
144
  output2 = LigerDyTFunction.apply(x2, alpha2, gamma2, beta2)
@@ -133,4 +152,5 @@ def test_liger_dyt_functional(B, T, hidden_size, dtype, atol, rtol):
133
152
  assert_verbose_allclose(x1.grad, x2.grad, rtol=rtol, atol=atol)
134
153
  assert_verbose_allclose(alpha1.grad, alpha2.grad, rtol=rtol, atol=atol)
135
154
  assert_verbose_allclose(gamma1.grad, gamma2.grad, rtol=rtol, atol=atol)
136
- assert_verbose_allclose(beta1.grad, beta2.grad, rtol=rtol, atol=atol)
155
+ if beta:
156
+ assert_verbose_allclose(beta1.grad, beta2.grad, rtol=rtol, atol=atol)