liger-kernel-nightly 0.5.3.dev20250221011217__tar.gz → 0.5.3.dev20250221230243__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (226)
  1. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/PKG-INFO +2 -1
  2. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/README.md +1 -0
  3. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_tvd.py +8 -11
  4. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/dev/modal/tests.py +1 -1
  5. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/training.py +2 -0
  6. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/training_multimodal.py +67 -23
  7. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/pyproject.toml +1 -1
  8. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/tvd.py +6 -7
  9. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/functional.py +4 -1
  10. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/tvd.py +1 -3
  11. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/utils.py +2 -0
  12. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/PKG-INFO +2 -1
  13. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/bf16/test_mini_models.py +52 -37
  14. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/bf16/test_mini_models_with_logits.py +51 -37
  15. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/fp32/test_mini_models.py +61 -37
  16. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/fp32/test_mini_models_with_logits.py +60 -37
  17. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_flex_attention.py +25 -17
  18. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_tvd.py +13 -20
  19. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  20. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  21. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/pull_request_template.md +0 -0
  22. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/amd-ci.yml +0 -0
  23. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/docs.yml +0 -0
  24. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/intel-ci.yml +0 -0
  25. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/nvi-ci.yml +0 -0
  26. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/publish-nightly.yml +0 -0
  27. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.github/workflows/publish-release.yml +0 -0
  28. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/.gitignore +0 -0
  29. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/LICENSE +0 -0
  30. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/Makefile +0 -0
  31. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/NOTICE +0 -0
  32. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/README.md +0 -0
  33. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/__init__.py +0 -0
  34. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/benchmarks_visualizer.py +0 -0
  35. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/data/all_benchmark_data.csv +0 -0
  36. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/__init__.py +0 -0
  37. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  38. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  39. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  40. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  41. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_embedding.py +0 -0
  42. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  43. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  44. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_geglu.py +0 -0
  45. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_group_norm.py +0 -0
  46. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_jsd.py +0 -0
  47. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_kl_div.py +0 -0
  48. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  49. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  50. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  51. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  52. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  53. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_rope.py +0 -0
  54. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  55. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/benchmark_swiglu.py +0 -0
  56. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/benchmark/scripts/utils.py +0 -0
  57. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/dev/fmt-requirements.txt +0 -0
  58. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/dev/modal/tests_bwd.py +0 -0
  59. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/Examples.md +0 -0
  60. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/Getting-Started.md +0 -0
  61. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/High-Level-APIs.md +0 -0
  62. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/Low-Level-APIs.md +0 -0
  63. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/acknowledgement.md +0 -0
  64. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/contributing.md +0 -0
  65. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/banner.GIF +0 -0
  66. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/compose.gif +0 -0
  67. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/e2e-memory.png +0 -0
  68. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/e2e-tps.png +0 -0
  69. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/logo-banner.png +0 -0
  70. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/patch.gif +0 -0
  71. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/images/post-training.png +0 -0
  72. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/index.md +0 -0
  73. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/docs/license.md +0 -0
  74. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/alignment/accelerate_config.yaml +0 -0
  75. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/alignment/run_orpo.py +0 -0
  76. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/README.md +0 -0
  77. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/callback.py +0 -0
  78. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/config/fsdp_config.json +0 -0
  79. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  80. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  81. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  82. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/llama_tps.png +0 -0
  83. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  84. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/img/qwen_tps.png +0 -0
  85. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/launch_on_modal.py +0 -0
  86. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/requirements.txt +0 -0
  87. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_benchmarks.sh +0 -0
  88. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_gemma.sh +0 -0
  89. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_llama.sh +0 -0
  90. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_qwen.sh +0 -0
  91. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/huggingface/run_qwen2_vl.sh +0 -0
  92. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/lightning/README.md +0 -0
  93. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/lightning/requirements.txt +0 -0
  94. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/lightning/training.py +0 -0
  95. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/README.md +0 -0
  96. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/callback.py +0 -0
  97. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  98. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  99. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  100. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  101. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  102. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  103. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  104. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  105. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  106. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/medusa_util.py +0 -0
  107. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/requirements.txt +0 -0
  108. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  109. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/examples/medusa/train.py +0 -0
  110. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-Apache-2.0 +0 -0
  111. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  112. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  113. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-MIT-llmc +0 -0
  114. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/licenses/LICENSE-MIT-triton +0 -0
  115. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/mkdocs.yml +0 -0
  116. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/setup.cfg +0 -0
  117. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/setup.py +0 -0
  118. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/__init__.py +0 -0
  119. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/README.md +0 -0
  120. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  121. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  122. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  123. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/functional.py +0 -0
  124. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  125. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  126. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
  127. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  128. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  131. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  132. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  133. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/env_report.py +0 -0
  134. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/__init__.py +0 -0
  135. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/cross_entropy.py +0 -0
  136. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  137. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  138. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  139. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  140. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/geglu.py +0 -0
  141. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/group_norm.py +0 -0
  142. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/jsd.py +0 -0
  143. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/kl_div.py +0 -0
  144. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/layer_norm.py +0 -0
  145. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  146. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/rms_norm.py +0 -0
  147. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/rope.py +0 -0
  148. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/swiglu.py +0 -0
  149. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/ops/utils.py +0 -0
  150. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/__init__.py +1 -1
  151. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/auto_model.py +0 -0
  152. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  153. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  154. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  155. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  156. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/geglu.py +0 -0
  157. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/group_norm.py +0 -0
  158. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/jsd.py +0 -0
  159. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/kl_div.py +0 -0
  160. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/layer_norm.py +0 -0
  161. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/__init__.py +0 -0
  162. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/gemma.py +0 -0
  163. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  164. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/llama.py +0 -0
  165. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/mistral.py +0 -0
  166. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  167. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/mllama.py +0 -0
  168. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/phi3.py +0 -0
  169. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  170. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  171. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  172. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  173. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/rms_norm.py +0 -0
  174. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/rope.py +0 -0
  175. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/swiglu.py +0 -0
  176. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  177. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  178. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  179. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/triton/__init__.py +0 -0
  180. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel/triton/monkey_patch.py +0 -0
  181. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  182. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  183. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  184. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  185. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/__init__.py +0 -0
  186. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/__init__.py +0 -0
  187. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_cpo_loss.py +0 -0
  188. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_dpo_loss.py +0 -0
  189. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_grpo_loss.py +0 -0
  190. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_jsd_loss.py +0 -0
  191. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_kto_loss.py +0 -0
  192. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_orpo_loss.py +0 -0
  193. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/chunked_loss/test_simpo_loss.py +0 -0
  194. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/conftest.py +0 -0
  195. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/__init__.py +0 -0
  196. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/bf16/__init__.py +0 -0
  197. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  198. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/fp32/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  200. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  201. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  202. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  203. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/tiny_shakespeare.txt +0 -0
  204. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  205. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  206. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  207. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_auto_model.py +0 -0
  208. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_cross_entropy.py +0 -0
  209. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_embedding.py +0 -0
  210. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  211. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_fused_linear_jsd.py +0 -0
  212. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_geglu.py +0 -0
  213. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_group_norm.py +0 -0
  214. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_jsd.py +0 -0
  215. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_kl_div.py +0 -0
  216. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_layer_norm.py +0 -0
  217. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_mm_int8int2.py +0 -0
  218. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_monkey_patch.py +0 -0
  219. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_qwen2vl_mrope.py +0 -0
  220. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_rms_norm.py +0 -0
  221. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_rope.py +0 -0
  222. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_swiglu.py +0 -0
  223. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_trainer_integration.py +0 -0
  224. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/transformers/test_transformers.py +0 -0
  225. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/triton/test_triton_monkey_patch.py +0 -0
  226. {liger_kernel_nightly-0.5.3.dev20250221011217 → liger_kernel_nightly-0.5.3.dev20250221230243}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221011217
3
+ Version: 0.5.3.dev20250221230243
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -313,6 +313,7 @@ loss.backward()
313
313
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
314
314
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
315
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
316
+ | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
316
317
 
317
318
 
318
319
  ## Low-level APIs
@@ -265,6 +265,7 @@ loss.backward()
265
265
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
266
266
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
267
267
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
268
+ | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
268
269
 
269
270
 
270
271
  ## Low-level APIs
@@ -1,13 +1,12 @@
1
1
  import torch
2
2
  import triton
3
- from utils import (
4
- QUANTILES,
5
- SingleBenchmarkRunInput,
6
- SingleBenchmarkRunOutput,
7
- _test_memory,
8
- parse_benchmark_script_args,
9
- run_benchmarks,
10
- )
3
+
4
+ from utils import QUANTILES
5
+ from utils import SingleBenchmarkRunInput
6
+ from utils import SingleBenchmarkRunOutput
7
+ from utils import _test_memory
8
+ from utils import parse_benchmark_script_args
9
+ from utils import run_benchmarks
11
10
 
12
11
  from liger_kernel.transformers.tvd import LigerTVDLoss
13
12
 
@@ -67,9 +66,7 @@ def bench_speed_tvd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
67
66
  y = fwd()
68
67
  y.backward(retain_graph=True)
69
68
 
70
- ms_50, ms_20, ms_80 = triton.testing.do_bench(
71
- full, quantiles=QUANTILES, rep=100
72
- )
69
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
73
70
  return SingleBenchmarkRunOutput(
74
71
  y_20=ms_20,
75
72
  y_50=ms_50,
@@ -14,7 +14,7 @@ app = modal.App("liger_tests", image=image)
14
14
  repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)
15
15
 
16
16
 
17
- @app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
17
+ @app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
18
18
  def liger_tests():
19
19
  import subprocess
20
20
 
@@ -15,6 +15,7 @@ from liger_kernel.transformers import AutoLigerKernelForCausalLM
15
15
  class CustomArguments:
16
16
  model_name: str = "meta-llama/Meta-Llama-3-8B"
17
17
  dataset: str = "tatsu-lab/alpaca"
18
+ max_seq_length: int = 512
18
19
  use_liger: bool = False
19
20
 
20
21
 
@@ -65,6 +66,7 @@ def train():
65
66
  model=model,
66
67
  args=training_args,
67
68
  data_collator=collator,
69
+ max_seq_length=custom_args.max_seq_length,
68
70
  train_dataset=train_dataset,
69
71
  eval_dataset=eval_dataset,
70
72
  formatting_func=formatting_prompts_func,
@@ -1,11 +1,15 @@
1
1
  import os
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import datasets
2
6
  import torch
3
7
  import transformers
4
- import datasets
5
- from dataclasses import dataclass
6
- from trl import SFTTrainer, SFTConfig
7
- from trl.trainer import ConstantLengthDataset
8
+
9
+ from callback import EfficiencyCallback
8
10
  from datasets import Image as ImageFeature
11
+ from trl import SFTTrainer
12
+
9
13
  from liger_kernel.transformers import monkey_patch
10
14
 
11
15
 
@@ -15,6 +19,8 @@ class CustomArguments:
15
19
  dataset: str = "HuggingFaceM4/the_cauldron"
16
20
  dataset_subset: str = "ai2d"
17
21
  dataset_split: str = "train"
22
+ max_seq_length: int = 512
23
+ dataset_text_field: str = "texts"
18
24
  use_liger: bool = False
19
25
 
20
26
 
@@ -89,37 +95,75 @@ def _format_for_convo(example, tokenizer):
89
95
  def train():
90
96
  parser = transformers.HfArgumentParser((transformers.TrainingArguments, CustomArguments))
91
97
  training_args, custom_args = parser.parse_args_into_dataclasses()
98
+ training_args.remove_unused_columns = False # required to not drop the image column
99
+ training_args.dataset_kwargs = {"skip_prepare_dataset": True}
92
100
 
93
- model, processor, image_token_id = construct_model_and_processor(
94
- custom_args.model_name, custom_args.use_liger
95
- )
101
+ model, processor, image_token_id = construct_model_and_processor(custom_args.model_name, custom_args.use_liger)
96
102
 
97
- dataset = datasets.load_dataset(
98
- custom_args.dataset,
99
- custom_args.dataset_subset,
100
- split=custom_args.dataset_split
103
+ dataset = (
104
+ datasets.load_dataset(
105
+ custom_args.dataset,
106
+ custom_args.dataset_subset,
107
+ split=custom_args.dataset_split,
108
+ )
109
+ .map(
110
+ _validate_and_extract_the_cauldron,
111
+ batched=True,
112
+ num_proc=min(os.cpu_count(), 16),
113
+ desc="Extracting text and images",
114
+ )
115
+ .map(
116
+ _format_for_convo,
117
+ fn_kwargs={"tokenizer": processor.tokenizer},
118
+ desc="Formatting for convo",
119
+ )
120
+ .cast_column("images", ImageFeature())
121
+ .train_test_split(test_size=0.1)
101
122
  )
102
123
 
103
- train_dataset, eval_dataset = prepare_dataset(dataset, processor, image_token_id)
124
+ train_dataset = dataset["train"]
125
+ eval_dataset = dataset["test"]
104
126
 
105
- sft_config = SFTConfig(
106
- output_dir=training_args.output_dir,
107
- per_device_train_batch_size=training_args.per_device_train_batch_size,
108
- per_device_eval_batch_size=training_args.per_device_eval_batch_size,
109
- learning_rate=training_args.learning_rate,
110
- num_train_epochs=training_args.num_train_epochs,
111
- gradient_accumulation_steps=training_args.gradient_accumulation_steps,
112
- )
127
+ def collate_fn(examples):
128
+ """
129
+ Taken directly from the TRL documentation with minor modifications:
130
+ https://huggingface.co/docs/trl/en/sft_trainer#a-custom-collator-for-processing-multi-modal-data
131
+
132
+ Modifications:
133
+ 1. `apply_chat_template` is used to preprocess the texts before training begins (see above)
134
+ 2. `example["messages"]` -> `example["texts"]` to conform with the_cauldron dataset schema
135
+ 3. Ignoring image tokens in the loss computation
136
+ """
137
+ # Get the texts and images
138
+ texts = [example["texts"] for example in examples]
139
+ images = [example["images"] for example in examples]
140
+
141
+ # Tokenize the texts and process the images
142
+ batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
143
+
144
+ # The labels are the input_ids, and we mask the padding tokens in the loss computation
145
+ labels = batch["input_ids"].clone()
146
+ labels[labels == processor.tokenizer.pad_token_id] = -100
147
+
148
+ # Ignore the image token index in the loss computation
149
+ labels[labels == image_token_id] = -100
150
+ batch["labels"] = labels
151
+
152
+ return batch
113
153
 
114
154
  trainer = SFTTrainer(
115
155
  model=model,
116
- args=sft_config,
156
+ args=training_args,
157
+ data_collator=collate_fn,
158
+ max_seq_length=custom_args.max_seq_length,
159
+ dataset_text_field=custom_args.dataset_text_field,
117
160
  train_dataset=train_dataset,
118
161
  eval_dataset=eval_dataset,
119
- processing_class=processor,
162
+ tokenizer=processor.tokenizer,
163
+ callbacks=[EfficiencyCallback()],
120
164
  )
121
165
  trainer.train()
122
166
 
123
167
 
124
168
  if __name__ == "__main__":
125
- train()
169
+ train()
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.3.dev20250221011217"
7
+ version = "0.5.3.dev20250221230243"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -1,4 +1,5 @@
1
- from typing import Literal, Optional
1
+ from typing import Literal
2
+ from typing import Optional
2
3
 
3
4
  import torch
4
5
  import triton
@@ -178,15 +179,13 @@ class LigerTVDLossFunction(torch.autograd.Function):
178
179
  """
179
180
  has_label = False
180
181
  if shift_labels is not None:
181
- assert shift_labels.shape == (
182
- p.shape[0],
183
- ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
182
+ assert shift_labels.shape == (p.shape[0],), (
183
+ f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
184
+ )
184
185
  shift_labels = shift_labels.contiguous()
185
186
  has_label = True
186
187
 
187
- loss, grads = tv_distance_forward_triton(
188
- p, q, shift_labels, reduction, ignore_index, has_label
189
- )
188
+ loss, grads = tv_distance_forward_triton(p, q, shift_labels, reduction, ignore_index, has_label)
190
189
  ctx.save_for_backward(grads)
191
190
  return loss
192
191
 
@@ -14,6 +14,7 @@ from liger_kernel.ops.rope import LigerRopeFunction
14
14
  from liger_kernel.ops.swiglu import LigerSiLUMulFunction
15
15
  from liger_kernel.ops.tvd import LigerTVDLossFunction
16
16
 
17
+
17
18
  # conform to the function signature in https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
18
19
  # `weight` and `size_average` are placeholders and not implemented yet
19
20
  def liger_cross_entropy(
@@ -156,6 +157,7 @@ def liger_kl_div(
156
157
  eps,
157
158
  )
158
159
 
160
+
159
161
  def liger_tvd(
160
162
  input,
161
163
  target,
@@ -169,7 +171,8 @@ def liger_tvd(
169
171
  shift_labels,
170
172
  reduction,
171
173
  ignore_index,
172
- )
174
+ )
175
+
173
176
 
174
177
  def liger_layer_norm(X, W, B, eps):
175
178
  return LigerLayerNormFunction.apply(X, W, B, eps)
@@ -10,6 +10,4 @@ class LigerTVDLoss(nn.Module):
10
10
  self.ignore_index = ignore_index
11
11
 
12
12
  def forward(self, p, q, shift_labels=None):
13
- return LigerTVDLossFunction.apply(
14
- p, q, shift_labels, self.reduction, self.ignore_index
15
- )
13
+ return LigerTVDLossFunction.apply(p, q, shift_labels, self.reduction, self.ignore_index)
@@ -9,5 +9,7 @@ def infer_device():
9
9
  return "cuda"
10
10
  elif torch.xpu.is_available():
11
11
  return "xpu"
12
+ elif torch.hip.is_available():
13
+ return "hip"
12
14
  else:
13
15
  return "cpu"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221011217
3
+ Version: 0.5.3.dev20250221230243
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -313,6 +313,7 @@ loss.backward()
313
313
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
314
314
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
315
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
316
+ | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
316
317
 
317
318
 
318
319
  ## Low-level APIs
@@ -7,8 +7,6 @@ from transformers.models.gemma import GemmaConfig
7
7
  from transformers.models.gemma import GemmaForCausalLM
8
8
  from transformers.models.gemma2 import Gemma2Config
9
9
  from transformers.models.gemma2 import Gemma2ForCausalLM
10
- from transformers.models.granite import GraniteConfig
11
- from transformers.models.granite import GraniteForCausalLM
12
10
  from transformers.models.llama import LlamaConfig
13
11
  from transformers.models.llama import LlamaForCausalLM
14
12
  from transformers.models.mistral import MistralConfig
@@ -65,44 +63,19 @@ try:
65
63
  except ImportError:
66
64
  QWEN2_VL_AVAILABLE = False
67
65
 
66
+ try:
67
+ from transformers.models.granite import GraniteConfig
68
+ from transformers.models.granite import GraniteForCausalLM
69
+
70
+ GRANITE_AVAILABLE = True
71
+ except ImportError:
72
+ GRANITE_AVAILABLE = False
73
+
68
74
  from liger_kernel.utils import infer_device
69
75
 
70
76
  device = infer_device()
71
77
 
72
78
  MINI_MODEL_SETUPS = {
73
- "mini_granite3": MiniModelConfig(
74
- liger_kernel_patch_func=apply_liger_kernel_to_granite,
75
- liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
76
- model_class=GraniteForCausalLM,
77
- mini_model_config=GraniteConfig(
78
- attention_bias=False,
79
- attention_dropout=0.1,
80
- # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
81
- # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
82
- bos_token_id=1, # 128000
83
- eos_token_id=2, # 128001
84
- hidden_act="silu",
85
- hidden_size=1024, # 4096
86
- initializer_range=0.02,
87
- intermediate_size=2048, # 14336
88
- max_position_embeddings=8192,
89
- num_attention_heads=8, # 32
90
- num_hidden_layers=4, # 32
91
- num_key_value_heads=2, # 8
92
- pretraining_tp=1,
93
- rms_norm_eps=1e-5,
94
- rope_scaling=None,
95
- rope_theta=500000.0,
96
- tie_word_embeddings=False,
97
- use_cache=True,
98
- vocab_size=32000, # 128256,
99
- # At rope backward
100
- # Eager produces incontiguous dq and dk
101
- # SDPA produces contiguous dq and incontiguous dk
102
- # Flash_attn produces contiguous dq and dk
103
- attn_implementation="sdpa", # default value, pytorch native attention
104
- ),
105
- ),
106
79
  "mini_llama3": MiniModelConfig(
107
80
  liger_kernel_patch_func=apply_liger_kernel_to_llama,
108
81
  liger_kernel_patch_revert_func=revert_liger_kernel_to_llama,
@@ -418,6 +391,41 @@ if QWEN2_VL_AVAILABLE:
418
391
  ),
419
392
  )
420
393
 
394
+ if GRANITE_AVAILABLE:
395
+ MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig(
396
+ liger_kernel_patch_func=apply_liger_kernel_to_granite,
397
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
398
+ model_class=GraniteForCausalLM,
399
+ mini_model_config=GraniteConfig(
400
+ attention_bias=False,
401
+ attention_dropout=0.1,
402
+ # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
403
+ # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
404
+ bos_token_id=1, # 128000
405
+ eos_token_id=2, # 128001
406
+ hidden_act="silu",
407
+ hidden_size=1024, # 4096
408
+ initializer_range=0.02,
409
+ intermediate_size=2048, # 14336
410
+ max_position_embeddings=8192,
411
+ num_attention_heads=8, # 32
412
+ num_hidden_layers=4, # 32
413
+ num_key_value_heads=2, # 8
414
+ pretraining_tp=1,
415
+ rms_norm_eps=1e-5,
416
+ rope_scaling=None,
417
+ rope_theta=500000.0,
418
+ tie_word_embeddings=False,
419
+ use_cache=True,
420
+ vocab_size=32000, # 128256,
421
+ # At rope backward
422
+ # Eager produces incontiguous dq and dk
423
+ # SDPA produces contiguous dq and incontiguous dk
424
+ # Flash_attn produces contiguous dq and dk
425
+ attn_implementation="sdpa", # default value, pytorch native attention
426
+ ),
427
+ )
428
+
421
429
 
422
430
  def create_model(model_name="mini_llama3"):
423
431
  """
@@ -462,7 +470,8 @@ def run_mini_model(
462
470
  else:
463
471
  kwargs["swiglu"] = True
464
472
 
465
- kwargs["fused_linear_cross_entropy"] = True
473
+ # fused_linear_cross_entropy is not supported in mini_granite3
474
+ kwargs["fused_linear_cross_entropy"] = True if model_name != "mini_granite3" else False
466
475
  kwargs["cross_entropy"] = False
467
476
 
468
477
  MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
@@ -518,7 +527,13 @@ def run_mini_model(
518
527
  1e-2,
519
528
  1e-2,
520
529
  1e-2,
521
- marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
530
+ marks=[
531
+ pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
532
+ pytest.mark.skipif(
533
+ not GRANITE_AVAILABLE,
534
+ reason="Granite not available in this version of transformers",
535
+ ),
536
+ ],
522
537
  ),
523
538
  pytest.param(
524
539
  "mini_mllama",
@@ -7,8 +7,6 @@ from transformers.models.gemma import GemmaConfig
7
7
  from transformers.models.gemma import GemmaForCausalLM
8
8
  from transformers.models.gemma2 import Gemma2Config
9
9
  from transformers.models.gemma2 import Gemma2ForCausalLM
10
- from transformers.models.granite import GraniteConfig
11
- from transformers.models.granite import GraniteForCausalLM
12
10
  from transformers.models.llama import LlamaConfig
13
11
  from transformers.models.llama import LlamaForCausalLM
14
12
  from transformers.models.mistral import MistralConfig
@@ -65,6 +63,14 @@ try:
65
63
  except ImportError:
66
64
  QWEN2_VL_AVAILABLE = False
67
65
 
66
+ try:
67
+ from transformers.models.granite import GraniteConfig
68
+ from transformers.models.granite import GraniteForCausalLM
69
+
70
+ GRANITE_AVAILABLE = True
71
+ except ImportError:
72
+ GRANITE_AVAILABLE = False
73
+
68
74
  from liger_kernel.utils import infer_device
69
75
 
70
76
  device = infer_device()
@@ -103,40 +109,6 @@ MINI_MODEL_SETUPS = {
103
109
  attn_implementation="sdpa", # default value, pytorch native attention
104
110
  ),
105
111
  ),
106
- "mini_granite3": MiniModelConfig(
107
- liger_kernel_patch_func=apply_liger_kernel_to_granite,
108
- liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
109
- model_class=GraniteForCausalLM,
110
- mini_model_config=GraniteConfig(
111
- attention_bias=False,
112
- attention_dropout=0.0,
113
- # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
114
- # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
115
- bos_token_id=1, # 128000
116
- eos_token_id=2, # 128001
117
- hidden_act="silu",
118
- hidden_size=1024, # 4096
119
- initializer_range=0.02,
120
- intermediate_size=2048, # 14336
121
- max_position_embeddings=8192,
122
- num_attention_heads=8, # 32
123
- num_hidden_layers=4, # 32
124
- num_key_value_heads=2, # 8
125
- pretraining_tp=1,
126
- rms_norm_eps=1e-5,
127
- rope_scaling=None,
128
- rope_theta=500000.0,
129
- tie_word_embeddings=False,
130
- use_cache=True,
131
- vocab_size=32000, # 128256,
132
- logits_scaling=8.0,
133
- # At rope backward
134
- # Eager produces incontiguous dq and dk
135
- # SDPA produces contiguous dq and incontiguous dk
136
- # Flash_attn produces contiguous dq and dk
137
- attn_implementation="sdpa", # default value, pytorch native attention
138
- ),
139
- ),
140
112
  "mini_qwen2": MiniModelConfig(
141
113
  liger_kernel_patch_func=apply_liger_kernel_to_qwen2,
142
114
  liger_kernel_patch_revert_func=revert_liger_kernel_to_qwen2,
@@ -419,6 +391,42 @@ if QWEN2_VL_AVAILABLE:
419
391
  ),
420
392
  )
421
393
 
394
+ if GRANITE_AVAILABLE:
395
+ MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig(
396
+ liger_kernel_patch_func=apply_liger_kernel_to_granite,
397
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_granite,
398
+ model_class=GraniteForCausalLM,
399
+ mini_model_config=GraniteConfig(
400
+ attention_bias=False,
401
+ attention_dropout=0.0,
402
+ # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
403
+ # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
404
+ bos_token_id=1, # 128000
405
+ eos_token_id=2, # 128001
406
+ hidden_act="silu",
407
+ hidden_size=1024, # 4096
408
+ initializer_range=0.02,
409
+ intermediate_size=2048, # 14336
410
+ max_position_embeddings=8192,
411
+ num_attention_heads=8, # 32
412
+ num_hidden_layers=4, # 32
413
+ num_key_value_heads=2, # 8
414
+ pretraining_tp=1,
415
+ rms_norm_eps=1e-5,
416
+ rope_scaling=None,
417
+ rope_theta=500000.0,
418
+ tie_word_embeddings=False,
419
+ use_cache=True,
420
+ vocab_size=32000, # 128256,
421
+ logits_scaling=8.0,
422
+ # At rope backward
423
+ # Eager produces incontiguous dq and dk
424
+ # SDPA produces contiguous dq and incontiguous dk
425
+ # Flash_attn produces contiguous dq and dk
426
+ attn_implementation="sdpa", # default value, pytorch native attention
427
+ ),
428
+ )
429
+
422
430
 
423
431
  def create_model(model_name="mini_llama3"):
424
432
  """
@@ -518,7 +526,13 @@ def run_mini_model(
518
526
  1e-2, # logits rtol
519
527
  1e-2,
520
528
  1e-2,
521
- marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
529
+ marks=[
530
+ pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
531
+ pytest.mark.skipif(
532
+ not GRANITE_AVAILABLE,
533
+ reason="Granite not available in this version of transformers",
534
+ ),
535
+ ],
522
536
  ),
523
537
  pytest.param(
524
538
  "mini_mllama",