liger-kernel-nightly 0.5.6.dev20250402215200__tar.gz → 0.5.6.dev20250403190551__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243)
  1. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/PKG-INFO +3 -1
  2. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/README.md +2 -0
  3. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/__init__.py +2 -0
  5. liger_kernel_nightly-0.5.6.dev20250403190551/src/liger_kernel/transformers/gema3_rms.py +8 -0
  6. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/gemma.py +9 -4
  7. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/gemma2.py +9 -4
  8. liger_kernel_nightly-0.5.6.dev20250403190551/src/liger_kernel/transformers/model/gemma3.py +335 -0
  9. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/llama.py +9 -4
  10. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/mistral.py +19 -15
  11. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/mixtral.py +12 -11
  12. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/mllama.py +9 -4
  13. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/olmo2.py +9 -4
  14. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/phi3.py +9 -4
  15. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/qwen2.py +9 -4
  16. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/monkey_patch.py +173 -0
  17. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel_nightly.egg-info/PKG-INFO +3 -1
  18. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel_nightly.egg-info/SOURCES.txt +3 -0
  19. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/convergence/bf16/test_mini_models.py +59 -0
  20. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/convergence/bf16/test_mini_models_multimodal.py +100 -0
  21. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/convergence/bf16/test_mini_models_with_logits.py +59 -0
  22. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/convergence/fp32/test_mini_models.py +55 -0
  23. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/convergence/fp32/test_mini_models_multimodal.py +96 -2
  24. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/convergence/fp32/test_mini_models_with_logits.py +55 -0
  25. liger_kernel_nightly-0.5.6.dev20250403190551/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +90 -0
  26. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_monkey_patch.py +150 -0
  27. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/utils.py +29 -1
  28. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  29. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  30. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.github/pull_request_template.md +0 -0
  31. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.github/workflows/amd-ci.yml +0 -0
  32. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.github/workflows/docs.yml +0 -0
  33. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.github/workflows/intel-ci.yml +0 -0
  34. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.github/workflows/nvi-ci.yml +0 -0
  35. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.github/workflows/publish-nightly.yml +0 -0
  36. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.github/workflows/publish-release.yml +0 -0
  37. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/.gitignore +0 -0
  38. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/LICENSE +0 -0
  39. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/Makefile +0 -0
  40. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/NOTICE +0 -0
  41. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/README.md +0 -0
  42. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/__init__.py +0 -0
  43. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/benchmarks_visualizer.py +0 -0
  44. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/data/all_benchmark_data.csv +0 -0
  45. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/__init__.py +0 -0
  46. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  47. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  48. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  49. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  50. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_dyt.py +0 -0
  51. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_embedding.py +0 -0
  52. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  53. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  54. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_geglu.py +0 -0
  55. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_group_norm.py +0 -0
  56. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_jsd.py +0 -0
  57. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_kl_div.py +0 -0
  58. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  59. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  60. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  61. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  62. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  63. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_rope.py +0 -0
  64. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  65. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_swiglu.py +0 -0
  66. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/benchmark_tvd.py +0 -0
  67. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/benchmark/scripts/utils.py +0 -0
  68. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/dev/fmt-requirements.txt +0 -0
  69. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/dev/modal/tests.py +0 -0
  70. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/dev/modal/tests_bwd.py +0 -0
  71. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/Examples.md +0 -0
  72. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/Getting-Started.md +0 -0
  73. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/High-Level-APIs.md +0 -0
  74. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/Low-Level-APIs.md +0 -0
  75. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/acknowledgement.md +0 -0
  76. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/contributing.md +0 -0
  77. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/images/banner.GIF +0 -0
  78. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/images/compose.gif +0 -0
  79. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/images/e2e-memory.png +0 -0
  80. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/images/e2e-tps.png +0 -0
  81. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/images/logo-banner.png +0 -0
  82. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/images/patch.gif +0 -0
  83. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/images/post-training.png +0 -0
  84. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/index.md +0 -0
  85. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/docs/license.md +0 -0
  86. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/alignment/accelerate_config.yaml +0 -0
  87. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/alignment/run_orpo.py +0 -0
  88. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/README.md +0 -0
  89. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/callback.py +0 -0
  90. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/config/fsdp_config.json +0 -0
  91. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  92. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  93. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  94. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/img/llama_tps.png +0 -0
  95. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  96. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/img/qwen_tps.png +0 -0
  97. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/launch_on_modal.py +0 -0
  98. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/requirements.txt +0 -0
  99. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/run_benchmarks.sh +0 -0
  100. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/run_gemma.sh +0 -0
  101. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/run_llama.sh +0 -0
  102. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/run_qwen.sh +0 -0
  103. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/run_qwen2_vl.sh +0 -0
  104. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/training.py +0 -0
  105. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/huggingface/training_multimodal.py +0 -0
  106. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/lightning/README.md +0 -0
  107. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/lightning/requirements.txt +0 -0
  108. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/lightning/training.py +0 -0
  109. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/README.md +0 -0
  110. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/callback.py +0 -0
  111. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  112. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  113. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  114. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  115. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  116. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  117. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  118. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  119. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  120. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/medusa_util.py +0 -0
  121. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/requirements.txt +0 -0
  122. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  123. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/examples/medusa/train.py +0 -0
  124. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/licenses/LICENSE-Apache-2.0 +0 -0
  125. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  126. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  127. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/licenses/LICENSE-MIT-llmc +0 -0
  128. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/licenses/LICENSE-MIT-triton +0 -0
  129. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/mkdocs.yml +0 -0
  130. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/setup.cfg +0 -0
  131. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/setup.py +0 -0
  132. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/__init__.py +0 -0
  133. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/README.md +0 -0
  134. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  135. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  136. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  137. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/functional.py +0 -0
  138. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  139. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  140. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  141. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  142. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  143. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  144. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  145. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  146. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  147. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/env_report.py +0 -0
  148. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/__init__.py +0 -0
  149. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/cross_entropy.py +0 -0
  150. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/dyt.py +0 -0
  151. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  152. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  153. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  154. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  155. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/geglu.py +0 -0
  156. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/group_norm.py +0 -0
  157. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/jsd.py +0 -0
  158. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/kl_div.py +0 -0
  159. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/layer_norm.py +0 -0
  160. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  161. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/rms_norm.py +0 -0
  162. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/rope.py +0 -0
  163. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/swiglu.py +0 -0
  164. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/tvd.py +0 -0
  165. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/ops/utils.py +0 -0
  166. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/auto_model.py +0 -0
  167. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  168. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/dyt.py +0 -0
  169. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  170. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/functional.py +0 -0
  171. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  172. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  173. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/geglu.py +0 -0
  174. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/group_norm.py +0 -0
  175. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/jsd.py +0 -0
  176. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/kl_div.py +0 -0
  177. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/layer_norm.py +0 -0
  178. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/__init__.py +0 -0
  179. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/llava.py +0 -0
  180. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  181. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  182. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  183. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  184. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  185. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/rms_norm.py +0 -0
  186. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/rope.py +0 -0
  187. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/swiglu.py +0 -0
  188. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  189. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  190. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  191. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/transformers/tvd.py +0 -0
  192. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/triton/__init__.py +0 -0
  193. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/triton/monkey_patch.py +0 -0
  194. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel/utils.py +0 -0
  195. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  196. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  197. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  198. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/chunked_loss/__init__.py +0 -0
  200. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/chunked_loss/test_cpo_loss.py +0 -0
  201. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/chunked_loss/test_dpo_loss.py +0 -0
  202. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/chunked_loss/test_grpo_loss.py +0 -0
  203. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/chunked_loss/test_jsd_loss.py +0 -0
  204. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/chunked_loss/test_kto_loss.py +0 -0
  205. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/chunked_loss/test_orpo_loss.py +0 -0
  206. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/chunked_loss/test_simpo_loss.py +0 -0
  207. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/conftest.py +0 -0
  208. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/convergence/__init__.py +0 -0
  209. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/convergence/bf16/__init__.py +0 -0
  210. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/convergence/fp32/__init__.py +0 -0
  211. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  212. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  213. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  214. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  215. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  216. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  217. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  218. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  219. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/tiny_shakespeare.txt +0 -0
  220. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  221. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  222. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  223. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_auto_model.py +0 -0
  224. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_cross_entropy.py +0 -0
  225. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_dyt.py +0 -0
  226. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_embedding.py +0 -0
  227. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_flex_attention.py +0 -0
  228. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  229. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_fused_linear_jsd.py +0 -0
  230. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_geglu.py +0 -0
  231. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_group_norm.py +0 -0
  232. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_jsd.py +0 -0
  233. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_kl_div.py +0 -0
  234. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_layer_norm.py +0 -0
  235. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_mm_int8int2.py +0 -0
  236. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_qwen2vl_mrope.py +0 -0
  237. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_rms_norm.py +0 -0
  238. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_rope.py +0 -0
  239. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_swiglu.py +0 -0
  240. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_trainer_integration.py +0 -0
  241. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_transformers.py +0 -0
  242. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/transformers/test_tvd.py +0 -0
  243. {liger_kernel_nightly-0.5.6.dev20250402215200 → liger_kernel_nightly-0.5.6.dev20250403190551}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: liger_kernel_nightly
-Version: 0.5.6.dev20250402215200
+Version: 0.5.6.dev20250403190551
 Summary: Efficient Triton kernels for LLM Training
 License: BSD 2-CLAUSE LICENSE
 Copyright 2024 LinkedIn Corporation
@@ -311,6 +311,8 @@ loss.backward()
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
@@ -263,6 +263,8 @@ loss.backward()
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
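The two new rows above document the Gemma3 entry points added in this release. As a rough usage sketch (not part of the diff), they are expected to follow the same pattern as the other `apply_liger_kernel_to_*` functions, i.e. apply the monkey patch before the model is instantiated; the checkpoint id below is borrowed from the new test fixture path and is only illustrative:

```python
# Hedged usage sketch under the assumptions stated above.
from transformers import Gemma3ForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_gemma3

apply_liger_kernel_to_gemma3()  # patches LayerNorm, RoPE, RMSNorm, GeGLU, and the fused cross-entropy loss
model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")  # illustrative checkpoint
```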
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "liger_kernel_nightly"
-version = "0.5.6.dev20250402215200"
+version = "0.5.6.dev20250403190551"
 description = "Efficient Triton kernels for LLM Training"
 urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
 readme = { file = "README.md", content-type = "text/markdown" }
@@ -10,6 +10,8 @@ from liger_kernel.transformers.monkey_patch import _apply_liger_kernel  # noqa:
 from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava  # noqa: F401
@@ -0,0 +1,8 @@
+from .rms_norm import LigerRMSNorm
+
+
+class LigerRMSNormForGemma3(LigerRMSNorm):
+    """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
+
+    def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
+        super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
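The new module above only adapts the constructor signature: Gemma3 builds its `q_norm`/`k_norm` layers with a `dim` argument (the attention head dimension) rather than `hidden_size`, so the subclass simply forwards `dim` to `LigerRMSNorm` together with Gemma-style defaults (`offset=1.0`, `casting_mode="gemma"`, zero-initialized weight). A minimal instantiation sketch, assuming the usual `LigerRMSNorm` behavior; the head dimension of 256 is illustrative:

```python
# Sketch only, not part of the diff: construct the wrapper the way a patched q_norm/k_norm would be.
from liger_kernel.transformers.gema3_rms import LigerRMSNormForGemma3

q_norm = LigerRMSNormForGemma3(dim=256, eps=1e-6)
print(q_norm.weight.shape)  # expected torch.Size([256]); weight starts at zeros and is applied as (weight + offset)
# The forward pass dispatches to a Triton kernel, so inputs should live on the GPU, e.g.:
# q_norm.to("cuda")(torch.randn(2, 8, 256, device="cuda"))
```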
@@ -12,6 +12,7 @@ from transformers.models.gemma.modeling_gemma import _CONFIG_FOR_DOC
 from transformers.models.gemma.modeling_gemma import GEMMA_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -127,6 +128,7 @@ def lce_forward_deprecated(
     )


+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
@@ -142,7 +144,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -152,10 +154,12 @@ def lce_forward(
         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-    num_logits_to_keep (`int`, *optional*):
-        Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
         `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).

     Returns:

@@ -209,7 +213,8 @@ def lce_forward(
             **loss_kwargs,
         )
     else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if labels is not None:
             loss = self.loss_function(
                 logits=logits,
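The change above (mirrored in the Gemma2 and new Gemma3 files) replaces the integer-only `num_logits_to_keep` with `logits_to_keep`, which may also be a 1D index tensor. A standalone sketch of the `slice_indices` logic, independent of the model code:

```python
# Illustration (not from the diff) of how slice_indices behaves for the three supported cases.
import torch

hidden_states = torch.randn(1, 6, 4)  # (batch, seq_len, hidden)


def kept(logits_to_keep):
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]


print(kept(0).shape)                        # torch.Size([1, 6, 4]) -- 0 means keep every position
print(kept(1).shape)                        # torch.Size([1, 1, 4]) -- only the last token, the generation case
print(kept(torch.tensor([0, 3, 5])).shape)  # torch.Size([1, 3, 4]) -- explicit indices, e.g. packed sequences
```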
@@ -13,6 +13,7 @@ from transformers.models.gemma2.modeling_gemma2 import _CONFIG_FOR_DOC
 from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
 from transformers.utils import add_start_docstrings_to_model_forward
 from transformers.utils import replace_return_docstrings
+from transformers.utils.deprecation import deprecate_kwarg

 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -134,6 +135,7 @@ def lce_forward_deprecated(
     )


+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
 @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 def lce_forward(
@@ -149,7 +151,7 @@ def lce_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
-    num_logits_to_keep: int = 0,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     **loss_kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
@@ -159,10 +161,12 @@ def lce_forward(
         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

-    num_logits_to_keep (`int`, *optional*):
-        Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
         `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
         token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).

     Returns:

@@ -223,7 +227,8 @@ def lce_forward(
        )

    else:  # if in inference mode materialize logits
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
         if self.config.final_logit_softcapping is not None:
             logits = logits / self.config.final_logit_softcapping
             logits = torch.tanh(logits)
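The Gemma2 inference branch above retains final-logit softcapping after materializing the logits. A small numeric sketch of that arithmetic (the cap of 30.0 mirrors Gemma2's usual `final_logit_softcapping`, but the value here is only illustrative):

```python
# Standalone sketch, not from the diff: softcapping squashes logits into (-cap, cap) via tanh.
import torch

final_logit_softcapping = 30.0
logits = torch.tensor([-100.0, -10.0, 0.0, 10.0, 100.0])

capped = logits / final_logit_softcapping
capped = torch.tanh(capped)
capped = capped * final_logit_softcapping
print(capped)  # approximately tensor([-29.92, -9.65, 0.00, 9.65, 29.92])
```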
@@ -0,0 +1,335 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from transformers.cache_utils import Cache
10
+ from transformers.cache_utils import HybridCache
11
+ from transformers.modeling_outputs import CausalLMOutputWithPast
12
+ from transformers.models.gemma3.modeling_gemma3 import _CONFIG_FOR_DOC
13
+ from transformers.models.gemma3.modeling_gemma3 import GEMMA3_INPUTS_DOCSTRING
14
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast
15
+ from transformers.utils import add_start_docstrings_to_model_forward
16
+ from transformers.utils import is_torchdynamo_compiling
17
+ from transformers.utils import logging
18
+ from transformers.utils import replace_return_docstrings
19
+ from transformers.utils.deprecation import deprecate_kwarg
20
+
21
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
22
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
28
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
29
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
30
+ def causal_forward(
31
+ self,
32
+ input_ids: torch.LongTensor = None,
33
+ attention_mask: Optional[torch.Tensor] = None,
34
+ position_ids: Optional[torch.LongTensor] = None,
35
+ past_key_values: Optional[HybridCache] = None,
36
+ inputs_embeds: Optional[torch.FloatTensor] = None,
37
+ labels: Optional[torch.LongTensor] = None,
38
+ use_cache: Optional[bool] = None,
39
+ output_attentions: Optional[bool] = None,
40
+ output_hidden_states: Optional[bool] = None,
41
+ return_dict: Optional[bool] = None,
42
+ cache_position: Optional[torch.LongTensor] = None,
43
+ logits_to_keep: Union[int, torch.Tensor] = 0,
44
+ **loss_kwargs,
45
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
46
+ r"""
47
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
48
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
49
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
50
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
51
+
52
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
53
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
54
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
55
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
56
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
57
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
58
+
59
+ Returns:
60
+
61
+ Example:
62
+
63
+ ```python
64
+ >>> from transformers import AutoTokenizer, Gemma3ForCausalLM
65
+
66
+ >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
67
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
68
+
69
+ >>> prompt = "What is your favorite condiment?"
70
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
71
+
72
+ >>> # Generate
73
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
74
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
75
+ "What is your favorite condiment?"
76
+ ```"""
77
+
78
+ if self.training and self.config._attn_implementation != "eager":
79
+ logger.warning_once(
80
+ "It is strongly recommended to train Gemma3 models with the `eager` attention implementation "
81
+ f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
82
+ )
83
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
84
+ output_hidden_states = (
85
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
86
+ )
87
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
88
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
89
+ outputs = self.model(
90
+ input_ids=input_ids,
91
+ attention_mask=attention_mask,
92
+ position_ids=position_ids,
93
+ past_key_values=past_key_values,
94
+ inputs_embeds=inputs_embeds,
95
+ use_cache=use_cache,
96
+ output_attentions=output_attentions,
97
+ output_hidden_states=output_hidden_states,
98
+ return_dict=return_dict,
99
+ cache_position=cache_position,
100
+ **loss_kwargs,
101
+ )
102
+
103
+ hidden_states = outputs[0]
104
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
105
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
106
+ kept_hidden_states = hidden_states[:, slice_indices, :]
107
+ loss = None
108
+ logits = None
109
+ if self.training and (labels is not None):
110
+ loss = LigerForCausalLMLoss(
111
+ hidden_states=kept_hidden_states,
112
+ lm_head_weight=self.lm_head.weight,
113
+ labels=labels,
114
+ hidden_size=self.config.hidden_size,
115
+ softcap=self.config.final_logit_softcapping,
116
+ **loss_kwargs,
117
+ )
118
+
119
+ else:
120
+ logits = self.lm_head(kept_hidden_states)
121
+ if self.config.final_logit_softcapping is not None:
122
+ logits = logits / self.config.final_logit_softcapping
123
+ logits = torch.tanh(logits)
124
+ logits = logits * self.config.final_logit_softcapping
125
+ if labels is not None:
126
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
127
+
128
+ if not return_dict:
129
+ output = (logits,) + outputs[1:]
130
+ return (loss,) + output if loss is not None else output
131
+
132
+ return CausalLMOutputWithPast(
133
+ loss=loss,
134
+ logits=logits,
135
+ past_key_values=outputs.past_key_values,
136
+ hidden_states=outputs.hidden_states,
137
+ attentions=outputs.attentions,
138
+ )
139
+
140
+
141
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
142
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
143
+ @replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
144
+ def multimodal_forward(
145
+ self,
146
+ input_ids: torch.LongTensor = None,
147
+ pixel_values: torch.FloatTensor = None,
148
+ attention_mask: Optional[torch.Tensor] = None,
149
+ position_ids: Optional[torch.LongTensor] = None,
150
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
151
+ token_type_ids: Optional[torch.LongTensor] = None,
152
+ cache_position: Optional[torch.LongTensor] = None,
153
+ inputs_embeds: Optional[torch.FloatTensor] = None,
154
+ labels: Optional[torch.LongTensor] = None,
155
+ use_cache: Optional[bool] = None,
156
+ output_attentions: Optional[bool] = None,
157
+ output_hidden_states: Optional[bool] = None,
158
+ return_dict: Optional[bool] = None,
159
+ logits_to_keep: Union[int, torch.Tensor] = 0,
160
+ **lm_kwargs,
161
+ ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
162
+ r"""
163
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
164
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
165
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
166
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
167
+
168
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
169
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
170
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
171
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
172
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
173
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
174
+
175
+ Returns:
176
+
177
+ Example:
178
+
179
+ ```python
180
+ >>> from PIL import Image
181
+ >>> import requests
182
+ >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
183
+
184
+ >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
185
+ >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
186
+
187
+ >>> prompt = "answer en Where is the cow standing?"
188
+ >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
189
+ >>> image = Image.open(requests.get(url, stream=True).raw)
190
+
191
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
192
+
193
+ >>> # Generate
194
+ >>> generate_ids = model.generate(**inputs, max_length=30)
195
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
196
+ "answer en Where is the cow standing?\nbeach"
197
+ ```"""
198
+
199
+ if (input_ids is None) ^ (inputs_embeds is not None):
200
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
201
+
202
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
203
+ output_hidden_states = (
204
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
205
+ )
206
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
207
+
208
+ is_training = token_type_ids is not None and labels is not None
209
+
210
+ # Replace image id woth PAD if the image token if OOV, to avoid index-errors
211
+ if input_ids is not None and self.config.image_token_index >= self.vocab_size:
+ special_image_mask = input_ids == self.config.image_token_index
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[special_image_mask] = 0
+ else:
+ llm_input_ids = input_ids
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0) + 1 # Gemma3 positions are 1-indexed
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of images does not match number of special image tokens in the input text. "
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+ "tokens from image embeddings."
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # mask out pad-token-ids in labels for BC
+ if labels is not None and self.pad_token_id in labels:
+ logger.warning_once(
+ "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
+ "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
+ )
+ labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+ )
+ outputs = self.language_model.model(
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **lm_kwargs,
+ )
+
+ hidden_states = outputs[0]
+ loss = None
+ logits = None
+
+ if self.training and (labels is not None):
+ shift_hidden_states = hidden_states[..., :-1, :]
+ shift_labels = labels[..., 1:]
+
+ hidden_device = shift_hidden_states.device
+ if attention_mask is not None:
+ # we use the input attention mask to shift the hidden_states and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_hidden_states.shape[1] :].to(hidden_device)
+ shift_hidden_states = shift_hidden_states[shift_attention_mask.to(hidden_device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_hidden_states = shift_hidden_states.contiguous()
+ shift_labels = shift_labels.contiguous()
+
+ # Flatten hidden state
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.text_config.hidden_size)
+ shift_labels = shift_labels.view(-1).to(hidden_device)
+
+ lce = LigerFusedLinearCrossEntropyLoss()
+ loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+ else:
+ logits = self.language_model.lm_head(hidden_states)
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ if attention_mask is not None:
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_logits = shift_logits.contiguous()
+ shift_labels = shift_labels.contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
+ loss = loss_fct(flat_logits, flat_labels)
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Gemma3CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
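
For context on the training branch added above: instead of projecting hidden states through `lm_head` and materializing a full `(tokens, vocab_size)` logits tensor before cross-entropy, the patched Gemma3 forward passes the `lm_head` weight together with the flattened, shifted hidden states and labels to `LigerFusedLinearCrossEntropyLoss`. The snippet below is a minimal, self-contained sketch of that call pattern only; the shapes are made up, it assumes a CUDA device with Triton available, and it is not part of this diff.

```python
# Minimal sketch of the fused linear + cross-entropy training path (illustrative shapes).
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

batch, seq_len, hidden, vocab = 2, 16, 64, 256
hidden_states = torch.randn(batch, seq_len, hidden, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq_len), device="cuda")
lm_head_weight = torch.randn(vocab, hidden, device="cuda", requires_grad=True)

# Next-token prediction: position t is scored against the token at position t + 1.
shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, hidden)
shift_labels = labels[..., 1:].contiguous().view(-1)

# Same call pattern as in the forward above: weight first, then flattened inputs and targets,
# so the (tokens, vocab) logits tensor is never materialized.
lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
loss.backward()  # gradients flow to both the hidden states and the lm_head weight
```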
@@ -13,6 +13,7 @@ from transformers.models.llama.modeling_llama import _CONFIG_FOR_DOC
  from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
@@ -135,6 +136,7 @@ def lce_forward_deprecated(
  )


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
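
The `deprecate_kwarg` decorator added in the hunk above keeps the old `num_logits_to_keep` keyword usable while the signature moves to `logits_to_keep`. The sketch below, using a hypothetical function `f` that is not from this package, shows the expected remapping behavior; treat it as an assumption about the decorator rather than a verbatim excerpt.

```python
# Hypothetical illustration of the deprecation shim: callers passing the old keyword
# are expected to have it forwarded to the new name, with a deprecation warning.
from transformers.utils.deprecation import deprecate_kwarg


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def f(logits_to_keep=0):
    return logits_to_keep


print(f(num_logits_to_keep=5))  # 5, emitted together with a warning about the renamed kwarg
```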
@@ -150,7 +152,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -160,10 +162,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).

  Returns:

@@ -222,7 +226,8 @@ def lce_forward(
  )

  else: # if in inference mode materialize logits
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
  if labels is not None:
  loss = self.loss_function(
  logits=logits,
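
The `slice_indices` line in the hunk above accepts either an `int` or a 1D index tensor for `logits_to_keep`. The following sketch (made-up shapes, not part of the diff) shows how the two cases slice the sequence dimension:

```python
# Illustration of the logits_to_keep slicing semantics used above.
import torch

hidden_states = torch.randn(2, 10, 32)  # (batch, seq_len, hidden)


def keep(logits_to_keep):
    # Same expression as in the patched forward.
    return slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep


print(hidden_states[:, keep(1), :].shape)   # torch.Size([2, 1, 32])  - last position only
print(hidden_states[:, keep(0), :].shape)   # torch.Size([2, 10, 32]) - slice(0, None) keeps everything
print(hidden_states[:, keep(torch.tensor([3, 9])), :].shape)  # torch.Size([2, 2, 32]) - explicit positions
```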
@@ -5,17 +5,18 @@ from typing import Union

  import torch

- from torch.nn import CrossEntropyLoss
  from transformers.cache_utils import Cache
  from transformers.modeling_outputs import CausalLMOutputWithPast
  from transformers.models.mistral.modeling_mistral import _CONFIG_FOR_DOC
  from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
  from transformers.utils import add_start_docstrings_to_model_forward
  from transformers.utils import replace_return_docstrings
+ from transformers.utils.deprecation import deprecate_kwarg

  from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
  @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def lce_forward(
@@ -31,6 +32,7 @@ def lce_forward(
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
  **loss_kwargs,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
@@ -43,6 +45,12 @@ def lce_forward(
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
  Returns:

  Example:
@@ -97,21 +105,17 @@ def lce_forward(
  )

  else:
- logits = self.lm_head(hidden_states)
- if labels is not None:
- # Upcast to float if we need to compute the loss to avoid potential precision issues
- logits = logits.float()
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Ensure tensors are on the same device
- shift_labels = shift_labels.to(shift_logits.device)
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(shift_logits, shift_labels)
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])

+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **loss_kwargs,
+ )
  if not return_dict:
  output = (logits,) + outputs[1:]
  return (loss,) + output if loss is not None else output
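
The block removed from the Mistral forward above computed the next-token cross-entropy by hand; the replacement defers to the model's shared `self.loss_function` helper. As a standalone reference, the sketch below reproduces the equivalent shift-and-flatten computation in plain PyTorch, with `-100` as the ignore index; it is a conceptual sketch, not the helper's actual internals, and is not part of this diff.

```python
# Conceptual sketch of the next-token cross-entropy that the deleted block implemented by hand.
import torch
from torch.nn import CrossEntropyLoss


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Shift so that logits at position t are scored against the label at position t + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens; positions labeled -100 are excluded from the loss.
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    return loss_fct(
        shift_logits.view(-1, vocab_size).float(),
        shift_labels.view(-1).to(shift_logits.device),
    )


logits = torch.randn(2, 8, 32)
labels = torch.randint(0, 32, (2, 8))
labels[:, :2] = -100  # e.g. a masked prompt prefix
print(causal_lm_loss(logits, labels, vocab_size=32))
```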