liger-kernel-nightly 0.5.9.dev20250512213150__tar.gz → 0.5.9.dev20250515065336__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/data/all_benchmark_data.csv +72 -0
  3. liger_kernel_nightly-0.5.9.dev20250515065336/benchmark/scripts/benchmark_sparsemax.py +172 -0
  4. liger_kernel_nightly-0.5.9.dev20250515065336/examples/medusa/requirements.txt +3 -0
  5. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/scripts/llama3_8b_medusa.sh +2 -5
  6. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/train.py +36 -38
  7. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/pyproject.toml +1 -1
  8. liger_kernel_nightly-0.5.9.dev20250515065336/src/liger_kernel/ops/sparsemax.py +167 -0
  9. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/functional.py +8 -0
  10. liger_kernel_nightly-0.5.9.dev20250515065336/src/liger_kernel/transformers/sparsemax.py +16 -0
  11. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  12. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
  13. liger_kernel_nightly-0.5.9.dev20250515065336/test/transformers/test_sparsemax.py +111 -0
  14. liger_kernel_nightly-0.5.9.dev20250512213150/examples/medusa/requirements.txt +0 -3
  15. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  16. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  17. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/pull_request_template.md +0 -0
  18. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/amd-ci.yml +0 -0
  19. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/docs.yml +0 -0
  20. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/intel-ci.yml +0 -0
  21. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/nvi-ci.yml +0 -0
  22. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/publish-nightly.yml +0 -0
  23. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.github/workflows/publish-release.yml +0 -0
  24. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.gitignore +0 -0
  25. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/.idea/workspace.xml +0 -0
  26. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/LICENSE +0 -0
  27. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/Makefile +0 -0
  28. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/NOTICE +0 -0
  29. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/README.md +0 -0
  30. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/README.md +0 -0
  31. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/__init__.py +0 -0
  32. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/benchmarks_visualizer.py +0 -0
  33. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/__init__.py +0 -0
  34. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  35. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  36. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  37. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  38. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_dyt.py +0 -0
  39. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_embedding.py +0 -0
  40. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  41. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  42. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_geglu.py +0 -0
  43. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_group_norm.py +0 -0
  44. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_jsd.py +0 -0
  45. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_kl_div.py +0 -0
  46. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  47. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  48. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  49. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  50. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  51. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_rope.py +0 -0
  52. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  53. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_swiglu.py +0 -0
  54. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/benchmark_tvd.py +0 -0
  55. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/benchmark/scripts/utils.py +0 -0
  56. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/fmt-requirements.txt +0 -0
  57. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/modal/tests.py +0 -0
  58. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/dev/modal/tests_bwd.py +0 -0
  59. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Examples.md +0 -0
  60. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Getting-Started.md +0 -0
  61. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/High-Level-APIs.md +0 -0
  62. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/Low-Level-APIs.md +0 -0
  63. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/acknowledgement.md +0 -0
  64. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/contributing.md +0 -0
  65. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/banner.GIF +0 -0
  66. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/compose.gif +0 -0
  67. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/e2e-memory.png +0 -0
  68. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/e2e-tps.png +0 -0
  69. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/logo-banner.png +0 -0
  70. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/patch.gif +0 -0
  71. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/images/post-training.png +0 -0
  72. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/index.md +0 -0
  73. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/docs/license.md +0 -0
  74. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/alignment/accelerate_config.yaml +0 -0
  75. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/alignment/run_orpo.py +0 -0
  76. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/README.md +0 -0
  77. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/callback.py +0 -0
  78. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/config/fsdp_config.json +0 -0
  79. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  80. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  81. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  82. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/llama_tps.png +0 -0
  83. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  84. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/img/qwen_tps.png +0 -0
  85. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/launch_on_modal.py +0 -0
  86. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/requirements.txt +0 -0
  87. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_benchmarks.sh +0 -0
  88. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_gemma.sh +0 -0
  89. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_llama.sh +0 -0
  90. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_qwen.sh +0 -0
  91. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/run_qwen2_vl.sh +0 -0
  92. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/training.py +0 -0
  93. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/huggingface/training_multimodal.py +0 -0
  94. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/README.md +0 -0
  95. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/requirements.txt +0 -0
  96. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/lightning/training.py +0 -0
  97. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/README.md +0 -0
  98. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/callback.py +0 -0
  99. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  100. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  101. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  102. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  103. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  104. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  105. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  106. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  107. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  108. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/examples/medusa/medusa_util.py +0 -0
  109. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-Apache-2.0 +0 -0
  110. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  111. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  112. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-llmc +0 -0
  113. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/licenses/LICENSE-MIT-triton +0 -0
  114. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/mkdocs.yml +0 -0
  115. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/setup.cfg +0 -0
  116. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/setup.py +0 -0
  117. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/__init__.py +0 -0
  118. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/README.md +0 -0
  119. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  120. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  121. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  122. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/functional.py +0 -0
  123. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  124. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  125. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  126. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  127. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  131. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  132. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/env_report.py +0 -0
  133. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/__init__.py +0 -0
  134. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/cross_entropy.py +0 -0
  135. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/dyt.py +0 -0
  136. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  137. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  138. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  139. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  140. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/geglu.py +0 -0
  141. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/group_norm.py +0 -0
  142. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/jsd.py +0 -0
  143. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/kl_div.py +0 -0
  144. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/layer_norm.py +0 -0
  145. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  146. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/rms_norm.py +0 -0
  147. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/rope.py +0 -0
  148. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/swiglu.py +0 -0
  149. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/tvd.py +0 -0
  150. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/ops/utils.py +0 -0
  151. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/__init__.py +0 -0
  152. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/auto_model.py +0 -0
  153. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  154. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/dyt.py +0 -0
  155. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  156. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  157. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  158. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/geglu.py +0 -0
  159. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  160. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/group_norm.py +0 -0
  161. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/jsd.py +0 -0
  162. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/kl_div.py +0 -0
  163. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/layer_norm.py +0 -0
  164. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/__init__.py +0 -0
  165. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma.py +0 -0
  166. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  167. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  168. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/glm4.py +0 -0
  169. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/llama.py +0 -0
  170. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/llava.py +0 -0
  171. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  172. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mistral.py +0 -0
  173. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  174. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/mllama.py +0 -0
  175. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  176. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  177. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/phi3.py +0 -0
  178. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  179. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  180. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  181. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  182. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  183. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  184. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  185. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/rms_norm.py +0 -0
  186. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/rope.py +0 -0
  187. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/swiglu.py +0 -0
  188. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  189. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  190. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  191. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/transformers/tvd.py +0 -0
  192. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/triton/__init__.py +0 -0
  193. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/triton/monkey_patch.py +0 -0
  194. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel/utils.py +0 -0
  195. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  196. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  197. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  198. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/__init__.py +0 -0
  200. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_cpo_loss.py +0 -0
  201. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_dpo_loss.py +0 -0
  202. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_grpo_loss.py +0 -0
  203. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_jsd_loss.py +0 -0
  204. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_kto_loss.py +0 -0
  205. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_orpo_loss.py +0 -0
  206. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/chunked_loss/test_simpo_loss.py +0 -0
  207. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/conftest.py +0 -0
  208. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/__init__.py +0 -0
  209. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/__init__.py +0 -0
  210. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models.py +0 -0
  211. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  212. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  213. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/__init__.py +0 -0
  214. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models.py +0 -0
  215. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  216. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  217. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  218. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  219. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  220. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  221. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  222. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  223. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  224. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  225. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  226. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare.txt +0 -0
  227. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  228. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  229. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  230. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_auto_model.py +0 -0
  231. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_cross_entropy.py +0 -0
  232. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_dyt.py +0 -0
  233. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_embedding.py +0 -0
  234. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_flex_attention.py +0 -0
  235. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  236. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_fused_linear_jsd.py +0 -0
  237. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_geglu.py +0 -0
  238. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_group_norm.py +0 -0
  239. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_jsd.py +0 -0
  240. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_kl_div.py +0 -0
  241. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_layer_norm.py +0 -0
  242. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_mm_int8int2.py +0 -0
  243. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_monkey_patch.py +0 -0
  244. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_qwen2vl_mrope.py +0 -0
  245. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_rms_norm.py +0 -0
  246. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_rope.py +0 -0
  247. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_swiglu.py +0 -0
  248. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_trainer_integration.py +0 -0
  249. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_transformers.py +0 -0
  250. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/transformers/test_tvd.py +0 -0
  251. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/triton/test_triton_monkey_patch.py +0 -0
  252. {liger_kernel_nightly-0.5.9.dev20250512213150 → liger_kernel_nightly-0.5.9.dev20250515065336}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.9.dev20250512213150
3
+ Version: 0.5.9.dev20250515065336
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -805,3 +805,75 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6099.26513671875,6099.265
805
805
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9613.298828125,9613.298828125,9613.298828125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
806
806
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16643.365234375,16643.365234375,16643.365234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
807
807
  kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30703.498046875,30703.498046875,30703.498046875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA H100 80GB HBM3,2025-03-03 08:24:11,0.5.4
808
+ sparsemax,liger,forward,speed,ms,V,feature size,1024,0.41471999883651733,0.4126720130443573,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
809
+ sparsemax,liger,forward,speed,ms,V,feature size,2048,0.7608320116996765,0.7598080039024353,0.7628800272941589,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
810
+ sparsemax,liger,forward,speed,ms,V,feature size,4096,1.4561280012130737,1.4540799856185913,1.4581760168075562,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
811
+ sparsemax,liger,forward,speed,ms,V,feature size,8192,5.288959980010986,5.2848639488220215,5.29986572265625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
812
+ sparsemax,liger,forward,speed,ms,V,feature size,16384,10.734624862670898,10.729472160339355,11.096882820129395,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
813
+ sparsemax,liger,forward,speed,ms,V,feature size,32768,21.729312896728516,21.7128963470459,22.20728302001953,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:08,0.5.8
814
+ sparsemax,torch,forward,speed,ms,V,feature size,1024,0.42291200160980225,0.42188799381256104,0.42393600940704346,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
815
+ sparsemax,torch,forward,speed,ms,V,feature size,2048,0.7782400250434875,0.7772160172462463,0.779263973236084,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
816
+ sparsemax,torch,forward,speed,ms,V,feature size,4096,1.4940160512924194,1.491968035697937,1.4960639476776123,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
817
+ sparsemax,torch,forward,speed,ms,V,feature size,8192,5.359615802764893,5.356544017791748,5.366579055786133,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
818
+ sparsemax,torch,forward,speed,ms,V,feature size,16384,10.883584022521973,10.874879837036133,11.224268913269043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
819
+ sparsemax,torch,forward,speed,ms,V,feature size,32768,22.19878387451172,22.018457412719727,22.48888397216797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:12,0.5.8
820
+ sparsemax,liger,full,speed,ms,V,feature size,1024,0.4558719992637634,0.45558398962020874,0.45772799849510193,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
821
+ sparsemax,liger,full,speed,ms,V,feature size,2048,0.8488960266113281,0.8478720188140869,0.8509439826011658,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
822
+ sparsemax,liger,full,speed,ms,V,feature size,4096,1.6476160287857056,1.6465920209884644,1.6499264240264893,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
823
+ sparsemax,liger,full,speed,ms,V,feature size,8192,5.664768218994141,5.660672187805176,5.681356906890869,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
824
+ sparsemax,liger,full,speed,ms,V,feature size,16384,11.486207962036133,11.478015899658203,11.874713897705078,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
825
+ sparsemax,liger,full,speed,ms,V,feature size,32768,23.457279205322266,23.289682388305664,23.76642608642578,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:16,0.5.8
826
+ sparsemax,torch,full,speed,ms,V,feature size,1024,0.6021119952201843,0.6010879874229431,0.6041600108146667,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
827
+ sparsemax,torch,full,speed,ms,V,feature size,2048,1.1212799549102783,1.119264006614685,1.1223039627075195,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
828
+ sparsemax,torch,full,speed,ms,V,feature size,4096,2.1637120246887207,2.1616640090942383,2.165760040283203,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
829
+ sparsemax,torch,full,speed,ms,V,feature size,8192,6.693888187408447,6.68723201751709,6.705561637878418,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
830
+ sparsemax,torch,full,speed,ms,V,feature size,16384,13.523456573486328,13.518848419189453,13.878681182861328,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
831
+ sparsemax,torch,full,speed,ms,V,feature size,32768,27.604991912841797,27.295129776000977,27.77518081665039,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:20,0.5.8
832
+ sparsemax,liger,backward,speed,ms,V,feature size,1024,0.04403200000524521,0.043007999658584595,0.05222399905323982,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
833
+ sparsemax,liger,backward,speed,ms,V,feature size,2048,0.08806400001049042,0.08713600039482117,0.08806400001049042,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
834
+ sparsemax,liger,backward,speed,ms,V,feature size,4096,0.1884160041809082,0.1884160041809082,0.18943999707698822,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
835
+ sparsemax,liger,backward,speed,ms,V,feature size,8192,0.374783992767334,0.37376001477241516,0.37486720085144043,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
836
+ sparsemax,liger,backward,speed,ms,V,feature size,16384,0.7516160011291504,0.7505919933319092,0.7516160011291504,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
837
+ sparsemax,liger,backward,speed,ms,V,feature size,32768,1.5738879442214966,1.572864055633545,1.575935959815979,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:22,0.5.8
838
+ sparsemax,torch,backward,speed,ms,V,feature size,1024,0.1812479943037033,0.1802240014076233,0.18227200210094452,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
839
+ sparsemax,torch,backward,speed,ms,V,feature size,2048,0.34406399726867676,0.34406399726867676,0.34508800506591797,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
840
+ sparsemax,torch,backward,speed,ms,V,feature size,4096,0.6717439889907837,0.6707199811935425,0.6727679967880249,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
841
+ sparsemax,torch,backward,speed,ms,V,feature size,8192,1.3250559568405151,1.3241215944290161,1.3260799646377563,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
842
+ sparsemax,torch,backward,speed,ms,V,feature size,16384,2.629631996154785,2.628607988357544,2.6306560039520264,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
843
+ sparsemax,torch,backward,speed,ms,V,feature size,32768,5.236735820770264,5.235712051391602,5.239808082580566,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
844
+ sparsemax,liger,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
845
+ sparsemax,liger,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
846
+ sparsemax,liger,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
847
+ sparsemax,liger,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
848
+ sparsemax,liger,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
849
+ sparsemax,liger,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:25,0.5.8
850
+ sparsemax,torch,full,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
851
+ sparsemax,torch,full,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
852
+ sparsemax,torch,full,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
853
+ sparsemax,torch,full,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
854
+ sparsemax,torch,full,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
855
+ sparsemax,torch,full,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-04-28 00:38:26,0.5.8
856
+ sparsemax,liger,forward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
857
+ sparsemax,liger,forward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
858
+ sparsemax,liger,forward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
859
+ sparsemax,liger,forward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
860
+ sparsemax,liger,forward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
861
+ sparsemax,liger,forward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
862
+ sparsemax,torch,forward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
863
+ sparsemax,torch,forward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
864
+ sparsemax,torch,forward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
865
+ sparsemax,torch,forward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
866
+ sparsemax,torch,forward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
867
+ sparsemax,torch,forward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:39,0.5.8
868
+ sparsemax,liger,backward,memory,MB,V,feature size,1024,56.0078125,56.0078125,56.0078125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
869
+ sparsemax,liger,backward,memory,MB,V,feature size,2048,112.015625,112.015625,112.015625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
870
+ sparsemax,liger,backward,memory,MB,V,feature size,4096,224.03125,224.03125,224.03125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
871
+ sparsemax,liger,backward,memory,MB,V,feature size,8192,768.00048828125,768.00048828125,768.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
872
+ sparsemax,liger,backward,memory,MB,V,feature size,16384,1536.00048828125,1536.00048828125,1536.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
873
+ sparsemax,liger,backward,memory,MB,V,feature size,32768,3072.00048828125,3072.00048828125,3072.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:40,0.5.8
874
+ sparsemax,torch,backward,memory,MB,V,feature size,1024,82.03515625,82.03515625,82.03515625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
875
+ sparsemax,torch,backward,memory,MB,V,feature size,2048,164.0390625,164.0390625,164.0390625,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
876
+ sparsemax,torch,backward,memory,MB,V,feature size,4096,328.046875,328.046875,328.046875,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
877
+ sparsemax,torch,backward,memory,MB,V,feature size,8192,704.00048828125,704.00048828125,704.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
878
+ sparsemax,torch,backward,memory,MB,V,feature size,16384,1408.00048828125,1408.00048828125,1408.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
879
+ sparsemax,torch,backward,memory,MB,V,feature size,32768,2816.00048828125,2816.00048828125,2816.00048828125,"{""B"": 4, ""T"": 512, ""dim"": -1, ""dtype"": ""torch.float32""}",NVIDIA GeForce RTX 3090,2025-05-15 02:04:41,0.5.8
@@ -0,0 +1,172 @@
1
+ import torch
2
+ import triton
3
+
4
+ from utils import QUANTILES
5
+ from utils import SingleBenchmarkRunInput
6
+ from utils import SingleBenchmarkRunOutput
7
+ from utils import _test_memory
8
+ from utils import parse_benchmark_script_args
9
+ from utils import run_benchmarks
10
+
11
+ from liger_kernel.transformers.sparsemax import LigerSparsemax
12
+ from liger_kernel.utils import infer_device
13
+
14
+ device = infer_device()
15
+
16
+
17
def torch_sparsemax(input_tensor: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Pure-PyTorch sparsemax, used as the baseline the Liger kernel is timed against.

    The values along ``dim`` are sorted in descending order; a sorted position
    ``j`` (1-based) belongs to the support when ``1 + j * z_j`` exceeds the
    running sum up to ``j``. The threshold ``tau`` is then
    ``(sum of supported values - 1) / support size`` and the result is
    ``max(input_tensor - tau, 0)``.

    Args:
        input_tensor: Tensor to project.
        dim: Dimension to reduce over; a negative value counts from the end.

    Returns:
        A tensor with the same shape as ``input_tensor``.
    """
    ndim = input_tensor.dim()
    axis = ndim + dim if dim < 0 else dim

    z_sorted = input_tensor.sort(dim=axis, descending=True).values
    z_cumsum = z_sorted.cumsum(dim=axis)

    n = input_tensor.size(axis)
    # Ranks 1..n laid out so they broadcast against the sorted tensor along `axis`.
    ranks = torch.arange(1, n + 1, device=input_tensor.device, dtype=input_tensor.dtype)
    ranks = ranks.view([n if i == axis else 1 for i in range(ndim)])

    in_support = (1 + ranks * z_sorted) > z_cumsum
    # Clamp to at least 1 so the division below never sees an empty support.
    support_size = in_support.sum(dim=axis, keepdim=True).clamp(min=1)
    support_total = (z_sorted * in_support).sum(dim=axis, keepdim=True)
    threshold = (support_total - 1) / support_size
    return (input_tensor - threshold).clamp(min=0)
34
+
35
+
36
class TorchSparsemax(torch.nn.Module):
    """Module wrapper around ``torch_sparsemax`` so the baseline is called like a layer."""

    def __init__(self, dim: int = -1):
        """Remember the dimension that sparsemax is applied over."""
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the reference sparsemax to ``x`` along ``self.dim``."""
        reduce_dim = self.dim
        return torch_sparsemax(x, dim=reduce_dim)
43
+
44
+
45
def bench_speed_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time one sparsemax configuration and return its latency quantiles.

    Args:
        input: Benchmark description. ``input.x`` is the feature size ``V``,
            ``input.kernel_provider`` is ``"liger"`` or ``"torch"``,
            ``input.kernel_operation_mode`` is ``"forward"``, ``"backward"`` or
            ``"full"``, and ``input.extra_benchmark_config`` supplies ``B``,
            ``T``, ``dim`` and ``dtype``.

    Returns:
        A ``SingleBenchmarkRunOutput`` holding the 20th/50th/80th percentile
        timings reported by ``triton.testing.do_bench``.

    Raises:
        ValueError: If the provider or the operation mode is not recognised.
            Previously an unknown provider made the forward helper return
            ``None`` and an unknown mode left the timing variables unbound, so
            the failure surfaced later as an unrelated error.
    """
    V = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    T = extra_benchmark_config["T"]
    dim = extra_benchmark_config["dim"]
    dtype = extra_benchmark_config["dtype"]

    # Build only the implementation under test and reject anything else up front.
    if provider == "liger":
        sparsemax_module = LigerSparsemax(dim=dim).to(device)
    elif provider == "torch":
        sparsemax_module = TorchSparsemax(dim=dim).to(device)
    else:
        raise ValueError(f"Unsupported kernel provider: {provider!r}")

    # Input is flattened to (B * T, V); dy is the upstream gradient fed to backward.
    x = torch.randn((B * T, V), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        return sparsemax_module(x)

    if mode == "forward":
        bench_fn = y_fwd
    elif mode == "backward":
        # Run forward once outside the timed region; retain_graph lets the same
        # graph be back-propagated on every timed repetition.
        y = y_fwd()

        def bench_fn():
            y.backward(dy, retain_graph=True)

    elif mode == "full":

        def bench_fn():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

    else:
        raise ValueError(f"Unsupported kernel operation mode: {mode!r}")

    # NOTE(review): the unpacking order assumes QUANTILES is [0.5, 0.2, 0.8] - confirm in utils.
    ms_50, ms_20, ms_80 = triton.testing.do_bench(
        bench_fn,
        grad_to_none=[x],
        rep=500,
        quantiles=QUANTILES,
    )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
105
+
106
+
107
def bench_memory_sparsemax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory use of one forward + backward sparsemax pass.

    Args:
        input: Benchmark description. ``input.x`` is the feature size ``V``,
            ``input.kernel_provider`` is ``"liger"`` or ``"torch"``, and
            ``input.extra_benchmark_config`` supplies ``B``, ``T``, ``dim`` and
            ``dtype``.

    Returns:
        A ``SingleBenchmarkRunOutput`` holding the 20th/50th/80th percentile
        memory figures reported by ``_test_memory``.

    Raises:
        ValueError: If the provider is not recognised. Previously the forward
            helper returned ``None`` for an unknown provider and the run failed
            later on ``None.backward``.
    """
    V = input.x
    provider = input.kernel_provider

    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    T = extra_benchmark_config["T"]
    dim = extra_benchmark_config["dim"]
    dtype = extra_benchmark_config["dtype"]

    # Build only the implementation under test and reject anything else up front.
    if provider == "liger":
        sparsemax_module = LigerSparsemax(dim=dim).to(device)
    elif provider == "torch":
        sparsemax_module = TorchSparsemax(dim=dim).to(device)
    else:
        raise ValueError(f"Unsupported kernel provider: {provider!r}")

    # Input is flattened to (B * T, V); dy is the upstream gradient fed to backward.
    x = torch.randn((B * T, V), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def full():
        y = sparsemax_module(x)
        y.backward(dy, retain_graph=True)

    # NOTE(review): the unpacking order assumes QUANTILES is [0.5, 0.2, 0.8] - confirm in utils.
    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
144
+
145
+
146
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Settings shared by both sweeps: feature size V goes 1024 -> 32768 in powers of two.
    common_configs = dict(
        kernel_name="sparsemax",
        x_name="V",
        x_label="feature size",
        x_values=[1 << exponent for exponent in range(10, 16)],
        kernel_providers=["liger", "torch"],
        extra_benchmark_configs=[dict(B=4, T=512, dim=-1, dtype=torch.float32)],
        overwrite=args.overwrite,
    )

    # (benchmark function, operation modes, metric name, metric unit); speed runs first.
    sweeps = (
        (bench_speed_sparsemax, ["forward", "full", "backward"], "speed", "ms"),
        (bench_memory_sparsemax, ["full"], "memory", "MB"),
    )
    for bench_fn, modes, metric, unit in sweeps:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=modes,
            metric_name=metric,
            metric_unit=unit,
            **common_configs,
        )
@@ -0,0 +1,3 @@
1
+ accelerate==1.6.0
2
+ scikit-learn
3
+ transformers==4.51.3
@@ -22,9 +22,6 @@ export MEDUSA_LR_MULTIPLIER=4.0
22
22
  accelerate launch --config_file fsdp/acc-fsdp.conf \
23
23
  --num_machines $NUM_NODES \
24
24
  --num_processes $WORLD_SIZE \
25
- --main_process_ip $MASTER_ADDR \
26
- --main_process_port $MASTER_PORT \
27
- --machine_rank $RANK \
28
25
  train.py \
29
26
  --bf16 True \
30
27
  --output_dir $OUTPUT_DIR \
@@ -32,7 +29,7 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \
32
29
  --per_device_train_batch_size $LOCAL_TRAIN_BATCH_SIZE \
33
30
  --per_device_eval_batch_size 1 \
34
31
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
35
- --evaluation_strategy "no" \
32
+ --eval_strategy "no" \
36
33
  --save_strategy "no" \
37
34
  --prediction_loss_only \
38
35
  --learning_rate $LR \
@@ -53,4 +50,4 @@ accelerate launch --config_file fsdp/acc-fsdp.conf \
53
50
  --medusa_lr_multiplier $MEDUSA_LR_MULTIPLIER \
54
51
  --medusa_only_heads False \
55
52
  --medusa_return True \
56
- --use_liger True
53
+ --use_liger True
@@ -32,21 +32,18 @@ from callback import EfficiencyCallback
32
32
  from medusa_util import add_medusa_heads
33
33
  from safetensors.torch import save_file
34
34
  from sklearn.model_selection import train_test_split
35
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
36
- from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
37
- from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
38
35
  from torch.utils.data import Dataset
39
36
  from transformers import Trainer
40
37
  from transformers.trainer_pt_utils import LabelSmoother
41
38
 
42
- from liger_kernel.transformers import apply_liger_kernel_to_llama
39
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM
43
40
 
44
41
  IGNORE_TOKEN_ID = LabelSmoother.ignore_index
45
42
 
46
43
 
47
44
  @dataclass
48
45
  class ModelArguments:
49
- model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B")
46
+ model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B-Instruct")
50
47
 
51
48
 
52
49
  @dataclass
@@ -310,29 +307,36 @@ def train():
310
307
  print(tokenizer(["This is a test", "secondary"], padding=True))
311
308
  print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}]))
312
309
 
313
- # Load model and tokenizer
314
- model = transformers.AutoModelForCausalLM.from_pretrained(
315
- model_args.model_name_or_path,
316
- # config=config,
317
- cache_dir=training_args.cache_dir,
318
- torch_dtype=torch.bfloat16,
319
- )
310
+ def _model_loader():
311
+ # we use a customized model loader to inject medusa heads to FSDP-wrapped model variables properly.
312
+ # see https://github.com/linkedin/Liger-Kernel/issues/309#issuecomment-2455077623 for details.
320
313
 
321
- if training_args.use_liger is True:
322
- apply_liger_kernel_to_llama()
314
+ # Load model
315
+ if training_args.use_liger:
316
+ model_builder = AutoLigerKernelForCausalLM.from_pretrained
317
+ else:
318
+ model_builder = transformers.AutoModelForCausalLM.from_pretrained
319
+ model = model_builder(
320
+ model_args.model_name_or_path,
321
+ cache_dir=training_args.cache_dir,
322
+ torch_dtype=torch.bfloat16,
323
+ )
323
324
 
324
- # Freeze the base model
325
- for param in model.base_model.parameters():
326
- param.requires_grad = False
325
+ # Freeze the base model
326
+ for param in model.base_model.parameters():
327
+ param.requires_grad = False
328
+
329
+ # Inject Medusa heads
330
+ add_medusa_heads(
331
+ model,
332
+ training_args.medusa_num_heads,
333
+ training_args.medusa_num_layers,
334
+ training_args.medusa_return,
335
+ training_args.medusa_only_heads,
336
+ training_args.use_liger,
337
+ )
338
+ return model
327
339
 
328
- add_medusa_heads(
329
- model,
330
- training_args.medusa_num_heads,
331
- training_args.medusa_num_layers,
332
- training_args.medusa_return,
333
- training_args.medusa_only_heads,
334
- training_args.use_liger,
335
- )
336
340
  # Format output dir
337
341
  training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"
338
342
 
@@ -341,7 +345,7 @@ def train():
341
345
 
342
346
  # Start trainner
343
347
  trainer = Trainer(
344
- model=model,
348
+ model_init=_model_loader,
345
349
  tokenizer=tokenizer,
346
350
  args=training_args,
347
351
  callbacks=[EfficiencyCallback()],
@@ -355,17 +359,11 @@ def train():
355
359
 
356
360
  if training_args.medusa_return and training_args.medusa_only_heads:
357
361
  # Save only the updated head without saving the backbone model
358
- if hasattr(model, "module"):
359
- lm_head = model.module.medusa_head
360
- else:
361
- lm_head = model.medusa_head
362
-
363
- with FSDP.state_dict_type(
364
- model,
365
- StateDictType.FULL_STATE_DICT,
366
- FullStateDictConfig(offload_to_cpu=True),
367
- ):
368
- state_dict = lm_head.state_dict()
362
+ state_dict = {
363
+ k.replace("medusa_head.", ""): v.to(torch.bfloat16)
364
+ for k, v in trainer.accelerator.get_state_dict(trainer.model).items()
365
+ if "medusa_head" in k
366
+ }
369
367
 
370
368
  # Save Medusa heads
371
369
  if local_rank == 0:
@@ -373,9 +371,9 @@ def train():
373
371
  state_dict,
374
372
  os.path.join(training_args.output_dir, "medusa_lm_head.safetensors"),
375
373
  )
374
+ trainer.accelerator.wait_for_everyone()
376
375
  else:
377
376
  # Save the whole model weight
378
- trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
379
377
  trainer.save_model(training_args.output_dir)
380
378
 
381
379
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.9.dev20250512213150"
7
+ version = "0.5.9.dev20250515065336"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,167 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from liger_kernel.ops.utils import calculate_settings
6
+ from liger_kernel.ops.utils import ensure_contiguous
7
+
8
+
9
@triton.jit
def _sparsemax_forward_kernel(
    x_ptr,
    x_stride_row,
    sorted_x_ptr,
    sorted_x_stride_row,
    o_ptr,
    o_stride_row,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Compute sparsemax for one row per program instance.

    Each program reads a single block of ``BLOCK_SIZE`` columns from its row,
    so columns at index ``BLOCK_SIZE`` or beyond are never touched; the
    launcher is expected to pick ``BLOCK_SIZE >= n_cols``.

    Args:
        x_ptr: pointer to the unsorted input rows.
        x_stride_row: element stride between consecutive rows of ``x_ptr``.
        sorted_x_ptr: pointer to the same rows with their values pre-sorted
            (the launcher in this file sorts them in descending order).
        sorted_x_stride_row: element stride between rows of ``sorted_x_ptr``.
        o_ptr: pointer to the output rows.
        o_stride_row: element stride between rows of ``o_ptr``.
        n_cols: number of valid columns in a row.
        BLOCK_SIZE: compile-time number of lanes processed per row.
        num_warps: compile-time launch setting; not read in the kernel body.
    """
    row_idx = tl.program_id(0)
    in_row = x_ptr + row_idx * x_stride_row
    sorted_row = sorted_x_ptr + row_idx * sorted_x_stride_row
    out_row = o_ptr + row_idx * o_stride_row

    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    # Padding lanes read -inf so they can never pass the support test below.
    z_sorted = tl.load(
        sorted_row + cols,
        mask=in_bounds,
        other=-float("inf"),
        cache_modifier=".ca",
    ).to(tl.float32)

    # Running sum of the sorted values; padding lanes contribute 0 instead of -inf.
    prefix_sum = tl.cumsum(tl.where(in_bounds, z_sorted, 0.0), 0)

    # 1-based position of each lane; padding lanes use 1 to keep the division benign.
    rank = tl.where(in_bounds, (cols + 1).to(tl.float32), 1.0)
    threshold = (prefix_sum - 1.0) / rank

    # A sorted entry belongs to the support when it exceeds its candidate threshold.
    in_support = (z_sorted > threshold) & in_bounds

    # Support size, clamped to at least 1 so the division below is always defined.
    support_size = tl.maximum(tl.sum(in_support.to(tl.int32), 0), 1).to(tl.float32)
    support_sum = tl.sum(tl.where(in_support, z_sorted, 0.0), 0)

    tau = (support_sum - 1.0) / support_size

    x_vals = tl.load(
        in_row + cols,
        mask=in_bounds,
        other=0.0,
        cache_modifier=".ca",
    ).to(tl.float32)

    # Shift by tau and clip at zero, then cast back to the output element type.
    result = tl.maximum(x_vals - tau, 0.0)

    tl.store(
        out_row + cols,
        result.to(out_row.dtype.element_ty),
        mask=in_bounds,
        cache_modifier=".cs",
    )
69
+
70
+
71
@triton.jit
def _sparsemax_backward_kernel(
    o_ptr,
    go_ptr,
    gi_ptr,
    stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_warps: tl.constexpr,
):
    """Compute the sparsemax input gradient for one row per program instance.

    The same ``stride`` is applied to all three pointers, so the forward
    output, incoming gradient and result buffers must share a row layout.

    Args:
        o_ptr: pointer to the forward output rows.
        go_ptr: pointer to the gradient with respect to the output.
        gi_ptr: pointer to the buffer receiving the gradient with respect to
            the input.
        stride: element stride between consecutive rows of all three buffers.
        n_cols: number of valid columns in a row.
        BLOCK_SIZE: compile-time number of columns handled per loop iteration.
        num_warps: compile-time launch setting; not read in the kernel body.
    """
    pid = tl.program_id(0)
    row_offset = pid * stride
    out_row = o_ptr + row_offset
    grad_out_row = go_ptr + row_offset
    grad_in_row = gi_ptr + row_offset

    lane = tl.arange(0, BLOCK_SIZE)
    n_tiles = tl.cdiv(n_cols, BLOCK_SIZE)

    support_size = tl.zeros((), tl.float32)
    grad_sum = tl.zeros((), tl.float32)

    # Pass 1: count the support (output > 0) and sum the gradient over it.
    for tile in tl.range(0, n_tiles):
        cols = tile * BLOCK_SIZE + lane
        valid = cols < n_cols
        out_val = tl.load(out_row + cols, mask=valid, other=0.0, cache_modifier=".ca").to(tl.float32)
        grad_val = tl.load(grad_out_row + cols, mask=valid, other=0.0).to(tl.float32)
        active = out_val > 0.0
        grad_sum += tl.sum(tl.where(active, grad_val, 0.0))
        support_size += tl.sum(active.to(tl.float32))

    # Mean gradient over the support. It is rounded through the output element
    # type before being used in float32; the 1e-6 floor avoids dividing by zero
    # when no entry is in the support.
    mean_grad = tl.cast(grad_sum / tl.maximum(support_size, 1e-6), grad_in_row.dtype.element_ty).to(tl.float32)

    # Pass 2: subtract the mean inside the support, write zero elsewhere.
    for tile in tl.range(0, n_tiles):
        cols = tile * BLOCK_SIZE + lane
        valid = cols < n_cols
        out_val = tl.load(out_row + cols, mask=valid, other=0.0, cache_modifier=".ca").to(tl.float32)
        grad_val = tl.load(grad_out_row + cols, mask=valid, other=0.0).to(tl.float32)
        active = out_val > 0.0
        grad_in = tl.where(active, grad_val - mean_grad, 0.0)
        tl.store(grad_in_row + cols, grad_in.to(grad_in_row.dtype.element_ty), mask=valid, cache_modifier=".wb")
106
+
107
+
108
class LigerSparsemaxFunction(torch.autograd.Function):
    """Autograd function that evaluates sparsemax along one dimension using Triton kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x: torch.Tensor, dim: int):
        """Apply sparsemax to ``x`` along ``dim``.

        Args:
            ctx: autograd context; receives the normalised ``dim`` and the
                flattened output needed by ``backward``.
            x: input tensor.
            dim: dimension to normalise over; negative values count from the
                last dimension.

        Returns:
            A tensor with the same shape and dtype as ``x``.
        """
        if dim < 0:
            dim += x.dim()
        ctx.dim = dim

        # Move the reduction dimension last so that every row of the 2-D view
        # below is one independent sparsemax problem.
        x_sw = x.transpose(dim, -1).contiguous()

        if x_sw.numel() == 0:
            # Zero-element input: there is nothing to compute, and the row
            # count below would otherwise divide by zero when the reduction
            # dimension itself is empty.
            out_empty = torch.empty_like(x_sw)
            ctx.save_for_backward(out_empty)
            return out_empty.transpose(dim, -1)

        n_cols = x_sw.size(-1)
        n_rows = x_sw.numel() // n_cols
        x_flat = x_sw.view(n_rows, n_cols)

        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        out_flat = torch.empty_like(x_flat)
        grid = (n_rows,)

        # The kernel derives the threshold from rows sorted in descending
        # order; the sort is done in float32 regardless of the input dtype.
        x_sorted_flat = torch.sort(x_flat.float(), dim=-1, descending=True).values

        _sparsemax_forward_kernel[grid](
            x_flat,
            x_flat.stride(0),
            x_sorted_flat,
            x_sorted_flat.stride(0),
            out_flat,
            out_flat.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        ctx.save_for_backward(out_flat)
        return out_flat.view_as(x_sw).transpose(dim, -1)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_out: torch.Tensor):
        """Propagate ``grad_out`` through sparsemax.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_out: gradient with respect to the forward output.

        Returns:
            A pair ``(grad_input, None)``; ``None`` is the gradient slot for
            the non-tensor ``dim`` argument.
        """
        (out_flat,) = ctx.saved_tensors
        dim = ctx.dim

        if grad_out.numel() == 0:
            # Mirrors the zero-element shortcut in ``forward``.
            return torch.empty_like(grad_out), None

        go_sw = grad_out.transpose(dim, -1).contiguous()
        n_cols = go_sw.size(-1)
        n_rows = go_sw.numel() // n_cols
        go_flat = go_sw.view(n_rows, n_cols)

        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        gi_flat = torch.empty_like(go_flat)
        grid = (n_rows,)

        # One row stride is passed for all three buffers: each is a freshly
        # made contiguous (n_rows, n_cols) tensor.
        _sparsemax_backward_kernel[grid](
            out_flat,
            go_flat,
            gi_flat,
            out_flat.stride(0),
            n_cols,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        return gi_flat.view_as(go_sw).transpose(dim, -1), None
@@ -12,6 +12,7 @@ from liger_kernel.ops.layer_norm import LigerLayerNormFunction
12
12
  from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
13
13
  from liger_kernel.ops.rms_norm import LigerRMSNormFunction
14
14
  from liger_kernel.ops.rope import LigerRopeFunction
15
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
15
16
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
16
17
  from liger_kernel.ops.tvd import LigerTVDLossFunction
17
18
 
@@ -159,6 +160,13 @@ def liger_kl_div(
159
160
  )
160
161
 
161
162
 
163
def liger_sparsemax(input, dim: int = -1):
    """Functional entry point for sparsemax.

    Args:
        input: tensor to normalise.
        dim: dimension along which sparsemax is applied (defaults to the last).

    Returns:
        Whatever ``LigerSparsemaxFunction.apply`` returns for ``(input, dim)``.
    """
    output = LigerSparsemaxFunction.apply(input, dim)
    return output
168
+
169
+
162
170
  def liger_tvd(
163
171
  input,
164
172
  target,
@@ -0,0 +1,16 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from liger_kernel.ops.sparsemax import LigerSparsemaxFunction
5
+
6
+
7
class LigerSparsemax(nn.Module):
    """Module wrapper that applies ``LigerSparsemaxFunction`` along a fixed dimension."""

    def __init__(self, dim: int = -1):
        """Remember the dimension to normalise over (defaults to the last one)."""
        super().__init__()
        self.dim = dim

    def extra_repr(self) -> str:
        """Return the configured dimension for the module's printed representation."""
        return f"dim={self.dim}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply sparsemax to ``x`` along ``self.dim``."""
        result = LigerSparsemaxFunction.apply(x, self.dim)
        return result