liger-kernel-nightly 0.5.3.dev20250221011147__tar.gz → 0.5.3.dev20250221011217__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  4. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
  5. liger_kernel_nightly-0.5.3.dev20250221011217/test/transformers/test_flex_attention.py +283 -0
  6. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  7. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  8. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.github/pull_request_template.md +0 -0
  9. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.github/workflows/amd-ci.yml +0 -0
  10. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.github/workflows/docs.yml +0 -0
  11. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.github/workflows/intel-ci.yml +0 -0
  12. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.github/workflows/nvi-ci.yml +0 -0
  13. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.github/workflows/publish-nightly.yml +0 -0
  14. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.github/workflows/publish-release.yml +0 -0
  15. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/.gitignore +0 -0
  16. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/LICENSE +0 -0
  17. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/Makefile +0 -0
  18. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/NOTICE +0 -0
  19. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/README.md +0 -0
  20. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/README.md +0 -0
  21. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/__init__.py +0 -0
  22. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/benchmarks_visualizer.py +0 -0
  23. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/data/all_benchmark_data.csv +0 -0
  24. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/__init__.py +0 -0
  25. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  26. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  27. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  28. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  29. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_embedding.py +0 -0
  30. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  31. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  32. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_geglu.py +0 -0
  33. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_group_norm.py +0 -0
  34. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_jsd.py +0 -0
  35. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_kl_div.py +0 -0
  36. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  37. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  38. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  39. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  40. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  41. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_rope.py +0 -0
  42. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  43. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_swiglu.py +0 -0
  44. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/benchmark_tvd.py +0 -0
  45. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/benchmark/scripts/utils.py +0 -0
  46. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/dev/fmt-requirements.txt +0 -0
  47. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/dev/modal/tests.py +0 -0
  48. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/dev/modal/tests_bwd.py +0 -0
  49. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/Examples.md +0 -0
  50. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/Getting-Started.md +0 -0
  51. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/High-Level-APIs.md +0 -0
  52. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/Low-Level-APIs.md +0 -0
  53. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/acknowledgement.md +0 -0
  54. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/contributing.md +0 -0
  55. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/images/banner.GIF +0 -0
  56. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/images/compose.gif +0 -0
  57. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/images/e2e-memory.png +0 -0
  58. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/images/e2e-tps.png +0 -0
  59. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/images/logo-banner.png +0 -0
  60. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/images/patch.gif +0 -0
  61. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/images/post-training.png +0 -0
  62. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/index.md +0 -0
  63. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/docs/license.md +0 -0
  64. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/alignment/accelerate_config.yaml +0 -0
  65. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/alignment/run_orpo.py +0 -0
  66. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/README.md +0 -0
  67. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/callback.py +0 -0
  68. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/config/fsdp_config.json +0 -0
  69. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  70. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  71. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  72. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/img/llama_tps.png +0 -0
  73. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  74. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/img/qwen_tps.png +0 -0
  75. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/launch_on_modal.py +0 -0
  76. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/requirements.txt +0 -0
  77. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/run_benchmarks.sh +0 -0
  78. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/run_gemma.sh +0 -0
  79. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/run_llama.sh +0 -0
  80. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/run_qwen.sh +0 -0
  81. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/run_qwen2_vl.sh +0 -0
  82. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/training.py +0 -0
  83. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/huggingface/training_multimodal.py +0 -0
  84. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/lightning/README.md +0 -0
  85. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/lightning/requirements.txt +0 -0
  86. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/lightning/training.py +0 -0
  87. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/README.md +0 -0
  88. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/callback.py +0 -0
  89. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  90. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  91. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  92. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  93. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  94. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  95. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  96. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  97. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  98. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/medusa_util.py +0 -0
  99. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/requirements.txt +0 -0
  100. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  101. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/examples/medusa/train.py +0 -0
  102. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/licenses/LICENSE-Apache-2.0 +0 -0
  103. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  104. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  105. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/licenses/LICENSE-MIT-llmc +0 -0
  106. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/licenses/LICENSE-MIT-triton +0 -0
  107. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/mkdocs.yml +0 -0
  108. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/setup.cfg +0 -0
  109. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/setup.py +0 -0
  110. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/__init__.py +0 -0
  111. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/README.md +0 -0
  112. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  113. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  114. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  115. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/functional.py +0 -0
  116. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  117. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  118. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
  119. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  120. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  121. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  122. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  123. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  124. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  125. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/env_report.py +0 -0
  126. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/__init__.py +0 -0
  127. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/cross_entropy.py +0 -0
  128. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  129. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  130. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  131. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  132. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/geglu.py +0 -0
  133. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/group_norm.py +0 -0
  134. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/jsd.py +0 -0
  135. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/kl_div.py +0 -0
  136. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/layer_norm.py +0 -0
  137. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  138. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/rms_norm.py +0 -0
  139. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/rope.py +0 -0
  140. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/swiglu.py +0 -0
  141. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/tvd.py +0 -0
  142. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/ops/utils.py +0 -0
  143. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/__init__.py +0 -0
  144. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/auto_model.py +0 -0
  145. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  146. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  147. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/functional.py +0 -0
  148. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  149. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  150. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/geglu.py +0 -0
  151. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/group_norm.py +0 -0
  152. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/jsd.py +0 -0
  153. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/kl_div.py +0 -0
  154. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/layer_norm.py +0 -0
  155. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/__init__.py +0 -0
  156. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/gemma.py +0 -0
  157. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  158. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/llama.py +0 -0
  159. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/mistral.py +0 -0
  160. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  161. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/mllama.py +0 -0
  162. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/phi3.py +0 -0
  163. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  164. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  165. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  166. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  167. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/rms_norm.py +0 -0
  168. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/rope.py +0 -0
  169. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/swiglu.py +0 -0
  170. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  171. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  172. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  173. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/transformers/tvd.py +0 -0
  174. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/triton/__init__.py +0 -0
  175. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/triton/monkey_patch.py +0 -0
  176. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel/utils.py +0 -0
  177. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  178. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  179. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  180. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/__init__.py +0 -0
  181. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/chunked_loss/__init__.py +0 -0
  182. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/chunked_loss/test_cpo_loss.py +0 -0
  183. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/chunked_loss/test_dpo_loss.py +0 -0
  184. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/chunked_loss/test_grpo_loss.py +0 -0
  185. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/chunked_loss/test_jsd_loss.py +0 -0
  186. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/chunked_loss/test_kto_loss.py +0 -0
  187. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/chunked_loss/test_orpo_loss.py +0 -0
  188. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/chunked_loss/test_simpo_loss.py +0 -0
  189. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/conftest.py +0 -0
  190. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/convergence/__init__.py +0 -0
  191. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/convergence/bf16/__init__.py +0 -0
  192. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/convergence/bf16/test_mini_models.py +0 -0
  193. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  194. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  195. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/convergence/fp32/__init__.py +0 -0
  196. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/convergence/fp32/test_mini_models.py +0 -0
  197. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  198. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  199. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  200. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  201. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  202. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/resources/tiny_shakespeare.txt +0 -0
  203. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  204. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  205. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  206. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_auto_model.py +0 -0
  207. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_cross_entropy.py +0 -0
  208. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_embedding.py +0 -0
  209. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  210. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_fused_linear_jsd.py +0 -0
  211. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_geglu.py +0 -0
  212. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_group_norm.py +0 -0
  213. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_jsd.py +0 -0
  214. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_kl_div.py +0 -0
  215. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_layer_norm.py +0 -0
  216. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_mm_int8int2.py +0 -0
  217. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_monkey_patch.py +0 -0
  218. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_qwen2vl_mrope.py +0 -0
  219. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_rms_norm.py +0 -0
  220. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_rope.py +0 -0
  221. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_swiglu.py +0 -0
  222. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_trainer_integration.py +0 -0
  223. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_transformers.py +0 -0
  224. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/transformers/test_tvd.py +0 -0
  225. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/triton/test_triton_monkey_patch.py +0 -0
  226. {liger_kernel_nightly-0.5.3.dev20250221011147 → liger_kernel_nightly-0.5.3.dev20250221011217}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221011147
3
+ Version: 0.5.3.dev20250221011217
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.3.dev20250221011147"
7
+ version = "0.5.3.dev20250221011217"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221011147
3
+ Version: 0.5.3.dev20250221011217
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -204,6 +204,7 @@ test/resources/tiny_shakespeare_tokenized/state.json
204
204
  test/transformers/test_auto_model.py
205
205
  test/transformers/test_cross_entropy.py
206
206
  test/transformers/test_embedding.py
207
+ test/transformers/test_flex_attention.py
207
208
  test/transformers/test_fused_linear_cross_entropy.py
208
209
  test/transformers/test_fused_linear_jsd.py
209
210
  test/transformers/test_geglu.py
@@ -0,0 +1,283 @@
1
+ import pytest
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from test.utils import assert_verbose_allclose
6
+ from test.utils import set_seed
7
+ from test.utils import supports_bfloat16
8
+ from torch.nn.attention.flex_attention import create_block_mask
9
+ from torch.nn.attention.flex_attention import create_mask
10
+ from torch.nn.attention.flex_attention import flex_attention
11
+
12
+ from liger_kernel.utils import infer_device
13
+
14
+
15
def causal_mask(b, h, q_idx, kv_idx):
    """Causal mask predicate: a query position may see keys at or before its own position.

    ``b`` and ``h`` are accepted so the signature matches the four-argument mask-function form
    used with ``create_block_mask`` / ``create_mask`` in this file; they are not used.
    """
    key_is_not_in_future = q_idx >= kv_idx
    return key_is_not_in_future
17
+
18
+
19
def prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index):
    """Causal mask that additionally hides one key span from late queries.

    For batch element ``b``, a query at or after ``rejected_index[b]`` is not allowed to see keys in
    the half-open span ``[chosen_index[b], rejected_index[b])``. Every other (query, key) pair
    follows the plain causal rule ``q_idx >= kv_idx``. ``h`` is unused.
    """
    span_start = chosen_index[b]
    span_end = rejected_index[b]

    query_past_span = q_idx >= span_end
    key_inside_span = (span_start <= kv_idx) & (kv_idx < span_end)
    hidden = query_past_span & key_inside_span

    causal = q_idx >= kv_idx
    return ~hidden & causal
23
+
24
+
25
# Module-level device handle obtained from liger_kernel's `infer_device()`.
# NOTE(review): presumably a torch device string for the available accelerator -- confirm.
device = infer_device()
# Seed once at import time via the shared test helper (seed value 42).
set_seed(42)
27
+
28
+
29
def _test_correctness_flex(B, H, S, D, mask_func, dtype, atol, rtol, device="cuda"):
    """
    Check that ``flex_attention`` agrees with ``F.scaled_dot_product_attention`` under one mask.

    Both implementations receive identical random inputs and the same ``mask_func`` (as a block mask
    for flex attention, as a regular mask for SDPA). Forward outputs and the gradients with respect
    to query, key and value are compared.

    Parameters:
        B (int): Batch size
        H (int): Number of attention heads
        S (int): Sequence length
        D (int): Hidden dimension per head
        mask_func: Callable ``(b, h, q_idx, kv_idx)`` describing which positions may attend
        dtype: Data type for computation
        atol (float): Absolute tolerance for comparison
        rtol (float): Relative tolerance for comparison
        device: Device on which tensors and masks are created
    """
    torch.manual_seed(0)

    # Attention-head inputs (the tensors after the q, k and v projections), drawn in q, k, v order.
    reference_qkv = [
        torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True) for _ in range(3)
    ]
    # Independent leaf copies so each implementation accumulates its own gradients.
    flex_qkv = [t.clone().detach().requires_grad_(True) for t in reference_qkv]

    sparse_mask = create_block_mask(mask_func, B, H, S, S, device=device)  # Sparsity block mask
    dense_mask = create_mask(mask_func, B, H, S, S, device=device)  # Regular mask

    # With a purely causal mask and FA2, `is_causal` could be used instead of an explicit mask, e.g.
    # F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
    ref_q, ref_k, ref_v = reference_qkv
    expected = F.scaled_dot_product_attention(ref_q, ref_k, ref_v, attn_mask=dense_mask)

    flex_q, flex_k, flex_v = flex_qkv
    actual = flex_attention(flex_q, flex_k, flex_v, block_mask=sparse_mask)

    # Forward pass
    assert_verbose_allclose(actual, expected, atol=atol, rtol=rtol)

    # Backward pass: push the same upstream gradient through both graphs.
    upstream_grad = torch.randn_like(expected)
    expected.backward(upstream_grad)
    actual.backward(upstream_grad)

    # Gradients, checked in query, key, value order.
    for flex_tensor, ref_tensor in zip(flex_qkv, reference_qkv):
        assert_verbose_allclose(flex_tensor.grad, ref_tensor.grad, atol=atol, rtol=rtol)
76
+
77
+
78
@pytest.mark.parametrize(
    "B, H, S, D",
    [
        (2, 8, 1024, 32),
        (3, 12, 2048, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        pytest.param(
            torch.bfloat16,
            3e-2,
            5e-1,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
        (torch.float16, 1e-2, 5e-3),
        (torch.float32, 1e-3, 5e-4),
    ],
)
def test_correctness_flex(B, H, S, D, dtype, atol, rtol):
    """Compare flex attention with SDPA for a causal mask and for a randomized prefix-sharing mask.

    Tensors and masks are created on the module-level ``device`` (from ``infer_device()``) instead of
    a hard-coded ``"cuda"``, so the index tensors captured by the mask function live on the same
    device as the attention inputs.
    """
    _test_correctness_flex(B, H, S, D, causal_mask, dtype, atol, rtol, device=device)

    # Roughly generate custom rejected and chosen indices for each batch:
    # chosen_index falls in the first half of the sequence, rejected_index in the second half.
    chosen_index = torch.randint(0, S // 2, (B,), device=device)
    rejected_index = torch.randint(S // 2, S, (B,), device=device)

    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
        # Bind the per-batch indices so the mask has the (b, h, q_idx, kv_idx) signature.
        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

    _test_correctness_flex(B, H, S, D, wrapped_prefix_mask, dtype, atol, rtol, device=device)
109
+
110
+
111
def _test_correctness_prefix(
    B=2,
    H=8,
    P=512,
    C=256,
    R=256,
    D=32,
    dtype=torch.float32,
    atol=1e-3,
    rtol=5e-4,
    device="cuda",
):
    """
    Test that prefix-sharing attention matches separate computations (i.e. two separate causal
    masked attentions, prefix+chosen and prefix+rejected).
    The mental model is:

    A. prefix + chosen
        P
        P P
        P P P
        P P P C
        P P P C C
        P P P C C C

    B. prefix + rejected
        P
        P P
        P P P
        P P P R
        P P P R R
        P P P R R R

    C. shared prefix + chosen + rejected
        P
        P P
        P P P
        P P P C
        P P P C C
        P P P C C C
        P P P       R
        P P P       R R
        P P P       R R R

    We test them as below to ensure attention value equivalence:
    1. prefix rows of shared attn (upper of C.) == prefix rows of chosen attn (upper of A.)
    2. prefix rows of shared attn (upper of C.) == prefix rows of rejected attn (upper of B.)
        P            P
        P P     =    P P
        P P P        P P P

    3. chosen rows of shared attn (middle of C.) == chosen rows of chosen attn (lower of A.)
        C            C
        C C     =    C C
        C C C        C C C

    4. rejected rows of shared attn (lower of C.) == rejected rows of rejected attn (lower of B.)
        R            R
        R R     =    R R
        R R R        R R R

    Args:
        B: batch size
        H: number of heads
        P: prefix length
        C: chosen response length
        R: rejected response length
        D: hidden dimension per head
        dtype: data type for computation
        atol: absolute tolerance for comparison
        rtol: relative tolerance for comparison
        device: device on which tensors and masks are created
    """
    torch.manual_seed(0)

    # Total sequence length for shared version
    S = P + C + R
    PC = P + C
    PR = P + R

    # Initialize input tensors, i.e. the tensors after q, k, and v projections of hidden states (attention head input)
    query = torch.randn(B, H, S, D, device=device, dtype=dtype)
    key = torch.randn(B, H, S, D, device=device, dtype=dtype)
    value = torch.randn(B, H, S, D, device=device, dtype=dtype)

    # prefix_mask hides keys in [chosen_index, rejected_index) from queries at or after rejected_index.
    # For rejected tokens to skip the chosen span, chosen_index must be where the chosen response
    # starts (P) and rejected_index where the rejected response starts (P + C). Using (P + C, S)
    # instead would leave no query past rejected_index and reduce the mask to plain causal.
    chosen_index = torch.full((B,), P, device=device)
    rejected_index = torch.full((B,), PC, device=device)

    def wrapped_prefix_mask(b, h, q_idx, kv_idx):
        return prefix_mask(b, h, q_idx, kv_idx, rejected_index, chosen_index)

    block_mask = create_block_mask(wrapped_prefix_mask, B, H, S, S, device=device)
    shared_out = flex_attention(query, key, value, block_mask=block_mask)

    # Compute attention for prefix + chosen separately (the first P + C positions are already contiguous).
    query_pc = query[:, :, :PC, :]
    key_pc = key[:, :, :PC, :]
    value_pc = value[:, :, :PC, :]

    pc_block_mask = create_block_mask(causal_mask, B, H, PC, PC, device=device)
    pc_out = flex_attention(query_pc, key_pc, value_pc, block_mask=pc_block_mask)

    # Compute attention for prefix + rejected separately (drop the chosen span in the middle).
    query_pr = torch.cat([query[:, :, :P, :], query[:, :, PC:, :]], dim=2)
    key_pr = torch.cat([key[:, :, :P, :], key[:, :, PC:, :]], dim=2)
    value_pr = torch.cat([value[:, :, :P, :], value[:, :, PC:, :]], dim=2)

    pr_block_mask = create_block_mask(causal_mask, B, H, PR, PR, device=device)
    pr_out = flex_attention(query_pr, key_pr, value_pr, block_mask=pr_block_mask)

    # Attention outputs are (B, H, seq, D): select spans along the sequence axis (dim 2) only and
    # keep the full head dimension. Slicing the last axis with sequence offsets would yield empty
    # tensors whenever P >= D and make the comparisons below pass vacuously.
    shared_prefix = shared_out[:, :, :P, :]
    shared_chosen = shared_out[:, :, P:PC, :]
    shared_rejected = shared_out[:, :, PC:, :]

    separate_prefix_c = pc_out[:, :, :P, :]
    separate_chosen = pc_out[:, :, P:, :]
    separate_prefix_r = pr_out[:, :, :P, :]
    separate_rejected = pr_out[:, :, P:, :]

    # Verify prefix outputs are identical
    assert torch.allclose(
        shared_prefix, separate_prefix_c, atol=atol, rtol=rtol
    ), "Prefix attention from shared computation doesn't match prefix+chosen computation"
    assert torch.allclose(
        shared_prefix, separate_prefix_r, atol=atol, rtol=rtol
    ), "Prefix attention from shared computation doesn't match prefix+rejected computation"

    # Verify chosen and rejected outputs
    assert torch.allclose(
        shared_chosen, separate_chosen, atol=atol, rtol=rtol
    ), "Chosen response attention doesn't match between shared and separate computation"
    assert torch.allclose(
        shared_rejected, separate_rejected, atol=atol, rtol=rtol
    ), "Rejected response attention doesn't match between shared and separate computation"

    print("All attention values match between shared and separate computations!")
259
+
260
+
261
@pytest.mark.parametrize(
    "B, H, P, C, R, D",
    [
        (2, 8, 512, 256, 256, 32),
        (3, 12, 1024, 512, 512, 64),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        pytest.param(
            torch.bfloat16,
            3e-2,
            5e-1,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
        (torch.float16, 1e-2, 5e-3),
        (torch.float32, 1e-3, 5e-4),
    ],
)
def test_correctness_prefix(B, H, P, C, R, D, dtype, atol, rtol):
    """Parametrized test for different configurations.

    Runs the shared-prefix equivalence check on the module-level ``device`` (from ``infer_device()``)
    rather than relying on the helper's ``"cuda"`` default.
    """
    _test_correctness_prefix(B=B, H=H, P=P, C=C, R=R, D=D, dtype=dtype, atol=atol, rtol=rtol, device=device)