liger-kernel-nightly 0.5.3.dev20250221233257__tar.gz → 0.5.3.dev20250224212643__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (227)
  1. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/PKG-INFO +2 -1
  2. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/README.md +1 -0
  3. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/layer_norm.py +20 -7
  5. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/__init__.py +1 -0
  6. liger_kernel_nightly-0.5.3.dev20250224212643/src/liger_kernel/transformers/model/olmo2.py +124 -0
  7. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/monkey_patch.py +64 -0
  8. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel_nightly.egg-info/PKG-INFO +2 -1
  9. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
  10. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/convergence/bf16/test_mini_models.py +59 -0
  11. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/convergence/bf16/test_mini_models_with_logits.py +60 -0
  12. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/convergence/fp32/test_mini_models.py +56 -0
  13. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/convergence/fp32/test_mini_models_with_logits.py +56 -0
  14. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_layer_norm.py +20 -5
  15. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_monkey_patch.py +51 -0
  16. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/utils.py +12 -0
  17. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  18. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  19. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.github/pull_request_template.md +0 -0
  20. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.github/workflows/amd-ci.yml +0 -0
  21. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.github/workflows/docs.yml +0 -0
  22. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.github/workflows/intel-ci.yml +0 -0
  23. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.github/workflows/nvi-ci.yml +0 -0
  24. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.github/workflows/publish-nightly.yml +0 -0
  25. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.github/workflows/publish-release.yml +0 -0
  26. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/.gitignore +0 -0
  27. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/LICENSE +0 -0
  28. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/Makefile +0 -0
  29. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/NOTICE +0 -0
  30. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/README.md +0 -0
  31. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/__init__.py +0 -0
  32. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/benchmarks_visualizer.py +0 -0
  33. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/data/all_benchmark_data.csv +0 -0
  34. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/__init__.py +0 -0
  35. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  36. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  37. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  38. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  39. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_embedding.py +0 -0
  40. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  41. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  42. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_geglu.py +0 -0
  43. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_group_norm.py +0 -0
  44. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_jsd.py +0 -0
  45. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_kl_div.py +0 -0
  46. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  47. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  48. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  49. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  50. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  51. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_rope.py +0 -0
  52. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  53. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_swiglu.py +0 -0
  54. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/benchmark_tvd.py +0 -0
  55. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/benchmark/scripts/utils.py +0 -0
  56. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/dev/fmt-requirements.txt +0 -0
  57. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/dev/modal/tests.py +0 -0
  58. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/dev/modal/tests_bwd.py +0 -0
  59. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/Examples.md +0 -0
  60. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/Getting-Started.md +0 -0
  61. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/High-Level-APIs.md +0 -0
  62. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/Low-Level-APIs.md +0 -0
  63. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/acknowledgement.md +0 -0
  64. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/contributing.md +0 -0
  65. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/images/banner.GIF +0 -0
  66. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/images/compose.gif +0 -0
  67. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/images/e2e-memory.png +0 -0
  68. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/images/e2e-tps.png +0 -0
  69. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/images/logo-banner.png +0 -0
  70. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/images/patch.gif +0 -0
  71. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/images/post-training.png +0 -0
  72. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/index.md +0 -0
  73. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/docs/license.md +0 -0
  74. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/alignment/accelerate_config.yaml +0 -0
  75. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/alignment/run_orpo.py +0 -0
  76. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/README.md +0 -0
  77. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/callback.py +0 -0
  78. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/config/fsdp_config.json +0 -0
  79. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  80. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  81. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  82. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/img/llama_tps.png +0 -0
  83. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  84. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/img/qwen_tps.png +0 -0
  85. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/launch_on_modal.py +0 -0
  86. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/requirements.txt +0 -0
  87. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/run_benchmarks.sh +0 -0
  88. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/run_gemma.sh +0 -0
  89. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/run_llama.sh +0 -0
  90. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/run_qwen.sh +0 -0
  91. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/run_qwen2_vl.sh +0 -0
  92. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/training.py +0 -0
  93. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/huggingface/training_multimodal.py +0 -0
  94. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/lightning/README.md +0 -0
  95. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/lightning/requirements.txt +0 -0
  96. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/lightning/training.py +0 -0
  97. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/README.md +0 -0
  98. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/callback.py +0 -0
  99. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  100. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  101. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  102. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  103. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  104. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  105. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  106. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  107. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  108. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/medusa_util.py +0 -0
  109. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/requirements.txt +0 -0
  110. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  111. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/examples/medusa/train.py +0 -0
  112. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/licenses/LICENSE-Apache-2.0 +0 -0
  113. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  114. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  115. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/licenses/LICENSE-MIT-llmc +0 -0
  116. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/licenses/LICENSE-MIT-triton +0 -0
  117. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/mkdocs.yml +0 -0
  118. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/setup.cfg +0 -0
  119. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/setup.py +0 -0
  120. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/__init__.py +0 -0
  121. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/README.md +0 -0
  122. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  123. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  124. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  125. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/functional.py +0 -0
  126. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  127. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  128. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
  129. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  130. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  131. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  132. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  133. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  134. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  135. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/env_report.py +0 -0
  136. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/__init__.py +0 -0
  137. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/cross_entropy.py +0 -0
  138. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  139. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  140. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  141. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  142. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/geglu.py +0 -0
  143. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/group_norm.py +0 -0
  144. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/jsd.py +0 -0
  145. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/kl_div.py +0 -0
  146. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  147. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/rms_norm.py +0 -0
  148. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/rope.py +0 -0
  149. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/swiglu.py +0 -0
  150. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/tvd.py +0 -0
  151. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/ops/utils.py +0 -0
  152. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/auto_model.py +0 -0
  153. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  154. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  155. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/functional.py +0 -0
  156. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  157. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  158. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/geglu.py +0 -0
  159. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/group_norm.py +0 -0
  160. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/jsd.py +0 -0
  161. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/kl_div.py +0 -0
  162. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/layer_norm.py +0 -0
  163. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/__init__.py +0 -0
  164. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/gemma.py +0 -0
  165. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  166. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/llama.py +0 -0
  167. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/mistral.py +0 -0
  168. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  169. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/mllama.py +0 -0
  170. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/phi3.py +0 -0
  171. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  172. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  173. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  174. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/rms_norm.py +0 -0
  175. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/rope.py +0 -0
  176. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/swiglu.py +0 -0
  177. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  178. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  179. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  180. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/transformers/tvd.py +0 -0
  181. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/triton/__init__.py +0 -0
  182. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/triton/monkey_patch.py +0 -0
  183. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel/utils.py +0 -0
  184. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  185. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  186. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  187. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/__init__.py +0 -0
  188. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/chunked_loss/__init__.py +0 -0
  189. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/chunked_loss/test_cpo_loss.py +0 -0
  190. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/chunked_loss/test_dpo_loss.py +0 -0
  191. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/chunked_loss/test_grpo_loss.py +0 -0
  192. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/chunked_loss/test_jsd_loss.py +0 -0
  193. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/chunked_loss/test_kto_loss.py +0 -0
  194. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/chunked_loss/test_orpo_loss.py +0 -0
  195. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/chunked_loss/test_simpo_loss.py +0 -0
  196. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/conftest.py +0 -0
  197. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/convergence/__init__.py +0 -0
  198. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/convergence/bf16/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  200. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/convergence/fp32/__init__.py +0 -0
  201. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  202. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  203. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  204. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  205. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/resources/tiny_shakespeare.txt +0 -0
  206. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  207. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  208. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  209. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_auto_model.py +0 -0
  210. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_cross_entropy.py +0 -0
  211. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_embedding.py +0 -0
  212. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_flex_attention.py +0 -0
  213. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  214. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_fused_linear_jsd.py +0 -0
  215. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_geglu.py +0 -0
  216. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_group_norm.py +0 -0
  217. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_jsd.py +0 -0
  218. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_kl_div.py +0 -0
  219. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_mm_int8int2.py +0 -0
  220. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_qwen2vl_mrope.py +0 -0
  221. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_rms_norm.py +0 -0
  222. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_rope.py +0 -0
  223. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_swiglu.py +0 -0
  224. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_trainer_integration.py +0 -0
  225. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_transformers.py +0 -0
  226. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/transformers/test_tvd.py +0 -0
  227. {liger_kernel_nightly-0.5.3.dev20250221233257 → liger_kernel_nightly-0.5.3.dev20250224212643}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221233257
3
+ Version: 0.5.3.dev20250224212643
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -314,6 +314,7 @@ loss.backward()
314
314
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
315
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
316
316
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
317
+ | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
317
318
 
318
319
 
319
320
  ## Low-level APIs
@@ -266,6 +266,7 @@ loss.backward()
266
266
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
267
267
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
268
268
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
269
+ | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
269
270
 
270
271
 
271
272
  ## Low-level APIs
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.3.dev20250221233257"
7
+ version = "0.5.3.dev20250224212643"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -57,13 +57,14 @@ def _layer_norm_forward_kernel(
57
57
  B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)
58
58
 
59
59
  mean = tl.sum(X_row, axis=0) / n_cols
60
- var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols
60
+ Xmm = tl.where(mask, X_row - mean, 0)
61
+ var = tl.sum(Xmm * Xmm, axis=0) / n_cols
61
62
  rstd = rsqrt(var + eps)
62
63
 
63
64
  tl.store(Mean_ptr, mean)
64
65
  tl.store(RSTD_ptr, rstd)
65
66
 
66
- Y_row = (X_row - mean) * rstd * W_row + B_row
67
+ Y_row = Xmm * rstd * W_row + B_row
67
68
 
68
69
  tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
69
70
 
@@ -147,9 +148,11 @@ def layer_norm_forward(X, W, B, eps):
147
148
  Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
148
149
  Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)
149
150
  RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)
150
- assert X.shape[1] == W.shape[0], (
151
- f"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}"
152
- )
151
+ if X.shape[1] != W.shape[0]:
152
+ raise ValueError(
153
+ f"Incompatible dimensions: input feature size (X.shape[1]={X.shape[1]}) "
154
+ f"must match weight size (W.shape[0]={W.shape[0]})"
155
+ )
153
156
 
154
157
  _layer_norm_forward_kernel[(n_rows,)](
155
158
  Y,
@@ -190,11 +193,21 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
190
193
 
191
194
  BLOCK_SIZE, num_warps = calculate_settings(n_cols)
192
195
  if n_cols > BLOCK_SIZE:
193
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
196
+ raise RuntimeError(
197
+ f"Feature dimension {n_cols} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
198
+ )
194
199
 
195
200
  rows_per_program = math.ceil(n_rows / sm_count)
196
201
  grid = (sm_count,)
197
- triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16
202
+ triton_dtype = (
203
+ tl.float32
204
+ if X.dtype == torch.float32
205
+ else tl.bfloat16
206
+ if X.dtype == torch.bfloat16
207
+ else tl.float16
208
+ if X.dtype == torch.float16
209
+ else tl.float32 # fallback to float32 for other types
210
+ )
198
211
  _layer_norm_backward_kernel[grid](
199
212
  X,
200
213
  W,
@@ -14,6 +14,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama
14
14
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
15
15
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
16
16
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
17
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
17
18
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
18
19
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
19
20
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
@@ -0,0 +1,124 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from transformers.models.olmo2.modeling_olmo2 import _CONFIG_FOR_DOC
10
+ from transformers.models.olmo2.modeling_olmo2 import OLMO2_INPUTS_DOCSTRING
11
+ from transformers.utils import add_start_docstrings_to_model_forward
12
+ from transformers.utils import replace_return_docstrings
13
+
14
+ from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
15
+
16
+
17
+ @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING)
18
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
19
+ def lce_forward(
20
+ self,
21
+ input_ids: torch.LongTensor = None,
22
+ attention_mask: Optional[torch.Tensor] = None,
23
+ position_ids: Optional[torch.LongTensor] = None,
24
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
25
+ inputs_embeds: Optional[torch.FloatTensor] = None,
26
+ labels: Optional[torch.LongTensor] = None,
27
+ use_cache: Optional[bool] = None,
28
+ output_attentions: Optional[bool] = None,
29
+ output_hidden_states: Optional[bool] = None,
30
+ return_dict: Optional[bool] = None,
31
+ cache_position: Optional[torch.LongTensor] = None,
32
+ num_logits_to_keep: int = 0,
33
+ **loss_kwargs,
34
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
35
+ r"""
36
+ Args:
37
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
38
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
39
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
40
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
41
+
42
+ num_logits_to_keep (`int`, *optional*):
43
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
44
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
45
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
46
+
47
+ Returns:
48
+
49
+ Example:
50
+
51
+ ```python
52
+ >>> from transformers import AutoTokenizer, Olmo2ForCausalLM
53
+
54
+ >>> model = Olmo2ForCausalLM.from_pretrained("allenai/Olmo2-1B-hf")
55
+ >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo2-1B-hf")
56
+
57
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
58
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
59
+
60
+ >>> # Generate
61
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
62
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
63
+ 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
64
+ ```
65
+ """
66
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
67
+ output_hidden_states = (
68
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
69
+ )
70
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
71
+
72
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
73
+ outputs = self.model(
74
+ input_ids=input_ids,
75
+ attention_mask=attention_mask,
76
+ position_ids=position_ids,
77
+ past_key_values=past_key_values,
78
+ inputs_embeds=inputs_embeds,
79
+ use_cache=use_cache,
80
+ output_attentions=output_attentions,
81
+ output_hidden_states=output_hidden_states,
82
+ return_dict=return_dict,
83
+ cache_position=cache_position,
84
+ )
85
+
86
+ hidden_states = outputs[0]
87
+
88
+ logits = None
89
+ loss = None
90
+ # if in training mode, don't materialize logits
91
+ if self.training and (labels is not None):
92
+ # We do the same thing as ForCausalLMLoss but using Liger FLCE
93
+
94
+ shift_hidden_states = hidden_states[..., :-1, :].contiguous()
95
+ shift_labels = labels[..., 1:].contiguous()
96
+
97
+ # flatten tokens
98
+ shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
99
+ shift_labels = shift_labels.view(-1)
100
+
101
+ reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
102
+ lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
103
+
104
+ loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
105
+ if reduction == "sum":
106
+ loss /= loss_kwargs["num_items_in_batch"]
107
+
108
+ else: # if in inference mode materialize logits
109
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
110
+ if labels is not None:
111
+ loss = self.loss_function(
112
+ logits=logits,
113
+ labels=labels,
114
+ vocab_size=self.config.vocab_size,
115
+ **loss_kwargs,
116
+ )
117
+
118
+ return CausalLMOutputWithPast(
119
+ loss=loss,
120
+ logits=logits,
121
+ past_key_values=outputs.past_key_values,
122
+ hidden_states=outputs.hidden_states,
123
+ attentions=outputs.attentions,
124
+ )
@@ -814,6 +814,69 @@ def apply_liger_kernel_to_phi3(
814
814
  _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
815
815
 
816
816
 
817
+ def apply_liger_kernel_to_olmo2(
818
+ rope: bool = True,
819
+ cross_entropy: bool = False,
820
+ fused_linear_cross_entropy: bool = True,
821
+ rms_norm: bool = True,
822
+ swiglu: bool = True,
823
+ model: PreTrainedModel = None,
824
+ ) -> None:
825
+ """
826
+ Apply Liger kernels to replace original implementation in HuggingFace OLMO2 models.
827
+
828
+ Args:
829
+ rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
830
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
831
+ fused_linear_cross_entropy (bool):
832
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
833
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
834
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
835
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
836
+ swiglu (bool): Whether to apply Liger's SwiGLU Olmo2MLP. Default is True.
837
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
838
+ loaded. Default is None.
839
+ """
840
+ assert not (cross_entropy and fused_linear_cross_entropy), (
841
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
842
+ )
843
+
844
+ from transformers.models.olmo2 import modeling_olmo2
845
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2Model
846
+
847
+ from liger_kernel.transformers.model.olmo2 import lce_forward as olmo2_lce_forward
848
+
849
+ if rope:
850
+ modeling_olmo2.apply_rotary_pos_emb = liger_rotary_pos_emb
851
+ if rms_norm:
852
+ modeling_olmo2.Olmo2RMSNorm = partial(LigerRMSNorm, in_place=False)
853
+ if swiglu:
854
+ modeling_olmo2.Olmo2MLP = LigerSwiGLUMLP
855
+ if cross_entropy:
856
+ from transformers.loss.loss_utils import nn
857
+
858
+ nn.functional.cross_entropy = liger_cross_entropy
859
+ if fused_linear_cross_entropy:
860
+ modeling_olmo2.Olmo2ForCausalLM.forward = olmo2_lce_forward
861
+
862
+ if model is not None:
863
+ # The model instance already exists, so we need to additionally patch the
864
+ # instance variables that reference already-instantiated modules
865
+
866
+ # get the base model from the model instance
867
+ base_model: Olmo2Model = getattr(model, model.base_model_prefix, model)
868
+
869
+ if rms_norm:
870
+ _patch_rms_norm_module(base_model.norm)
871
+
872
+ for decoder_layer in base_model.layers:
873
+ if swiglu:
874
+ _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward)
875
+ if rms_norm:
876
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
877
+ _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
878
+
879
+
817
880
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
818
881
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
819
882
  "gemma": apply_liger_kernel_to_gemma,
@@ -824,6 +887,7 @@ MODEL_TYPE_TO_APPLY_LIGER_FN = {
824
887
  "mllama_text_model": apply_liger_kernel_to_mllama,
825
888
  "mistral": apply_liger_kernel_to_mistral,
826
889
  "mixtral": apply_liger_kernel_to_mixtral,
890
+ "olmo2": apply_liger_kernel_to_olmo2,
827
891
  "qwen2": apply_liger_kernel_to_qwen2,
828
892
  "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
829
893
  "phi3": apply_liger_kernel_to_phi3,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.3.dev20250221233257
3
+ Version: 0.5.3.dev20250224212643
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -314,6 +314,7 @@ loss.backward()
314
314
  | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
315
315
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
316
316
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
317
+ | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
317
318
 
318
319
 
319
320
  ## Low-level APIs
@@ -162,6 +162,7 @@ src/liger_kernel/transformers/model/llama.py
162
162
  src/liger_kernel/transformers/model/mistral.py
163
163
  src/liger_kernel/transformers/model/mixtral.py
164
164
  src/liger_kernel/transformers/model/mllama.py
165
+ src/liger_kernel/transformers/model/olmo2.py
165
166
  src/liger_kernel/transformers/model/phi3.py
166
167
  src/liger_kernel/transformers/model/qwen2.py
167
168
  src/liger_kernel/transformers/model/qwen2_vl.py
@@ -25,6 +25,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_llama
25
25
  from liger_kernel.transformers import apply_liger_kernel_to_mistral
26
26
  from liger_kernel.transformers import apply_liger_kernel_to_mixtral
27
27
  from liger_kernel.transformers import apply_liger_kernel_to_mllama
28
+ from liger_kernel.transformers import apply_liger_kernel_to_olmo2
28
29
  from liger_kernel.transformers import apply_liger_kernel_to_phi3
29
30
  from liger_kernel.transformers import apply_liger_kernel_to_qwen2
30
31
  from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
@@ -38,6 +39,7 @@ from test.utils import revert_liger_kernel_to_llama
38
39
  from test.utils import revert_liger_kernel_to_mistral
39
40
  from test.utils import revert_liger_kernel_to_mixtral
40
41
  from test.utils import revert_liger_kernel_to_mllama
42
+ from test.utils import revert_liger_kernel_to_olmo2
41
43
  from test.utils import revert_liger_kernel_to_phi3
42
44
  from test.utils import revert_liger_kernel_to_qwen2
43
45
  from test.utils import revert_liger_kernel_to_qwen2_vl
@@ -71,6 +73,15 @@ try:
71
73
  except ImportError:
72
74
  GRANITE_AVAILABLE = False
73
75
 
76
+ try:
77
+ # OLMO2 is only available in transformers>=4.47.0
78
+ from transformers.models.olmo2.configuration_olmo2 import Olmo2Config
79
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2ForCausalLM
80
+
81
+ OLMO2_AVAILABLE = True
82
+ except ImportError:
83
+ OLMO2_AVAILABLE = False
84
+
74
85
  from liger_kernel.utils import infer_device
75
86
 
76
87
  device = infer_device()
@@ -426,6 +437,35 @@ if GRANITE_AVAILABLE:
426
437
  ),
427
438
  )
428
439
 
440
+ if OLMO2_AVAILABLE:
441
+ MINI_MODEL_SETUPS["mini_olmo2"] = MiniModelConfig(
442
+ liger_kernel_patch_func=apply_liger_kernel_to_olmo2,
443
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo2,
444
+ model_class=Olmo2ForCausalLM,
445
+ mini_model_config=Olmo2Config(
446
+ bos_token_id=1, # 128000
447
+ eos_token_id=2, # 128001
448
+ pad_token_id=2,
449
+ cross_attention_layers=None,
450
+ dropout=0,
451
+ hidden_act="silu",
452
+ hidden_size=1024, # 4096
453
+ initializer_range=0.02,
454
+ intermediate_size=2048, # 14336
455
+ max_position_embeddings=4096,
456
+ num_attention_heads=8, # 32
457
+ num_hidden_layers=4, # 40
458
+ num_key_value_heads=2, # 8
459
+ rms_norm_eps=1e-5,
460
+ rope_scaling=None,
461
+ rope_theta=500_000,
462
+ tie_word_embeddings=False,
463
+ use_cache=True,
464
+ vocab_size=32000, # 128256,
465
+ attn_implementation="sdpa", # default value, pytorch native attention
466
+ ),
467
+ )
468
+
429
469
 
430
470
  def create_model(model_name="mini_llama3"):
431
471
  """
@@ -612,6 +652,25 @@ def run_mini_model(
612
652
  1e-2,
613
653
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
614
654
  ),
655
+ pytest.param(
656
+ "mini_olmo2",
657
+ 32,
658
+ 1e-4,
659
+ torch.bfloat16,
660
+ 1e-3,
661
+ 1e-2,
662
+ 1e-1,
663
+ 1e-2,
664
+ 1e-2,
665
+ 1e-2,
666
+ marks=[
667
+ pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
668
+ pytest.mark.skipif(
669
+ not OLMO2_AVAILABLE,
670
+ reason="OLMO2 not available in this version of transformers",
671
+ ),
672
+ ],
673
+ ),
615
674
  # TODO: mixtral is flaky so disable the test for now
616
675
  # pytest.param(
617
676
  # "mini_mixtral",
@@ -25,6 +25,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_llama
25
25
  from liger_kernel.transformers import apply_liger_kernel_to_mistral
26
26
  from liger_kernel.transformers import apply_liger_kernel_to_mixtral
27
27
  from liger_kernel.transformers import apply_liger_kernel_to_mllama
28
+ from liger_kernel.transformers import apply_liger_kernel_to_olmo2
28
29
  from liger_kernel.transformers import apply_liger_kernel_to_phi3
29
30
  from liger_kernel.transformers import apply_liger_kernel_to_qwen2
30
31
  from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
@@ -38,6 +39,7 @@ from test.utils import revert_liger_kernel_to_llama
38
39
  from test.utils import revert_liger_kernel_to_mistral
39
40
  from test.utils import revert_liger_kernel_to_mixtral
40
41
  from test.utils import revert_liger_kernel_to_mllama
42
+ from test.utils import revert_liger_kernel_to_olmo2
41
43
  from test.utils import revert_liger_kernel_to_phi3
42
44
  from test.utils import revert_liger_kernel_to_qwen2
43
45
  from test.utils import revert_liger_kernel_to_qwen2_vl
@@ -71,6 +73,15 @@ try:
71
73
  except ImportError:
72
74
  GRANITE_AVAILABLE = False
73
75
 
76
+ try:
77
+ # OLMO2 is only available in transformers>=4.47.0
78
+ from transformers.models.olmo2.configuration_olmo2 import Olmo2Config
79
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2ForCausalLM
80
+
81
+ OLMO2_AVAILABLE = True
82
+ except ImportError:
83
+ OLMO2_AVAILABLE = False
84
+
74
85
  from liger_kernel.utils import infer_device
75
86
 
76
87
  device = infer_device()
@@ -390,6 +401,34 @@ if QWEN2_VL_AVAILABLE:
390
401
  attn_implementation="sdpa",
391
402
  ),
392
403
  )
404
+ if OLMO2_AVAILABLE:
405
+ MINI_MODEL_SETUPS["mini_olmo2"] = MiniModelConfig(
406
+ liger_kernel_patch_func=apply_liger_kernel_to_olmo2,
407
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo2,
408
+ model_class=Olmo2ForCausalLM,
409
+ mini_model_config=Olmo2Config(
410
+ bos_token_id=1, # 128000
411
+ eos_token_id=2, # 128001
412
+ pad_token_id=2,
413
+ cross_attention_layers=None,
414
+ dropout=0,
415
+ hidden_act="silu",
416
+ hidden_size=1024, # 4096
417
+ initializer_range=0.02,
418
+ intermediate_size=2048, # 14336
419
+ max_position_embeddings=4096,
420
+ num_attention_heads=8, # 32
421
+ num_hidden_layers=4, # 40
422
+ num_key_value_heads=2, # 8
423
+ rms_norm_eps=1e-5,
424
+ rope_scaling=None,
425
+ rope_theta=500_000,
426
+ tie_word_embeddings=False,
427
+ use_cache=True,
428
+ vocab_size=32000, # 128256,
429
+ attn_implementation="sdpa", # default value, pytorch native attention
430
+ ),
431
+ )
393
432
 
394
433
  if GRANITE_AVAILABLE:
395
434
  MINI_MODEL_SETUPS["mini_granite3"] = MiniModelConfig(
@@ -655,6 +694,25 @@ def run_mini_model(
655
694
  1e-2,
656
695
  marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
657
696
  ),
697
+ pytest.param(
698
+ "mini_olmo2",
699
+ 32,
700
+ 1e-4,
701
+ torch.bfloat16,
702
+ 1e-3,
703
+ 1e-2,
704
+ 1e-1,
705
+ 1e-2,
706
+ 1e-2,
707
+ 1e-2,
708
+ marks=[
709
+ pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
710
+ pytest.mark.skipif(
711
+ not OLMO2_AVAILABLE,
712
+ reason="OLMO2 not available in this version of transformers",
713
+ ),
714
+ ],
715
+ ),
658
716
  # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate
659
717
  # pytest.param(
660
718
  # "mini_gemma2",
@@ -699,6 +757,8 @@ def test_mini_model(
699
757
  rtol=loss_rtol,
700
758
  )
701
759
 
760
+ # No logits are materialized
761
+ # import pdb; pdb.set_trace()
702
762
  # Compare the logits from the last step
703
763
  assert_verbose_allclose(
704
764
  expected_output["logits"],
@@ -25,6 +25,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_llama
25
25
  from liger_kernel.transformers import apply_liger_kernel_to_mistral
26
26
  from liger_kernel.transformers import apply_liger_kernel_to_mixtral
27
27
  from liger_kernel.transformers import apply_liger_kernel_to_mllama
28
+ from liger_kernel.transformers import apply_liger_kernel_to_olmo2
28
29
  from liger_kernel.transformers import apply_liger_kernel_to_phi3
29
30
  from liger_kernel.transformers import apply_liger_kernel_to_qwen2
30
31
  from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
@@ -38,6 +39,7 @@ from test.utils import revert_liger_kernel_to_llama
38
39
  from test.utils import revert_liger_kernel_to_mistral
39
40
  from test.utils import revert_liger_kernel_to_mixtral
40
41
  from test.utils import revert_liger_kernel_to_mllama
42
+ from test.utils import revert_liger_kernel_to_olmo2
41
43
  from test.utils import revert_liger_kernel_to_phi3
42
44
  from test.utils import revert_liger_kernel_to_qwen2
43
45
  from test.utils import revert_liger_kernel_to_qwen2_vl
@@ -70,6 +72,15 @@ try:
70
72
  except ImportError:
71
73
  GRANITE_AVAILABLE = False
72
74
 
75
+ try:
76
+ # OLMO2 is only available in transformers>=4.47.0
77
+ from transformers.models.olmo2.configuration_olmo2 import Olmo2Config
78
+ from transformers.models.olmo2.modeling_olmo2 import Olmo2ForCausalLM
79
+
80
+ OLMO2_AVAILABLE = True
81
+ except ImportError:
82
+ OLMO2_AVAILABLE = False
83
+
73
84
  from liger_kernel.utils import infer_device
74
85
 
75
86
  device = infer_device()
@@ -425,6 +436,35 @@ if GRANITE_AVAILABLE:
425
436
  ),
426
437
  )
427
438
 
439
+ if OLMO2_AVAILABLE:
440
+ MINI_MODEL_SETUPS["mini_olmo2"] = MiniModelConfig(
441
+ liger_kernel_patch_func=apply_liger_kernel_to_olmo2,
442
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo2,
443
+ model_class=Olmo2ForCausalLM,
444
+ mini_model_config=Olmo2Config(
445
+ bos_token_id=1, # 128000
446
+ eos_token_id=2, # 128001
447
+ pad_token_id=2,
448
+ cross_attention_layers=None,
449
+ dropout=0,
450
+ hidden_act="silu",
451
+ hidden_size=1024, # 4096
452
+ initializer_range=0.02,
453
+ intermediate_size=2048, # 14336
454
+ max_position_embeddings=4096,
455
+ num_attention_heads=8, # 32
456
+ num_hidden_layers=4, # 40
457
+ num_key_value_heads=2, # 8
458
+ rms_norm_eps=1e-5,
459
+ rope_scaling=None,
460
+ rope_theta=500_000,
461
+ tie_word_embeddings=False,
462
+ use_cache=True,
463
+ vocab_size=32000, # 128256,
464
+ attn_implementation="sdpa", # default value, pytorch native attention
465
+ ),
466
+ )
467
+
428
468
 
429
469
  def create_model(model_name="mini_llama3"):
430
470
  """
@@ -536,6 +576,22 @@ def run_mini_model(
536
576
  reason="Qwen2-VL not available in this version of transformers",
537
577
  ),
538
578
  ),
579
+ pytest.param(
580
+ "mini_olmo2",
581
+ 32,
582
+ 1e-4,
583
+ torch.float32,
584
+ 1e-8,
585
+ 1e-5,
586
+ 5e-3,
587
+ 1e-5,
588
+ 5e-3,
589
+ 1e-5,
590
+ marks=pytest.mark.skipif(
591
+ not OLMO2_AVAILABLE,
592
+ reason="OLMO2 not available in this version of transformers",
593
+ ),
594
+ ),
539
595
  ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
540
596
  ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
541
597
  # TODO: mixtral is flaky so disable the test for now