liger-kernel-nightly 0.5.9.dev20250515034325__tar.gz → 0.5.9.dev20250515065336__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (251):
  1. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/data/all_benchmark_data.csv +72 -0
  3. liger_kernel_nightly-0.5.9.dev20250515065336/benchmark/scripts/benchmark_sparsemax.py +172 -0
  4. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/pyproject.toml +1 -1
  5. liger_kernel_nightly-0.5.9.dev20250515065336/src/liger_kernel/ops/sparsemax.py +167 -0
  6. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/functional.py +8 -0
  7. liger_kernel_nightly-0.5.9.dev20250515065336/src/liger_kernel/transformers/sparsemax.py +16 -0
  8. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  9. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
  10. liger_kernel_nightly-0.5.9.dev20250515065336/test/transformers/test_sparsemax.py +111 -0
  11. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  12. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  13. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/pull_request_template.md +0 -0
  14. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/amd-ci.yml +0 -0
  15. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/docs.yml +0 -0
  16. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/intel-ci.yml +0 -0
  17. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/nvi-ci.yml +0 -0
  18. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/publish-nightly.yml +0 -0
  19. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/publish-release.yml +0 -0
  20. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.gitignore +0 -0
  21. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/.idea/workspace.xml +0 -0
  22. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/LICENSE +0 -0
  23. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/Makefile +0 -0
  24. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/NOTICE +0 -0
  25. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/README.md +0 -0
  26. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/README.md +0 -0
  27. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/__init__.py +0 -0
  28. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/benchmarks_visualizer.py +0 -0
  29. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/__init__.py +0 -0
  30. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  33. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_dyt.py +0 -0
  35. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_embedding.py +0 -0
  36. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  38. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_geglu.py +0 -0
  39. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_group_norm.py +0 -0
  40. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_jsd.py +0 -0
  41. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_kl_div.py +0 -0
  42. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  43. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  44. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  45. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  46. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  47. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_rope.py +0 -0
  48. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  49. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_swiglu.py +0 -0
  50. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_tvd.py +0 -0
  51. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/utils.py +0 -0
  52. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/fmt-requirements.txt +0 -0
  53. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/modal/tests.py +0 -0
  54. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/modal/tests_bwd.py +0 -0
  55. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Examples.md +0 -0
  56. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Getting-Started.md +0 -0
  57. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/High-Level-APIs.md +0 -0
  58. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Low-Level-APIs.md +0 -0
  59. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/acknowledgement.md +0 -0
  60. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/contributing.md +0 -0
  61. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/banner.GIF +0 -0
  62. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/compose.gif +0 -0
  63. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/e2e-memory.png +0 -0
  64. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/e2e-tps.png +0 -0
  65. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/logo-banner.png +0 -0
  66. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/patch.gif +0 -0
  67. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/post-training.png +0 -0
  68. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/index.md +0 -0
  69. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/license.md +0 -0
  70. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/alignment/accelerate_config.yaml +0 -0
  71. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/alignment/run_orpo.py +0 -0
  72. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/README.md +0 -0
  73. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/callback.py +0 -0
  74. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/config/fsdp_config.json +0 -0
  75. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  76. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  77. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  78. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/llama_tps.png +0 -0
  79. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  80. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/qwen_tps.png +0 -0
  81. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/launch_on_modal.py +0 -0
  82. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/requirements.txt +0 -0
  83. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_benchmarks.sh +0 -0
  84. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_gemma.sh +0 -0
  85. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_llama.sh +0 -0
  86. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_qwen.sh +0 -0
  87. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_qwen2_vl.sh +0 -0
  88. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/training.py +0 -0
  89. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/training_multimodal.py +0 -0
  90. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/README.md +0 -0
  91. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/requirements.txt +0 -0
  92. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/training.py +0 -0
  93. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/README.md +0 -0
  94. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/callback.py +0 -0
  95. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  96. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  97. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  98. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  99. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  100. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  101. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  102. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  103. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  104. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/medusa_util.py +0 -0
  105. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/requirements.txt +0 -0
  106. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  107. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/train.py +0 -0
  108. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-Apache-2.0 +0 -0
  109. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  110. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  111. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-llmc +0 -0
  112. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-triton +0 -0
  113. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/mkdocs.yml +0 -0
  114. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/setup.cfg +0 -0
  115. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/setup.py +0 -0
  116. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/__init__.py +0 -0
  117. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/README.md +0 -0
  118. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  119. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  120. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  121. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/functional.py +0 -0
  122. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  123. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  124. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  125. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  126. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  131. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/env_report.py +0 -0
  132. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/__init__.py +0 -0
  133. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/cross_entropy.py +0 -0
  134. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/dyt.py +0 -0
  135. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  136. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  137. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  138. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  139. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/geglu.py +0 -0
  140. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/group_norm.py +0 -0
  141. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/jsd.py +0 -0
  142. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/kl_div.py +0 -0
  143. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/layer_norm.py +0 -0
  144. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  145. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/rms_norm.py +0 -0
  146. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/rope.py +0 -0
  147. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/swiglu.py +0 -0
  148. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/tvd.py +0 -0
  149. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/utils.py +0 -0
  150. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/__init__.py +0 -0
  151. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/auto_model.py +0 -0
  152. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  153. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/dyt.py +0 -0
  154. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  155. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  156. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  157. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/geglu.py +0 -0
  158. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  159. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/group_norm.py +0 -0
  160. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/jsd.py +0 -0
  161. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/kl_div.py +0 -0
  162. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/layer_norm.py +0 -0
  163. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/__init__.py +0 -0
  164. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma.py +0 -0
  165. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  166. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  167. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/glm4.py +0 -0
  168. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/llama.py +0 -0
  169. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/llava.py +0 -0
  170. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  171. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mistral.py +0 -0
  172. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  173. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mllama.py +0 -0
  174. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  175. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  176. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/phi3.py +0 -0
  177. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  178. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  179. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  180. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  181. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  182. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  183. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  184. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/rms_norm.py +0 -0
  185. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/rope.py +0 -0
  186. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/swiglu.py +0 -0
  187. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  188. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  189. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  190. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/tvd.py +0 -0
  191. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/triton/__init__.py +0 -0
  192. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/triton/monkey_patch.py +0 -0
  193. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/utils.py +0 -0
  194. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  195. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  196. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  197. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/__init__.py +0 -0
  198. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_cpo_loss.py +0 -0
  200. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_dpo_loss.py +0 -0
  201. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_grpo_loss.py +0 -0
  202. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_jsd_loss.py +0 -0
  203. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_kto_loss.py +0 -0
  204. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_orpo_loss.py +0 -0
  205. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_simpo_loss.py +0 -0
  206. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/conftest.py +0 -0
  207. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/__init__.py +0 -0
  208. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/__init__.py +0 -0
  209. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models.py +0 -0
  210. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  211. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  212. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/__init__.py +0 -0
  213. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models.py +0 -0
  214. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  215. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  216. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  217. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  218. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  219. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  220. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  221. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  222. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  223. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  224. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  225. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare.txt +0 -0
  226. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  227. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  228. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  229. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_auto_model.py +0 -0
  230. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_cross_entropy.py +0 -0
  231. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_dyt.py +0 -0
  232. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_embedding.py +0 -0
  233. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_flex_attention.py +0 -0
  234. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  235. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_fused_linear_jsd.py +0 -0
  236. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_geglu.py +0 -0
  237. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_group_norm.py +0 -0
  238. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_jsd.py +0 -0
  239. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_kl_div.py +0 -0
  240. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_layer_norm.py +0 -0
  241. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_mm_int8int2.py +0 -0
  242. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_monkey_patch.py +0 -0
  243. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_qwen2vl_mrope.py +0 -0
  244. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_rms_norm.py +0 -0
  245. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_rope.py +0 -0
  246. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_swiglu.py +0 -0
  247. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_trainer_integration.py +0 -0
  248. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_transformers.py +0 -0
  249. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_tvd.py +0 -0
  250. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/triton/test_triton_monkey_patch.py +0 -0
  251. {liger_kernel_nightly-0.5.9.dev20250515034325 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250515034325
3
+ Version: 0.5.9.dev20250515065336
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -805,3 +805,75 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.265
805
805
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
806
806
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
807
807
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
808
+ sparsemax,liger,forward,speed,ms,V,feature size,1024,0.41471999883651733,0.4126720130443573,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
809
+ sparsemax,liger,forward,speed,ms,V,feature size,2048,0.7608320116996765,0.7598080039024353,0.7628800272941589,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
810
+ sparsemax,liger,forward,speed,ms,V,feature size,4096,1.4561280012130737,1.4540799856185913,1.4581760168075562,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
811
+ sparsemax,liger,forward,speed,ms,V,feature size,8192,5.288959980010986,5.2848639488220215,5.29986572265625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
812
+ sparsemax,liger,forward,speed,ms,V,feature size,16384,10.734624862670898,10.729472160339355,11.096882820129395,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
813
+ sparsemax,liger,forward,speed,ms,V,feature size,32768,21.729312896728516,21.7128963470459,22.20728302001953,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
814
+ sparsemax,torch,forward,speed,ms,V,feature size,1024,0.42291200160980225,0.42188799381256104,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
815
+ sparsemax,torch,forward,speed,ms,V,feature size,2048,0.7782400250434875,0.7772160172462463,0.779263973236084,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
816
+ sparsemax,torch,forward,speed,ms,V,feature size,4096,1.4940160512924194,1.491968035697937,1.4960639476776123,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
817
+ sparsemax,torch,forward,speed,ms,V,feature size,8192,5.359615802764893,5.356544017791748,5.366579055786133,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
818
+ sparsemax,torch,forward,speed,ms,V,feature size,16384,10.883584022521973,10.874879837036133,11.224268913269043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
819
+ sparsemax,torch,forward,speed,ms,V,feature size,32768,22.19878387451172,22.018457412719727,22.48888397216797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
820
+ sparsemax,liger,full,speed,ms,V,feature size,1024,0.4558719992637634,0.45558398962020874,0.45772799849510193,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
821
+ sparsemax,liger,full,speed,ms,V,feature size,2048,0.8488960266113281,0.8478720188140869,0.8509439826011658,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
822
+ sparsemax,liger,full,speed,ms,V,feature size,4096,1.6476160287857056,1.6465920209884644,1.6499264240264893,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
823
+ sparsemax,liger,full,speed,ms,V,feature size,8192,5.664768218994141,5.660672187805176,5.681356906890869,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
824
+ sparsemax,liger,full,speed,ms,V,feature size,16384,11.486207962036133,11.478015899658203,11.874713897705078,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
825
+ sparsemax,liger,full,speed,ms,V,feature size,32768,23.457279205322266,23.289682388305664,23.76642608642578,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
826
+ sparsemax,torch,full,speed,ms,V,feature size,1024,0.6021119952201843,0.6010879874229431,0.6041600108146667,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
827
+ sparsemax,torch,full,speed,ms,V,feature size,2048,1.1212799549102783,1.119264006614685,1.1223039627075195,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
828
+ sparsemax,torch,full,speed,ms,V,feature size,4096,2.1637120246887207,2.1616640090942383,2.165760040283203,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
829
+ sparsemax,torch,full,speed,ms,V,feature size,8192,6.693888187408447,6.68723201751709,6.705561637878418,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
830
+ sparsemax,torch,full,speed,ms,V,feature size,16384,13.523456573486328,13.518848419189453,13.878681182861328,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
831
+ sparsemax,torch,full,speed,ms,V,feature size,32768,27.604991912841797,27.295129776000977,27.77518081665039,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
832
+ sparsemax,liger,backward,speed,ms,V,feature size,1024,0.04403200000524521,0.043007999658584595,0.05222399905323982,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
833
+ sparsemax,liger,backward,speed,ms,V,feature size,2048,0.08806400001049042,0.08713600039482117,0.08806400001049042,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
834
+ sparsemax,liger,backward,speed,ms,V,feature size,4096,0.1884160041809082,0.1884160041809082,0.18943999707698822,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
835
+ sparsemax,liger,backward,speed,ms,V,feature size,8192,0.374783992767334,0.37376001477241516,0.37486720085144043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
836
+ sparsemax,liger,backward,speed,ms,V,feature size,16384,0.7516160011291504,0.7505919933319092,0.7516160011291504,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
837
+ sparsemax,liger,backward,speed,ms,V,feature size,32768,1.5738879442214966,1.572864055633545,1.575935959815979,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
838
+ sparsemax,torch,backward,speed,ms,V,feature size,1024,0.1812479943037033,0.1802240014076233,0.18227200210094452,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
839
+ sparsemax,torch,backward,speed,ms,V,feature size,2048,0.34406399726867676,0.34406399726867676,0.34508800506591797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
840
+ sparsemax,torch,backward,speed,ms,V,feature size,4096,0.6717439889907837,0.6707199811935425,0.6727679967880249,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
841
+ sparsemax,torch,backward,speed,ms,V,feature size,8192,1.3250559568405151,1.3241215944290161,1.3260799646377563,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
842
+ sparsemax,torch,backward,speed,ms,V,feature size,16384,2.629631996154785,2.628607988357544,2.6306560039520264,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
843
+ sparsemax,torch,backward,speed,ms,V,feature size,32768,5.236735820770264,5.235712051391602,5.239808082580566,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
844
+ sparsemax,liger,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
845
+ sparsemax,liger,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
846
+ sparsemax,liger,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
847
+ sparsemax,liger,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
848
+ sparsemax,liger,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
849
+ sparsemax,liger,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
850
+ sparsemax,torch,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
851
+ sparsemax,torch,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
852
+ sparsemax,torch,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
853
+ sparsemax,torch,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
854
+ sparsemax,torch,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
855
+ sparsemax,torch,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
856
+ sparsemax,liger,forward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
857
+ sparsemax,liger,forward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
858
+ sparsemax,liger,forward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
859
+ sparsemax,liger,forward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
860
+ sparsemax,liger,forward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
861
+ sparsemax,liger,forward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
862
+ sparsemax,torch,forward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
863
+ sparsemax,torch,forward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
864
+ sparsemax,torch,forward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
865
+ sparsemax,torch,forward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
866
+ sparsemax,torch,forward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
867
+ sparsemax,torch,forward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
868
+ sparsemax,liger,backward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
869
+ sparsemax,liger,backward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
870
+ sparsemax,liger,backward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
871
+ sparsemax,liger,backward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
872
+ sparsemax,liger,backward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
873
+ sparsemax,liger,backward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
874
+ sparsemax,torch,backward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
875
+ sparsemax,torch,backward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
876
+ sparsemax,torch,backward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
877
+ sparsemax,torch,backward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
878
+ sparsemax,torch,backward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
879
+ sparsemax,torch,backward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
@@ -0,0 +1,172 @@
1
+ import torch
2
+ import triton
3
+
4
+ from utils import QUANTILES
5
+ from utils import SingleBenchmarkRunInput
6
+ from utils import SingleBenchmarkRunOutput
7
+ from utils import _test_memory
8
+ from utils import parse_benchmark_script_args
9
+ from utils import run_benchmarks
10
+
11
+ from liger_kernel.transformers.sparsemax import LigerSparsemax
12
+ from liger_kernel.utils import infer_device
13
+
14
+ device = infer_device()
15
+
16
+
17
def torch_sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Reference sparsemax written with plain PyTorch ops (benchmark baseline).

    Sorts the slice along ``dim`` in descending order and finds the support size
    ``k`` as the number of ranks ``r`` with ``1 + r * z_r > sum_{j<=r} z_j``
    (clamped to at least 1). It then subtracts the threshold
    ``tau = (sum of support - 1) / k`` and clips at zero.

    Args:
        input_tensor: tensor to normalise.
        dim: dimension to normalise over; negative values count from the end.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    ndim = input_tensor.dim()
    axis = dim + ndim if dim < 0 else dim
    n = input_tensor.size(axis)

    z_desc = torch.sort(input_tensor, dim=axis, descending=True).values
    running_total = z_desc.cumsum(dim=axis)

    # 1-based ranks, shaped to broadcast along `axis` only.
    view_shape = [1] * ndim
    view_shape[axis] = n
    ranks = torch.arange(1, n + 1, device=input_tensor.device, dtype=input_tensor.dtype).view(view_shape)

    in_support = (1 + ranks * z_desc) > running_total
    support_size = in_support.sum(dim=axis, keepdim=True).clamp(min=1)
    support_total = (z_desc * in_support).sum(dim=axis, keepdim=True)
    threshold = (support_total - 1) / support_size
    return (input_tensor - threshold).clamp(min=0)
+
35
+
36
class TorchSparsemax(torch.nn.Module):
    """``nn.Module`` wrapper around the pure-PyTorch ``torch_sparsemax`` baseline.

    Args:
        dim: dimension sparsemax normalises over (default: last).
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim  # forwarded unchanged to torch_sparsemax

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        target_dim = self.dim
        return torch_sparsemax(x, dim=target_dim)
+
44
+
45
def bench_speed_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time sparsemax for one feature size with ``triton.testing.do_bench``.

    ``input.x`` is the feature size ``V``. The benchmarked tensor has shape
    ``(B * T, V)``, with ``B``, ``T``, ``dim`` and ``dtype`` read from
    ``input.extra_benchmark_config``.

    ``input.kernel_operation_mode`` selects what is timed:
        - ``"forward"``: the forward pass only.
        - ``"backward"``: ``backward`` over a graph built once beforehand
          (``retain_graph=True`` so it can be replayed).
        - ``"full"``: forward followed by backward.

    Returns:
        The 20th/50th/80th percentile timings in milliseconds.

    Raises:
        ValueError: if the provider or the mode is not recognised. Previously an
            unknown mode crashed with ``UnboundLocalError``, and an unknown
            provider crashed on a ``None`` output.
    """
    V = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    T = extra_benchmark_config["T"]
    dim = extra_benchmark_config["dim"]
    dtype = extra_benchmark_config["dtype"]

    if provider == "liger":
        sparsemax_module = LigerSparsemax(dim=dim).to(device)
    elif provider == "torch":
        sparsemax_module = TorchSparsemax(dim=dim).to(device)
    else:
        raise ValueError(f"Unknown kernel provider: {provider!r}")

    x = torch.randn((B * T, V), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        return sparsemax_module(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    if mode == "forward":
        bench_fn = y_fwd
    elif mode == "backward":
        # Build the graph once; only the backward pass is timed.
        y = y_fwd()
        bench_fn = lambda: y.backward(dy, retain_graph=True)  # noqa: E731
    elif mode == "full":
        bench_fn = full
    else:
        raise ValueError(f"Unknown kernel operation mode: {mode!r}")

    ms_50, ms_20, ms_80 = triton.testing.do_bench(
        bench_fn,
        grad_to_none=[x],  # reset x.grad between runs so accumulation isn't timed
        rep=500,
        quantiles=QUANTILES,
    )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
105
+
106
+
107
def bench_memory_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure peak memory of sparsemax for one feature size.

    ``input.x`` is the feature size ``V``. The benchmarked tensor has shape
    ``(B * T, V)``, with ``B``, ``T``, ``dim`` and ``dtype`` read from
    ``input.extra_benchmark_config``.

    ``input.kernel_operation_mode`` selects what is measured, mirroring
    ``bench_speed_sparsemax``:
        - ``"forward"``: the forward pass only.
        - ``"backward"``: ``backward`` over a graph built once beforehand.
        - ``"full"``: forward followed by backward (the previous, only behaviour).

    Returns:
        The 20th/50th/80th percentile memory readings reported by ``_test_memory``.

    Raises:
        ValueError: if the provider or the mode is not recognised.
    """
    V = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    T = extra_benchmark_config["T"]
    dim = extra_benchmark_config["dim"]
    dtype = extra_benchmark_config["dtype"]

    if provider == "liger":
        sparsemax_module = LigerSparsemax(dim=dim).to(device)
    elif provider == "torch":
        sparsemax_module = TorchSparsemax(dim=dim).to(device)
    else:
        raise ValueError(f"Unknown kernel provider: {provider!r}")

    x = torch.randn((B * T, V), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        return sparsemax_module(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    if mode == "forward":
        mem_50, mem_20, mem_80 = _test_memory(y_fwd, quantiles=QUANTILES)
    elif mode == "backward":
        # Build the graph once; only the backward pass is measured.
        y = y_fwd()
        mem_50, mem_20, mem_80 = _test_memory(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
        )
    elif mode == "full":
        mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    else:
        raise ValueError(f"Unknown kernel operation mode: {mode!r}")

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
144
+
145
+
146
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Shared sweep: feature size V = 2**10 .. 2**15 at fixed B=4, T=512, float32.
    common_configs = dict(
        kernel_name="sparsemax",
        x_name="V",
        x_label="feature size",
        x_values=[1 << power for power in range(10, 16)],
        kernel_providers=["liger", "torch"],
        extra_benchmark_configs=[{"B": 4, "T": 512, "dim": -1, "dtype": torch.float32}],
        overwrite=args.overwrite,
    )

    # (benchmark fn, operation modes, metric name, metric unit): speed first, then memory.
    benchmark_plan = (
        (bench_speed_sparsemax, ["forward", "full", "backward"], "speed", "ms"),
        (bench_memory_sparsemax, ["full"], "memory", "MB"),
    )
    for bench_fn, modes, metric, unit in benchmark_plan:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=modes,
            metric_name=metric,
            metric_unit=unit,
            **common_configs,
        )
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.9.dev20250515034325"
7
+ version = "0.5.9.dev20250515065336"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,167 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from liger_kernel.ops.utils import calculate_settings
6
+ from liger_kernel.ops.utils import ensure_contiguous
7
+
8
+
9
@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,
    x_stride_row,
    sorted_x_ptr,
    sorted_x_stride_row,
    o_ptr,
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Sparsemax forward: one program per row.

    Only the first ``BLOCK_SIZE`` columns of a row are touched; there is no
    column loop, so the caller must launch with ``BLOCK_SIZE >= n_cols``. The
    row at ``sorted_x_ptr`` is expected to hold the row at ``x_ptr`` sorted in
    descending order. ``num_warps`` is declared but not read in the body.
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    # --- threshold tau from the descending-sorted row ---
    z_desc = tl.load(
        sorted_x_ptr + row_idx * sorted_x_stride_row + cols,
        mask=in_bounds,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)
    # Prefix sums of the sorted values; padded lanes contribute 0.
    prefix_sum = tl.cumsum(tl.where(in_bounds, z_desc, 0.0), 0)
    # 1-based rank per lane; padded lanes get 1 so the division stays finite.
    rank = tl.where(in_bounds, (cols + 1).to(tl.float32), 1.0)
    # Rank r is in the support iff z_r > (prefix_sum_r - 1) / r.
    in_support = (z_desc > (prefix_sum - 1.0) / rank) & in_bounds
    support_size = tl.maximum(tl.sum(in_support.to(tl.int32), 0), 1).to(tl.float32)
    support_total = tl.sum(tl.where(in_support, z_desc, 0.0), 0)
    tau = (support_total - 1.0) / support_size

    # --- project the original (unsorted) row: max(x - tau, 0) ---
    x_row = tl.load(
        x_ptr + row_idx * x_stride_row + cols,
        mask=in_bounds,
        other=0.0,
        cache_modifier=".ca",
    ).to(tl.float32)
    out_row_ptr = o_ptr + row_idx * o_stride_row
    tl.store(
        out_row_ptr + cols,
        tl.maximum(x_row - tau, 0.0).to(out_row_ptr.dtype.element_ty),
        mask=in_bounds,
        cache_modifier=".cs",
    )
69
+
70
+
71
@triton.jit
def _sparsemax_backward_kernel(
    o_ptr, go_ptr, gi_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr, num_warps: tl.constexpr
):
    """Sparsemax backward: one program per row.

    For each element, ``grad_in = grad_out - mean(grad_out over support)`` where
    the forward output ``o > 0``, and 0 elsewhere. The output, grad_out and
    grad_in rows share the same row ``stride``. ``num_warps`` is declared but
    not read in the body.
    """
    row_start = tl.program_id(0) * stride
    out_row = o_ptr + row_start
    grad_out_row = go_ptr + row_start
    grad_in_row = gi_ptr + row_start

    lane = tl.arange(0, BLOCK_SIZE)
    n_tiles = tl.cdiv(n_cols, BLOCK_SIZE)

    # Pass 1: support size and sum of grad_out over the support.
    active_count = tl.zeros((), tl.float32)
    grad_total = tl.zeros((), tl.float32)
    for tile in tl.range(0, n_tiles):
        cols = tile * BLOCK_SIZE + lane
        valid = cols < n_cols
        out_val = tl.load(out_row + cols, mask=valid, other=0.0, cache_modifier=".ca").to(tl.float32)
        g_val = tl.load(grad_out_row + cols, mask=valid, other=0.0).to(tl.float32)
        is_active = out_val > 0.0
        grad_total += tl.sum(tl.where(is_active, g_val, 0.0))
        active_count += tl.sum(is_active.to(tl.float32))

    # Mean over the support (guarded against an empty support), rounded to the
    # grad_in element type before being used in float32 arithmetic.
    mean_grad = tl.cast(grad_total / tl.maximum(active_count, 1e-6), grad_in_row.dtype.element_ty).to(tl.float32)

    # Pass 2: write grad_in.
    for tile in tl.range(0, n_tiles):
        cols = tile * BLOCK_SIZE + lane
        valid = cols < n_cols
        out_val = tl.load(out_row + cols, mask=valid, other=0.0, cache_modifier=".ca").to(tl.float32)
        g_val = tl.load(grad_out_row + cols, mask=valid, other=0.0).to(tl.float32)
        grad_in = tl.where(out_val > 0.0, g_val - mean_grad, 0.0)
        tl.store(grad_in_row + cols, grad_in.to(grad_in_row.dtype.element_ty), mask=valid, cache_modifier=".wb")
106
+
107
+
108
class LigerSparsemaxFunction(torch.autograd.Function):
    """Sparsemax along one dimension, computed with Triton kernels.

    ``dim`` is moved to the last axis and the tensor is flattened to
    ``(n_rows, n_cols)``, so each kernel program handles one row.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        """Apply sparsemax to every slice of ``x`` along ``dim``.

        Args:
            x: input tensor.
            dim: dimension to normalise over; negative values count from the end.

        Returns:
            Tensor with the same shape and dtype as ``x``. When ``dim`` is not
            the last dimension, it is a transposed (non-contiguous) view.
        """
        if dim < 0:
            dim += x.dim()
        ctx.dim = dim

        # Move `dim` last and make it contiguous so rows are dense.
        x_sw = x.transpose(dim, -1).contiguous()
        n_cols = x_sw.size(-1)

        if x_sw.numel() == 0:
            # Empty input (including a zero-length `dim`): nothing to compute.
            # Without this guard, `numel() // n_cols` raises ZeroDivisionError
            # when n_cols == 0.
            out_sw = torch.empty_like(x_sw)
            ctx.save_for_backward(out_sw)
            return out_sw.transpose(dim, -1)

        n_rows = x_sw.numel() // n_cols
        x_flat = x_sw.view(n_rows, n_cols)

        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        out_flat = torch.empty_like(x_flat)
        grid = (n_rows,)

        # The kernel needs each row sorted in descending order to locate the
        # threshold. The sort is done in float32, matching the kernel's float32 math.
        x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values

        _sparsemax_forward_kernel[grid](
            x_flat,
            x_flat.stride(0),
            x_sorted_flat,
            x_sorted_flat.stride(0),
            out_flat,
            out_flat.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        # Backward only needs the output: its support is where out > 0.
        ctx.save_for_backward(out_flat)
        return out_flat.view_as(x_sw).transpose(dim, -1)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        """Return the gradient w.r.t. ``x`` (and ``None`` for ``dim``)."""
        if grad_out.numel() == 0:
            # Mirrors the empty-input early return in forward.
            return torch.zeros_like(grad_out), None

        (out_flat,) = ctx.saved_tensors
        dim = ctx.dim

        go_sw = grad_out.transpose(dim, -1).contiguous()
        n_cols = go_sw.size(-1)
        n_rows = go_sw.numel() // n_cols
        go_flat = go_sw.view(n_rows, n_cols)

        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        gi_flat = torch.empty_like(go_flat)
        grid = (n_rows,)

        # out_flat, go_flat and gi_flat are all contiguous (n_rows, n_cols),
        # so a single row stride serves all three.
        _sparsemax_backward_kernel[grid](
            out_flat,
            go_flat,
            gi_flat,
            out_flat.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        return gi_flat.view_as(go_sw).transpose(dim, -1), None
@@ -12,6 +12,7 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
12
12
  from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
13
13
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
14
14
  from liger_kernel.ops.rope import LigerRopeFunction
15
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
15
16
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
16
17
  from liger_kernel.ops.tvd import LigerTVDLossFunction
17
18
 
@@ -159,6 +160,13 @@ def liger_kl_div(
159
160
  )
160
161
 
161
162
 
163
def liger_sparsemax(
    input,
    dim: int = -1,
):
    """Functional sparsemax backed by Liger's Triton kernels.

    Args:
        input: Tensor to normalize.
        dim: Dimension along which sparsemax is computed (default: the last one).

    Returns:
        The result of ``LigerSparsemaxFunction.apply(input, dim)``.
    """
    sparsemax_fn = LigerSparsemaxFunction.apply
    return sparsemax_fn(input, dim)
168
+
169
+
162
170
  def liger_tvd(
163
171
  input,
164
172
  target,
@@ -0,0 +1,16 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
5
+
6
+
7
class LigerSparsemax(nn.Module):
    """``nn.Module`` that applies Liger's sparsemax along a fixed dimension.

    Args:
        dim: Dimension along which sparsemax is computed (default: -1).
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        # Stored as a plain attribute; read on every forward call.
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sparsemax of ``x`` along ``self.dim``."""
        sparsemax_fn = LigerSparsemaxFunction.apply
        return sparsemax_fn(x, self.dim)

    def extra_repr(self) -> str:
        # nn.Module embeds this string in the module's printed repr.
        return "dim={}".format(self.dim)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250515034325
3
+ Version: 0.5.9.dev20250515065336
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -40,6 +40,7 @@ benchmark/scripts/benchmark_qwen2vl_mrope.py
40
40
  benchmark/scripts/benchmark_rms_norm.py
41
41
  benchmark/scripts/benchmark_rope.py
42
42
  benchmark/scripts/benchmark_simpo_loss.py
43
+ benchmark/scripts/benchmark_sparsemax.py
43
44
  benchmark/scripts/benchmark_swiglu.py
44
45
  benchmark/scripts/benchmark_tvd.py
45
46
  benchmark/scripts/utils.py
@@ -134,6 +135,7 @@ src/liger_kernel/ops/layer_norm.py
134
135
  src/liger_kernel/ops/qwen2vl_mrope.py
135
136
  src/liger_kernel/ops/rms_norm.py
136
137
  src/liger_kernel/ops/rope.py
138
+ src/liger_kernel/ops/sparsemax.py
137
139
  src/liger_kernel/ops/swiglu.py
138
140
  src/liger_kernel/ops/tvd.py
139
141
  src/liger_kernel/ops/utils.py
@@ -156,6 +158,7 @@ src/liger_kernel/transformers/monkey_patch.py
156
158
  src/liger_kernel/transformers/qwen2vl_mrope.py
157
159
  src/liger_kernel/transformers/rms_norm.py
158
160
  src/liger_kernel/transformers/rope.py
161
+ src/liger_kernel/transformers/sparsemax.py
159
162
  src/liger_kernel/transformers/swiglu.py
160
163
  src/liger_kernel/transformers/trainer_integration.py
161
164
  src/liger_kernel/transformers/tvd.py
@@ -238,6 +241,7 @@ test/transformers/test_monkey_patch.py
238
241
  test/transformers/test_qwen2vl_mrope.py
239
242
  test/transformers/test_rms_norm.py
240
243
  test/transformers/test_rope.py
244
+ test/transformers/test_sparsemax.py
241
245
  test/transformers/test_swiglu.py
242
246
  test/transformers/test_trainer_integration.py
243
247
  test/transformers/test_transformers.py
@@ -0,0 +1,111 @@
1
+ import pytest
2
+ import torch
3
+
4
+ from test.utils import assert_verbose_allclose
5
+ from test.utils import set_seed
6
+
7
+ from liger_kernel.transformers.functional import liger_sparsemax
8
+ from liger_kernel.transformers.sparsemax import LigerSparsemax
9
+ from liger_kernel.utils import infer_device
10
+
11
+ device = infer_device()
12
+
13
+
14
def torch_sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Pure-PyTorch reference sparsemax along ``dim``.

    Sorts descending, finds the support (indices k with
    ``1 + k * z_(k) > sum_{j<=k} z_(j)``), derives the threshold
    ``tau = (sum of supported values - 1) / |support|`` and returns
    ``max(input - tau, 0)``.
    """
    ndim = input_tensor.dim()
    axis = dim + ndim if dim < 0 else dim
    n = input_tensor.size(axis)

    z_desc = torch.sort(input_tensor, dim=axis, descending=True).values
    running = z_desc.cumsum(dim=axis)

    # 1..n laid out along ``axis`` so it broadcasts against the sorted values.
    view_shape = [1] * ndim
    view_shape[axis] = n
    ks = torch.arange(1, n + 1, device=input_tensor.device, dtype=input_tensor.dtype).view(view_shape)

    in_support = (1 + ks * z_desc) > running
    # clamp guards the division against an empty support.
    k = in_support.sum(dim=axis, keepdim=True).clamp(min=1)
    tau = ((z_desc * in_support).sum(dim=axis, keepdim=True) - 1) / k
    return (input_tensor - tau).clamp(min=0)
31
+
32
+
33
@pytest.mark.parametrize(
    "batch_size, seq_len, features",
    [
        (2, 128, 512),
        (5, 123, 123),
    ],
)
@pytest.mark.parametrize("dim", [-1, 1])
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [(torch.float32, 1e-5, 1e-5)],
)
def test_liger_sparsemax_correctness(batch_size, seq_len, features, dim, dtype, atol, rtol):
    """Compare the ``LigerSparsemax`` module with the PyTorch reference.

    Checks forward outputs, that every slice along ``dim`` sums to one, and the
    input gradients for a random upstream gradient.
    """
    set_seed(0)
    shape = (batch_size, seq_len, features)
    rank = len(shape)
    if not -rank <= dim < rank:
        pytest.skip("invalid dim")
    axis = dim + rank if dim < 0 else dim
    if shape[axis] <= 1:
        pytest.skip("trivial dim")

    x = torch.randn(*shape, dtype=dtype, device=device)
    x_liger = x.clone().requires_grad_(True)
    x_ref = x.clone().requires_grad_(True)

    y_liger = LigerSparsemax(dim=dim).to(device)(x_liger)
    y_ref = torch_sparsemax(x_ref, dim=dim)
    assert_verbose_allclose(y_liger, y_ref, atol=atol, rtol=rtol)

    # Sparsemax outputs lie on the probability simplex along ``dim``.
    for y in (y_liger, y_ref):
        total = y.sum(dim=dim)
        assert_verbose_allclose(total, torch.ones_like(total), atol=atol * 10, rtol=rtol * 10)

    upstream = torch.randn_like(x)
    y_liger.backward(upstream)
    y_ref.backward(upstream)
    assert_verbose_allclose(x_liger.grad, x_ref.grad, atol=atol, rtol=rtol)
71
+
72
+
73
@pytest.mark.parametrize(
    "batch_size, seq_len, features",
    [
        (2, 128, 512),
        (5, 123, 123),
    ],
)
@pytest.mark.parametrize("dim", [-1, 1])
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
    ],
)
def test_liger_sparsemax_functional_correctness(batch_size, seq_len, features, dim, dtype, atol, rtol):
    """Compare the functional ``liger_sparsemax`` with the PyTorch reference.

    Checks forward outputs, that every slice along ``dim`` sums to one, and the
    input gradients for a random upstream gradient.
    """
    set_seed(0)
    shape = (batch_size, seq_len, features)
    rank = len(shape)
    if not -rank <= dim < rank:
        pytest.skip("invalid dim")
    axis = dim + rank if dim < 0 else dim
    if shape[axis] <= 1:
        pytest.skip("trivial dim")

    x = torch.randn(*shape, dtype=dtype, device=device)
    x_liger = x.clone().requires_grad_(True)
    x_ref = x.clone().requires_grad_(True)

    y_liger = liger_sparsemax(x_liger, dim=dim)
    y_ref = torch_sparsemax(x_ref, dim=dim)
    assert_verbose_allclose(y_liger, y_ref, atol=atol, rtol=rtol)

    # Sparsemax outputs lie on the probability simplex along ``dim``.
    for y in (y_liger, y_ref):
        total = y.sum(dim=dim)
        assert_verbose_allclose(total, torch.ones_like(total), atol=atol * 10, rtol=rtol * 10)

    upstream = torch.randn_like(x)
    y_liger.backward(upstream)
    y_ref.backward(upstream)
    assert_verbose_allclose(x_liger.grad, x_ref.grad, atol=atol, rtol=rtol)