liger-kernel-nightly 0.5.3.dev20250212203051__tar.gz → 0.5.3.dev20250214100345__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (218) hide show
  1. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/Makefile +7 -3
  2. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/PKG-INFO +1 -1
  3. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  5. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel_nightly.egg-info/SOURCES.txt +8 -3
  6. liger_kernel_nightly-0.5.3.dev20250214100345/test/convergence/bf16/__init__.py +0 -0
  7. {liger_kernel_nightly-0.5.3.dev20250212203051/test/convergence → liger_kernel_nightly-0.5.3.dev20250214100345/test/convergence/bf16}/test_mini_models.py +0 -40
  8. {liger_kernel_nightly-0.5.3.dev20250212203051/test/convergence → liger_kernel_nightly-0.5.3.dev20250214100345/test/convergence/bf16}/test_mini_models_multimodal.py +0 -35
  9. {liger_kernel_nightly-0.5.3.dev20250212203051/test/convergence → liger_kernel_nightly-0.5.3.dev20250214100345/test/convergence/bf16}/test_mini_models_with_logits.py +0 -39
  10. liger_kernel_nightly-0.5.3.dev20250214100345/test/convergence/fp32/__init__.py +0 -0
  11. liger_kernel_nightly-0.5.3.dev20250214100345/test/convergence/fp32/test_mini_models.py +546 -0
  12. liger_kernel_nightly-0.5.3.dev20250214100345/test/convergence/fp32/test_mini_models_multimodal.py +416 -0
  13. liger_kernel_nightly-0.5.3.dev20250214100345/test/convergence/fp32/test_mini_models_with_logits.py +545 -0
  14. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  15. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  16. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.github/pull_request_template.md +0 -0
  17. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.github/workflows/amd-ci.yml +0 -0
  18. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.github/workflows/docs.yml +0 -0
  19. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.github/workflows/intel-ci.yml +0 -0
  20. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.github/workflows/nvi-ci.yml +0 -0
  21. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.github/workflows/publish-nightly.yml +0 -0
  22. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.github/workflows/publish-release.yml +0 -0
  23. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/.gitignore +0 -0
  24. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/LICENSE +0 -0
  25. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/NOTICE +0 -0
  26. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/README.md +0 -0
  27. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/README.md +0 -0
  28. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/__init__.py +0 -0
  29. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/benchmarks_visualizer.py +0 -0
  30. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/data/all_benchmark_data.csv +0 -0
  31. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/__init__.py +0 -0
  32. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  33. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  34. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  35. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  36. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_embedding.py +0 -0
  37. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  38. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  39. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_geglu.py +0 -0
  40. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_group_norm.py +0 -0
  41. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_jsd.py +0 -0
  42. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_kl_div.py +0 -0
  43. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  44. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  45. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  46. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  47. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  48. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_rope.py +0 -0
  49. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  50. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/benchmark_swiglu.py +0 -0
  51. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/benchmark/scripts/utils.py +0 -0
  52. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/dev/fmt-requirements.txt +0 -0
  53. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/dev/modal/tests.py +0 -0
  54. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/dev/modal/tests_bwd.py +0 -0
  55. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/Examples.md +0 -0
  56. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/Getting-Started.md +0 -0
  57. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/High-Level-APIs.md +0 -0
  58. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/Low-Level-APIs.md +0 -0
  59. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/acknowledgement.md +0 -0
  60. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/contributing.md +0 -0
  61. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/images/banner.GIF +0 -0
  62. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/images/compose.gif +0 -0
  63. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/images/e2e-memory.png +0 -0
  64. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/images/e2e-tps.png +0 -0
  65. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/images/logo-banner.png +0 -0
  66. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/images/patch.gif +0 -0
  67. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/images/post-training.png +0 -0
  68. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/index.md +0 -0
  69. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/docs/license.md +0 -0
  70. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/alignment/accelerate_config.yaml +0 -0
  71. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/alignment/run_orpo.py +0 -0
  72. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/README.md +0 -0
  73. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/callback.py +0 -0
  74. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/config/fsdp_config.json +0 -0
  75. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  76. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  77. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  78. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/img/llama_tps.png +0 -0
  79. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  80. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/img/qwen_tps.png +0 -0
  81. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/launch_on_modal.py +0 -0
  82. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/requirements.txt +0 -0
  83. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/run_benchmarks.sh +0 -0
  84. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/run_gemma.sh +0 -0
  85. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/run_llama.sh +0 -0
  86. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/run_qwen.sh +0 -0
  87. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/run_qwen2_vl.sh +0 -0
  88. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/training.py +0 -0
  89. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/huggingface/training_multimodal.py +0 -0
  90. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/lightning/README.md +0 -0
  91. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/lightning/requirements.txt +0 -0
  92. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/lightning/training.py +0 -0
  93. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/README.md +0 -0
  94. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/callback.py +0 -0
  95. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  96. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  97. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  98. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  99. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  100. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  101. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  102. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  103. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  104. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/medusa_util.py +0 -0
  105. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/requirements.txt +0 -0
  106. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  107. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/examples/medusa/train.py +0 -0
  108. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/licenses/LICENSE-Apache-2.0 +0 -0
  109. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  110. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  111. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/licenses/LICENSE-MIT-llmc +0 -0
  112. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/licenses/LICENSE-MIT-triton +0 -0
  113. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/mkdocs.yml +0 -0
  114. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/setup.cfg +0 -0
  115. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/setup.py +0 -0
  116. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/__init__.py +0 -0
  117. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/README.md +0 -0
  118. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  119. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  120. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  121. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/functional.py +0 -0
  122. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  123. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  124. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  125. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  126. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/env_report.py +0 -0
  130. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/__init__.py +0 -0
  131. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/cross_entropy.py +0 -0
  132. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  133. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  134. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  135. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  136. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/geglu.py +0 -0
  137. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/group_norm.py +0 -0
  138. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/jsd.py +0 -0
  139. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/kl_div.py +0 -0
  140. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/layer_norm.py +0 -0
  141. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  142. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/rms_norm.py +0 -0
  143. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/rope.py +0 -0
  144. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/swiglu.py +0 -0
  145. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/ops/utils.py +0 -0
  146. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/__init__.py +0 -0
  147. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/auto_model.py +0 -0
  148. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  149. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  150. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/functional.py +0 -0
  151. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  152. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  153. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/geglu.py +0 -0
  154. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/group_norm.py +0 -0
  155. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/jsd.py +0 -0
  156. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/kl_div.py +0 -0
  157. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/layer_norm.py +0 -0
  158. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/__init__.py +0 -0
  159. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/gemma.py +0 -0
  160. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  161. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/llama.py +0 -0
  162. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/mistral.py +0 -0
  163. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  164. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/mllama.py +0 -0
  165. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/phi3.py +0 -0
  166. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  167. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  168. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  169. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  170. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/rms_norm.py +0 -0
  171. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/rope.py +0 -0
  172. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/swiglu.py +0 -0
  173. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  174. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  175. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  176. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/triton/__init__.py +0 -0
  177. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/triton/monkey_patch.py +0 -0
  178. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel/utils.py +0 -0
  179. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  180. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  181. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  182. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/__init__.py +0 -0
  183. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/chunked_loss/__init__.py +0 -0
  184. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/chunked_loss/test_cpo_loss.py +0 -0
  185. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/chunked_loss/test_dpo_loss.py +0 -0
  186. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/chunked_loss/test_jsd_loss.py +0 -0
  187. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/chunked_loss/test_kto_loss.py +0 -0
  188. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/chunked_loss/test_orpo_loss.py +0 -0
  189. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/chunked_loss/test_simpo_loss.py +0 -0
  190. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/conftest.py +0 -0
  191. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/convergence/__init__.py +0 -0
  192. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  193. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  194. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  195. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/resources/tiny_shakespeare.txt +0 -0
  196. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  197. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  198. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  199. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_auto_model.py +0 -0
  200. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_cross_entropy.py +0 -0
  201. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_embedding.py +0 -0
  202. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  203. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_fused_linear_jsd.py +0 -0
  204. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_geglu.py +0 -0
  205. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_group_norm.py +0 -0
  206. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_jsd.py +0 -0
  207. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_kl_div.py +0 -0
  208. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_layer_norm.py +0 -0
  209. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_mm_int8int2.py +0 -0
  210. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_monkey_patch.py +0 -0
  211. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_qwen2vl_mrope.py +0 -0
  212. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_rms_norm.py +0 -0
  213. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_rope.py +0 -0
  214. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_swiglu.py +0 -0
  215. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_trainer_integration.py +0 -0
  216. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/transformers/test_transformers.py +0 -0
  217. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/triton/test_triton_monkey_patch.py +0 -0
  218. {liger_kernel_nightly-0.5.3.dev20250212203051 → liger_kernel_nightly-0.5.3.dev20250214100345}/test/utils.py +0 -0
@@ -18,9 +18,13 @@ checkstyle:
18
18
  # Command to run pytest for convergence tests
19
19
  # We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286
20
20
  test-convergence:
21
- HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models.py
22
- HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_multimodal.py
23
- HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/test_mini_models_with_logits.py
21
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models.py
22
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_multimodal.py
23
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/fp32/test_mini_models_with_logits.py
24
+
25
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models.py
26
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_multimodal.py
27
+ HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence/bf16/test_mini_models_with_logits.py
24
28
 
25
29
  # Command to run all benchmark scripts and update benchmarking data file
26
30
  # By default this doesn't overwrite existing data for the same benchmark experiment
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250212203051
3
+ Version: 0.5.3.dev20250214100345
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.3.dev20250212203051"
7
+ version = "0.5.3.dev20250214100345"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250212203051
3
+ Version: 0.5.3.dev20250214100345
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -180,9 +180,14 @@ test/chunked_loss/test_kto_loss.py
180
180
  test/chunked_loss/test_orpo_loss.py
181
181
  test/chunked_loss/test_simpo_loss.py
182
182
  test/convergence/__init__.py
183
- test/convergence/test_mini_models.py
184
- test/convergence/test_mini_models_multimodal.py
185
- test/convergence/test_mini_models_with_logits.py
183
+ test/convergence/bf16/__init__.py
184
+ test/convergence/bf16/test_mini_models.py
185
+ test/convergence/bf16/test_mini_models_multimodal.py
186
+ test/convergence/bf16/test_mini_models_with_logits.py
187
+ test/convergence/fp32/__init__.py
188
+ test/convergence/fp32/test_mini_models.py
189
+ test/convergence/fp32/test_mini_models_multimodal.py
190
+ test/convergence/fp32/test_mini_models_with_logits.py
186
191
  test/resources/tiny_shakespeare.txt
187
192
  test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json
188
193
  test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json
@@ -457,7 +457,6 @@ def run_mini_model(
457
457
  @pytest.mark.parametrize(
458
458
  "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
459
459
  [
460
- ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
461
460
  pytest.param(
462
461
  "mini_llama3",
463
462
  32,
@@ -471,22 +470,6 @@ def run_mini_model(
471
470
  1e-2,
472
471
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
473
472
  ),
474
- pytest.param(
475
- "mini_mllama",
476
- 32,
477
- 1e-4,
478
- torch.float32,
479
- 1e-8,
480
- 1e-5,
481
- 5e-3,
482
- 1e-5,
483
- 5e-3,
484
- 1e-5,
485
- marks=pytest.mark.skipif(
486
- not MLLAMA_AVAILABLE,
487
- reason="Mllama not available in this version of transformers",
488
- ),
489
- ),
490
473
  pytest.param(
491
474
  "mini_mllama",
492
475
  32,
@@ -506,7 +489,6 @@ def run_mini_model(
506
489
  ),
507
490
  ],
508
491
  ),
509
- ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
510
492
  pytest.param(
511
493
  "mini_qwen2",
512
494
  32,
@@ -520,22 +502,6 @@ def run_mini_model(
520
502
  1e-2,
521
503
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
522
504
  ),
523
- pytest.param( # qwen2_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0
524
- "mini_qwen2_vl",
525
- 32,
526
- 1e-4,
527
- torch.float32,
528
- 1e-5, # 1e-8,
529
- 1e-1, # 1e-5,
530
- 5e-3,
531
- 1e-5,
532
- 5e-3,
533
- 1e-5,
534
- marks=pytest.mark.skipif(
535
- not QWEN2_VL_AVAILABLE,
536
- reason="Qwen2-VL not available in this version of transformers",
537
- ),
538
- ),
539
505
  pytest.param(
540
506
  "mini_qwen2_vl",
541
507
  32,
@@ -555,7 +521,6 @@ def run_mini_model(
555
521
  ),
556
522
  ],
557
523
  ),
558
- ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
559
524
  pytest.param(
560
525
  "mini_phi3",
561
526
  32,
@@ -569,7 +534,6 @@ def run_mini_model(
569
534
  1e-2,
570
535
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
571
536
  ),
572
- ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
573
537
  pytest.param(
574
538
  "mini_mistral",
575
539
  32,
@@ -584,7 +548,6 @@ def run_mini_model(
584
548
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
585
549
  ),
586
550
  # TODO: mixtral is flaky so disable the test for now
587
- # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
588
551
  # pytest.param(
589
552
  # "mini_mixtral",
590
553
  # 32,
@@ -601,7 +564,6 @@ def run_mini_model(
601
564
  # ),
602
565
  # ),
603
566
  # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
604
- ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
605
567
  pytest.param(
606
568
  "mini_gemma1",
607
569
  32,
@@ -615,7 +577,6 @@ def run_mini_model(
615
577
  1e-2,
616
578
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
617
579
  ),
618
- ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
619
580
  pytest.param(
620
581
  "mini_gemma1.1",
621
582
  32,
@@ -629,7 +590,6 @@ def run_mini_model(
629
590
  1e-2,
630
591
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
631
592
  ),
632
- ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
633
593
  # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate
634
594
  # pytest.param(
635
595
  # "mini_gemma2",
@@ -335,25 +335,6 @@ def run_mini_model_multimodal(
335
335
  @pytest.mark.parametrize(
336
336
  "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
337
337
  [
338
- pytest.param(
339
- "mini_qwen2_vl",
340
- 32,
341
- 1e-4,
342
- torch.float32,
343
- 1e-8,
344
- 1e-5,
345
- 5e-3,
346
- 1e-5,
347
- 5e-3,
348
- 1e-5,
349
- marks=[
350
- pytest.mark.skipif(
351
- not QWEN2_VL_AVAILABLE,
352
- reason="Qwen2-VL not available in this version of transformers",
353
- ),
354
- pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
355
- ],
356
- ),
357
338
  pytest.param(
358
339
  "mini_qwen2_vl",
359
340
  32,
@@ -374,22 +355,6 @@ def run_mini_model_multimodal(
374
355
  pytest.mark.skipif(device == "xpu", reason="skip for XPU"),
375
356
  ],
376
357
  ),
377
- pytest.param(
378
- "mini_mllama",
379
- 32,
380
- 1e-4,
381
- torch.float32,
382
- 1e-8,
383
- 1e-5,
384
- 5e-3,
385
- 1e-5,
386
- 5e-3,
387
- 1e-5,
388
- marks=pytest.mark.skipif(
389
- not MLLAMA_AVAILABLE,
390
- reason="Mllama not available in this version of transformers",
391
- ),
392
- ),
393
358
  pytest.param(
394
359
  "mini_mllama",
395
360
  32,
@@ -456,7 +456,6 @@ def run_mini_model(
456
456
  @pytest.mark.parametrize(
457
457
  "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
458
458
  [
459
- ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
460
459
  pytest.param(
461
460
  "mini_llama3",
462
461
  32,
@@ -470,22 +469,6 @@ def run_mini_model(
470
469
  1e-2,
471
470
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
472
471
  ),
473
- pytest.param(
474
- "mini_mllama",
475
- 32,
476
- 1e-4,
477
- torch.float32,
478
- 1e-8,
479
- 1e-5,
480
- 5e-3,
481
- 1e-5,
482
- 5e-3,
483
- 1e-5,
484
- marks=pytest.mark.skipif(
485
- not MLLAMA_AVAILABLE,
486
- reason="Mllama not available in this version of transformers",
487
- ),
488
- ),
489
472
  pytest.param(
490
473
  "mini_mllama",
491
474
  32,
@@ -519,22 +502,6 @@ def run_mini_model(
519
502
  1e-2,
520
503
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
521
504
  ),
522
- pytest.param(
523
- "mini_qwen2_vl",
524
- 32,
525
- 1e-4,
526
- torch.float32,
527
- 1e-8,
528
- 1e-5,
529
- 5e-3,
530
- 1e-5,
531
- 5e-3,
532
- 1e-5,
533
- marks=pytest.mark.skipif(
534
- not QWEN2_VL_AVAILABLE,
535
- reason="Qwen2-VL not available in this version of transformers",
536
- ),
537
- ),
538
505
  pytest.param(
539
506
  "mini_qwen2_vl",
540
507
  32,
@@ -554,7 +521,6 @@ def run_mini_model(
554
521
  ),
555
522
  ],
556
523
  ),
557
- ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
558
524
  pytest.param(
559
525
  "mini_phi3",
560
526
  32,
@@ -568,7 +534,6 @@ def run_mini_model(
568
534
  1e-2,
569
535
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
570
536
  ),
571
- ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
572
537
  pytest.param(
573
538
  "mini_mistral",
574
539
  32,
@@ -583,7 +548,6 @@ def run_mini_model(
583
548
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
584
549
  ),
585
550
  # TODO: mixtral is flaky so disable the test for now
586
- # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
587
551
  # pytest.param(
588
552
  # "mini_mixtral",
589
553
  # 32,
@@ -600,7 +564,6 @@ def run_mini_model(
600
564
  # ),
601
565
  # ),
602
566
  # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match
603
- ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
604
567
  pytest.param(
605
568
  "mini_gemma1",
606
569
  32,
@@ -614,7 +577,6 @@ def run_mini_model(
614
577
  1e-2,
615
578
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
616
579
  ),
617
- ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
618
580
  pytest.param(
619
581
  "mini_gemma1.1",
620
582
  32,
@@ -628,7 +590,6 @@ def run_mini_model(
628
590
  1e-2,
629
591
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
630
592
  ),
631
- ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
632
593
  # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate
633
594
  # pytest.param(
634
595
  # "mini_gemma2",