liger-kernel-nightly 0.5.8.dev20250429220905__tar.gz → 0.5.8.dev20250502215739__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/PKG-INFO +2 -1
  2. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/README.md +1 -0
  3. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/pyproject.toml +1 -1
  4. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/cross_entropy.py +4 -1
  5. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/fused_linear_cross_entropy.py +4 -3
  6. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/__init__.py +3 -0
  7. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +1 -2
  8. liger_kernel_nightly-0.5.8.dev20250502215739/src/liger_kernel/transformers/model/glm4.py +123 -0
  9. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/monkey_patch.py +65 -0
  10. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel_nightly.egg-info/PKG-INFO +2 -1
  11. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel_nightly.egg-info/SOURCES.txt +1 -0
  12. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/convergence/bf16/test_mini_models.py +63 -0
  13. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/convergence/bf16/test_mini_models_with_logits.py +63 -1
  14. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/convergence/fp32/test_mini_models.py +60 -0
  15. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/convergence/fp32/test_mini_models_with_logits.py +60 -1
  16. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_cross_entropy.py +108 -2
  17. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_fused_linear_cross_entropy.py +3 -4
  18. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_monkey_patch.py +52 -0
  19. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/utils.py +12 -0
  20. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  21. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  22. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.github/pull_request_template.md +0 -0
  23. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.github/workflows/amd-ci.yml +0 -0
  24. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.github/workflows/docs.yml +0 -0
  25. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.github/workflows/intel-ci.yml +0 -0
  26. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.github/workflows/nvi-ci.yml +0 -0
  27. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.github/workflows/publish-nightly.yml +0 -0
  28. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.github/workflows/publish-release.yml +0 -0
  29. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/.gitignore +0 -0
  30. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/LICENSE +0 -0
  31. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/Makefile +0 -0
  32. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/NOTICE +0 -0
  33. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/README.md +0 -0
  34. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/__init__.py +0 -0
  35. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/benchmarks_visualizer.py +0 -0
  36. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/data/all_benchmark_data.csv +0 -0
  37. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/__init__.py +0 -0
  38. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  39. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  40. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  41. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_dyt.py +0 -0
  43. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_embedding.py +0 -0
  44. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  45. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  46. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_geglu.py +0 -0
  47. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_group_norm.py +0 -0
  48. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_jsd.py +0 -0
  49. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_kl_div.py +0 -0
  50. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  51. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  52. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  53. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  54. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  55. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_rope.py +0 -0
  56. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  57. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_swiglu.py +0 -0
  58. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/benchmark_tvd.py +0 -0
  59. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/benchmark/scripts/utils.py +0 -0
  60. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/dev/fmt-requirements.txt +0 -0
  61. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/dev/modal/tests.py +0 -0
  62. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/dev/modal/tests_bwd.py +0 -0
  63. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/Examples.md +0 -0
  64. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/Getting-Started.md +0 -0
  65. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/High-Level-APIs.md +0 -0
  66. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/Low-Level-APIs.md +0 -0
  67. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/acknowledgement.md +0 -0
  68. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/contributing.md +0 -0
  69. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/images/banner.GIF +0 -0
  70. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/images/compose.gif +0 -0
  71. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/images/e2e-memory.png +0 -0
  72. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/images/e2e-tps.png +0 -0
  73. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/images/logo-banner.png +0 -0
  74. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/images/patch.gif +0 -0
  75. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/images/post-training.png +0 -0
  76. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/index.md +0 -0
  77. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/docs/license.md +0 -0
  78. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/alignment/accelerate_config.yaml +0 -0
  79. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/alignment/run_orpo.py +0 -0
  80. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/README.md +0 -0
  81. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/callback.py +0 -0
  82. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/config/fsdp_config.json +0 -0
  83. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  84. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  85. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  86. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/img/llama_tps.png +0 -0
  87. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  88. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/img/qwen_tps.png +0 -0
  89. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/launch_on_modal.py +0 -0
  90. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/requirements.txt +0 -0
  91. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/run_benchmarks.sh +0 -0
  92. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/run_gemma.sh +0 -0
  93. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/run_llama.sh +0 -0
  94. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/run_qwen.sh +0 -0
  95. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/run_qwen2_vl.sh +0 -0
  96. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/training.py +0 -0
  97. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/huggingface/training_multimodal.py +0 -0
  98. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/lightning/README.md +0 -0
  99. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/lightning/requirements.txt +0 -0
  100. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/lightning/training.py +0 -0
  101. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/README.md +0 -0
  102. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/callback.py +0 -0
  103. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  104. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  105. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  106. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  107. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  108. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  109. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  110. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  111. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  112. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/medusa_util.py +0 -0
  113. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/requirements.txt +0 -0
  114. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  115. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/examples/medusa/train.py +0 -0
  116. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/licenses/LICENSE-Apache-2.0 +0 -0
  117. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  118. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  119. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/licenses/LICENSE-MIT-llmc +0 -0
  120. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/licenses/LICENSE-MIT-triton +0 -0
  121. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/mkdocs.yml +0 -0
  122. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/setup.cfg +0 -0
  123. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/setup.py +0 -0
  124. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/__init__.py +0 -0
  125. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/README.md +0 -0
  126. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  127. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/functional.py +0 -0
  130. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  131. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/fused_linear_ppo.py +0 -0
  132. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  133. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  134. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  135. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  136. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  137. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  138. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  139. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/env_report.py +0 -0
  140. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/__init__.py +0 -0
  141. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/dyt.py +0 -0
  142. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  143. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  144. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  145. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/geglu.py +0 -0
  146. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/group_norm.py +0 -0
  147. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/jsd.py +0 -0
  148. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/kl_div.py +0 -0
  149. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/layer_norm.py +0 -0
  150. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  151. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/rms_norm.py +0 -0
  152. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/rope.py +0 -0
  153. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/swiglu.py +0 -0
  154. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/tvd.py +0 -0
  155. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/ops/utils.py +0 -0
  156. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/auto_model.py +0 -0
  157. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  158. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/dyt.py +0 -0
  159. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  160. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/functional.py +0 -0
  161. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  162. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/geglu.py +0 -0
  163. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/gema3_rms.py +0 -0
  164. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/group_norm.py +0 -0
  165. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/jsd.py +0 -0
  166. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/kl_div.py +0 -0
  167. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/layer_norm.py +0 -0
  168. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/__init__.py +0 -0
  169. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/gemma.py +0 -0
  170. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  171. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/gemma3.py +0 -0
  172. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/llama.py +0 -0
  173. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/llava.py +0 -0
  174. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  175. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/mistral.py +0 -0
  176. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  177. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/mllama.py +0 -0
  178. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  179. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  180. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/phi3.py +0 -0
  181. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  182. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  183. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  184. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  185. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/rms_norm.py +0 -0
  186. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/rope.py +0 -0
  187. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/swiglu.py +0 -0
  188. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  189. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  190. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  191. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/transformers/tvd.py +0 -0
  192. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/triton/__init__.py +0 -0
  193. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/triton/monkey_patch.py +0 -0
  194. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel/utils.py +0 -0
  195. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  196. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  197. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  198. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/chunked_loss/__init__.py +0 -0
  200. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/chunked_loss/test_cpo_loss.py +0 -0
  201. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/chunked_loss/test_dpo_loss.py +0 -0
  202. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/chunked_loss/test_grpo_loss.py +0 -0
  203. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/chunked_loss/test_jsd_loss.py +0 -0
  204. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/chunked_loss/test_kto_loss.py +0 -0
  205. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/chunked_loss/test_orpo_loss.py +0 -0
  206. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/chunked_loss/test_simpo_loss.py +0 -0
  207. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/conftest.py +0 -0
  208. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/convergence/__init__.py +0 -0
  209. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/convergence/bf16/__init__.py +0 -0
  210. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  211. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/convergence/fp32/__init__.py +0 -0
  212. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  213. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/fake_configs/Google/Gemma3/gemma-3-4b-it/tokenizer_config.json +0 -0
  214. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  215. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json +0 -0
  216. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json +0 -0
  217. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json +0 -0
  218. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  219. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  220. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  221. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  222. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/tiny_shakespeare.txt +0 -0
  223. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  224. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  225. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  226. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_auto_model.py +0 -0
  227. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_dyt.py +0 -0
  228. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_embedding.py +0 -0
  229. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_flex_attention.py +0 -0
  230. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_fused_linear_jsd.py +0 -0
  231. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_geglu.py +0 -0
  232. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_group_norm.py +0 -0
  233. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_jsd.py +0 -0
  234. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_kl_div.py +0 -0
  235. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_layer_norm.py +0 -0
  236. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_mm_int8int2.py +0 -0
  237. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_qwen2vl_mrope.py +0 -0
  238. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_rms_norm.py +0 -0
  239. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_rope.py +0 -0
  240. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_swiglu.py +0 -0
  241. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_trainer_integration.py +0 -0
  242. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_transformers.py +0 -0
  243. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/transformers/test_tvd.py +0 -0
  244. {liger_kernel_nightly-0.5.8.dev20250429220905 → liger_kernel_nightly-0.5.8.dev20250502215739}/test/triton/test_triton_monkey_patch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.8.dev20250429220905
3
+ Version: 0.5.8.dev20250502215739
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -320,6 +320,7 @@ loss.backward()
320
320
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
321
321
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
322
322
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
323
+ | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
323
324
 
324
325
 
325
326
  ## Low-level APIs
@@ -272,6 +272,7 @@ loss.backward()
272
272
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
273
273
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
274
274
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
275
+ | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
275
276
 
276
277
 
277
278
  ## Low-level APIs
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.8.dev20250429220905"
7
+ version = "0.5.8.dev20250502215739"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -351,7 +351,10 @@ def cross_entropy_backward(_input, grad_output):
351
351
  # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
352
352
  if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
353
353
  pass
354
-
354
+ # If reduction is 'none'
355
+ elif grad_output.ndim > 0:
356
+ _input = _input * grad_output.unsqueeze(dim=1)
357
+ # If reduction is ['mean', 'sum'], grad_output is just a scalar
355
358
  # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
356
359
  # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
357
360
  else:
@@ -143,9 +143,10 @@ def fused_linear_cross_entropy_forward(
143
143
  alpha=1.0,
144
144
  )
145
145
 
146
- if reduction == "none":
147
- loss = loss_1d
148
- z_loss = z_loss_1d if return_z_loss else None
146
+ # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
147
+ # if reduction == "none":
148
+ # loss = loss_1d
149
+ # z_loss = z_loss_1d if return_z_loss else None
149
150
 
150
151
  else:
151
152
  loss = torch.sum(loss_1d)
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
26
26
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
27
27
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
28
28
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
29
+ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
29
30
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
30
31
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
31
32
  from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
@@ -79,6 +80,7 @@ def __getattr__(name: str):
79
80
  "apply_liger_kernel_to_gemma2",
80
81
  "apply_liger_kernel_to_gemma3",
81
82
  "apply_liger_kernel_to_gemma3_text",
83
+ "apply_liger_kernel_to_glm4",
82
84
  "apply_liger_kernel_to_granite",
83
85
  "apply_liger_kernel_to_llama",
84
86
  "apply_liger_kernel_to_llava",
@@ -129,6 +131,7 @@ if _TRANSFORMERS_AVAILABLE:
129
131
  "apply_liger_kernel_to_gemma2",
130
132
  "apply_liger_kernel_to_gemma3",
131
133
  "apply_liger_kernel_to_gemma3_text",
134
+ "apply_liger_kernel_to_glm4",
132
135
  "apply_liger_kernel_to_granite",
133
136
  "apply_liger_kernel_to_llama",
134
137
  "apply_liger_kernel_to_llava",
@@ -23,8 +23,7 @@ class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
23
23
  assert reduction in {
24
24
  "mean",
25
25
  "sum",
26
- "none",
27
- }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
26
+ }, f"reduction must be 'mean' or 'sum'. Got: {reduction}"
28
27
  assert softcap is None or softcap > 0, f"softcap must greater than 0.0 or None. Got: {softcap}"
29
28
  self.ce_weight = ce_weight
30
29
  self.ignore_index = ignore_index
@@ -0,0 +1,123 @@
1
+ from typing import List
2
+ from typing import Optional
3
+ from typing import Tuple
4
+ from typing import Union
5
+
6
+ import torch
7
+
8
+ from transformers.modeling_outputs import CausalLMOutputWithPast
9
+ from transformers.models.glm4.modeling_glm4 import _CONFIG_FOR_DOC
10
+ from transformers.models.glm4.modeling_glm4 import GLM4_INPUTS_DOCSTRING
11
+ from transformers.utils import add_start_docstrings_to_model_forward
12
+ from transformers.utils import replace_return_docstrings
13
+ from transformers.utils.deprecation import deprecate_kwarg
14
+
15
+ from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
16
+
17
+
18
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
19
+ @add_start_docstrings_to_model_forward(GLM4_INPUTS_DOCSTRING)
20
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
21
+ def lce_forward(
22
+ self,
23
+ input_ids: torch.LongTensor = None,
24
+ attention_mask: Optional[torch.Tensor] = None,
25
+ position_ids: Optional[torch.LongTensor] = None,
26
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
27
+ inputs_embeds: Optional[torch.FloatTensor] = None,
28
+ labels: Optional[torch.LongTensor] = None,
29
+ use_cache: Optional[bool] = None,
30
+ output_attentions: Optional[bool] = None,
31
+ output_hidden_states: Optional[bool] = None,
32
+ return_dict: Optional[bool] = None,
33
+ cache_position: Optional[torch.LongTensor] = None,
34
+ logits_to_keep: Union[int, torch.Tensor] = 0,
35
+ **loss_kwargs,
36
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
37
+ r"""
38
+ Args:
39
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
40
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
41
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
42
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
43
+
44
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
45
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
46
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
47
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
48
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
49
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
50
+
51
+ Returns:
52
+
53
+ Example:
54
+
55
+ ```python
56
+ >>> from transformers import AutoTokenizer, Glm4ForCausalLM
57
+
58
+ >>> model = Glm4ForCausalLM.from_pretrained("THUDM/GLM-4-9B-0414")
59
+ >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/GLM-4-9B-0414")
60
+
61
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
62
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
63
+
64
+ >>> # Generate
65
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
66
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
67
+ 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
68
+ ```
69
+ """
70
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
71
+ output_hidden_states = (
72
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
73
+ )
74
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
75
+
76
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
77
+ outputs = self.model(
78
+ input_ids=input_ids,
79
+ attention_mask=attention_mask,
80
+ position_ids=position_ids,
81
+ past_key_values=past_key_values,
82
+ inputs_embeds=inputs_embeds,
83
+ use_cache=use_cache,
84
+ output_attentions=output_attentions,
85
+ output_hidden_states=output_hidden_states,
86
+ return_dict=return_dict,
87
+ cache_position=cache_position,
88
+ )
89
+
90
+ hidden_states = outputs[0]
91
+
92
+ shift_labels = loss_kwargs.pop("shift_labels", None)
93
+ logits = None
94
+ loss = None
95
+ # if in training mode, don't materialize logits
96
+ if self.training and (labels is not None or shift_labels is not None):
97
+ loss = LigerForCausalLMLoss(
98
+ hidden_states=hidden_states,
99
+ lm_head_weight=self.lm_head.weight,
100
+ labels=labels,
101
+ shift_labels=shift_labels,
102
+ hidden_size=self.config.hidden_size,
103
+ **loss_kwargs,
104
+ )
105
+
106
+ else: # if in inference mode materialize logits
107
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
108
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
109
+ if labels is not None:
110
+ loss = self.loss_function(
111
+ logits=logits,
112
+ labels=labels,
113
+ vocab_size=self.config.vocab_size,
114
+ **loss_kwargs,
115
+ )
116
+
117
+ return CausalLMOutputWithPast(
118
+ loss=loss,
119
+ logits=logits,
120
+ past_key_values=outputs.past_key_values,
121
+ hidden_states=outputs.hidden_states,
122
+ attentions=outputs.attentions,
123
+ )
@@ -17,6 +17,7 @@ from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forwa
17
17
  from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
18
18
  from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
19
19
  from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
20
+ from liger_kernel.transformers.model.glm4 import lce_forward as glm4_lce_forward
20
21
  from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
21
22
  from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
22
23
  from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -1319,12 +1320,76 @@ def apply_liger_kernel_to_olmo2(
1319
1320
  _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
1320
1321
 
1321
1322
 
1323
+ def apply_liger_kernel_to_glm4(
1324
+ rope: bool = False,
1325
+ cross_entropy: bool = False,
1326
+ fused_linear_cross_entropy: bool = True,
1327
+ rms_norm: bool = True,
1328
+ swiglu: bool = True,
1329
+ model: PreTrainedModel = None,
1330
+ ) -> None:
1331
+ """
1332
+ Apply Liger kernels to replace original implementation in HuggingFace GLM-4 models.
1333
+
1334
+ Args:
1335
+ rope (bool): Whether to apply Liger's rotary position embedding. Not supported for GLM-4: passing True raises NotImplementedError. Default is False.
1336
+ cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
1337
+ fused_linear_cross_entropy (bool):
1338
+ Whether to apply Liger's fused linear cross entropy loss. Default is True.
1339
+ `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
1340
+ If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
1341
+ rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
1342
+ swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
1343
+ model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
1344
+ loaded. Default is None.
1345
+ """
1346
+ assert not (cross_entropy and fused_linear_cross_entropy), (
1347
+ "cross_entropy and fused_linear_cross_entropy cannot both be True."
1348
+ )
1349
+
1350
+ from transformers.models.glm4 import modeling_glm4
1351
+ from transformers.models.glm4.modeling_glm4 import Glm4Model
1352
+
1353
+ if rope:
1354
+ raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4 models.")
1355
+ if rms_norm:
1356
+ modeling_glm4.Glm4RMSNorm = partial(LigerRMSNorm, in_place=False)
1357
+ if swiglu:
1358
+ modeling_glm4.Glm4MLP = LigerPhi3SwiGLUMLP
1359
+ if cross_entropy:
1360
+ from transformers.loss.loss_utils import nn
1361
+
1362
+ nn.functional.cross_entropy = liger_cross_entropy
1363
+ if fused_linear_cross_entropy:
1364
+ modeling_glm4.Glm4ForCausalLM.forward = glm4_lce_forward
1365
+
1366
+ if model is not None:
1367
+ # The model instance already exists, so we need to additionally patch the
1368
+ # instance variables that reference already-instantiated modules
1369
+
1370
+ # get the base model from the model instance
1371
+ base_model: Glm4Model = getattr(model, model.base_model_prefix, model)
1372
+
1373
+ if rms_norm:
1374
+ _patch_rms_norm_module(base_model.norm, in_place=False)
1375
+
1376
+ for decoder_layer in base_model.layers:
1377
+ if swiglu:
1378
+ _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
1379
+ if rms_norm:
1380
+ _patch_rms_norm_module(decoder_layer.input_layernorm, in_place=False)
1381
+ _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
1382
+ _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False)
1383
+ _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
1384
+
1385
+
1322
1386
  # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
1323
1387
  MODEL_TYPE_TO_APPLY_LIGER_FN = {
1324
1388
  "gemma": apply_liger_kernel_to_gemma,
1325
1389
  "gemma2": apply_liger_kernel_to_gemma2,
1326
1390
  "gemma3_text": apply_liger_kernel_to_gemma3_text,
1327
1391
  "gemma3": apply_liger_kernel_to_gemma3,
1392
+ "glm4": apply_liger_kernel_to_glm4,
1328
1393
  "llama": apply_liger_kernel_to_llama,
1329
1394
  "llava": apply_liger_kernel_to_llava,
1330
1395
  "granite": apply_liger_kernel_to_granite,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.8.dev20250429220905
3
+ Version: 0.5.8.dev20250502215739
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -320,6 +320,7 @@ loss.backward()
320
320
  | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
321
321
  | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
322
322
  | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
323
+ | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
323
324
 
324
325
 
325
326
  ## Low-level APIs
@@ -163,6 +163,7 @@ src/liger_kernel/transformers/model/__init__.py
163
163
  src/liger_kernel/transformers/model/gemma.py
164
164
  src/liger_kernel/transformers/model/gemma2.py
165
165
  src/liger_kernel/transformers/model/gemma3.py
166
+ src/liger_kernel/transformers/model/glm4.py
166
167
  src/liger_kernel/transformers/model/llama.py
167
168
  src/liger_kernel/transformers/model/llava.py
168
169
  src/liger_kernel/transformers/model/loss_utils.py
@@ -21,6 +21,7 @@ from transformers.models.qwen2 import Qwen2ForCausalLM
21
21
  from liger_kernel.transformers import apply_liger_kernel_to_gemma
22
22
  from liger_kernel.transformers import apply_liger_kernel_to_gemma2
23
23
  from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text
24
+ from liger_kernel.transformers import apply_liger_kernel_to_glm4
24
25
  from liger_kernel.transformers import apply_liger_kernel_to_granite
25
26
  from liger_kernel.transformers import apply_liger_kernel_to_llama
26
27
  from liger_kernel.transformers import apply_liger_kernel_to_llava
@@ -38,6 +39,7 @@ from test.utils import assert_verbose_allclose
38
39
  from test.utils import revert_liger_kernel_to_gemma
39
40
  from test.utils import revert_liger_kernel_to_gemma2
40
41
  from test.utils import revert_liger_kernel_to_gemma3_text
42
+ from test.utils import revert_liger_kernel_to_glm4
41
43
  from test.utils import revert_liger_kernel_to_granite
42
44
  from test.utils import revert_liger_kernel_to_llama
43
45
  from test.utils import revert_liger_kernel_to_llava
@@ -106,6 +108,14 @@ try:
106
108
  except ImportError:
107
109
  OLMO2_AVAILABLE = False
108
110
 
111
+ try:
112
+ # Glm4 is only available in transformers>=4.51.3
113
+ from transformers.models.glm4.configuration_glm4 import Glm4Config
114
+ from transformers.models.glm4.modeling_glm4 import Glm4ForCausalLM
115
+
116
+ GLM4_AVAILABLE = True
117
+ except ImportError:
118
+ GLM4_AVAILABLE = False
109
119
 
110
120
  try:
111
121
  from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
@@ -644,6 +654,37 @@ if OLMO2_AVAILABLE:
644
654
  ),
645
655
  )
646
656
 
657
+ if GLM4_AVAILABLE:
658
+ MINI_MODEL_SETUPS["mini_glm4"] = MiniModelConfig(
659
+ liger_kernel_patch_func=apply_liger_kernel_to_glm4,
660
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4,
661
+ model_class=Glm4ForCausalLM,
662
+ mini_model_config=Glm4Config(
663
+ bos_token_id=1, # None
664
+ eos_token_id=2, # 151329, 151336, 151338
665
+ pad_token_id=2, # 151329
666
+ partial_rotary_factor=0.5,
667
+ cross_attention_layers=None,
668
+ dropout=0,
669
+ hidden_act="silu",
670
+ hidden_size=1024, # 6144
671
+ initializer_range=0.02,
672
+ intermediate_size=2048, # 14336
673
+ max_position_embeddings=4096, # 32768
674
+ num_attention_heads=8, # 48
675
+ num_hidden_layers=4, # 61
676
+ num_key_value_heads=2,
677
+ rms_norm_eps=1e-5,
678
+ rope_scaling=None,
679
+ rope_theta=500_000,
680
+ tie_word_embeddings=False,
681
+ use_cache=True,
682
+ vocab_size=32000, # 151552
683
+ attention_bias=True,
684
+ attn_implementation="sdpa", # default value, pytorch native attention
685
+ ),
686
+ )
687
+
647
688
 
648
689
  def create_model(model_name="mini_llama3"):
649
690
  """
@@ -679,6 +720,9 @@ def run_mini_model(
679
720
  "rms_norm": True,
680
721
  }
681
722
 
723
+ if "glm4" in model_name:
724
+ kwargs["rope"] = False
725
+
682
726
  model_supports_layer_norm = "qwen2_vl" in model_name
683
727
  if model_supports_layer_norm:
684
728
  kwargs["layer_norm"] = True
@@ -890,6 +934,25 @@ def run_mini_model(
890
934
  ),
891
935
  ],
892
936
  ),
937
+ pytest.param(
938
+ "mini_glm4",
939
+ 32,
940
+ 1e-4,
941
+ torch.bfloat16,
942
+ 1e-3,
943
+ 1e-2,
944
+ 1e-1,
945
+ 1e-2,
946
+ 1e-2,
947
+ 1e-2,
948
+ marks=[
949
+ pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
950
+ pytest.mark.skipif(
951
+ not GLM4_AVAILABLE,
952
+ reason="Glm4 not available in this version of transformers",
953
+ ),
954
+ ],
955
+ ),
893
956
  # TODO: mixtral is flaky so disable the test for now
894
957
  # pytest.param(
895
958
  # "mini_mixtral",
@@ -21,6 +21,7 @@ from transformers.models.qwen2 import Qwen2ForCausalLM
21
21
  from liger_kernel.transformers import apply_liger_kernel_to_gemma
22
22
  from liger_kernel.transformers import apply_liger_kernel_to_gemma2
23
23
  from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text
24
+ from liger_kernel.transformers import apply_liger_kernel_to_glm4
24
25
  from liger_kernel.transformers import apply_liger_kernel_to_granite
25
26
  from liger_kernel.transformers import apply_liger_kernel_to_llama
26
27
  from liger_kernel.transformers import apply_liger_kernel_to_llava
@@ -38,6 +39,7 @@ from test.utils import assert_verbose_allclose
38
39
  from test.utils import revert_liger_kernel_to_gemma
39
40
  from test.utils import revert_liger_kernel_to_gemma2
40
41
  from test.utils import revert_liger_kernel_to_gemma3_text
42
+ from test.utils import revert_liger_kernel_to_glm4
41
43
  from test.utils import revert_liger_kernel_to_granite
42
44
  from test.utils import revert_liger_kernel_to_llama
43
45
  from test.utils import revert_liger_kernel_to_llava
@@ -106,6 +108,14 @@ try:
106
108
  except ImportError:
107
109
  OLMO2_AVAILABLE = False
108
110
 
111
+ try:
112
+ # Glm4 is only available in transformers>=4.51.3
113
+ from transformers.models.glm4.configuration_glm4 import Glm4Config
114
+ from transformers.models.glm4.modeling_glm4 import Glm4ForCausalLM
115
+
116
+ GLM4_AVAILABLE = True
117
+ except ImportError:
118
+ GLM4_AVAILABLE = False
109
119
 
110
120
  try:
111
121
  from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
@@ -377,7 +387,6 @@ if GEMMA3_AVAILABLE:
377
387
  ),
378
388
  )
379
389
 
380
-
381
390
  if MLLAMA_AVAILABLE:
382
391
  MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig(
383
392
  liger_kernel_patch_func=apply_liger_kernel_to_mllama,
@@ -645,6 +654,37 @@ if OLMO2_AVAILABLE:
645
654
  ),
646
655
  )
647
656
 
657
+ if GLM4_AVAILABLE:
658
+ MINI_MODEL_SETUPS["mini_glm4"] = MiniModelConfig(
659
+ liger_kernel_patch_func=apply_liger_kernel_to_glm4,
660
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4,
661
+ model_class=Glm4ForCausalLM,
662
+ mini_model_config=Glm4Config(
663
+ bos_token_id=1, # None
664
+ eos_token_id=2, # 151329, 151336, 151338
665
+ pad_token_id=2, # 151329
666
+ partial_rotary_factor=0.5,
667
+ cross_attention_layers=None,
668
+ dropout=0,
669
+ hidden_act="silu",
670
+ hidden_size=1024, # 6144
671
+ initializer_range=0.02,
672
+ intermediate_size=2048, # 14336
673
+ max_position_embeddings=4096, # 32768
674
+ num_attention_heads=8, # 48
675
+ num_hidden_layers=4, # 61
676
+ num_key_value_heads=2,
677
+ rms_norm_eps=1e-5,
678
+ rope_scaling=None,
679
+ rope_theta=500_000,
680
+ tie_word_embeddings=False,
681
+ use_cache=True,
682
+ vocab_size=32000, # 151552
683
+ attention_bias=True,
684
+ attn_implementation="sdpa", # default value, pytorch native attention
685
+ ),
686
+ )
687
+
648
688
 
649
689
  def create_model(model_name="mini_llama3"):
650
690
  """
@@ -680,6 +720,9 @@ def run_mini_model(
680
720
  "rms_norm": True,
681
721
  }
682
722
 
723
+ if "glm4" in model_name:
724
+ kwargs["rope"] = False
725
+
683
726
  model_supports_layer_norm = "qwen2_vl" in model_name
684
727
  if model_supports_layer_norm:
685
728
  kwargs["layer_norm"] = True
@@ -933,6 +976,25 @@ def run_mini_model(
933
976
  ),
934
977
  ],
935
978
  ),
979
+ pytest.param(
980
+ "mini_glm4",
981
+ 32,
982
+ 1e-4,
983
+ torch.bfloat16,
984
+ 1e-3,
985
+ 1e-2,
986
+ 1e-1,
987
+ 1e-2,
988
+ 1e-2,
989
+ 1e-2,
990
+ marks=[
991
+ pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
992
+ pytest.mark.skipif(
993
+ not GLM4_AVAILABLE,
994
+ reason="Glm4 not available in this version of transformers",
995
+ ),
996
+ ],
997
+ ),
936
998
  # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate
937
999
  # pytest.param(
938
1000
  # "mini_gemma2",
@@ -21,6 +21,7 @@ from transformers.models.qwen2 import Qwen2ForCausalLM
21
21
  from liger_kernel.transformers import apply_liger_kernel_to_gemma
22
22
  from liger_kernel.transformers import apply_liger_kernel_to_gemma2
23
23
  from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text
24
+ from liger_kernel.transformers import apply_liger_kernel_to_glm4
24
25
  from liger_kernel.transformers import apply_liger_kernel_to_granite
25
26
  from liger_kernel.transformers import apply_liger_kernel_to_llama
26
27
  from liger_kernel.transformers import apply_liger_kernel_to_llava
@@ -38,6 +39,7 @@ from test.utils import assert_verbose_allclose
38
39
  from test.utils import revert_liger_kernel_to_gemma
39
40
  from test.utils import revert_liger_kernel_to_gemma2
40
41
  from test.utils import revert_liger_kernel_to_gemma3_text
42
+ from test.utils import revert_liger_kernel_to_glm4
41
43
  from test.utils import revert_liger_kernel_to_granite
42
44
  from test.utils import revert_liger_kernel_to_llama
43
45
  from test.utils import revert_liger_kernel_to_llava
@@ -105,6 +107,14 @@ try:
105
107
  except ImportError:
106
108
  OLMO2_AVAILABLE = False
107
109
 
110
+ try:
111
+ # Glm4 is only available in transformers>=4.51.3
112
+ from transformers.models.glm4.configuration_glm4 import Glm4Config
113
+ from transformers.models.glm4.modeling_glm4 import Glm4ForCausalLM
114
+
115
+ GLM4_AVAILABLE = True
116
+ except ImportError:
117
+ GLM4_AVAILABLE = False
108
118
 
109
119
  try:
110
120
  from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
@@ -583,6 +593,37 @@ if OLMO2_AVAILABLE:
583
593
  ),
584
594
  )
585
595
 
596
+ if GLM4_AVAILABLE:
597
+ MINI_MODEL_SETUPS["mini_glm4"] = MiniModelConfig(
598
+ liger_kernel_patch_func=apply_liger_kernel_to_glm4,
599
+ liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4,
600
+ model_class=Glm4ForCausalLM,
601
+ mini_model_config=Glm4Config(
602
+ bos_token_id=1, # None
603
+ eos_token_id=2, # 151329, 151336, 151338
604
+ pad_token_id=2, # 151329
605
+ partial_rotary_factor=0.5,
606
+ cross_attention_layers=None,
607
+ dropout=0,
608
+ hidden_act="silu",
609
+ hidden_size=1024, # 6144
610
+ initializer_range=0.02,
611
+ intermediate_size=2048, # 14336
612
+ max_position_embeddings=32768,
613
+ num_attention_heads=8, # 48
614
+ num_hidden_layers=4, # 61
615
+ num_key_value_heads=2,
616
+ rms_norm_eps=1e-5,
617
+ rope_scaling=None,
618
+ rope_theta=500_000,
619
+ tie_word_embeddings=False,
620
+ use_cache=True,
621
+ vocab_size=32000, # 151552
622
+ attention_bias=True,
623
+ attn_implementation="sdpa", # default value, pytorch native attention
624
+ ),
625
+ )
626
+
586
627
  if LLAVA_AVAILABLE:
587
628
  # https://huggingface.co/llava-hf/llava-1.5-7b-hf
588
629
  MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig(
@@ -677,6 +718,9 @@ def run_mini_model(
677
718
  "rms_norm": True,
678
719
  }
679
720
 
721
+ if "glm4" in model_name:
722
+ kwargs["rope"] = False
723
+
680
724
  model_supports_layer_norm = "qwen2_vl" in model_name
681
725
  if model_supports_layer_norm:
682
726
  kwargs["layer_norm"] = True
@@ -820,6 +864,22 @@ def run_mini_model(
820
864
  reason="OLMO2 not available in this version of transformers",
821
865
  ),
822
866
  ),
867
+ pytest.param(
868
+ "mini_glm4",
869
+ 32,
870
+ 1e-4,
871
+ torch.float32,
872
+ 1e-8,
873
+ 1e-5,
874
+ 5e-3,
875
+ 1e-5,
876
+ 5e-3,
877
+ 1e-5,
878
+ marks=pytest.mark.skipif(
879
+ not GLM4_AVAILABLE,
880
+ reason="Glm4 not available in this version of transformers",
881
+ ),
882
+ ),
823
883
  ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
824
884
  ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
825
885
  # TODO: mixtral is flaky so disable the test for now