liger-kernel-nightly 0.5.5.dev20250314203927__tar.gz → 0.5.5.dev20250316145754__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic; see the registry's page for this release for more details.

Files changed (231)
  1. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/PKG-INFO +2 -1
  2. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/README.md +1 -0
  3. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/__init__.py +1 -0
  5. liger_kernel_nightly-0.5.5.dev20250316145754/src/liger_kernel/transformers/model/paligemma.py +213 -0
  6. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/monkey_patch.py +85 -0
  7. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel_nightly.egg-info/PKG-INFO +2 -1
  8. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel_nightly.egg-info/SOURCES.txt +2 -0
  9. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/convergence/bf16/test_mini_models_multimodal.py +107 -0
  10. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/convergence/fp32/test_mini_models_multimodal.py +103 -0
  11. liger_kernel_nightly-0.5.5.dev20250316145754/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +61 -0
  12. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/utils.py +16 -0
  13. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  14. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  15. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.github/pull_request_template.md +0 -0
  16. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.github/workflows/amd-ci.yml +0 -0
  17. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.github/workflows/docs.yml +0 -0
  18. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.github/workflows/intel-ci.yml +0 -0
  19. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.github/workflows/nvi-ci.yml +0 -0
  20. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.github/workflows/publish-nightly.yml +0 -0
  21. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.github/workflows/publish-release.yml +0 -0
  22. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/.gitignore +0 -0
  23. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/LICENSE +0 -0
  24. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/Makefile +0 -0
  25. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/NOTICE +0 -0
  26. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/README.md +0 -0
  27. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/__init__.py +0 -0
  28. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/benchmarks_visualizer.py +0 -0
  29. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/data/all_benchmark_data.csv +0 -0
  30. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/__init__.py +0 -0
  31. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  32. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  33. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  34. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  35. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_embedding.py +0 -0
  36. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  38. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_geglu.py +0 -0
  39. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_group_norm.py +0 -0
  40. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_jsd.py +0 -0
  41. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_kl_div.py +0 -0
  42. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  43. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  44. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  45. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  46. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  47. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_rope.py +0 -0
  48. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  49. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_swiglu.py +0 -0
  50. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/benchmark_tvd.py +0 -0
  51. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/benchmark/scripts/utils.py +0 -0
  52. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/dev/fmt-requirements.txt +0 -0
  53. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/dev/modal/tests.py +0 -0
  54. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/dev/modal/tests_bwd.py +0 -0
  55. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/Examples.md +0 -0
  56. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/Getting-Started.md +0 -0
  57. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/High-Level-APIs.md +0 -0
  58. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/Low-Level-APIs.md +0 -0
  59. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/acknowledgement.md +0 -0
  60. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/contributing.md +0 -0
  61. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/images/banner.GIF +0 -0
  62. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/images/compose.gif +0 -0
  63. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/images/e2e-memory.png +0 -0
  64. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/images/e2e-tps.png +0 -0
  65. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/images/logo-banner.png +0 -0
  66. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/images/patch.gif +0 -0
  67. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/images/post-training.png +0 -0
  68. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/index.md +0 -0
  69. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/docs/license.md +0 -0
  70. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/alignment/accelerate_config.yaml +0 -0
  71. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/alignment/run_orpo.py +0 -0
  72. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/README.md +0 -0
  73. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/callback.py +0 -0
  74. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/config/fsdp_config.json +0 -0
  75. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  76. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  77. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  78. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/img/llama_tps.png +0 -0
  79. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  80. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/img/qwen_tps.png +0 -0
  81. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/launch_on_modal.py +0 -0
  82. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/requirements.txt +0 -0
  83. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/run_benchmarks.sh +0 -0
  84. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/run_gemma.sh +0 -0
  85. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/run_llama.sh +0 -0
  86. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/run_qwen.sh +0 -0
  87. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/run_qwen2_vl.sh +0 -0
  88. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/training.py +0 -0
  89. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/huggingface/training_multimodal.py +0 -0
  90. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/lightning/README.md +0 -0
  91. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/lightning/requirements.txt +0 -0
  92. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/lightning/training.py +0 -0
  93. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/README.md +0 -0
  94. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/callback.py +0 -0
  95. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  96. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  97. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  98. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  99. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  100. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  101. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  102. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  103. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  104. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/medusa_util.py +0 -0
  105. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/requirements.txt +0 -0
  106. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  107. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/examples/medusa/train.py +0 -0
  108. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/licenses/LICENSE-Apache-2.0 +0 -0
  109. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  110. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  111. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/licenses/LICENSE-MIT-llmc +0 -0
  112. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/licenses/LICENSE-MIT-triton +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/mkdocs.yml +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/setup.cfg +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/setup.py +0 -0
  116. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/__init__.py +0 -0
  117. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/README.md +0 -0
  118. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  119. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  120. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  121. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/functional.py +0 -0
  122. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  123. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  124. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
  125. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  126. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  131. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/env_report.py +0 -0
  132. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/__init__.py +0 -0
  133. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/cross_entropy.py +0 -0
  134. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  135. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  136. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  137. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  138. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/geglu.py +0 -0
  139. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/group_norm.py +0 -0
  140. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/jsd.py +0 -0
  141. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/kl_div.py +0 -0
  142. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/layer_norm.py +0 -0
  143. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  144. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/rms_norm.py +0 -0
  145. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/rope.py +0 -0
  146. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/swiglu.py +0 -0
  147. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/tvd.py +0 -0
  148. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/ops/utils.py +0 -0
  149. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/auto_model.py +0 -0
  150. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  151. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  152. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/functional.py +0 -0
  153. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  154. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  155. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/geglu.py +0 -0
  156. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/group_norm.py +0 -0
  157. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/jsd.py +0 -0
  158. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/kl_div.py +0 -0
  159. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/layer_norm.py +0 -0
  160. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/__init__.py +0 -0
  161. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/gemma.py +0 -0
  162. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  163. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/llama.py +0 -0
  164. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/mistral.py +0 -0
  165. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  166. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/mllama.py +0 -0
  167. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  168. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/phi3.py +0 -0
  169. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  170. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  171. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  172. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  173. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/rms_norm.py +0 -0
  174. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/rope.py +0 -0
  175. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/swiglu.py +0 -0
  176. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  177. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  178. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  179. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/transformers/tvd.py +0 -0
  180. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/triton/__init__.py +0 -0
  181. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/triton/monkey_patch.py +0 -0
  182. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel/utils.py +0 -0
  183. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  184. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  185. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  186. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/__init__.py +0 -0
  187. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/chunked_loss/__init__.py +0 -0
  188. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/chunked_loss/test_cpo_loss.py +0 -0
  189. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/chunked_loss/test_dpo_loss.py +0 -0
  190. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/chunked_loss/test_grpo_loss.py +0 -0
  191. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/chunked_loss/test_jsd_loss.py +0 -0
  192. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/chunked_loss/test_kto_loss.py +0 -0
  193. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/chunked_loss/test_orpo_loss.py +0 -0
  194. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/chunked_loss/test_simpo_loss.py +0 -0
  195. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/conftest.py +0 -0
  196. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/convergence/__init__.py +0 -0
  197. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/convergence/bf16/__init__.py +0 -0
  198. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/convergence/bf16/test_mini_models.py +0 -0
  199. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  200. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/convergence/fp32/__init__.py +0 -0
  201. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/convergence/fp32/test_mini_models.py +0 -0
  202. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  203. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  204. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  205. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  206. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  207. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/resources/tiny_shakespeare.txt +0 -0
  208. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  209. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  210. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  211. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_auto_model.py +0 -0
  212. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_cross_entropy.py +0 -0
  213. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_embedding.py +0 -0
  214. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_flex_attention.py +0 -0
  215. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  216. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_fused_linear_jsd.py +0 -0
  217. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_geglu.py +0 -0
  218. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_group_norm.py +0 -0
  219. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_jsd.py +0 -0
  220. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_kl_div.py +0 -0
  221. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_layer_norm.py +0 -0
  222. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_mm_int8int2.py +0 -0
  223. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_monkey_patch.py +0 -0
  224. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_qwen2vl_mrope.py +0 -0
  225. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_rms_norm.py +0 -0
  226. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_rope.py +0 -0
  227. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_swiglu.py +0 -0
  228. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_trainer_integration.py +0 -0
  229. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_transformers.py +0 -0
  230. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/transformers/test_tvd.py +0 -0
  231. {liger_kernel_nightly-0.5.5.dev20250314203927 → liger_kernel_nightly-0.5.5.dev20250316145754}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250314203927
3
+ Version: 0.5.5.dev20250316145754
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -310,6 +310,7 @@ loss.backward()
310
310
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
311
311
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
312
312
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
313
+ | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
313
314
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
314
315
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
316
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -262,6 +262,7 @@ loss.backward()
262
262
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
263
263
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
264
264
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
265
+ | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
265
266
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
266
267
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
267
268
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.5.dev20250314203927"
7
+ version = "0.5.5.dev20250316145754"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -15,6 +15,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral
15
15
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
16
16
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
17
17
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
18
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
18
19
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
19
20
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
20
21
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
@@ -0,0 +1,213 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from torch.nn import CrossEntropyLoss
9
+ from transformers.cache_utils import Cache
10
+ from transformers.models.paligemma.modeling_paligemma import _CONFIG_FOR_DOC
11
+ from transformers.models.paligemma.modeling_paligemma import PALIGEMMA_INPUTS_DOCSTRING
12
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaCausalLMOutputWithPast
13
+ from transformers.utils import add_start_docstrings_to_model_forward
14
+ from transformers.utils import is_torchdynamo_compiling
15
+ from transformers.utils import logging
16
+ from transformers.utils import replace_return_docstrings
17
+ from transformers.utils.deprecation import deprecate_kwarg
18
+
19
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
25
+ @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
26
+ @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
27
+ def lce_forward(
28
+ self,
29
+ input_ids: torch.LongTensor = None,
30
+ pixel_values: torch.FloatTensor = None,
31
+ attention_mask: Optional[torch.Tensor] = None,
32
+ position_ids: Optional[torch.LongTensor] = None,
33
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
34
+ token_type_ids: Optional[torch.LongTensor] = None,
35
+ cache_position: Optional[torch.LongTensor] = None,
36
+ inputs_embeds: Optional[torch.FloatTensor] = None,
37
+ labels: Optional[torch.LongTensor] = None,
38
+ use_cache: Optional[bool] = None,
39
+ output_attentions: Optional[bool] = None,
40
+ output_hidden_states: Optional[bool] = None,
41
+ return_dict: Optional[bool] = None,
42
+ logits_to_keep: Union[int, torch.Tensor] = 0,
43
+ **lm_kwargs,
44
+ ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
45
+ r"""
46
+ Args:
47
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
48
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
49
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
50
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
51
+
52
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
53
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
54
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
55
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
56
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
57
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
58
+
59
+ Returns:
60
+
61
+ Example:
62
+
63
+ ```python
64
+ >>> from PIL import Image
65
+ >>> import requests
66
+ >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
67
+
68
+ >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/PaliGemma-test-224px-hf")
69
+ >>> processor = AutoProcessor.from_pretrained("google/PaliGemma-test-224px-hf")
70
+
71
+ >>> prompt = "answer en Where is the cow standing?"
72
+ >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
73
+ >>> image = Image.open(requests.get(url, stream=True).raw)
74
+
75
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
76
+
77
+ >>> # Generate
78
+ >>> generate_ids = model.generate(**inputs, max_length=30)
79
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
80
+ "answer en Where is the cow standing?\nbeach"
81
+ ```"""
82
+
83
+ if (input_ids is None) ^ (inputs_embeds is not None):
84
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
85
+
86
+ if pixel_values is not None and inputs_embeds is not None:
87
+ raise ValueError(
88
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
89
+ )
90
+
91
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
92
+ output_hidden_states = (
93
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
94
+ )
95
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
96
+
97
+ is_training = token_type_ids is not None and labels is not None
98
+
99
+ if inputs_embeds is None:
100
+ inputs_embeds = self.get_input_embeddings()(input_ids)
101
+
102
+ if cache_position is None:
103
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
104
+ cache_position = torch.arange(
105
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
106
+ )
107
+
108
+ if position_ids is None:
109
+ position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
110
+
111
+ # Merge text and images
112
+ if pixel_values is not None:
113
+ image_features = self.get_image_features(pixel_values)
114
+
115
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
116
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
117
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
118
+ image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
119
+ raise ValueError(
120
+ f"Number of images does not match number of special image tokens in the input text. "
121
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
122
+ "tokens from image embeddings."
123
+ )
124
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
125
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
126
+
127
+ # mask out pad-token-ids in labels for BC
128
+ if labels is not None and self.pad_token_id in labels:
129
+ logger.warning_once(
130
+ "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
131
+ "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
132
+ )
133
+ labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
134
+
135
+ causal_mask = self._update_causal_mask(
136
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
137
+ )
138
+
139
+ outputs = self.language_model.model(
140
+ attention_mask=causal_mask,
141
+ position_ids=position_ids,
142
+ past_key_values=past_key_values,
143
+ inputs_embeds=inputs_embeds,
144
+ use_cache=use_cache,
145
+ output_attentions=output_attentions,
146
+ output_hidden_states=output_hidden_states,
147
+ return_dict=return_dict,
148
+ cache_position=cache_position,
149
+ logits_to_keep=logits_to_keep,
150
+ **lm_kwargs,
151
+ )
152
+
153
+ hidden_states = outputs[0]
154
+
155
+ loss = None
156
+ logits = None
157
+
158
+ if self.training and (labels is not None):
159
+ shift_hidden_states = hidden_states[..., :-1, :]
160
+ shift_labels = labels[..., 1:]
161
+
162
+ hidden_device = shift_hidden_states.device
163
+
164
+ if attention_mask is not None:
165
+ # we use the input attention mask to shift the hidden_states and labels, because it is 2D.
166
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
167
+ shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
168
+ shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
169
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
170
+ else:
171
+ shift_hidden_states = shift_hidden_states.contiguous()
172
+ shift_labels = shift_labels.contiguous()
173
+
174
+ # Flatten hidden state
175
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
176
+ shift_labels = shift_labels.view(-1).to(hidden_device)
177
+
178
+ lce = LigerFusedLinearCrossEntropyLoss()
179
+ loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
180
+ else:
181
+ logits = self.language_model.lm_head(hidden_states)
182
+ if labels is not None:
183
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
184
+ logits = logits.float()
185
+ shift_logits = logits[..., :-1, :]
186
+ shift_labels = labels[..., 1:]
187
+ if attention_mask is not None:
188
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
189
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
190
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
191
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
192
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
193
+ else:
194
+ shift_logits = shift_logits.contiguous()
195
+ shift_labels = shift_labels.contiguous()
196
+ # Flatten the tokens
197
+ loss_fct = CrossEntropyLoss()
198
+
199
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
200
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
201
+ loss = loss_fct(flat_logits, flat_labels)
202
+ if not return_dict:
203
+ output = (logits,) + outputs[1:]
204
+ return (loss,) + output if loss is not None else output
205
+
206
+ return PaliGemmaCausalLMOutputWithPast(
207
+ loss=loss,
208
+ logits=logits,
209
+ past_key_values=outputs.past_key_values,
210
+ hidden_states=outputs.hidden_states,
211
+ attentions=outputs.attentions,
212
+ image_hidden_states=image_features if pixel_values is not None else None,
213
+ )
@@ -600,6 +600,90 @@ def apply_liger_kernel_to_gemma2(
600
600
  _patch_rms_norm_module_for_gemma2(decoder_layer.post_feedforward_layernorm)
601
601
 
602
602
 
603
+ def apply_liger_kernel_to_paligemma(
604
+ rope: bool = True,
605
+ cross_entropy: bool = False,
606
+ fused_linear_cross_entropy: bool = True,
607
+ layer_norm: bool = True,
608
+ rms_norm: bool = True,
609
+ geglu: bool = True,
610
+ model: PreTrainedModel = None,
611
+ ) -> None:
612
+ """
613
+ Apply Liger kernels to replace original implementation in HuggingFace PaliGemma
614
+
615
+ Args:
616
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
617
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
618
+ fused_linear_cross_entropy (bool):
619
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
620
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
621
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
622
+ layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
623
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
624
+ geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
625
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
626
+ loaded. Default is None.
627
+ """
628
+ assert not (cross_entropy and fused_linear_cross_entropy), (
629
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
630
+ )
631
+
632
+ # PaliGemma submodules are ['vision_tower', 'multi_modal_projector', 'language_model']
633
+
634
+ from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
635
+ from transformers.models.paligemma import modeling_paligemma
636
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
637
+ from transformers.models.siglip import modeling_siglip
638
+ from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer
639
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
640
+
641
+ from liger_kernel.transformers.model.paligemma import lce_forward
642
+
643
+ # The vision_tower is a SiglipVisionModel
644
+ if layer_norm:
645
+ modeling_siglip.nn.LayerNorm = LigerLayerNorm
646
+
647
+ # SiglipMLP is standard FFN so LigerGEGLUMLP is not compatible
648
+ # The multi_modal_projector is Linear, nothing to do
649
+
650
+ # The language_model is Gemma2ForCausalLM
651
+ apply_liger_kernel_to_gemma2(rope=rope, cross_entropy=False, fused_linear_cross_entropy=False, geglu=geglu)
652
+ # Handle loss function
653
+ if cross_entropy:
654
+ modeling_paligemma.nn.CrossEntropyLoss = LigerCrossEntropyLoss
655
+ if fused_linear_cross_entropy:
656
+ modeling_paligemma.PaliGemmaForConditionalGeneration.forward = lce_forward
657
+
658
+ if model is not None:
659
+ # The model instance already exists, so we need to additionally patch the
660
+ # instance variables that reference already-instantiated modules
661
+
662
+ if not isinstance(model, PaliGemmaForConditionalGeneration):
663
+ raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")
664
+
665
+ vision_tower: SiglipVisionModel = model.vision_tower
666
+
667
+ _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)
668
+
669
+ for layer in vision_tower.vision_model.encoder.layers:
670
+ layer: SiglipEncoderLayer
671
+ if layer_norm:
672
+ _patch_layer_norm_module(layer.layer_norm1)
673
+ _patch_layer_norm_module(layer.layer_norm2)
674
+
675
+ language_model: Gemma2ForCausalLM = model.language_model
676
+
677
+ apply_liger_kernel_to_gemma2(
678
+ rope=rope,
679
+ cross_entropy=False,
680
+ fused_linear_cross_entropy=False,
681
+ rms_norm=rms_norm,
682
+ geglu=geglu,
683
+ model=language_model,
684
+ )
685
+
686
+
603
687
  def apply_liger_kernel_to_qwen2(
604
688
  rope: bool = True,
605
689
  cross_entropy: bool = False,
@@ -959,6 +1043,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
959
1043
  "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
960
1044
  "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
961
1045
  "phi3": apply_liger_kernel_to_phi3,
1046
+ "paligemma": apply_liger_kernel_to_paligemma,
962
1047
  }
963
1048
 
964
1049
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250314203927
3
+ Version: 0.5.5.dev20250316145754
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -310,6 +310,7 @@ loss.backward()
310
310
  | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
311
311
  | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
312
312
  | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
313
+ | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
313
314
  | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
314
315
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
316
  | Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -163,6 +163,7 @@ src/liger_kernel/transformers/model/mistral.py
163
163
  src/liger_kernel/transformers/model/mixtral.py
164
164
  src/liger_kernel/transformers/model/mllama.py
165
165
  src/liger_kernel/transformers/model/olmo2.py
166
+ src/liger_kernel/transformers/model/paligemma.py
166
167
  src/liger_kernel/transformers/model/phi3.py
167
168
  src/liger_kernel/transformers/model/qwen2.py
168
169
  src/liger_kernel/transformers/model/qwen2_5_vl.py
@@ -197,6 +198,7 @@ test/convergence/fp32/test_mini_models.py
197
198
  test/convergence/fp32/test_mini_models_multimodal.py
198
199
  test/convergence/fp32/test_mini_models_with_logits.py
199
200
  test/resources/tiny_shakespeare.txt
201
+ test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json
200
202
  test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json
201
203
  test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json
202
204
  test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json
@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
9
9
  from transformers import PreTrainedTokenizerFast
10
10
 
11
11
  from liger_kernel.transformers import apply_liger_kernel_to_mllama
12
+ from liger_kernel.transformers import apply_liger_kernel_to_paligemma
12
13
  from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
13
14
  from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
14
15
  from test.utils import FAKE_CONFIGS_PATH
@@ -18,6 +19,7 @@ from test.utils import assert_verbose_allclose
18
19
  from test.utils import load_tokenizer_config
19
20
  from test.utils import multimodal_collate_fn
20
21
  from test.utils import revert_liger_kernel_to_mllama
22
+ from test.utils import revert_liger_kernel_to_Paligemma
21
23
  from test.utils import revert_liger_kernel_to_qwen2_5_vl
22
24
  from test.utils import revert_liger_kernel_to_qwen2_vl
23
25
  from test.utils import set_seed
@@ -61,6 +63,19 @@ try:
61
63
  except ImportError:
62
64
  MLLAMA_AVAILABLE = False
63
65
 
66
+ try:
67
+ from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
68
+ from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
69
+ from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
70
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
71
+ from transformers.models.paligemma.processing_paligemma import PaliGemmaProcessor
72
+ from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
73
+ from transformers.models.siglip.image_processing_siglip import SiglipImageProcessor
74
+
75
+ PALIGEMMA_AVAILABLE = True
76
+ except ImportError:
77
+ PALIGEMMA_AVAILABLE = False
78
+
64
79
  from liger_kernel.utils import infer_device
65
80
 
66
81
  device = infer_device()
@@ -135,6 +150,58 @@ if MLLAMA_AVAILABLE:
135
150
  ),
136
151
  )
137
152
 
153
+ if PALIGEMMA_AVAILABLE:
154
+ MINI_MODEL_SETUPS["mini_paligemma"] = MiniModelConfig(
155
+ liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_paligemma, fused_linear_cross_entropy=False),
156
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_Paligemma,
157
+ model_class=PaliGemmaForConditionalGeneration,
158
+ mini_model_config=PaliGemmaConfig(
159
+ vision_config=SiglipVisionConfig(
160
+ attention_dropout=0.0,
161
+ hidden_act="gelu_pytorch_tanh",
162
+ hidden_size=1152,
163
+ image_size=224,
164
+ intermediate_size=2048, # 4304
165
+ layer_norm_eps=1e-06,
166
+ num_attention_heads=4, # 16
167
+ num_channels=3,
168
+ num_hidden_layers=4, # 27
169
+ num_image_tokens=256,
170
+ num_positions=256,
171
+ patch_size=14,
172
+ projection_dim=1024, # 2304
173
+ ),
174
+ text_config=Gemma2Config(
175
+ vocab_size=32000, # 256000
176
+ hidden_size=1024, # 3072
177
+ intermediate_size=2048, # 24576
178
+ num_hidden_layers=4, # 28
179
+ num_attention_heads=4, # 16
180
+ num_key_value_heads=4, # 16
181
+ head_dim=256,
182
+ hidden_activation="gelu_pytorch_tanh",
183
+ max_position_embeddings=8192,
184
+ initializer_range=0.02,
185
+ rms_norm_eps=1e-06,
186
+ use_cache=True,
187
+ pad_token_id=0,
188
+ # Special token ids/vocab size to match Mistral-7B tokenizer used to create the tokenized dataset
189
+ # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
190
+ bos_token_id=1, # 128000
191
+ eos_token_id=2, # 128001
192
+ tie_word_embeddings=True,
193
+ rope_theta=10000.0,
194
+ attention_bias=False,
195
+ attention_dropout=0.0,
196
+ ),
197
+ image_token_index=4, # NOTE: outside the vocab size
198
+ attn_implementation="eager",
199
+ vocab_size=32000,
200
+ projection_dim=1024,
201
+ ),
202
+ )
203
+
204
+
138
205
  if QWEN2_VL_AVAILABLE:
139
206
  MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig(
140
207
  liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_qwen2_vl, fused_linear_cross_entropy=False),
@@ -284,6 +351,27 @@ def create_processor(model_name):
284
351
  fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config)
285
352
  image_processor = MllamaImageProcessor(size={"height": 560, "width": 560})
286
353
  return MllamaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer)
354
+
355
+ elif model_name == "mini_paligemma":
356
+ tokenizer_config = load_tokenizer_config(
357
+ os.path.join(
358
+ FAKE_CONFIGS_PATH,
359
+ "Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json",
360
+ )
361
+ )
362
+ tokenizer_base = train_bpe_tokenizer(
363
+ [
364
+ token.content
365
+ for key, token in sorted(
366
+ tokenizer_config["added_tokens_decoder"].items(),
367
+ key=lambda x: int(x[0]),
368
+ )
369
+ ]
370
+ )
371
+ fast_tokenizer = GemmaTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config)
372
+ image_processor = SiglipImageProcessor(size={"height": 224, "width": 224}, image_seq_length=256)
373
+ return PaliGemmaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer)
374
+
287
375
  else:
288
376
  raise ValueError(f"Processor not available for model {model_name}")
289
377
 
@@ -473,6 +561,25 @@ def run_mini_model_multimodal(
473
561
  ),
474
562
  ],
475
563
  ),
564
+ pytest.param(
565
+ "mini_paligemma",
566
+ 32,
567
+ 1e-4,
568
+ torch.bfloat16,
569
+ 1e-3,
570
+ 1e-2,
571
+ 1e-1,
572
+ 1e-2,
573
+ 1e-2,
574
+ 1e-2,
575
+ marks=[
576
+ pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
577
+ pytest.mark.skipif(
578
+ not PALIGEMMA_AVAILABLE,
579
+ reason="Paligemma not available in this version of transformers",
580
+ ),
581
+ ],
582
+ ),
476
583
  ],
477
584
  )
478
585
  def test_mini_model_multimodal(