liger-kernel-nightly 0.5.10.dev20250702150221__tar.gz → 0.5.10.dev20250704061237__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (274)
  1. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.gitignore +1 -0
  2. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/PKG-INFO +1 -1
  3. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/geglu.py +1 -1
  5. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/swiglu.py +1 -1
  6. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/gemma.py +9 -1
  7. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/gemma2.py +9 -1
  8. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/llama.py +10 -1
  9. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/mistral.py +0 -3
  10. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/phi3.py +9 -1
  11. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/qwen2.py +8 -0
  12. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/monkey_patch.py +10 -3
  13. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  14. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel_nightly.egg-info/SOURCES.txt +0 -1
  15. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/convergence/bf16/test_mini_models.py +108 -78
  16. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/convergence/bf16/test_mini_models_multimodal.py +33 -25
  17. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/convergence/bf16/test_mini_models_with_logits.py +92 -72
  18. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/convergence/fp32/test_mini_models.py +96 -55
  19. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/convergence/fp32/test_mini_models_multimodal.py +34 -24
  20. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/convergence/fp32/test_mini_models_with_logits.py +63 -44
  21. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_dyt.py +12 -8
  22. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_monkey_patch.py +4 -0
  23. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_multi_token_attention.py +3 -0
  24. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/utils.py +3 -2
  25. liger_kernel_nightly-0.5.10.dev20250702150221/.idea/workspace.xml +0 -79
  26. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  27. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  28. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/pull_request_template.md +0 -0
  29. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/workflows/amd-ci.yml +0 -0
  30. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/workflows/benchmark.yml +0 -0
  31. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/workflows/docs.yml +0 -0
  32. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/workflows/intel-ci.yml +0 -0
  33. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/workflows/nvi-ci.yml +0 -0
  34. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/workflows/publish-nightly.yml +0 -0
  35. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/.github/workflows/publish-release.yml +0 -0
  36. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/LICENSE +0 -0
  37. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/Makefile +0 -0
  38. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/NOTICE +0 -0
  39. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/README.md +0 -0
  40. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/README.md +0 -0
  41. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/__init__.py +0 -0
  42. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/benchmarks_visualizer.py +0 -0
  43. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/data/all_benchmark_data.csv +0 -0
  44. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/__init__.py +0 -0
  45. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  46. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  47. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_distill_cosine_loss.py +0 -0
  48. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  49. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  50. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_dyt.py +0 -0
  51. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_embedding.py +0 -0
  52. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  53. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  54. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_fused_neighborhood_attention.py +0 -0
  55. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_geglu.py +0 -0
  56. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_group_norm.py +0 -0
  57. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_jsd.py +0 -0
  58. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_kl_div.py +0 -0
  59. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  60. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  61. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_multi_token_attention.py +0 -0
  62. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  63. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  64. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  65. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_rope.py +0 -0
  66. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  67. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_softmax.py +0 -0
  68. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_sparse_multi_token_attention.py +0 -0
  69. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_sparsemax.py +0 -0
  70. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_swiglu.py +0 -0
  71. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/benchmark_tvd.py +0 -0
  72. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/benchmark/scripts/utils.py +0 -0
  73. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/dev/fmt-requirements.txt +0 -0
  74. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/dev/modal/benchmarks.py +0 -0
  75. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/dev/modal/tests.py +0 -0
  76. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/dev/modal/tests_bwd.py +0 -0
  77. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/Examples.md +0 -0
  78. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/Getting-Started.md +0 -0
  79. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/High-Level-APIs.md +0 -0
  80. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/Low-Level-APIs.md +0 -0
  81. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/acknowledgement.md +0 -0
  82. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/contributing.md +0 -0
  83. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/images/banner.GIF +0 -0
  84. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/images/compose.gif +0 -0
  85. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/images/e2e-memory.png +0 -0
  86. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/images/e2e-tps.png +0 -0
  87. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/images/logo-banner.png +0 -0
  88. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/images/patch.gif +0 -0
  89. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/images/post-training.png +0 -0
  90. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/index.md +0 -0
  91. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/docs/license.md +0 -0
  92. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/alignment/accelerate_config.yaml +0 -0
  93. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/alignment/run_orpo.py +0 -0
  94. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/README.md +0 -0
  95. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/callback.py +0 -0
  96. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/config/fsdp_config.json +0 -0
  97. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  98. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  99. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  100. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/img/llama_tps.png +0 -0
  101. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  102. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/img/qwen_tps.png +0 -0
  103. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/launch_on_modal.py +0 -0
  104. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/requirements.txt +0 -0
  105. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/run_benchmarks.sh +0 -0
  106. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/run_gemma.sh +0 -0
  107. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/run_llama.sh +0 -0
  108. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/run_qwen.sh +0 -0
  109. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/run_qwen2_vl.sh +0 -0
  110. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/training.py +0 -0
  111. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/huggingface/training_multimodal.py +0 -0
  112. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/lightning/README.md +0 -0
  113. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/lightning/requirements.txt +0 -0
  114. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/lightning/training.py +0 -0
  115. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/README.md +0 -0
  116. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/callback.py +0 -0
  117. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  118. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  119. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  120. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  121. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  122. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  123. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  124. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  125. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  126. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/medusa_util.py +0 -0
  127. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/requirements.txt +0 -0
  128. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  129. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/examples/medusa/train.py +0 -0
  130. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/licenses/LICENSE-Apache-2.0 +0 -0
  131. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  132. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  133. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/licenses/LICENSE-MIT-llmc +0 -0
  134. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/licenses/LICENSE-MIT-triton +0 -0
  135. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/mkdocs.yml +0 -0
  136. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/setup.cfg +0 -0
  137. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/setup.py +0 -0
  138. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/__init__.py +0 -0
  139. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/README.md +0 -0
  140. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  141. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/cosine_similarity_loss.py +0 -0
  142. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  143. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  144. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/functional.py +0 -0
  145. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  146. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  147. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  148. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  149. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  150. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  151. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  152. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  153. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  154. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/env_report.py +0 -0
  155. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/__init__.py +0 -0
  156. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/cross_entropy.py +0 -0
  157. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/dyt.py +0 -0
  158. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  159. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  160. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  161. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  162. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/fused_neighborhood_attention.py +0 -0
  163. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/group_norm.py +0 -0
  164. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/grpo_loss.py +0 -0
  165. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/jsd.py +0 -0
  166. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/kl_div.py +0 -0
  167. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/layer_norm.py +0 -0
  168. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/multi_token_attention.py +0 -0
  169. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  170. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/rms_norm.py +0 -0
  171. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/rope.py +0 -0
  172. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/softmax.py +0 -0
  173. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/sparsemax.py +0 -0
  174. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/tvd.py +0 -0
  175. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/ops/utils.py +0 -0
  176. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/__init__.py +0 -0
  177. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/auto_model.py +0 -0
  178. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  179. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/dyt.py +0 -0
  180. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  181. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/fsdp.py +0 -0
  182. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/functional.py +0 -0
  183. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  184. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  185. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/fused_neighborhood_attention.py +0 -0
  186. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/geglu.py +0 -0
  187. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/group_norm.py +0 -0
  188. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/grpo_loss.py +0 -0
  189. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/jsd.py +0 -0
  190. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/kl_div.py +0 -0
  191. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/layer_norm.py +0 -0
  192. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/__init__.py +0 -0
  193. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  194. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/glm4.py +0 -0
  195. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/llama4.py +0 -0
  196. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/llava.py +0 -0
  197. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  198. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  199. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/mllama.py +0 -0
  200. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  201. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  202. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  203. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  204. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/qwen3.py +0 -0
  205. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/model/qwen3_moe.py +0 -0
  206. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/multi_token_attention.py +0 -0
  207. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  208. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/rms_norm.py +0 -0
  209. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/rope.py +0 -0
  210. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/softmax.py +0 -0
  211. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/sparsemax.py +0 -0
  212. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/swiglu.py +0 -0
  213. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  214. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  215. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  216. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/transformers/tvd.py +0 -0
  217. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/triton/__init__.py +0 -0
  218. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/triton/monkey_patch.py +0 -0
  219. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel/utils.py +0 -0
  220. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  221. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  222. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  223. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/__init__.py +0 -0
  224. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/chunked_loss/__init__.py +0 -0
  225. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/chunked_loss/test_cosine_loss.py +0 -0
  226. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/chunked_loss/test_cpo_loss.py +0 -0
  227. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/chunked_loss/test_dpo_loss.py +0 -0
  228. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/chunked_loss/test_grpo_loss.py +0 -0
  229. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/chunked_loss/test_jsd_loss.py +0 -0
  230. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/chunked_loss/test_kto_loss.py +0 -0
  231. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/chunked_loss/test_orpo_loss.py +0 -0
  232. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/chunked_loss/test_simpo_loss.py +0 -0
  233. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/conftest.py +0 -0
  234. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/convergence/__init__.py +0 -0
  235. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/convergence/bf16/__init__.py +0 -0
  236. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/convergence/fp32/__init__.py +0 -0
  237. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  238. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  239. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  240. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  241. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  242. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  243. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  244. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  245. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/fake_configs/meta-llama/Llama-4-Scout-17B-16E-Instruct/tokenizer_config.json +0 -0
  246. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  247. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/tiny_shakespeare.txt +0 -0
  248. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  249. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  250. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  251. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_auto_model.py +0 -0
  252. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_cross_entropy.py +0 -0
  253. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_embedding.py +0 -0
  254. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_flex_attention.py +0 -0
  255. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  256. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_fused_linear_jsd.py +0 -0
  257. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_fused_neighborhood_attention.py +0 -0
  258. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_geglu.py +0 -0
  259. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_group_norm.py +0 -0
  260. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_grpo_loss.py +0 -0
  261. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_jsd.py +0 -0
  262. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_kl_div.py +0 -0
  263. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_layer_norm.py +0 -0
  264. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_mm_int8int2.py +0 -0
  265. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_qwen2vl_mrope.py +0 -0
  266. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_rms_norm.py +0 -0
  267. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_rope.py +0 -0
  268. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_softmax.py +0 -0
  269. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_sparsemax.py +0 -0
  270. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_swiglu.py +0 -0
  271. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_trainer_integration.py +0 -0
  272. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_transformers.py +0 -0
  273. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/transformers/test_tvd.py +0 -0
  274. {liger_kernel_nightly-0.5.10.dev20250702150221 → liger_kernel_nightly-0.5.10.dev20250704061237}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -6,6 +6,7 @@ site/
6
6
  venv/
7
7
  .ipynb_checkpoints/
8
8
  .vscode/
9
+ .idea/
9
10
 
10
11
  # Misc
11
12
  .DS_Store
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.10.dev20250702150221
3
+ Version: 0.5.10.dev20250704061237
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.10.dev20250702150221"
7
+ version = "0.5.10.dev20250704061237"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -40,7 +40,7 @@ def _geglu_tanh_forward_kernel(a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE
40
40
  tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)
41
41
  tanh_result = tanh(tanh_arg)
42
42
  geglu_a = 0.5 * a_row * (1 + tanh_result)
43
- c_row = geglu_a * b_row
43
+ c_row = geglu_a.cast(b_row.dtype) * b_row
44
44
  tl.store(c + col_offsets, c_row, mask=mask)
45
45
 
46
46
 
@@ -26,7 +26,7 @@ def _swiglu_forward_kernel(a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BL
26
26
  # sigmoid requires type float32
27
27
  a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
28
28
  b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)
29
- c_row = silu(a_row) * b_row
29
+ c_row = silu(a_row).cast(b_row.dtype) * b_row
30
30
  tl.store(c_ptr + col_offsets, c_row, mask=mask)
31
31
 
32
32
 
@@ -27,6 +27,7 @@ def lce_forward_deprecated(
27
27
  output_hidden_states: Optional[bool] = None,
28
28
  return_dict: Optional[bool] = None,
29
29
  cache_position: Optional[torch.LongTensor] = None,
30
+ skip_logits: Optional[bool] = None,
30
31
  ) -> Union[Tuple, CausalLMOutputWithPast]:
31
32
  r"""
32
33
 
@@ -81,7 +82,14 @@ def lce_forward_deprecated(
81
82
  loss = None
82
83
  logits = None
83
84
 
84
- if self.training and (labels is not None):
85
+ if skip_logits and labels is None:
86
+ raise ValueError("skip_logits is True, but labels is None")
87
+
88
+ if skip_logits is None:
89
+ # By default, if in training mode, don't materialize logits
90
+ skip_logits = self.training and labels is not None
91
+
92
+ if skip_logits:
85
93
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
86
94
  shift_labels = labels[..., 1:].contiguous()
87
95
 
@@ -30,6 +30,7 @@ def lce_forward_deprecated(
30
30
  output_hidden_states: Optional[bool] = None,
31
31
  return_dict: Optional[bool] = None,
32
32
  cache_position: Optional[torch.LongTensor] = None,
33
+ skip_logits: Optional[bool] = None,
33
34
  **kwargs,
34
35
  ) -> Union[Tuple, CausalLMOutputWithPast]:
35
36
  r"""
@@ -85,7 +86,14 @@ def lce_forward_deprecated(
85
86
  loss = None
86
87
  logits = None
87
88
 
88
- if self.training and (labels is not None):
89
+ if skip_logits and labels is None:
90
+ raise ValueError("skip_logits is True, but labels is None")
91
+
92
+ if skip_logits is None:
93
+ # By default, if in training mode, don't materialize logits
94
+ skip_logits = self.training and labels is not None
95
+
96
+ if skip_logits:
89
97
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
90
98
  shift_labels = labels[..., 1:].contiguous()
91
99
 
@@ -37,6 +37,7 @@ def lce_forward_deprecated(
37
37
  output_hidden_states: Optional[bool] = None,
38
38
  return_dict: Optional[bool] = None,
39
39
  cache_position: Optional[torch.LongTensor] = None,
40
+ skip_logits: Optional[bool] = None,
40
41
  ) -> Union[Tuple, CausalLMOutputWithPast]:
41
42
  r"""
42
43
  Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy
@@ -91,7 +92,15 @@ def lce_forward_deprecated(
91
92
  loss = None
92
93
  logits = None
93
94
 
94
- if self.training and (labels is not None):
95
+ # if in training mode, don't materialize logits
96
+ if skip_logits and labels is None:
97
+ raise ValueError("skip_logits is True, but labels is None")
98
+
99
+ if skip_logits is None:
100
+ # By default, if in training mode, don't materialize logits
101
+ skip_logits = self.training and labels is not None
102
+
103
+ if skip_logits:
95
104
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
96
105
  shift_labels = labels[..., 1:].contiguous()
97
106
 
@@ -133,6 +133,3 @@ def lce_forward(
133
133
  hidden_states=outputs.hidden_states,
134
134
  attentions=outputs.attentions,
135
135
  )
136
-
137
-
138
- # Note: Grad Acc is not fixed in mistral at transformer 4.46.1
@@ -26,6 +26,7 @@ def lce_forward_deprecated(
26
26
  output_hidden_states: Optional[bool] = None,
27
27
  return_dict: Optional[bool] = None,
28
28
  cache_position: Optional[torch.LongTensor] = None,
29
+ skip_logits: Optional[bool] = None,
29
30
  ) -> Union[Tuple, CausalLMOutputWithPast]:
30
31
  r"""
31
32
  Copy paste phi3 forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
@@ -80,7 +81,14 @@ def lce_forward_deprecated(
80
81
  loss = None
81
82
  logits = None
82
83
 
83
- if self.training and labels is not None:
84
+ if skip_logits and labels is None:
85
+ raise ValueError("skip_logits is True, but labels is None")
86
+
87
+ if skip_logits is None:
88
+ # By default, if in training mode, don't materialize logits
89
+ skip_logits = self.training and labels is not None
90
+
91
+ if skip_logits:
84
92
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
85
93
  shift_labels = labels[..., 1:].contiguous()
86
94
 
@@ -26,6 +26,7 @@ def lce_forward_deprecated(
26
26
  output_hidden_states: Optional[bool] = None,
27
27
  return_dict: Optional[bool] = None,
28
28
  cache_position: Optional[torch.LongTensor] = None,
29
+ skip_logits: Optional[bool] = None,
29
30
  ) -> Union[Tuple, CausalLMOutputWithPast]:
30
31
  r"""
31
32
  Copy paste Qwen2's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -80,6 +81,13 @@ def lce_forward_deprecated(
80
81
  loss = None
81
82
  logits = None
82
83
 
84
+ if skip_logits and labels is None:
85
+ raise ValueError("skip_logits is True, but labels is None")
86
+
87
+ if skip_logits is None:
88
+ # By default, if in training mode, don't materialize logits
89
+ skip_logits = self.training and labels is not None
90
+
83
91
  if self.training and (labels is not None):
84
92
  shift_hidden_states = hidden_states[..., :-1, :].contiguous()
85
93
  shift_labels = labels[..., 1:].contiguous()
@@ -611,10 +611,17 @@ def apply_liger_kernel_to_mistral(
611
611
  if cross_entropy:
612
612
  modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
613
613
  if fused_linear_cross_entropy:
614
- if model is not None:
615
- model.forward = MethodType(mistral_lce_forward, model)
614
+ if transformer_version >= version.parse("4.49.0"):
615
+ if model is not None:
616
+ model.forward = MethodType(mistral_lce_forward, model)
617
+ else:
618
+ modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
616
619
  else:
617
- modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
620
+ logger.warning(
621
+ "The latest version of Liger does not support transformers < 4.49.0 for mistral. Please downgrade your liger version or upgrade your transformers version."
622
+ )
623
+ logger.warning("LigerFusedLinearCrossEntropy patch is not applied.")
624
+
618
625
  if swiglu:
619
626
  modeling_mistral.MistralMLP = LigerSwiGLUMLP
620
627
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.10.dev20250702150221
3
+ Version: 0.5.10.dev20250704061237
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -16,7 +16,6 @@ setup.py
16
16
  .github/workflows/nvi-ci.yml
17
17
  .github/workflows/publish-nightly.yml
18
18
  .github/workflows/publish-release.yml
19
- .idea/workspace.xml
20
19
  benchmark/README.md
21
20
  benchmark/__init__.py
22
21
  benchmark/benchmarks_visualizer.py
@@ -9,8 +9,6 @@ from transformers.models.gemma2 import Gemma2Config
9
9
  from transformers.models.gemma2 import Gemma2ForCausalLM
10
10
  from transformers.models.llama import LlamaConfig
11
11
  from transformers.models.llama import LlamaForCausalLM
12
- from transformers.models.llama4 import Llama4ForCausalLM
13
- from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
14
12
  from transformers.models.mistral import MistralConfig
15
13
  from transformers.models.mistral import MistralForCausalLM
16
14
  from transformers.models.mixtral import MixtralConfig
@@ -65,6 +63,14 @@ from test.utils import set_seed
65
63
  from test.utils import simple_collate_fn
66
64
  from test.utils import supports_bfloat16
67
65
 
66
+ try:
67
+ from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
68
+ from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
69
+
70
+ LLAMA4_AVAILABLE = True
71
+ except ImportError:
72
+ LLAMA4_AVAILABLE = False
73
+
68
74
  try:
69
75
  # Mllama is only available in transformers>=4.45.0
70
76
  from transformers.models.mllama.configuration_mllama import MllamaTextConfig
@@ -156,35 +162,6 @@ from liger_kernel.utils import infer_device
156
162
  device = infer_device()
157
163
 
158
164
  MINI_MODEL_SETUPS = {
159
- "mini_llama4": MiniModelConfig(
160
- liger_kernel_patch_func=apply_liger_kernel_to_llama4,
161
- liger_kernel_patch_revert_func=revert_liger_kernel_to_llama4,
162
- model_class=Llama4ForCausalLM,
163
- mini_model_config=Llama4TextConfig(
164
- bos_token_id=1, # None
165
- eos_token_id=2, # 151329, 151336, 151338
166
- pad_token_id=2, # 151329
167
- partial_rotary_factor=1.0,
168
- cross_attention_layers=None,
169
- dropout=0,
170
- hidden_act="silu",
171
- hidden_size=1024, # 6144
172
- initializer_range=0.02,
173
- intermediate_size=2048, # 14336
174
- max_position_embeddings=4096, # 32768
175
- num_attention_heads=8, # 48
176
- num_hidden_layers=4, # 61
177
- num_key_value_heads=2,
178
- rms_norm_eps=1e-5,
179
- rope_scaling=None,
180
- rope_theta=10000.0,
181
- tie_word_embeddings=False,
182
- use_cache=True,
183
- vocab_size=32000, # 151552
184
- attention_bias=True,
185
- attn_implementation="sdpa", # default value, pytorch native attention
186
- ),
187
- ),
188
165
  "mini_llama3": MiniModelConfig(
189
166
  liger_kernel_patch_func=apply_liger_kernel_to_llama,
190
167
  liger_kernel_patch_revert_func=revert_liger_kernel_to_llama,
@@ -413,6 +390,37 @@ MINI_MODEL_SETUPS = {
413
390
  ),
414
391
  }
415
392
 
393
+ if LLAMA4_AVAILABLE:
394
+ MINI_MODEL_SETUPS["mini_llama4"] = MiniModelConfig(
395
+ liger_kernel_patch_func=apply_liger_kernel_to_llama4,
396
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_llama4,
397
+ model_class=Llama4ForCausalLM,
398
+ mini_model_config=Llama4TextConfig(
399
+ bos_token_id=1, # None
400
+ eos_token_id=2, # 151329, 151336, 151338
401
+ pad_token_id=2, # 151329
402
+ partial_rotary_factor=1.0,
403
+ cross_attention_layers=None,
404
+ dropout=0,
405
+ hidden_act="silu",
406
+ hidden_size=1024, # 6144
407
+ initializer_range=0.02,
408
+ intermediate_size=2048, # 14336
409
+ max_position_embeddings=4096, # 32768
410
+ num_attention_heads=8, # 48
411
+ num_hidden_layers=4, # 61
412
+ num_key_value_heads=2,
413
+ rms_norm_eps=1e-5,
414
+ rope_scaling=None,
415
+ rope_theta=10000.0,
416
+ tie_word_embeddings=False,
417
+ use_cache=True,
418
+ vocab_size=32000, # 151552
419
+ attention_bias=True,
420
+ attn_implementation="sdpa", # default value, pytorch native attention
421
+ ),
422
+ )
423
+
416
424
 
417
425
  if QWEN3_AVAILABLE:
418
426
  MINI_MODEL_SETUPS["mini_qwen3"] = MiniModelConfig(
@@ -902,23 +910,29 @@ def run_mini_model(
902
910
  pytest.param(
903
911
  "mini_llama4",
904
912
  32,
905
- 1e-4,
913
+ 1e-5,
906
914
  torch.bfloat16,
907
- 1e-3,
908
915
  1e-2,
916
+ 5e-2,
909
917
  1e-1,
910
918
  1e-1,
911
919
  1e-2,
912
920
  1e-2,
913
- marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
921
+ marks=[
922
+ pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
923
+ pytest.mark.skipif(
924
+ not LLAMA4_AVAILABLE,
925
+ reason="Llama4 not available in this version of transformers",
926
+ ),
927
+ ],
914
928
  ),
915
929
  pytest.param(
916
930
  "mini_llama3",
917
931
  32,
918
- 1e-4,
932
+ 1e-5,
919
933
  torch.bfloat16,
920
- 1e-3,
921
934
  1e-2,
935
+ 5e-2,
922
936
  1e-1,
923
937
  1e-2,
924
938
  1e-2,
@@ -928,10 +942,10 @@ def run_mini_model(
928
942
  pytest.param(
929
943
  "mini_llava",
930
944
  32,
931
- 1e-4,
945
+ 1e-5,
932
946
  torch.bfloat16,
933
- 1e-3,
934
947
  1e-2,
948
+ 5e-2,
935
949
  1e-1,
936
950
  1e-1,
937
951
  1e-2,
@@ -942,17 +956,21 @@ def run_mini_model(
942
956
  not LLAVA_AVAILABLE,
943
957
  reason="LLaVa not available in this version of transformers",
944
958
  ),
959
+ pytest.mark.skipif(
960
+ version.parse(transformers.__version__) < version.parse("4.49.0"),
961
+ reason="Mistral not available in transformers<4.49.0",
962
+ ),
945
963
  ],
946
964
  ),
947
965
  pytest.param(
948
966
  "mini_granite3",
949
967
  32,
950
- 1e-4,
968
+ 1e-5,
951
969
  torch.bfloat16,
952
- 1e-3,
953
970
  1e-2,
971
+ 5e-2,
954
972
  1e-1, # 1e-1
955
- 1e-1, # 1e-2
973
+ 1e-2, # 1e-2
956
974
  1e-2,
957
975
  1e-2,
958
976
  marks=[
@@ -966,9 +984,9 @@ def run_mini_model(
966
984
  pytest.param(
967
985
  "mini_mllama",
968
986
  32,
969
- 1e-4,
987
+ 1e-5,
970
988
  torch.bfloat16,
971
- 1e-3,
989
+ 1e-2,
972
990
  1e-2,
973
991
  1e-1,
974
992
  1e-2,
@@ -985,10 +1003,10 @@ def run_mini_model(
985
1003
  pytest.param(
986
1004
  "mini_qwen2",
987
1005
  32,
988
- 1e-4,
1006
+ 1e-5,
989
1007
  torch.bfloat16,
990
- 1e-3,
991
1008
  1e-2,
1009
+ 5e-2,
992
1010
  1e-1,
993
1011
  1e-2,
994
1012
  1e-2,
@@ -998,10 +1016,10 @@ def run_mini_model(
998
1016
  pytest.param(
999
1017
  "mini_qwen3",
1000
1018
  32,
1001
- 1e-4,
1019
+ 1e-5,
1002
1020
  torch.bfloat16,
1003
- 1e-3,
1004
1021
  1e-2,
1022
+ 5e-2,
1005
1023
  1e-1,
1006
1024
  1e-2,
1007
1025
  1e-2,
@@ -1014,13 +1032,16 @@ def run_mini_model(
1014
1032
  ),
1015
1033
  ],
1016
1034
  ),
1035
+ # TODO(tcc): Investigate qwen3_moe on different machines.
1036
+ # The loss diverges on ci test (A10G), but it never diverges on my local machine (3080).
1037
+ # Qwen3_moe can pass float32 tests.
1017
1038
  pytest.param(
1018
1039
  "mini_qwen3_moe",
1019
1040
  32,
1020
- 1e-4,
1041
+ 1e-5,
1021
1042
  torch.bfloat16,
1022
- 1e-3,
1023
- 1e-2,
1043
+ 5e-2,
1044
+ 5e-2,
1024
1045
  1e-1, # 1e-1
1025
1046
  1e-1, # 1e-2
1026
1047
  1e-2,
@@ -1036,12 +1057,12 @@ def run_mini_model(
1036
1057
  pytest.param(
1037
1058
  "mini_qwen2_vl",
1038
1059
  32,
1039
- 1e-4,
1060
+ 1e-5,
1040
1061
  torch.bfloat16,
1041
- 1e-3,
1062
+ 1e-2,
1042
1063
  5e-2,
1043
- 1, # 1e-1
1044
- 1e-1, # 1e-2
1064
+ 1e-1, # 1e-1
1065
+ 1e-2, # 1e-2
1045
1066
  1e-2,
1046
1067
  1e-2,
1047
1068
  marks=[
@@ -1052,16 +1073,15 @@ def run_mini_model(
1052
1073
  ),
1053
1074
  ],
1054
1075
  ),
1055
- # TODO: logits tolerances are significantly larger than the other tests, need to investigate
1056
1076
  pytest.param(
1057
1077
  "mini_qwen2_5_vl",
1058
1078
  32,
1059
- 1e-4,
1079
+ 1e-5,
1060
1080
  torch.bfloat16,
1061
- 1e-3,
1081
+ 1e-2,
1062
1082
  5e-2,
1063
- 3, # 1e-1
1064
- 1e-1, # 1e-2
1083
+ 1e-1, # 1e-1
1084
+ 1e-2, # 1e-2
1065
1085
  1e-2,
1066
1086
  1e-2,
1067
1087
  marks=[
@@ -1075,9 +1095,9 @@ def run_mini_model(
1075
1095
  pytest.param(
1076
1096
  "mini_phi3",
1077
1097
  32,
1078
- 1e-4,
1098
+ 1e-5,
1079
1099
  torch.bfloat16,
1080
- 1e-3,
1100
+ 1e-2,
1081
1101
  1e-2,
1082
1102
  1e-1,
1083
1103
  1e-2,
@@ -1088,22 +1108,28 @@ def run_mini_model(
1088
1108
  pytest.param(
1089
1109
  "mini_mistral",
1090
1110
  32,
1091
- 1e-4,
1111
+ 1e-5,
1092
1112
  torch.bfloat16,
1093
- 1e-3,
1094
- 1e-2,
1113
+ 5e-2,
1114
+ 5e-2,
1095
1115
  1e-1,
1096
1116
  1e-2,
1097
1117
  1e-2,
1098
1118
  1e-2,
1099
- marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
1119
+ marks=[
1120
+ pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
1121
+ pytest.mark.skipif(
1122
+ version.parse(transformers.__version__) < version.parse("4.49.0"),
1123
+ reason="Mistral not available in transformers<4.49.0",
1124
+ ),
1125
+ ],
1100
1126
  ),
1101
1127
  pytest.param(
1102
1128
  "mini_olmo2",
1103
1129
  32,
1104
- 1e-4,
1130
+ 1e-5,
1105
1131
  torch.bfloat16,
1106
- 1e-3,
1132
+ 1e-2,
1107
1133
  1e-2,
1108
1134
  1e-1,
1109
1135
  1e-2,
@@ -1120,9 +1146,9 @@ def run_mini_model(
1120
1146
  pytest.param(
1121
1147
  "mini_glm4",
1122
1148
  32,
1123
- 1e-4,
1149
+ 1e-5,
1124
1150
  torch.bfloat16,
1125
- 1e-3,
1151
+ 1e-2,
1126
1152
  1e-2,
1127
1153
  1e-1,
1128
1154
  1e-2,
@@ -1156,27 +1182,27 @@ def run_mini_model(
1156
1182
  pytest.param(
1157
1183
  "mini_gemma1",
1158
1184
  32,
1159
- 1e-4,
1185
+ 1e-5,
1160
1186
  torch.bfloat16,
1161
- 1e-3,
1162
1187
  1e-2,
1163
1188
  1e-2,
1164
1189
  1e-1,
1165
1190
  1e-2,
1166
1191
  1e-2,
1192
+ 1e-2,
1167
1193
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
1168
1194
  ),
1169
1195
  pytest.param(
1170
1196
  "mini_gemma1.1",
1171
1197
  32,
1172
- 1e-4,
1198
+ 1e-5,
1173
1199
  torch.bfloat16,
1174
- 1e-3,
1175
1200
  1e-2,
1176
1201
  1e-2,
1177
1202
  1e-1,
1178
1203
  1e-2,
1179
1204
  1e-2,
1205
+ 1e-2,
1180
1206
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
1181
1207
  ),
1182
1208
  # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate
@@ -1198,12 +1224,12 @@ def run_mini_model(
1198
1224
  pytest.param(
1199
1225
  "mini_gemma3_text",
1200
1226
  32,
1201
- 1e-4,
1227
+ 1e-5,
1202
1228
  torch.bfloat16,
1203
- 1e-3,
1204
1229
  1e-2,
1205
- 3e-1,
1206
- 4e-1,
1230
+ 1e-2,
1231
+ 1e-1,
1232
+ 1e-2,
1207
1233
  1e-2,
1208
1234
  1e-2,
1209
1235
  marks=[
@@ -1240,6 +1266,7 @@ def test_mini_model(
1240
1266
  torch.tensor([actual_output["loss"]]),
1241
1267
  atol=loss_atol,
1242
1268
  rtol=loss_rtol,
1269
+ extra_info="[Loss]",
1243
1270
  )
1244
1271
 
1245
1272
  # Compare the topk logprobs from evaluation step
@@ -1249,6 +1276,7 @@ def test_mini_model(
1249
1276
  actual_output["topk_logprobs"],
1250
1277
  atol=logprobs_atol,
1251
1278
  rtol=logprobs_rtol,
1279
+ extra_info="[Top k logprobs]",
1252
1280
  )
1253
1281
 
1254
1282
  # Compare the params from the last step
@@ -1257,4 +1285,6 @@ def test_mini_model(
1257
1285
  expected_output["model"].named_parameters(),
1258
1286
  actual_output["model"].named_parameters(),
1259
1287
  ):
1260
- assert_verbose_allclose(expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol)
1288
+ assert_verbose_allclose(
1289
+ expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol, extra_info="[Model parameters]"
1290
+ )