liger-kernel-nightly 0.5.8.dev20250422210723__tar.gz → 0.5.8.dev20250428050809__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243)
  1. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/pyproject.toml +1 -1
  3. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/gemma.py +3 -1
  4. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/gemma2.py +3 -1
  5. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/gemma3.py +3 -1
  6. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/llama.py +3 -1
  7. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/mistral.py +3 -1
  8. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/mixtral.py +3 -1
  9. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/mllama.py +3 -1
  10. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/olmo2.py +3 -1
  11. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/phi3.py +3 -1
  12. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/qwen2.py +3 -1
  13. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/qwen2_5_vl.py +3 -1
  14. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/qwen2_vl.py +3 -1
  15. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  16. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/convergence/bf16/test_mini_models_multimodal.py +1 -3
  17. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/convergence/fp32/test_mini_models_multimodal.py +0 -2
  18. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_fused_linear_cross_entropy.py +2 -2
  19. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_rms_norm.py +1 -1
  20. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  21. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  22. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.github/pull_request_template.md +0 -0
  23. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.github/workflows/amd-ci.yml +0 -0
  24. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.github/workflows/docs.yml +0 -0
  25. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.github/workflows/intel-ci.yml +0 -0
  26. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.github/workflows/nvi-ci.yml +0 -0
  27. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.github/workflows/publish-nightly.yml +0 -0
  28. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.github/workflows/publish-release.yml +0 -0
  29. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/.gitignore +0 -0
  30. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/LICENSE +0 -0
  31. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/Makefile +0 -0
  32. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/NOTICE +0 -0
  33. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/README.md +0 -0
  34. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/README.md +0 -0
  35. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/__init__.py +0 -0
  36. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/benchmarks_visualizer.py +0 -0
  37. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/data/all_benchmark_data.csv +0 -0
  38. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/__init__.py +0 -0
  39. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  40. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  41. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  43. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_dyt.py +0 -0
  44. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_embedding.py +0 -0
  45. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  46. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  47. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_geglu.py +0 -0
  48. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_group_norm.py +0 -0
  49. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_jsd.py +0 -0
  50. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_kl_div.py +0 -0
  51. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  52. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  53. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  54. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  55. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  56. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_rope.py +0 -0
  57. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  58. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_swiglu.py +0 -0
  59. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/benchmark_tvd.py +0 -0
  60. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/benchmark/scripts/utils.py +0 -0
  61. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/dev/fmt-requirements.txt +0 -0
  62. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/dev/modal/tests.py +0 -0
  63. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/dev/modal/tests_bwd.py +0 -0
  64. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/Examples.md +0 -0
  65. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/Getting-Started.md +0 -0
  66. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/High-Level-APIs.md +0 -0
  67. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/Low-Level-APIs.md +0 -0
  68. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/acknowledgement.md +0 -0
  69. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/contributing.md +0 -0
  70. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/images/banner.GIF +0 -0
  71. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/images/compose.gif +0 -0
  72. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/images/e2e-memory.png +0 -0
  73. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/images/e2e-tps.png +0 -0
  74. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/images/logo-banner.png +0 -0
  75. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/images/patch.gif +0 -0
  76. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/images/post-training.png +0 -0
  77. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/index.md +0 -0
  78. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/docs/license.md +0 -0
  79. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/alignment/accelerate_config.yaml +0 -0
  80. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/alignment/run_orpo.py +0 -0
  81. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/README.md +0 -0
  82. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/callback.py +0 -0
  83. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/config/fsdp_config.json +0 -0
  84. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  85. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  86. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  87. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/img/llama_tps.png +0 -0
  88. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  89. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/img/qwen_tps.png +0 -0
  90. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/launch_on_modal.py +0 -0
  91. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/requirements.txt +0 -0
  92. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/run_benchmarks.sh +0 -0
  93. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/run_gemma.sh +0 -0
  94. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/run_llama.sh +0 -0
  95. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/run_qwen.sh +0 -0
  96. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/run_qwen2_vl.sh +0 -0
  97. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/training.py +0 -0
  98. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/huggingface/training_multimodal.py +0 -0
  99. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/lightning/README.md +0 -0
  100. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/lightning/requirements.txt +0 -0
  101. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/lightning/training.py +0 -0
  102. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/README.md +0 -0
  103. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/callback.py +0 -0
  104. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  105. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  106. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  107. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  108. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  109. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  110. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  111. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  112. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  113. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/medusa_util.py +0 -0
  114. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/requirements.txt +0 -0
  115. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  116. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/examples/medusa/train.py +0 -0
  117. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/licenses/LICENSE-Apache-2.0 +0 -0
  118. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  119. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  120. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/licenses/LICENSE-MIT-llmc +0 -0
  121. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/licenses/LICENSE-MIT-triton +0 -0
  122. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/mkdocs.yml +0 -0
  123. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/setup.cfg +0 -0
  124. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/setup.py +0 -0
  125. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/__init__.py +0 -0
  126. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/README.md +0 -0
  127. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  128. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/functional.py +0 -0
  131. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  132. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  133. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  134. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  135. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  136. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  137. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  138. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  139. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  140. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/env_report.py +0 -0
  141. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/__init__.py +0 -0
  142. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/cross_entropy.py +0 -0
  143. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/dyt.py +0 -0
  144. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  145. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  146. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  147. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  148. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/geglu.py +0 -0
  149. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/group_norm.py +0 -0
  150. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/jsd.py +0 -0
  151. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/kl_div.py +0 -0
  152. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/layer_norm.py +0 -0
  153. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  154. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/rms_norm.py +0 -0
  155. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/rope.py +0 -0
  156. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/swiglu.py +0 -0
  157. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/tvd.py +0 -0
  158. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/ops/utils.py +0 -0
  159. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/__init__.py +0 -0
  160. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/auto_model.py +0 -0
  161. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  162. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/dyt.py +0 -0
  163. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  164. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/functional.py +0 -0
  165. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  166. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  167. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/geglu.py +0 -0
  168. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  169. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/group_norm.py +0 -0
  170. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/jsd.py +0 -0
  171. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/kl_div.py +0 -0
  172. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/layer_norm.py +0 -0
  173. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/__init__.py +0 -0
  174. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/llava.py +0 -0
  175. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  176. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  177. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  178. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  179. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/rms_norm.py +0 -0
  180. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/rope.py +0 -0
  181. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/swiglu.py +0 -0
  182. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  183. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  184. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  185. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/transformers/tvd.py +0 -0
  186. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/triton/__init__.py +0 -0
  187. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/triton/monkey_patch.py +0 -0
  188. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel/utils.py +0 -0
  189. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -0
  190. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  191. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  192. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  193. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/__init__.py +0 -0
  194. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/chunked_loss/__init__.py +0 -0
  195. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/chunked_loss/test_cpo_loss.py +0 -0
  196. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/chunked_loss/test_dpo_loss.py +0 -0
  197. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/chunked_loss/test_grpo_loss.py +0 -0
  198. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/chunked_loss/test_jsd_loss.py +0 -0
  199. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/chunked_loss/test_kto_loss.py +0 -0
  200. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/chunked_loss/test_orpo_loss.py +0 -0
  201. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/chunked_loss/test_simpo_loss.py +0 -0
  202. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/conftest.py +0 -0
  203. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/convergence/__init__.py +0 -0
  204. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/convergence/bf16/__init__.py +0 -0
  205. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/convergence/bf16/test_mini_models.py +0 -0
  206. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  207. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/convergence/fp32/__init__.py +0 -0
  208. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/convergence/fp32/test_mini_models.py +0 -0
  209. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  210. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  211. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  212. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  213. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  214. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  215. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  216. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  217. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  218. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  219. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/tiny_shakespeare.txt +0 -0
  220. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  221. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  222. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  223. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_auto_model.py +0 -0
  224. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_cross_entropy.py +0 -0
  225. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_dyt.py +0 -0
  226. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_embedding.py +0 -0
  227. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_flex_attention.py +0 -0
  228. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_fused_linear_jsd.py +0 -0
  229. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_geglu.py +0 -0
  230. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_group_norm.py +0 -0
  231. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_jsd.py +0 -0
  232. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_kl_div.py +0 -0
  233. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_layer_norm.py +0 -0
  234. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_mm_int8int2.py +0 -0
  235. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_monkey_patch.py +0 -0
  236. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_qwen2vl_mrope.py +0 -0
  237. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_rope.py +0 -0
  238. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_swiglu.py +0 -0
  239. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_trainer_integration.py +0 -0
  240. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_transformers.py +0 -0
  241. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/transformers/test_tvd.py +0 -0
  242. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/triton/test_triton_monkey_patch.py +0 -0
  243. {liger_kernel_nightly-0.5.8.dev20250422210723 → liger_kernel_nightly-0.5.8.dev20250428050809}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.8.dev20250422210723
3
+ Version: 0.5.8.dev20250428050809
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.8.dev20250422210723"
7
+ version = "0.5.8.dev20250428050809"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -201,14 +201,16 @@ def lce_forward(
201
201
 
202
202
  hidden_states = outputs[0]
203
203
 
204
+ shift_labels = loss_kwargs.pop("shift_labels", None)
204
205
  logits = None
205
206
  loss = None
206
207
  # if in training mode, don't materialize logits
207
- if self.training and (labels is not None):
208
+ if self.training and (labels is not None or shift_labels is not None):
208
209
  loss = LigerForCausalLMLoss(
209
210
  hidden_states=hidden_states,
210
211
  lm_head_weight=self.lm_head.weight,
211
212
  labels=labels,
213
+ shift_labels=shift_labels,
212
214
  hidden_size=self.config.hidden_size,
213
215
  **loss_kwargs,
214
216
  )
@@ -213,14 +213,16 @@ def lce_forward(
213
213
 
214
214
  hidden_states = outputs[0]
215
215
 
216
+ shift_labels = loss_kwargs.pop("shift_labels", None)
216
217
  logits = None
217
218
  loss = None
218
219
  # if in training mode, don't materialize logits
219
- if self.training and (labels is not None):
220
+ if self.training and (labels is not None or shift_labels is not None):
220
221
  loss = LigerForCausalLMLoss(
221
222
  hidden_states=hidden_states,
222
223
  lm_head_weight=self.lm_head.weight,
223
224
  labels=labels,
225
+ shift_labels=shift_labels,
224
226
  hidden_size=self.config.hidden_size,
225
227
  final_logit_softcapping=self.config.final_logit_softcapping,
226
228
  **loss_kwargs,
@@ -104,13 +104,15 @@ def causal_forward(
104
104
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
105
105
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
106
106
  kept_hidden_states = hidden_states[:, slice_indices, :]
107
+ shift_labels = loss_kwargs.pop("shift_labels", None)
107
108
  loss = None
108
109
  logits = None
109
- if self.training and (labels is not None):
110
+ if self.training and (labels is not None or shift_labels is not None):
110
111
  loss = LigerForCausalLMLoss(
111
112
  hidden_states=kept_hidden_states,
112
113
  lm_head_weight=self.lm_head.weight,
113
114
  labels=labels,
115
+ shift_labels=shift_labels,
114
116
  hidden_size=self.config.hidden_size,
115
117
  final_logit_softcapping=self.config.final_logit_softcapping,
116
118
  **loss_kwargs,
@@ -213,14 +213,16 @@ def lce_forward(
213
213
  if self.config.pretraining_tp > 1:
214
214
  raise Exception("Liger Kernel does not support pretraining_tp!!")
215
215
 
216
+ shift_labels = loss_kwargs.pop("shift_labels", None)
216
217
  logits = None
217
218
  loss = None
218
219
  # if in training mode, don't materialize logits
219
- if self.training and (labels is not None):
220
+ if self.training and (labels is not None or shift_labels is not None):
220
221
  loss = LigerForCausalLMLoss(
221
222
  hidden_states=hidden_states,
222
223
  lm_head_weight=self.lm_head.weight,
223
224
  labels=labels,
225
+ shift_labels=shift_labels,
224
226
  hidden_size=self.config.hidden_size,
225
227
  **loss_kwargs,
226
228
  )
@@ -92,14 +92,16 @@ def lce_forward(
92
92
 
93
93
  hidden_states = outputs[0]
94
94
 
95
+ shift_labels = loss_kwargs.pop("shift_labels", None)
95
96
  loss = None
96
97
  logits = None
97
98
 
98
- if self.training and (labels is not None):
99
+ if self.training and (labels is not None or shift_labels is not None):
99
100
  loss = LigerForCausalLMLoss(
100
101
  hidden_states=hidden_states,
101
102
  lm_head_weight=self.lm_head.weight,
102
103
  labels=labels,
104
+ shift_labels=shift_labels,
103
105
  hidden_size=self.config.hidden_size,
104
106
  **loss_kwargs,
105
107
  )
@@ -226,14 +226,16 @@ def lce_forward(
226
226
 
227
227
  hidden_states = outputs[0]
228
228
 
229
+ shift_labels = loss_kwargs.pop("shift_labels", None)
229
230
  logits = None
230
231
  loss = None
231
232
  # if in training mode, don't materialize logits
232
- if self.training and (labels is not None):
233
+ if self.training and (labels is not None or shift_labels is not None):
233
234
  loss = LigerForCausalLMLoss(
234
235
  hidden_states=hidden_states,
235
236
  lm_head_weight=self.lm_head.weight,
236
237
  labels=labels,
238
+ shift_labels=shift_labels,
237
239
  hidden_size=self.config.hidden_size,
238
240
  **loss_kwargs,
239
241
  )
@@ -216,14 +216,16 @@ def lce_forward(
216
216
 
217
217
  hidden_states = outputs[0]
218
218
 
219
+ shift_labels = loss_kwargs.pop("shift_labels", None)
219
220
  logits = None
220
221
  loss = None
221
222
  # if in training mode, don't materialize logits
222
- if self.training and (labels is not None):
223
+ if self.training and (labels is not None or shift_labels is not None):
223
224
  loss = LigerForCausalLMLoss(
224
225
  hidden_states=hidden_states,
225
226
  lm_head_weight=self.lm_head.weight,
226
227
  labels=labels,
228
+ shift_labels=shift_labels,
227
229
  hidden_size=self.config.hidden_size,
228
230
  **loss_kwargs,
229
231
  )
@@ -89,14 +89,16 @@ def lce_forward(
89
89
 
90
90
  hidden_states = outputs[0]
91
91
 
92
+ shift_labels = loss_kwargs.pop("shift_labels", None)
92
93
  logits = None
93
94
  loss = None
94
95
  # if in training mode, don't materialize logits
95
- if self.training and (labels is not None):
96
+ if self.training and (labels is not None or shift_labels is not None):
96
97
  loss = LigerForCausalLMLoss(
97
98
  hidden_states=hidden_states,
98
99
  lm_head_weight=self.lm_head.weight,
99
100
  labels=labels,
101
+ shift_labels=shift_labels,
100
102
  hidden_size=self.config.hidden_size,
101
103
  **loss_kwargs,
102
104
  )
@@ -214,14 +214,16 @@ def lce_forward(
214
214
 
215
215
  hidden_states = outputs[0]
216
216
 
217
+ shift_labels = loss_kwargs.pop("shift_labels", None)
217
218
  logits = None
218
219
  loss = None
219
220
  # if in training mode, don't materialize logits
220
- if self.training and (labels is not None):
221
+ if self.training and (labels is not None or shift_labels is not None):
221
222
  loss = LigerForCausalLMLoss(
222
223
  hidden_states=hidden_states,
223
224
  lm_head_weight=self.lm_head.weight,
224
225
  labels=labels,
226
+ shift_labels=shift_labels,
225
227
  hidden_size=self.config.hidden_size,
226
228
  **loss_kwargs,
227
229
  )
@@ -200,14 +200,16 @@ def lce_forward(
200
200
 
201
201
  hidden_states = outputs[0]
202
202
 
203
+ shift_labels = loss_kwargs.pop("shift_labels", None)
203
204
  logits = None
204
205
  loss = None
205
206
  # if in training mode, don't materialize logits
206
- if self.training and (labels is not None):
207
+ if self.training and (labels is not None or shift_labels is not None):
207
208
  loss = LigerForCausalLMLoss(
208
209
  hidden_states=hidden_states,
209
210
  lm_head_weight=self.lm_head.weight,
210
211
  labels=labels,
212
+ shift_labels=shift_labels,
211
213
  hidden_size=self.config.hidden_size,
212
214
  **loss_kwargs,
213
215
  )
@@ -163,14 +163,16 @@ def lce_forward(
163
163
 
164
164
  hidden_states = outputs[0]
165
165
 
166
+ shift_labels = loss_kwargs.pop("shift_labels", None)
166
167
  loss = None
167
168
  logits = None
168
169
 
169
- if self.training and (labels is not None):
170
+ if self.training and (labels is not None or shift_labels is not None):
170
171
  loss = LigerForCausalLMLoss(
171
172
  hidden_states=hidden_states,
172
173
  lm_head_weight=self.lm_head.weight,
173
174
  labels=labels,
175
+ shift_labels=shift_labels,
174
176
  hidden_size=self.config.hidden_size,
175
177
  **loss_kwargs,
176
178
  )
@@ -167,14 +167,16 @@ def lce_forward(
167
167
 
168
168
  hidden_states = outputs[0]
169
169
 
170
+ shift_labels = loss_kwargs.pop("shift_labels", None)
170
171
  loss = None
171
172
  logits = None
172
173
 
173
- if self.training and (labels is not None):
174
+ if self.training and (labels is not None or shift_labels is not None):
174
175
  loss = LigerForCausalLMLoss(
175
176
  hidden_states=hidden_states,
176
177
  lm_head_weight=self.lm_head.weight,
177
178
  labels=labels,
179
+ shift_labels=shift_labels,
178
180
  hidden_size=self.config.hidden_size,
179
181
  **loss_kwargs,
180
182
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.8.dev20250422210723
3
+ Version: 0.5.8.dev20250428050809
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -770,7 +770,6 @@ def run_mini_model_multimodal(
770
770
  not QWEN2_VL_AVAILABLE,
771
771
  reason="Qwen2-VL not available in this version of transformers",
772
772
  ),
773
- pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
774
773
  ],
775
774
  ),
776
775
  pytest.param(
@@ -809,7 +808,6 @@ def run_mini_model_multimodal(
809
808
  not QWEN2_5_VL_AVAILABLE,
810
809
  reason="Qwen2.5-VL not available in this version of transformers",
811
810
  ),
812
- pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
813
811
  ],
814
812
  ),
815
813
  pytest.param(
@@ -876,7 +874,7 @@ def run_mini_model_multimodal(
876
874
  torch.bfloat16,
877
875
  1e-3,
878
876
  1e-2,
879
- 0.25, # Increase the absolute tolerance for the logits of Gemma-3.
877
+ 0.4, # Increase the absolute tolerance for the logits of Gemma-3.
880
878
  1e-1,
881
879
  1e-2,
882
880
  1e-2,
@@ -765,7 +765,6 @@ def run_mini_model_multimodal(
765
765
  not QWEN2_VL_AVAILABLE,
766
766
  reason="Qwen2-VL not available in this version of transformers",
767
767
  ),
768
- pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
769
768
  ],
770
769
  ),
771
770
  pytest.param(
@@ -800,7 +799,6 @@ def run_mini_model_multimodal(
800
799
  not QWEN2_5_VL_AVAILABLE,
801
800
  reason="Qwen2.5-VL not available in this version of transformers",
802
801
  ),
803
- pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
804
802
  ],
805
803
  ),
806
804
  pytest.param(
@@ -105,7 +105,7 @@ class LigerLMHeadCE(torch.nn.Module):
105
105
  @pytest.mark.parametrize(
106
106
  "B, T, H, V",
107
107
  [
108
- pytest.param(8, 128, 1024, 4096, marks=pytest.mark.skipif(device="xpu", reason="skip for XPU")),
108
+ (8, 128, 1024, 4096),
109
109
  (4, 47, 31, 123), # random shape
110
110
  ],
111
111
  )
@@ -287,7 +287,7 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, ce_weight, atol
287
287
  @pytest.mark.parametrize(
288
288
  "B, T, H, V",
289
289
  [
290
- pytest.param(8, 128, 1024, 4096, marks=pytest.mark.skipif(device="xpu", reason="skip for XPU")),
290
+ (8, 128, 1024, 4096),
291
291
  (4, 47, 31, 123), # random shape
292
292
  ],
293
293
  )
@@ -103,7 +103,7 @@ class GemmaRMSNorm(nn.Module):
103
103
  [
104
104
  (LlamaRMSNorm, 0.0, "llama"),
105
105
  (GemmaRMSNorm, 1.0, "gemma"),
106
- pytest.param(BaseRMSNorm, 0.0, "none", marks=pytest.mark.skipif(device="xpu", reason="skip for XPU")),
106
+ (BaseRMSNorm, 0.0, "none"),
107
107
  ],
108
108
  )
109
109
  @pytest.mark.parametrize(