liger-kernel-nightly 0.5.2.dev20241220231758__tar.gz → 0.5.2.dev20241223032630__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199) hide show
  1. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/cross_entropy.py +2 -2
  4. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/kl_div.py +4 -4
  5. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/rms_norm.py +3 -3
  6. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/rope.py +23 -9
  7. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/rope.py +2 -2
  8. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  9. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_rope.py +30 -2
  10. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.flake8 +0 -0
  11. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  12. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  13. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.github/pull_request_template.md +0 -0
  14. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.github/workflows/amd-ci.yml +0 -0
  15. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.github/workflows/nvi-ci.yml +0 -0
  16. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.github/workflows/publish-nightly.yml +0 -0
  17. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.github/workflows/publish-release.yml +0 -0
  18. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.gitignore +0 -0
  19. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/.isort.cfg +0 -0
  20. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/LICENSE +0 -0
  21. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/Makefile +0 -0
  22. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/NOTICE +0 -0
  23. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/README.md +0 -0
  24. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/__init__.py +0 -0
  25. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/benchmarks_visualizer.py +0 -0
  26. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/data/all_benchmark_data.csv +0 -0
  27. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/__init__.py +0 -0
  28. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  29. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  30. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_embedding.py +0 -0
  32. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  33. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  34. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_geglu.py +0 -0
  35. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_group_norm.py +0 -0
  36. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_jsd.py +0 -0
  37. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_kl_div.py +0 -0
  38. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  39. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  40. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  41. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  42. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_rope.py +0 -0
  43. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  44. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/benchmark_swiglu.py +0 -0
  45. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/benchmark/scripts/utils.py +0 -0
  46. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/dev/fmt-requirements.txt +0 -0
  47. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/dev/modal/tests.py +0 -0
  48. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/dev/modal/tests_bwd.py +0 -0
  49. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/Acknowledgement.md +0 -0
  50. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/CONTRIBUTING.md +0 -0
  51. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/License.md +0 -0
  52. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/images/banner.GIF +0 -0
  53. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/images/compose.gif +0 -0
  54. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/images/e2e-memory.png +0 -0
  55. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/images/e2e-tps.png +0 -0
  56. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/images/logo-banner.png +0 -0
  57. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/images/patch.gif +0 -0
  58. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/docs/images/post-training.png +0 -0
  59. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/alignment/accelerate_config.yaml +0 -0
  60. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/alignment/run_orpo.py +0 -0
  61. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/README.md +0 -0
  62. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/callback.py +0 -0
  63. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/config/fsdp_config.json +0 -0
  64. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  65. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  66. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  67. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/img/llama_tps.png +0 -0
  68. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  69. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/img/qwen_tps.png +0 -0
  70. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/launch_on_modal.py +0 -0
  71. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/requirements.txt +0 -0
  72. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/run_benchmarks.sh +0 -0
  73. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/run_gemma.sh +0 -0
  74. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/run_llama.sh +0 -0
  75. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/run_qwen.sh +0 -0
  76. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/run_qwen2_vl.sh +0 -0
  77. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/training.py +0 -0
  78. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/huggingface/training_multimodal.py +0 -0
  79. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/lightning/README.md +0 -0
  80. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/lightning/requirements.txt +0 -0
  81. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/lightning/training.py +0 -0
  82. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/README.md +0 -0
  83. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/callback.py +0 -0
  84. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  85. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  86. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  87. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  88. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  89. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  90. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  91. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  92. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  93. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/medusa_util.py +0 -0
  94. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/requirements.txt +0 -0
  95. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  96. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/examples/medusa/train.py +0 -0
  97. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/licenses/LICENSE-Apache-2.0 +0 -0
  98. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  99. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  100. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/licenses/LICENSE-MIT-llmc +0 -0
  101. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/licenses/LICENSE-MIT-triton +0 -0
  102. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/setup.cfg +0 -0
  103. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/setup.py +0 -0
  104. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/__init__.py +0 -0
  105. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/chunked_loss/README.md +0 -0
  106. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  107. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  108. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  109. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/chunked_loss/functional.py +0 -0
  110. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  111. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  112. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  113. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  114. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/env_report.py +0 -0
  115. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/__init__.py +0 -0
  116. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  117. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  118. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  119. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  120. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/geglu.py +0 -0
  121. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/group_norm.py +0 -0
  122. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/jsd.py +0 -0
  123. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/layer_norm.py +0 -0
  124. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  125. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/swiglu.py +0 -0
  126. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/ops/utils.py +0 -0
  127. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/__init__.py +0 -0
  128. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/auto_model.py +0 -0
  129. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  130. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  131. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/functional.py +0 -0
  132. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  133. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  134. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/geglu.py +0 -0
  135. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/group_norm.py +0 -0
  136. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/jsd.py +0 -0
  137. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/kl_div.py +0 -0
  138. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/layer_norm.py +0 -0
  139. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/__init__.py +0 -0
  140. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/gemma.py +0 -0
  141. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  142. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/llama.py +0 -0
  143. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/mistral.py +0 -0
  144. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  145. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/mllama.py +0 -0
  146. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/phi3.py +0 -0
  147. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  148. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  149. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  150. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  151. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/rms_norm.py +0 -0
  152. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/swiglu.py +0 -0
  153. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  154. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  155. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  156. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/triton/__init__.py +0 -0
  157. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/triton/monkey_patch.py +0 -0
  158. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel/utils.py +0 -0
  159. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  160. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  161. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  162. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  163. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/__init__.py +0 -0
  164. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/chunked_loss/__init__.py +0 -0
  165. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/chunked_loss/test_cpo_loss.py +0 -0
  166. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/chunked_loss/test_dpo_loss.py +0 -0
  167. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/chunked_loss/test_orpo_loss.py +0 -0
  168. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/chunked_loss/test_simpo_loss.py +0 -0
  169. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/conftest.py +0 -0
  170. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/convergence/__init__.py +0 -0
  171. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/convergence/test_mini_models.py +0 -0
  172. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/convergence/test_mini_models_multimodal.py +0 -0
  173. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/convergence/test_mini_models_with_logits.py +0 -0
  174. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  175. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  176. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  177. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/resources/tiny_shakespeare.txt +0 -0
  178. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  179. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  180. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  181. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_auto_model.py +0 -0
  182. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_cross_entropy.py +0 -0
  183. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_embedding.py +0 -0
  184. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  185. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_fused_linear_jsd.py +0 -0
  186. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_geglu.py +0 -0
  187. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_group_norm.py +0 -0
  188. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_jsd.py +0 -0
  189. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_kl_div.py +0 -0
  190. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_layer_norm.py +0 -0
  191. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_mm_int8int2.py +0 -0
  192. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_monkey_patch.py +0 -0
  193. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_qwen2vl_mrope.py +0 -0
  194. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_rms_norm.py +0 -0
  195. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_swiglu.py +0 -0
  196. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_trainer_integration.py +0 -0
  197. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/transformers/test_transformers.py +0 -0
  198. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/triton/test_triton_monkey_patch.py +0 -0
  199. {liger_kernel_nightly-0.5.2.dev20241220231758 → liger_kernel_nightly-0.5.2.dev20241223032630}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241220231758
3
+ Version: 0.5.2.dev20241223032630
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.2.dev20241220231758"
7
+ version = "0.5.2.dev20241223032630"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -17,8 +17,8 @@ if compare_version("triton", operator.ge, "3.0.0"):
17
17
  else:
18
18
  from triton.language.math import tanh
19
19
 
20
- _TRUE = tl.constexpr(1)
21
- _FALSE = tl.constexpr(0)
20
+ _TRUE: tl.constexpr = tl.constexpr(1)
21
+ _FALSE: tl.constexpr = tl.constexpr(0)
22
22
 
23
23
 
24
24
  @triton.jit
@@ -23,10 +23,10 @@ MAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best
23
23
 
24
24
  REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]
25
25
 
26
- _REDUCTION_MODE_NONE = tl.constexpr(0)
27
- _REDUCTION_MODE_SUM = tl.constexpr(1)
28
- _REDUCTION_MODE_MEAN = tl.constexpr(2)
29
- _REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)
26
+ _REDUCTION_MODE_NONE: tl.constexpr = tl.constexpr(0)
27
+ _REDUCTION_MODE_SUM: tl.constexpr = tl.constexpr(1)
28
+ _REDUCTION_MODE_MEAN: tl.constexpr = tl.constexpr(2)
29
+ _REDUCTION_MODE_BATCHMEAN: tl.constexpr = tl.constexpr(3)
30
30
 
31
31
  _str_to_reduction_mode = {
32
32
  "none": _REDUCTION_MODE_NONE.value,
@@ -35,9 +35,9 @@ else:
35
35
  from triton.language.math import rsqrt
36
36
 
37
37
 
38
- _CASTING_MODE_NONE = tl.constexpr(-1)
39
- _CASTING_MODE_LLAMA = tl.constexpr(0)
40
- _CASTING_MODE_GEMMA = tl.constexpr(1)
38
+ _CASTING_MODE_NONE: tl.constexpr = tl.constexpr(-1)
39
+ _CASTING_MODE_LLAMA: tl.constexpr = tl.constexpr(0)
40
+ _CASTING_MODE_GEMMA: tl.constexpr = tl.constexpr(1)
41
41
 
42
42
 
43
43
  @triton.jit
@@ -15,6 +15,7 @@ def _triton_rope(
15
15
  sin_row_stride,
16
16
  sl,
17
17
  bs: tl.constexpr,
18
+ cos_bs: tl.constexpr,
18
19
  n_qh: tl.constexpr,
19
20
  n_kh: tl.constexpr,
20
21
  hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
29
30
  # k size: (bsz, seq_len, num_kv_heads, head_dim)
30
31
  # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
31
32
 
32
- # cos size: (1, seq_len, head_dim)
33
+ # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
33
34
  # stride: (seq_len * head_dim, head_dim, 1)
34
35
  pid = tl.program_id(0)
35
36
 
@@ -48,9 +49,19 @@ def _triton_rope(
48
49
  # and pid % sl to get the sequence index.
49
50
  # 2. We only need the left half of cos and sin matrix because the right half is just
50
51
  # a clone of the left half.
51
- cos_row_idx = pid % (sl)
52
- cos = cos + cos_row_idx * cos_row_stride
53
- sin = sin + cos_row_idx * sin_row_stride
52
+ batch_idx = pid // sl
53
+ cos_row_idx = pid % sl
54
+ cos = cos + tl.where(
55
+ cos_bs == 1,
56
+ cos_row_idx * cos_row_stride,
57
+ batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride,
58
+ )
59
+ sin = sin + tl.where(
60
+ cos_bs == 1,
61
+ cos_row_idx * sin_row_stride,
62
+ batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride,
63
+ )
64
+
54
65
  cos_offsets = tl.arange(0, pad_hd // 2)
55
66
  cos_mask = cos_offsets < hd // 2
56
67
  cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
@@ -118,7 +129,6 @@ def _triton_rope(
118
129
 
119
130
 
120
131
  def rope_forward(q, k, cos, sin):
121
-
122
132
  # transpose it back to the physical shape because Triton looks at the physical storage
123
133
  # note: q and k are incontiguous before the transformation and will become contiguous after transpose
124
134
  q = q.transpose(1, 2)
@@ -138,6 +148,7 @@ def rope_forward(q, k, cos, sin):
138
148
  k = k.contiguous()
139
149
  cos = cos.contiguous()
140
150
  sin = sin.contiguous()
151
+ cos_batch_size = cos.shape[0]
141
152
 
142
153
  _triton_rope[(n_row,)](
143
154
  q,
@@ -150,6 +161,7 @@ def rope_forward(q, k, cos, sin):
150
161
  sin.stride(-2),
151
162
  seq_len,
152
163
  batch_size,
164
+ cos_batch_size,
153
165
  n_q_head,
154
166
  n_kv_head,
155
167
  head_dim,
@@ -167,6 +179,7 @@ def rope_backward(dq, dk, cos, sin):
167
179
  dk = dk.transpose(1, 2)
168
180
 
169
181
  batch_size, seq_len, n_q_head, head_dim = dq.shape
182
+ cos_batch_size = cos.shape[0]
170
183
  n_kv_head = dk.shape[2]
171
184
  pad_hd = triton.next_power_of_2(head_dim)
172
185
  pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +204,7 @@ def rope_backward(dq, dk, cos, sin):
191
204
  sin.stride(-2),
192
205
  seq_len,
193
206
  batch_size,
207
+ cos_batch_size,
194
208
  n_q_head,
195
209
  n_kv_head,
196
210
  head_dim,
@@ -221,8 +235,8 @@ class LigerRopeFunction(torch.autograd.Function):
221
235
  """
222
236
  q size: (bsz, n_q_head, seq_len, head_dim)
223
237
  k size: (bsz, n_kv_head, seq_len, head_dim)
224
- cos size: (1, seq_len, head_dim)
225
- sin size: (1, seq_len, head_dim)
238
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
239
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
226
240
  """
227
241
  q, k, cos, sin = rope_forward(q, k, cos, sin)
228
242
  ctx.save_for_backward(cos, sin)
@@ -232,8 +246,8 @@ class LigerRopeFunction(torch.autograd.Function):
232
246
  """
233
247
  dq size: (bsz, n_q_head, seq_len, head_dim)
234
248
  dk size: (bsz, n_kv_head, seq_len, head_dim)
235
- cos size: (1, seq_len, head_dim)
236
- sin size: (1, seq_len, head_dim)
249
+ cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
250
+ sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
237
251
  """
238
252
 
239
253
  cos, sin = ctx.saved_tensors
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
8
8
  Args:
9
9
  q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10
10
  k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11
- cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
12
- sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
11
+ cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
12
+ sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
13
13
  position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
14
14
  unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15
15
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.2.dev20241220231758
3
+ Version: 0.5.2.dev20241223032630
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -46,8 +46,20 @@ SLEEP_SECONDS = 0.1
46
46
  ),
47
47
  ],
48
48
  )
49
+ @pytest.mark.parametrize(
50
+ "expand_position_ids",
51
+ [True, False],
52
+ )
49
53
  def test_correctness(
50
- bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol
54
+ bsz,
55
+ seq_len,
56
+ num_q_heads,
57
+ num_kv_heads,
58
+ head_dim,
59
+ dtype,
60
+ expand_position_ids,
61
+ atol,
62
+ rtol,
51
63
  ):
52
64
  rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
53
65
 
@@ -70,6 +82,8 @@ def test_correctness(
70
82
  k2 = _tensor_k.clone().requires_grad_(True)
71
83
 
72
84
  pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
85
+ if expand_position_ids:
86
+ pos_ids = pos_ids.expand(bsz, -1)
73
87
  cos, sin = rotary_emb(k1, pos_ids)
74
88
 
75
89
  # validate forward pass
@@ -111,8 +125,20 @@ def test_correctness(
111
125
  (torch.bfloat16, 1e-1, 1e-5),
112
126
  ],
113
127
  )
128
+ @pytest.mark.parametrize(
129
+ "expand_position_ids",
130
+ [True, False],
131
+ )
114
132
  def test_functional_correctness(
115
- bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol
133
+ bsz,
134
+ seq_len,
135
+ num_q_heads,
136
+ num_kv_heads,
137
+ head_dim,
138
+ expand_position_ids,
139
+ dtype,
140
+ atol,
141
+ rtol,
116
142
  ):
117
143
  _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype)
118
144
  _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype)
@@ -126,6 +152,8 @@ def test_functional_correctness(
126
152
  rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
127
153
 
128
154
  pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
155
+ if expand_position_ids:
156
+ pos_ids = pos_ids.expand(bsz, -1)
129
157
  cos, sin = rotary_emb(k1, pos_ids)
130
158
 
131
159
  functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin)