liger-kernel-nightly 0.5.5.dev20250328142748__tar.gz → 0.5.5.dev20250331170510__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of liger-kernel-nightly might be problematic.

Files changed (240)
  1. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/__init__.py +1 -0
  4. liger_kernel_nightly-0.5.5.dev20250331170510/src/liger_kernel/transformers/model/llava.py +369 -0
  5. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/monkey_patch.py +84 -1
  6. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  7. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
  8. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/convergence/bf16/test_mini_models.py +93 -0
  9. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/convergence/bf16/test_mini_models_multimodal.py +138 -2
  10. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/convergence/bf16/test_mini_models_with_logits.py +94 -0
  11. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/convergence/fp32/test_mini_models.py +90 -0
  12. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/convergence/fp32/test_mini_models_multimodal.py +134 -1
  13. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/convergence/fp32/test_mini_models_with_logits.py +92 -0
  14. liger_kernel_nightly-0.5.5.dev20250331170510/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +28 -0
  15. liger_kernel_nightly-0.5.5.dev20250331170510/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +7 -0
  16. liger_kernel_nightly-0.5.5.dev20250331170510/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +66 -0
  17. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/utils.py +31 -0
  18. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  19. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  20. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.github/pull_request_template.md +0 -0
  21. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.github/workflows/amd-ci.yml +0 -0
  22. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.github/workflows/docs.yml +0 -0
  23. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.github/workflows/intel-ci.yml +0 -0
  24. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.github/workflows/nvi-ci.yml +0 -0
  25. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.github/workflows/publish-nightly.yml +0 -0
  26. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.github/workflows/publish-release.yml +0 -0
  27. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/.gitignore +0 -0
  28. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/LICENSE +0 -0
  29. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/Makefile +0 -0
  30. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/NOTICE +0 -0
  31. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/README.md +0 -0
  32. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/README.md +0 -0
  33. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/__init__.py +0 -0
  34. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/benchmarks_visualizer.py +0 -0
  35. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/data/all_benchmark_data.csv +0 -0
  36. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/__init__.py +0 -0
  37. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  38. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  39. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  40. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  41. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_dyt.py +0 -0
  42. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_embedding.py +0 -0
  43. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  44. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  45. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_geglu.py +0 -0
  46. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_group_norm.py +0 -0
  47. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_jsd.py +0 -0
  48. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_kl_div.py +0 -0
  49. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  50. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  51. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  52. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  53. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  54. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_rope.py +0 -0
  55. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  56. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_swiglu.py +0 -0
  57. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/benchmark_tvd.py +0 -0
  58. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/benchmark/scripts/utils.py +0 -0
  59. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/dev/fmt-requirements.txt +0 -0
  60. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/dev/modal/tests.py +0 -0
  61. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/dev/modal/tests_bwd.py +0 -0
  62. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/Examples.md +0 -0
  63. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/Getting-Started.md +0 -0
  64. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/High-Level-APIs.md +0 -0
  65. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/Low-Level-APIs.md +0 -0
  66. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/acknowledgement.md +0 -0
  67. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/contributing.md +0 -0
  68. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/images/banner.GIF +0 -0
  69. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/images/compose.gif +0 -0
  70. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/images/e2e-memory.png +0 -0
  71. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/images/e2e-tps.png +0 -0
  72. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/images/logo-banner.png +0 -0
  73. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/images/patch.gif +0 -0
  74. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/images/post-training.png +0 -0
  75. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/index.md +0 -0
  76. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/docs/license.md +0 -0
  77. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/alignment/accelerate_config.yaml +0 -0
  78. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/alignment/run_orpo.py +0 -0
  79. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/README.md +0 -0
  80. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/callback.py +0 -0
  81. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/config/fsdp_config.json +0 -0
  82. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  83. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  84. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  85. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/img/llama_tps.png +0 -0
  86. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  87. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/img/qwen_tps.png +0 -0
  88. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/launch_on_modal.py +0 -0
  89. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/requirements.txt +0 -0
  90. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/run_benchmarks.sh +0 -0
  91. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/run_gemma.sh +0 -0
  92. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/run_llama.sh +0 -0
  93. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/run_qwen.sh +0 -0
  94. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/run_qwen2_vl.sh +0 -0
  95. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/training.py +0 -0
  96. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/huggingface/training_multimodal.py +0 -0
  97. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/lightning/README.md +0 -0
  98. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/lightning/requirements.txt +0 -0
  99. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/lightning/training.py +0 -0
  100. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/README.md +0 -0
  101. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/callback.py +0 -0
  102. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  103. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  104. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  105. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  106. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  107. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  108. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  109. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  110. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  111. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/medusa_util.py +0 -0
  112. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/requirements.txt +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/examples/medusa/train.py +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/licenses/LICENSE-Apache-2.0 +0 -0
  116. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  117. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  118. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/licenses/LICENSE-MIT-llmc +0 -0
  119. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/licenses/LICENSE-MIT-triton +0 -0
  120. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/mkdocs.yml +0 -0
  121. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/setup.cfg +0 -0
  122. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/setup.py +0 -0
  123. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/__init__.py +0 -0
  124. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/README.md +0 -0
  125. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  126. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/functional.py +0 -0
  129. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  130. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  131. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
  132. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  133. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  134. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  135. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  136. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  137. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  138. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/env_report.py +0 -0
  139. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/__init__.py +0 -0
  140. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/cross_entropy.py +0 -0
  141. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/dyt.py +0 -0
  142. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  143. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  144. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  145. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  146. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/geglu.py +0 -0
  147. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/group_norm.py +0 -0
  148. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/jsd.py +0 -0
  149. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/kl_div.py +0 -0
  150. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/layer_norm.py +0 -0
  151. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  152. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/rms_norm.py +0 -0
  153. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/rope.py +0 -0
  154. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/swiglu.py +0 -0
  155. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/tvd.py +0 -0
  156. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/ops/utils.py +0 -0
  157. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/auto_model.py +0 -0
  158. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  159. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/dyt.py +0 -0
  160. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  161. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/functional.py +0 -0
  162. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  163. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  164. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/geglu.py +0 -0
  165. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/group_norm.py +0 -0
  166. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/jsd.py +0 -0
  167. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/kl_div.py +0 -0
  168. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/layer_norm.py +0 -0
  169. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/__init__.py +0 -0
  170. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/gemma.py +0 -0
  171. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  172. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/llama.py +0 -0
  173. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  174. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/mistral.py +0 -0
  175. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  176. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/mllama.py +0 -0
  177. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  178. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  179. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/phi3.py +0 -0
  180. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  181. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  182. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  183. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  184. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/rms_norm.py +0 -0
  185. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/rope.py +0 -0
  186. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/swiglu.py +0 -0
  187. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  188. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  189. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  190. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/transformers/tvd.py +0 -0
  191. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/triton/__init__.py +0 -0
  192. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/triton/monkey_patch.py +0 -0
  193. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel/utils.py +0 -0
  194. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  195. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  196. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  197. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/__init__.py +0 -0
  198. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/chunked_loss/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/chunked_loss/test_cpo_loss.py +0 -0
  200. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/chunked_loss/test_dpo_loss.py +0 -0
  201. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/chunked_loss/test_grpo_loss.py +0 -0
  202. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/chunked_loss/test_jsd_loss.py +0 -0
  203. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/chunked_loss/test_kto_loss.py +0 -0
  204. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/chunked_loss/test_orpo_loss.py +0 -0
  205. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/chunked_loss/test_simpo_loss.py +0 -0
  206. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/conftest.py +0 -0
  207. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/convergence/__init__.py +0 -0
  208. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/convergence/bf16/__init__.py +0 -0
  209. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/convergence/fp32/__init__.py +0 -0
  210. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  211. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  212. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  213. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  214. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  215. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/resources/tiny_shakespeare.txt +0 -0
  216. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  217. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  218. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  219. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_auto_model.py +0 -0
  220. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_cross_entropy.py +0 -0
  221. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_dyt.py +0 -0
  222. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_embedding.py +0 -0
  223. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_flex_attention.py +0 -0
  224. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  225. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_fused_linear_jsd.py +0 -0
  226. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_geglu.py +0 -0
  227. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_group_norm.py +0 -0
  228. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_jsd.py +0 -0
  229. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_kl_div.py +0 -0
  230. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_layer_norm.py +0 -0
  231. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_mm_int8int2.py +0 -0
  232. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_monkey_patch.py +0 -0
  233. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_qwen2vl_mrope.py +0 -0
  234. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_rms_norm.py +0 -0
  235. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_rope.py +0 -0
  236. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_swiglu.py +0 -0
  237. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_trainer_integration.py +0 -0
  238. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_transformers.py +0 -0
  239. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/transformers/test_tvd.py +0 -0
  240. {liger_kernel_nightly-0.5.5.dev20250328142748 → liger_kernel_nightly-0.5.5.dev20250331170510}/test/triton/test_triton_monkey_patch.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.5.dev20250328142748
+Version: 0.5.5.dev20250331170510
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "liger_kernel_nightly"
-version = "0.5.5.dev20250328142748"
+version = "0.5.5.dev20250331170510"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
src/liger_kernel/transformers/__init__.py
@@ -12,6 +12,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
src/liger_kernel/transformers/model/llava.py (new file)
@@ -0,0 +1,369 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
+from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
+from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
+from transformers.utils import add_start_docstrings_to_model_forward
+from transformers.utils import is_torchdynamo_compiling
+from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+
+
+@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
+@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+def lce_forward_deprecated(
+    self,
+    input_ids: torch.LongTensor = None,
+    pixel_values: torch.FloatTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    vision_feature_layer: Optional[int] = None,
+    vision_feature_select_strategy: Optional[str] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
+
+    >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
+    >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+
+    >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
+    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "USER:  \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
+    ```"""
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    vision_feature_layer = (
+        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+    )
+    vision_feature_select_strategy = (
+        vision_feature_select_strategy
+        if vision_feature_select_strategy is not None
+        else self.config.vision_feature_select_strategy
+    )
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if pixel_values is not None and inputs_embeds is not None:
+        raise ValueError(
+            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+        )
+
+    if inputs_embeds is None:
+        # 1. Extra the input embeddings
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        # 2. Merge text and images
+        if pixel_values is not None and input_ids.shape[1] != 1:
+            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+            # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
+            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+
+            if vision_feature_select_strategy == "default":
+                selected_image_feature = selected_image_feature[:, 1:]
+            elif vision_feature_select_strategy == "full":
+                selected_image_feature = selected_image_feature
+            else:
+                raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+
+            image_features = self.multi_modal_projector(selected_image_feature)
+            inputs_embeds = inputs_embeds.to(image_features.dtype)
+            inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                image_features, inputs_embeds, input_ids, attention_mask, labels
+            )
+
+        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+        # generation with cache
+        elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+            # Retrieve the first layer to inspect the logits and mask out the hidden states
+            # that are set to 0
+            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+            # Get the target length
+            target_length = input_ids.shape[1]
+            past_length = first_layer_past_key_value.shape[-1]
+
+            extended_attention_mask = torch.ones(
+                (attention_mask.shape[0], past_length),
+                dtype=attention_mask.dtype,
+                device=attention_mask.device,
+            )
+
+            # Filter out only the tokens that can be un-attended, this can happen
+            # if one uses Llava + Fused modules where the cache on the
+            # first iteration is already big enough, or if one passes custom cache
+            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+            new_batch_index = batch_index[valid_indices]
+            new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+            # Zero-out the places where we don't need to attend
+            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+            attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+            position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+        # TODO: @raushan retain only the new behavior after v4.47
+        elif image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+    outputs = self.language_model.model(
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        # Shift so that tokens < n predict n
+        if attention_mask is not None:
+            # we use the input attention mask to shift the logits and labels, because it is 2D.
+            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+            shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device)
+            shift_hidden_states = hidden_states[..., :-1, :][
+                shift_attention_mask.to(hidden_states.device) != 0
+            ].contiguous()
+            shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+        else:
+            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+
+    if not return_dict:
+        # NOTE: This part has not been tested.
+        output = outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return LlavaCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    pixel_values: torch.FloatTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    vision_feature_layer: Optional[int] = None,
+    vision_feature_select_strategy: Optional[str] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    image_sizes: torch.Tensor = None,
+    **lm_kwargs,
+) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from PIL import Image
+    >>> import requests
+    >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
+
+    >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
+    >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+
+    >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
+    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+    >>> image = Image.open(requests.get(url, stream=True).raw)
+
+    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
+    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "USER:  \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
+    ```"""
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+    vision_feature_layer = (
+        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+    )
+    vision_feature_select_strategy = (
+        vision_feature_select_strategy
+        if vision_feature_select_strategy is not None
+        else self.config.vision_feature_select_strategy
+    )
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if pixel_values is not None and inputs_embeds is not None:
+        raise ValueError(
+            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+        )
+
+    if inputs_embeds is None:
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+
+    if pixel_values is not None:
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=vision_feature_layer,
+            vision_feature_select_strategy=vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
+
+        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+            n_image_tokens = (input_ids == self.config.image_token_index).sum()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+    outputs = self.language_model.model(
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        logits_to_keep=logits_to_keep,
+        **lm_kwargs,
+    )
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and (labels is not None):
+        # Shift so that tokens < n predict n
+        if attention_mask is not None:
+            # we use the input attention mask to shift the logits and labels, because it is 2D.
+            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+            shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device)
+            shift_hidden_states = hidden_states[..., :-1, :][
+                shift_attention_mask.to(hidden_states.device) != 0
+            ].contiguous()
+            shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+        else:
+            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(
+            self.language_model.lm_head.weight,
+            shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
+            shift_labels.view(-1).to(shift_hidden_states.device),
+        )
+
+    if not return_dict:
+        # NOTE: This part has not been tested.
+        output = outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return LlavaCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        image_hidden_states=image_features if pixel_values is not None else None,
+    )
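
Both new forwards compute the training loss directly from the shifted hidden states and the `lm_head` weight via `LigerFusedLinearCrossEntropyLoss`, so the full `(batch * seq_len, vocab_size)` logits tensor is never materialized. Below is a minimal standalone sketch of that pattern, with illustrative shapes only and assuming a CUDA device (the kernel is Triton-based):

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

device = "cuda"  # Liger's Triton kernels require a GPU

# Illustrative shapes only: batch 2, sequence 8, hidden 64, vocab 128.
hidden_states = torch.randn(2, 8, 64, device=device, requires_grad=True)
lm_head_weight = torch.randn(128, 64, device=device, requires_grad=True)
labels = torch.randint(0, 128, (2, 8), device=device)

# Same shift-by-one scheme as the patched forward: tokens < n predict n.
shift_hidden = hidden_states[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(
    lm_head_weight,
    shift_hidden.view(-1, shift_hidden.size(-1)),  # flatten to (batch * seq, hidden)
    shift_labels.view(-1),
)
loss.backward()  # the (batch * seq, vocab) logits matrix is never built in full
```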
src/liger_kernel/transformers/monkey_patch.py
@@ -19,6 +19,8 @@ from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_for
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
+from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
+from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
@@ -57,7 +59,8 @@ def _patch_rms_norm_module(module, offset=0.0, eps=1e-6, casting_mode="llama", i
 
 def _patch_layer_norm_module(module, eps=1e-6):
     module.variance_epsilon = getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
-    module.hidden_size = module.normalized_shape
+    module.hidden_size = getattr(module, "hidden_size", None) or getattr(module, "normalized_shape", None)
+
     _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
     _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
     module.__class__.__name__ = LigerLayerNorm.__name__
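
The one-line fix above swaps an unconditional attribute read for a `getattr` chain: `torch.nn.LayerNorm` exposes `normalized_shape` rather than `hidden_size`, while some patched modules already carry a `hidden_size` attribute. A small illustrative sketch of the behavior being handled:

```python
import torch.nn as nn

ln = nn.LayerNorm(64)

# nn.LayerNorm stores its size as `normalized_shape`; it has no `hidden_size`.
print(ln.normalized_shape)               # (64,)
print(getattr(ln, "hidden_size", None))  # None

# The patched line prefers an existing `hidden_size` and only then falls back
# to `normalized_shape`, instead of reading the latter unconditionally.
hidden_size = getattr(ln, "hidden_size", None) or getattr(ln, "normalized_shape", None)
print(hidden_size)                       # (64,)
```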
@@ -224,6 +227,85 @@ def apply_liger_kernel_to_llama(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_llava(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    model: PreTrainedModel = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace LLaVA models.
+    Because LLaVA wraps separate language and vision models, the model instance must be passed
+    so that Liger-Kernel's patches can also be applied to the submodels connected to LLaVA.
+    However, if a language model not supported by Liger-Kernel is connected to LLaVA, unexpected side effects may occur.
+    NOTE: LLaVA is not available in transformers<4.36.0
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory-efficient.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+        **kwargs: Additional options (e.g. `rope`, `rms_norm`, `swiglu`) forwarded to the patch
+            functions of the connected text and vision submodels.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.llava import modeling_llava
+
+    if cross_entropy:
+        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+        modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if transformer_version >= version.parse("4.49.0"):
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        else:  # if version < 4.49.0
+            logger.warning(
+                "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+            )
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
+
+    if model is not None:
+        text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+        text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
+        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
+
+        kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
+        if text_liger_fn:
+            accept_params = inspect.signature(text_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {text_model_name}: {list(remain_params)}. "
+                    f"Applying only the supported parameters: {list(text_kwargs.keys())}.\n"
+                    f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
+                )
+            text_kwargs["model"] = model.language_model
+            text_liger_fn(**text_kwargs)
+        elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{text_model_name} is not supported by Liger kernel.")
+
+        if vision_liger_fn:
+            accept_params = inspect.signature(vision_liger_fn).parameters
+            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
+            vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+
+            if remain_params:
+                logger.warning(
+                    f"These parameters are not supported by {vision_model_name}: {list(remain_params)}. "
+                    f"Applying only the supported parameters: {list(vision_kwargs.keys())}.\n"
+                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
+                )
+            vision_kwargs["model"] = model.vision_tower
+            vision_liger_fn(**vision_kwargs)
+        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
+            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+
+
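A usage sketch for the new entry point (not part of the diff; the checkpoint name is only an example): the loaded model must be passed so the text and vision towers get patched too.

```python
import torch
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # example checkpoint
    torch_dtype=torch.bfloat16,
)
# Swaps LlavaForConditionalGeneration.forward for the fused-linear-cross-entropy
# version, then dispatches to apply_liger_kernel_to_llama for model.language_model;
# the CLIP vision tower has no registered patch, so a warning is logged for it.
apply_liger_kernel_to_llava(fused_linear_cross_entropy=True, model=model)
```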
 def apply_liger_kernel_to_mllama(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -1071,6 +1153,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
+    "llava": apply_liger_kernel_to_llava,
     "granite": apply_liger_kernel_to_granite,
     "mllama": apply_liger_kernel_to_mllama,
     "mllama_text_model": apply_liger_kernel_to_mllama,
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.5.dev20250328142748
+Version: 0.5.5.dev20250331170510
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
          Copyright 2024 LinkedIn Corporation
@@ -162,6 +162,7 @@ src/liger_kernel/transformers/model/__init__.py
 src/liger_kernel/transformers/model/gemma.py
 src/liger_kernel/transformers/model/gemma2.py
 src/liger_kernel/transformers/model/llama.py
+src/liger_kernel/transformers/model/llava.py
 src/liger_kernel/transformers/model/loss_utils.py
 src/liger_kernel/transformers/model/mistral.py
 src/liger_kernel/transformers/model/mixtral.py
@@ -203,6 +204,9 @@ test/convergence/fp32/test_mini_models_multimodal.py
 test/convergence/fp32/test_mini_models_with_logits.py
 test/resources/tiny_shakespeare.txt
 test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json
+test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json
+test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json
+test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json
 test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json
 test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json
 test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json
@@ -22,6 +22,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma
 from liger_kernel.transformers import apply_liger_kernel_to_gemma2
 from liger_kernel.transformers import apply_liger_kernel_to_granite
 from liger_kernel.transformers import apply_liger_kernel_to_llama
+from liger_kernel.transformers import apply_liger_kernel_to_llava
 from liger_kernel.transformers import apply_liger_kernel_to_mistral
 from liger_kernel.transformers import apply_liger_kernel_to_mixtral
 from liger_kernel.transformers import apply_liger_kernel_to_mllama
@@ -37,6 +38,7 @@ from test.utils import revert_liger_kernel_to_gemma
 from test.utils import revert_liger_kernel_to_gemma2
 from test.utils import revert_liger_kernel_to_granite
 from test.utils import revert_liger_kernel_to_llama
+from test.utils import revert_liger_kernel_to_llava
 from test.utils import revert_liger_kernel_to_mistral
 from test.utils import revert_liger_kernel_to_mixtral
 from test.utils import revert_liger_kernel_to_mllama
@@ -84,6 +86,15 @@ try:
 except ImportError:
     GRANITE_AVAILABLE = False
 
+try:
+    from transformers import CLIPVisionConfig
+    from transformers.models.llava.configuration_llava import LlavaConfig
+    from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration
+
+    LLAVA_AVAILABLE = True
+except ImportError:
+    LLAVA_AVAILABLE = False
+
 try:
     # OLMO2 is only available in transformers>=4.47.0
     from transformers.models.olmo2.configuration_olmo2 import Olmo2Config
@@ -93,6 +104,7 @@ try:
 except ImportError:
     OLMO2_AVAILABLE = False
 
+
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -504,6 +516,65 @@ if GRANITE_AVAILABLE:
         ),
     )
 
+if LLAVA_AVAILABLE:
+    # https://huggingface.co/llava-hf/llava-1.5-7b-hf
+    MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_llava,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_llava,
+        model_class=LlavaForConditionalGeneration,
+        mini_model_config=LlavaConfig(
+            text_config=LlamaConfig(
+                attention_bias=False,
+                attention_dropout=0.0,
+                bos_token_id=1,
+                eos_token_id=2,
+                hidden_act="silu",
+                hidden_size=1024,
+                initializer_range=0.02,
+                intermediate_size=2048,
+                num_attention_heads=8,
+                num_hidden_layers=4,
+                num_key_value_heads=2,
+                pretraining_tp=1,
+                rope_scaling=None,
+                rope_theta=500000.0,
+                tie_word_embeddings=False,
+                use_cache=True,
+                max_position_embeddings=4096,  # llava-1.5-7b-hf
+                rms_norm_eps=1e-05,  # llava-1.5-7b-hf
+                vocab_size=32064,  # llava-1.5-7b-hf
+                # At rope backward
+                # Eager produces incontiguous dq and dk
+                # SDPA produces contiguous dq and incontiguous dk
+                # Flash_attn produces contiguous dq and dk
+                attn_implementation="sdpa",  # default value, pytorch native attention
+            ),
+            vision_config=CLIPVisionConfig(
+                hidden_size=1024,
+                image_size=336,
+                intermediate_size=2048,  # 4096
+                model_type="clip_vision_model",
+                num_attention_heads=4,  # 16
+                num_hidden_layers=4,  # 24
+                patch_size=14,
+                projection_dim=768,
+                vocab_size=32000,
+            ),
+            vocab_size=32064,
+            ignore_index=-100,
+            pad_token_id=4,
+            image_token_index=3,
+            projector_hidden_act="gelu",
+            vision_feature_layer=-2,
+            vision_feature_select_strategy="default",
+            # At rope backward
+            # Eager produces incontiguous dq and dk
+            # SDPA produces contiguous dq and incontiguous dk
+            # Flash_attn produces contiguous dq and dk
+            attn_implementation="sdpa",  # default value, pytorch native attention
+        ),
+    )
+
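For orientation, the config above is enough to build a randomly initialized tiny LLaVA for the convergence run; a text-only sketch, assuming the imports already present in this test file:

```python
import torch

# Instantiate the tiny model from the registered mini config.
cfg = MINI_MODEL_SETUPS["mini_llava"].mini_model_config
model = LlavaForConditionalGeneration(cfg)

# Start token ids at 5 to avoid the image (3) and pad (4) ids from the config.
input_ids = torch.randint(5, cfg.text_config.vocab_size, (1, 16))
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss)
```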
 if OLMO2_AVAILABLE:
     MINI_MODEL_SETUPS["mini_olmo2"] = MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_olmo2,
@@ -577,6 +648,9 @@ def run_mini_model(
     else:
         kwargs["swiglu"] = True
 
+    if "llava" in model_name:
+        # mini_llava's text tower is a Llama model, so the Llama kernels are applied to it as well.
+        apply_liger_kernel_to_llama(**kwargs)
+
     # fused_linear_cross_entropy is not supported in mini_granite3
     kwargs["fused_linear_cross_entropy"] = True if model_name != "mini_granite3" else False
     kwargs["cross_entropy"] = False
@@ -623,6 +697,25 @@ def run_mini_model(
             1e-2,
             marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
         ),
+        pytest.param(
+            "mini_llava",
+            32,
+            1e-4,
+            torch.bfloat16,
+            1e-3,
+            1e-2,
+            1e-1,
+            1e-2,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not LLAVA_AVAILABLE,
+                    reason="LLaVa not available in this version of transformers",
+                ),
+            ],
+        ),
         pytest.param(
             "mini_granite3",
             32,