liger-kernel-nightly 0.5.3.dev20250221230243__tar.gz → 0.5.3.dev20250224175624__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic; the original registry page links to more details.

Files changed (227)
  1. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_rope.py +1 -1
  3. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/layer_norm.py +20 -7
  5. liger_kernel_nightly-0.5.3.dev20250224175624/src/liger_kernel/utils.py +62 -0
  6. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  7. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_layer_norm.py +20 -5
  8. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_rope.py +1 -1
  9. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/utils.py +0 -47
  10. liger_kernel_nightly-0.5.3.dev20250221230243/src/liger_kernel/utils.py +0 -15
  11. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  12. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  13. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/pull_request_template.md +0 -0
  14. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/amd-ci.yml +0 -0
  15. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/docs.yml +0 -0
  16. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/intel-ci.yml +0 -0
  17. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/nvi-ci.yml +0 -0
  18. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/publish-nightly.yml +0 -0
  19. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.github/workflows/publish-release.yml +0 -0
  20. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/.gitignore +0 -0
  21. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/LICENSE +0 -0
  22. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/Makefile +0 -0
  23. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/NOTICE +0 -0
  24. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/README.md +0 -0
  25. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/README.md +0 -0
  26. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/__init__.py +0 -0
  27. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/benchmarks_visualizer.py +0 -0
  28. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/data/all_benchmark_data.csv +0 -0
  29. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/__init__.py +0 -0
  30. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  33. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_embedding.py +0 -0
  35. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  36. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  37. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_geglu.py +0 -0
  38. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_group_norm.py +0 -0
  39. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_jsd.py +0 -0
  40. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_kl_div.py +0 -0
  41. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  43. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  44. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  45. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  46. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  47. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_swiglu.py +0 -0
  48. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/benchmark_tvd.py +0 -0
  49. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/benchmark/scripts/utils.py +0 -0
  50. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/dev/fmt-requirements.txt +0 -0
  51. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/dev/modal/tests.py +0 -0
  52. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/dev/modal/tests_bwd.py +0 -0
  53. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/Examples.md +0 -0
  54. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/Getting-Started.md +0 -0
  55. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/High-Level-APIs.md +0 -0
  56. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/Low-Level-APIs.md +0 -0
  57. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/acknowledgement.md +0 -0
  58. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/contributing.md +0 -0
  59. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/banner.GIF +0 -0
  60. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/compose.gif +0 -0
  61. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/e2e-memory.png +0 -0
  62. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/e2e-tps.png +0 -0
  63. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/logo-banner.png +0 -0
  64. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/patch.gif +0 -0
  65. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/images/post-training.png +0 -0
  66. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/index.md +0 -0
  67. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/docs/license.md +0 -0
  68. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/alignment/accelerate_config.yaml +0 -0
  69. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/alignment/run_orpo.py +0 -0
  70. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/README.md +0 -0
  71. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/callback.py +0 -0
  72. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/config/fsdp_config.json +0 -0
  73. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  74. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  75. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  76. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/llama_tps.png +0 -0
  77. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  78. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/img/qwen_tps.png +0 -0
  79. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/launch_on_modal.py +0 -0
  80. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/requirements.txt +0 -0
  81. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_benchmarks.sh +0 -0
  82. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_gemma.sh +0 -0
  83. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_llama.sh +0 -0
  84. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_qwen.sh +0 -0
  85. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/run_qwen2_vl.sh +0 -0
  86. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/training.py +0 -0
  87. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/huggingface/training_multimodal.py +0 -0
  88. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/lightning/README.md +0 -0
  89. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/lightning/requirements.txt +0 -0
  90. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/lightning/training.py +0 -0
  91. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/README.md +0 -0
  92. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/callback.py +0 -0
  93. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  94. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  95. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  96. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  97. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  98. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  99. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  100. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  101. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  102. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/medusa_util.py +0 -0
  103. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/requirements.txt +0 -0
  104. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  105. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/examples/medusa/train.py +0 -0
  106. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-Apache-2.0 +0 -0
  107. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  108. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  109. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-MIT-llmc +0 -0
  110. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/licenses/LICENSE-MIT-triton +0 -0
  111. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/mkdocs.yml +0 -0
  112. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/setup.cfg +0 -0
  113. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/setup.py +0 -0
  114. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/__init__.py +0 -0
  115. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/README.md +0 -0
  116. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  117. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  118. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  119. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/functional.py +0 -0
  120. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  121. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  122. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
  123. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  124. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  125. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  126. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/env_report.py +0 -0
  130. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/__init__.py +0 -0
  131. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/cross_entropy.py +0 -0
  132. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  133. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  134. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  135. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  136. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/geglu.py +0 -0
  137. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/group_norm.py +0 -0
  138. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/jsd.py +0 -0
  139. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/kl_div.py +0 -0
  140. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  141. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/rms_norm.py +0 -0
  142. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/rope.py +0 -0
  143. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/swiglu.py +0 -0
  144. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/tvd.py +0 -0
  145. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/ops/utils.py +0 -0
  146. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/__init__.py +0 -0
  147. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/auto_model.py +0 -0
  148. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  149. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  150. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/functional.py +0 -0
  151. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  152. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  153. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/geglu.py +0 -0
  154. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/group_norm.py +0 -0
  155. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/jsd.py +0 -0
  156. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/kl_div.py +0 -0
  157. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/layer_norm.py +0 -0
  158. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/__init__.py +0 -0
  159. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/gemma.py +0 -0
  160. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  161. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/llama.py +0 -0
  162. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/mistral.py +0 -0
  163. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  164. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/mllama.py +0 -0
  165. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/phi3.py +0 -0
  166. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  167. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  168. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  169. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  170. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/rms_norm.py +0 -0
  171. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/rope.py +0 -0
  172. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/swiglu.py +0 -0
  173. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  174. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  175. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  176. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/transformers/tvd.py +0 -0
  177. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/triton/__init__.py +0 -0
  178. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel/triton/monkey_patch.py +0 -0
  179. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  180. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  181. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  182. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  183. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/__init__.py +0 -0
  184. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/__init__.py +0 -0
  185. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_cpo_loss.py +0 -0
  186. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_dpo_loss.py +0 -0
  187. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_grpo_loss.py +0 -0
  188. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_jsd_loss.py +0 -0
  189. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_kto_loss.py +0 -0
  190. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_orpo_loss.py +0 -0
  191. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/chunked_loss/test_simpo_loss.py +0 -0
  192. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/conftest.py +0 -0
  193. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/__init__.py +0 -0
  194. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/bf16/__init__.py +0 -0
  195. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/bf16/test_mini_models.py +0 -0
  196. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  197. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  198. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/fp32/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/fp32/test_mini_models.py +0 -0
  200. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  201. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  202. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  203. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  204. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  205. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/tiny_shakespeare.txt +0 -0
  206. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  207. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  208. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  209. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_auto_model.py +0 -0
  210. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_cross_entropy.py +0 -0
  211. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_embedding.py +0 -0
  212. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_flex_attention.py +0 -0
  213. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  214. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_fused_linear_jsd.py +0 -0
  215. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_geglu.py +0 -0
  216. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_group_norm.py +0 -0
  217. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_jsd.py +0 -0
  218. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_kl_div.py +0 -0
  219. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_mm_int8int2.py +0 -0
  220. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_monkey_patch.py +0 -0
  221. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_qwen2vl_mrope.py +0 -0
  222. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_rms_norm.py +0 -0
  223. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_swiglu.py +0 -0
  224. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_trainer_integration.py +0 -0
  225. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_transformers.py +0 -0
  226. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/transformers/test_tvd.py +0 -0
  227. {liger_kernel_nightly-0.5.3.dev20250221230243 → liger_kernel_nightly-0.5.3.dev20250224175624}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221230243
3
+ Version: 0.5.3.dev20250224175624
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -1,7 +1,6 @@
1
1
  import torch
2
2
  import triton
3
3
 
4
- from test.utils import transformers_version_dispatch
5
4
  from transformers.models.llama.configuration_llama import LlamaConfig
6
5
  from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
7
6
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
@@ -14,6 +13,7 @@ from utils import run_benchmarks
14
13
 
15
14
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
16
15
  from liger_kernel.utils import infer_device
16
+ from liger_kernel.utils import transformers_version_dispatch
17
17
 
18
18
  device = infer_device()
19
19
 
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.3.dev20250221230243"
7
+ version = "0.5.3.dev20250224175624"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -57,13 +57,14 @@ def _layer_norm_forward_kernel(
57
57
  B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
58
58
 
59
59
  mean = tl.sum(X_row, axis=0) / n_cols
60
- var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols
60
+ Xmm = tl.where(mask, X_row - mean, 0)
61
+ var = tl.sum(Xmm * Xmm, axis=0) / n_cols
61
62
  rstd = rsqrt(var + eps)
62
63
 
63
64
  tl.store(Mean_ptr, mean)
64
65
  tl.store(RSTD_ptr, rstd)
65
66
 
66
- Y_row = (X_row - mean) * rstd * W_row + B_row
67
+ Y_row = Xmm * rstd * W_row + B_row
67
68
 
68
69
  tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
69
70
 
@@ -147,9 +148,11 @@ def layer_norm_forward(X, W, B, eps):
147
148
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
148
149
  Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
149
150
  RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
150
- assert X.shape[1] == W.shape[0], (
151
- f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}"
152
- )
151
+ if X.shape[1] != W.shape[0]:
152
+ raise ValueError(
153
+ f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
154
+ f"must match weight size (W.shape[0]={W.shape[0]})"
155
+ )
153
156
 
154
157
  _layer_norm_forward_kernel[(n_rows,)](
155
158
  Y,
@@ -190,11 +193,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
190
193
 
191
194
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
192
195
  if n_cols > BLOCK_SIZE:
193
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
196
+ raise RuntimeError(
197
+ f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
198
+ )
194
199
 
195
200
  rows_per_program = math.ceil(n_rows / sm_count)
196
201
  grid = (sm_count,)
197
- triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
202
+ triton_dtype = (
203
+ tl.float32
204
+ if X.dtype == torch.float32
205
+ else tl.bfloat16
206
+ if X.dtype == torch.bfloat16
207
+ else tl.float16
208
+ if X.dtype == torch.float16
209
+ else tl.float32 # fallback to float32 for other types
210
+ )
198
211
  _layer_norm_backward_kernel[grid](
199
212
  X,
200
213
  W,
@@ -0,0 +1,62 @@
1
+ import torch
2
+
3
+
4
+ def infer_device():
5
+ """
6
+ Get current device name based on available devices
7
+ """
8
+ if torch.cuda.is_available():
9
+ return "cuda"
10
+ elif torch.xpu.is_available():
11
+ return "xpu"
12
+ elif torch.hip.is_available():
13
+ return "hip"
14
+ else:
15
+ return "cpu"
16
+
17
+
18
+ def transformers_version_dispatch(
19
+ required_version: str,
20
+ before_fn,
21
+ after_fn,
22
+ before_args: tuple = (),
23
+ after_args: tuple = (),
24
+ before_kwargs: dict = None,
25
+ after_kwargs: dict = None,
26
+ ):
27
+ """
28
+ Dispatches to different functions based on package version comparison.
29
+
30
+ Args:
31
+ required_version: Version to compare against (e.g. "4.48.0")
32
+ before_fn: Function to call if package_version < required_version
33
+ after_fn: Function to call if package_version >= required_version
34
+ before_args: Positional arguments for before_fn
35
+ after_args: Positional arguments for after_fn
36
+ before_kwargs: Keyword arguments for before_fn
37
+ after_kwargs: Keyword arguments for after_fn
38
+
39
+ Returns:
40
+ Result from either before_fn or after_fn
41
+
42
+ Example:
43
+ >>> rotary_emb = transformers_version_dispatch(
44
+ ... "4.48.0",
45
+ ... LlamaRotaryEmbedding,
46
+ ... LlamaRotaryEmbedding,
47
+ ... before_args=(head_dim,),
48
+ ... after_args=(LlamaConfig(head_dim=head_dim),),
49
+ ... before_kwargs={'device': device},
50
+ ... after_kwargs={'device': device}
51
+ ... )
52
+ """
53
+ from packaging import version
54
+ from transformers import __version__ as transformers_version
55
+
56
+ before_kwargs = before_kwargs or {}
57
+ after_kwargs = after_kwargs or {}
58
+
59
+ if version.parse(transformers_version) < version.parse(required_version):
60
+ return before_fn(*before_args, **before_kwargs)
61
+ else:
62
+ return after_fn(*after_args, **after_kwargs)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221230243
3
+ Version: 0.5.3.dev20250224175624
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -14,6 +14,8 @@ device = infer_device()
14
14
  [
15
15
  (2, 8, 64),
16
16
  (4, 16, 128),
17
+ (1, 1, 1023), # Minimal batch/seq with near power-of-2 hidden
18
+ (3, 7, 256), # Prime numbers for batch/seq
17
19
  ],
18
20
  )
19
21
  @pytest.mark.parametrize(
@@ -22,7 +24,15 @@ device = infer_device()
22
24
  (torch.float32, 1e-5, 1e-5),
23
25
  ],
24
26
  )
25
- def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol):
27
+ def test_liger_layer_norm(
28
+ batch_size: int,
29
+ seq_len: int,
30
+ hidden_size: int,
31
+ dtype: torch.dtype,
32
+ atol: float,
33
+ rtol: float,
34
+ ) -> None:
35
+ """Test basic layer norm functionality against PyTorch implementation."""
26
36
  torch.manual_seed(0)
27
37
 
28
38
  x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
@@ -64,7 +74,15 @@ def test_liger_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol):
64
74
  (torch.float32, 1e-5, 1e-5),
65
75
  ],
66
76
  )
67
- def test_liger_layer_norm_functional(hidden_size, batch_size, seq_len, dtype, atol, rtol):
77
+ def test_liger_layer_norm_functional(
78
+ hidden_size: int,
79
+ batch_size: int,
80
+ seq_len: int,
81
+ dtype: torch.dtype,
82
+ atol: float,
83
+ rtol: float,
84
+ ) -> None:
85
+ """Test functional layer norm interface against autograd function."""
68
86
  torch.manual_seed(0)
69
87
 
70
88
  input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
@@ -73,12 +91,10 @@ def test_liger_layer_norm_functional(hidden_size, batch_size, seq_len, dtype, at
73
91
  x2 = input.clone().requires_grad_(True)
74
92
 
75
93
  w = torch.randn(hidden_size, device=device, dtype=dtype)
76
-
77
94
  w1 = w.clone().requires_grad_(True)
78
95
  w2 = w.clone().requires_grad_(True)
79
96
 
80
97
  b = torch.randn(hidden_size, device=device, dtype=dtype)
81
-
82
98
  b1 = b.clone().requires_grad_(True)
83
99
  b2 = b.clone().requires_grad_(True)
84
100
 
@@ -88,7 +104,6 @@ def test_liger_layer_norm_functional(hidden_size, batch_size, seq_len, dtype, at
88
104
  assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
89
105
 
90
106
  grad_output = torch.randn_like(y2)
91
-
92
107
  y1.backward(grad_output, retain_graph=True)
93
108
  y2.backward(grad_output, retain_graph=True)
94
109
 
@@ -2,7 +2,6 @@ import pytest
2
2
  import torch
3
3
 
4
4
  from test.utils import supports_bfloat16
5
- from test.utils import transformers_version_dispatch
6
5
  from transformers.models.llama.configuration_llama import LlamaConfig
7
6
  from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
8
7
  from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
@@ -11,6 +10,7 @@ from liger_kernel.ops.rope import LigerRopeFunction
11
10
  from liger_kernel.transformers.functional import liger_rope
12
11
  from liger_kernel.transformers.rope import liger_rotary_pos_emb
13
12
  from liger_kernel.utils import infer_device
13
+ from liger_kernel.utils import transformers_version_dispatch
14
14
 
15
15
  device = infer_device()
16
16
 
@@ -214,53 +214,6 @@ def supports_bfloat16():
214
214
  return False
215
215
 
216
216
 
217
- def transformers_version_dispatch(
218
- required_version: str,
219
- before_fn,
220
- after_fn,
221
- before_args: tuple = (),
222
- after_args: tuple = (),
223
- before_kwargs: dict = None,
224
- after_kwargs: dict = None,
225
- ):
226
- """
227
- Dispatches to different functions based on package version comparison.
228
-
229
- Args:
230
- required_version: Version to compare against (e.g. "4.48.0")
231
- before_fn: Function to call if package_version < required_version
232
- after_fn: Function to call if package_version >= required_version
233
- before_args: Positional arguments for before_fn
234
- after_args: Positional arguments for after_fn
235
- before_kwargs: Keyword arguments for before_fn
236
- after_kwargs: Keyword arguments for after_fn
237
-
238
- Returns:
239
- Result from either before_fn or after_fn
240
-
241
- Example:
242
- >>> rotary_emb = transformers_version_dispatch(
243
- ... "4.48.0",
244
- ... LlamaRotaryEmbedding,
245
- ... LlamaRotaryEmbedding,
246
- ... before_args=(head_dim,),
247
- ... after_args=(LlamaConfig(head_dim=head_dim),),
248
- ... before_kwargs={'device': device},
249
- ... after_kwargs={'device': device}
250
- ... )
251
- """
252
- from packaging import version
253
- from transformers import __version__ as transformers_version
254
-
255
- before_kwargs = before_kwargs or {}
256
- after_kwargs = after_kwargs or {}
257
-
258
- if version.parse(transformers_version) < version.parse(required_version):
259
- return before_fn(*before_args, **before_kwargs)
260
- else:
261
- return after_fn(*after_args, **after_kwargs)
262
-
263
-
264
217
  def revert_liger_kernel_to_granite(model_config: MiniModelConfig):
265
218
  """
266
219
  Revert all Liger kernel patches applied to Granite.
@@ -1,15 +0,0 @@
1
- import torch
2
-
3
-
4
- def infer_device():
5
- """
6
- Get current device name based on available devices
7
- """
8
- if torch.cuda.is_available():
9
- return "cuda"
10
- elif torch.xpu.is_available():
11
- return "xpu"
12
- elif torch.hip.is_available():
13
- return "hip"
14
- else:
15
- return "cpu"