liger-kernel-nightly 0.5.5.dev20250320214749__tar.gz → 0.5.5.dev20250324181221__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. Click here for more details.

Files changed (236)
  1. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/PKG-INFO +4 -4
  2. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/README.md +3 -3
  3. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/benchmarks_visualizer.py +2 -2
  4. liger_kernel_nightly-0.5.5.dev20250324181221/benchmark/scripts/benchmark_dyt.py +139 -0
  5. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/pyproject.toml +1 -1
  6. liger_kernel_nightly-0.5.5.dev20250324181221/src/liger_kernel/ops/dyt.py +225 -0
  7. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/__init__.py +1 -0
  8. liger_kernel_nightly-0.5.5.dev20250324181221/src/liger_kernel/transformers/dyt.py +20 -0
  9. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/functional.py +5 -0
  10. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/PKG-INFO +4 -4
  11. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
  12. liger_kernel_nightly-0.5.5.dev20250324181221/test/transformers/test_dyt.py +136 -0
  13. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  14. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  15. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/pull_request_template.md +0 -0
  16. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/amd-ci.yml +0 -0
  17. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/docs.yml +0 -0
  18. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/intel-ci.yml +0 -0
  19. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/nvi-ci.yml +0 -0
  20. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/publish-nightly.yml +0 -0
  21. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/publish-release.yml +0 -0
  22. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/.gitignore +0 -0
  23. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/LICENSE +0 -0
  24. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/Makefile +0 -0
  25. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/NOTICE +0 -0
  26. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/README.md +0 -0
  27. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/__init__.py +0 -0
  28. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/data/all_benchmark_data.csv +0 -0
  29. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/__init__.py +0 -0
  30. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  33. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_embedding.py +0 -0
  35. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  36. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  37. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_geglu.py +0 -0
  38. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_group_norm.py +0 -0
  39. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_jsd.py +0 -0
  40. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_kl_div.py +0 -0
  41. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  43. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  44. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  45. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  46. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_rope.py +0 -0
  47. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  48. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_swiglu.py +0 -0
  49. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_tvd.py +0 -0
  50. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/utils.py +0 -0
  51. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/dev/fmt-requirements.txt +0 -0
  52. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/dev/modal/tests.py +0 -0
  53. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/dev/modal/tests_bwd.py +0 -0
  54. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/Examples.md +0 -0
  55. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/Getting-Started.md +0 -0
  56. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/High-Level-APIs.md +0 -0
  57. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/Low-Level-APIs.md +0 -0
  58. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/acknowledgement.md +0 -0
  59. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/contributing.md +0 -0
  60. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/banner.GIF +0 -0
  61. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/compose.gif +0 -0
  62. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/e2e-memory.png +0 -0
  63. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/e2e-tps.png +0 -0
  64. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/logo-banner.png +0 -0
  65. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/patch.gif +0 -0
  66. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/post-training.png +0 -0
  67. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/index.md +0 -0
  68. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/license.md +0 -0
  69. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/alignment/accelerate_config.yaml +0 -0
  70. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/alignment/run_orpo.py +0 -0
  71. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/README.md +0 -0
  72. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/callback.py +0 -0
  73. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/config/fsdp_config.json +0 -0
  74. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  75. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  76. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  77. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/llama_tps.png +0 -0
  78. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  79. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/qwen_tps.png +0 -0
  80. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/launch_on_modal.py +0 -0
  81. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/requirements.txt +0 -0
  82. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_benchmarks.sh +0 -0
  83. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_gemma.sh +0 -0
  84. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_llama.sh +0 -0
  85. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_qwen.sh +0 -0
  86. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_qwen2_vl.sh +0 -0
  87. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/training.py +0 -0
  88. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/training_multimodal.py +0 -0
  89. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/lightning/README.md +0 -0
  90. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/lightning/requirements.txt +0 -0
  91. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/lightning/training.py +0 -0
  92. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/README.md +0 -0
  93. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/callback.py +0 -0
  94. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  95. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  96. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  97. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  98. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  99. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  100. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  101. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  102. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  103. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/medusa_util.py +0 -0
  104. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/requirements.txt +0 -0
  105. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  106. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/train.py +0 -0
  107. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-Apache-2.0 +0 -0
  108. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  109. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  110. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-llmc +0 -0
  111. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-triton +0 -0
  112. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/mkdocs.yml +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/setup.cfg +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/setup.py +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/__init__.py +0 -0
  116. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/README.md +0 -0
  117. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  118. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  119. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  120. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/functional.py +0 -0
  121. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  122. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  123. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
  124. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  125. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  126. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/env_report.py +0 -0
  131. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/__init__.py +0 -0
  132. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/cross_entropy.py +0 -0
  133. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  134. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  135. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  136. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  137. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/geglu.py +0 -0
  138. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/group_norm.py +0 -0
  139. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/jsd.py +0 -0
  140. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/kl_div.py +0 -0
  141. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/layer_norm.py +0 -0
  142. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  143. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/rms_norm.py +0 -0
  144. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/rope.py +0 -0
  145. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/swiglu.py +0 -0
  146. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/tvd.py +0 -0
  147. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/utils.py +0 -0
  148. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/auto_model.py +0 -0
  149. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  150. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  151. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  152. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  153. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/geglu.py +0 -0
  154. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/group_norm.py +0 -0
  155. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/jsd.py +0 -0
  156. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/kl_div.py +0 -0
  157. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/layer_norm.py +0 -0
  158. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/__init__.py +0 -0
  159. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/gemma.py +0 -0
  160. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  161. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/llama.py +0 -0
  162. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  163. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/mistral.py +0 -0
  164. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  165. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/mllama.py +0 -0
  166. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  167. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  168. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/phi3.py +0 -0
  169. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  170. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  171. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  172. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  173. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  174. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/rms_norm.py +0 -0
  175. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/rope.py +0 -0
  176. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/swiglu.py +0 -0
  177. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  178. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  179. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  180. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/tvd.py +0 -0
  181. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/triton/__init__.py +0 -0
  182. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/triton/monkey_patch.py +0 -0
  183. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/utils.py +0 -0
  184. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  185. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  186. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  187. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/__init__.py +0 -0
  188. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/__init__.py +0 -0
  189. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_cpo_loss.py +0 -0
  190. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_dpo_loss.py +0 -0
  191. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_grpo_loss.py +0 -0
  192. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_jsd_loss.py +0 -0
  193. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_kto_loss.py +0 -0
  194. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_orpo_loss.py +0 -0
  195. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_simpo_loss.py +0 -0
  196. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/conftest.py +0 -0
  197. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/__init__.py +0 -0
  198. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/test_mini_models.py +0 -0
  200. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  201. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  202. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/__init__.py +0 -0
  203. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/test_mini_models.py +0 -0
  204. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  205. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  206. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  207. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  208. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  209. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  210. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  211. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare.txt +0 -0
  212. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  213. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  214. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  215. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_auto_model.py +0 -0
  216. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_cross_entropy.py +0 -0
  217. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_embedding.py +0 -0
  218. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_flex_attention.py +0 -0
  219. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  220. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_fused_linear_jsd.py +0 -0
  221. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_geglu.py +0 -0
  222. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_group_norm.py +0 -0
  223. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_jsd.py +0 -0
  224. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_kl_div.py +0 -0
  225. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_layer_norm.py +0 -0
  226. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_mm_int8int2.py +0 -0
  227. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_monkey_patch.py +0 -0
  228. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_qwen2vl_mrope.py +0 -0
  229. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_rms_norm.py +0 -0
  230. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_rope.py +0 -0
  231. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_swiglu.py +0 -0
  232. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_trainer_integration.py +0 -0
  233. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_transformers.py +0 -0
  234. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_tvd.py +0 -0
  235. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/triton/test_triton_monkey_patch.py +0 -0
  236. {liger_kernel_nightly-0.5.5.dev20250320214749 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250320214749
3
+ Version: 0.5.5.dev20250324181221
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -176,7 +176,7 @@ y = orpo_loss(lm_head.weight, x, target)
176
176
  - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
177
177
  - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
178
178
  - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
179
- - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
179
+ - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
180
180
 
181
181
  ## Installation
182
182
 
@@ -386,8 +386,8 @@ loss.backward()
386
386
  ## Contact
387
387
 
388
388
  - For issues, create a Github ticket in this repository
389
- - For open discussion, join [our discord channel](https://discord.gg/gpumode)
390
- - For formal collaboration, send an email to yannchen@linkedin.com
389
+ - For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
390
+ - For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
391
391
 
392
392
  ## Cite this work
393
393
 
@@ -128,7 +128,7 @@ y = orpo_loss(lm_head.weight, x, target)
128
128
  - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
129
129
  - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
130
130
  - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
131
- - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
131
+ - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
132
132
 
133
133
  ## Installation
134
134
 
@@ -338,8 +338,8 @@ loss.backward()
338
338
  ## Contact
339
339
 
340
340
  - For issues, create a Github ticket in this repository
341
- - For open discussion, join [our discord channel](https://discord.gg/gpumode)
342
- - For formal collaboration, send an email to yannchen@linkedin.com
341
+ - For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
342
+ - For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
343
343
 
344
344
  ## Cite this work
345
345
 
@@ -8,8 +8,8 @@ import matplotlib.pyplot as plt
8
8
  import pandas as pd
9
9
  import seaborn as sns
10
10
 
11
- DATA_PATH = "data/all_benchmark_data.csv"
12
- VISUALIZATIONS_PATH = "visualizations/"
11
+ DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv"))
12
+ VISUALIZATIONS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "visualizations/"))
13
13
 
14
14
 
15
15
  @dataclass
@@ -0,0 +1,139 @@
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+ import triton
6
+
7
+ from utils import QUANTILES
8
+ from utils import SingleBenchmarkRunInput
9
+ from utils import SingleBenchmarkRunOutput
10
+ from utils import _test_memory
11
+ from utils import parse_benchmark_script_args
12
+ from utils import run_benchmarks
13
+
14
+ from liger_kernel.utils import infer_device
15
+
16
+ device = infer_device()
17
+
18
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
19
+
20
+
21
def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time one DyT configuration and return the measured latency quantiles.

    Args:
        input: Benchmark request. ``input.x`` is the number of rows (BT),
            ``input.kernel_provider`` selects the implementation
            ("liger", "torch" or "torch_compile"),
            ``input.kernel_operation_mode`` selects what is timed
            ("forward", "backward" or "full"), and
            ``input.extra_benchmark_config`` supplies "hidden_size" and "dtype".

    Returns:
        The 20th/50th/80th percentile values reported by
        ``triton.testing.do_bench``.

    Raises:
        ValueError: If the provider or the operation mode is not supported.
    """
    # Imported lazily: the repository root is only added to sys.path at module
    # import time, after the top-level imports have run.
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    BT = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]

    # Build only the module under test and fail loudly on an unknown provider
    # instead of letting fwd() silently return None.
    if provider == "liger":
        dyt = LigerDyT(hidden_size=hidden_size).to(device)
    elif provider == "torch":
        dyt = TorchDyT(hidden_size=hidden_size).to(device)
    elif provider == "torch_compile":
        dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
    else:
        raise ValueError(f"Unsupported kernel provider: {provider!r}")

    x = torch.randn((BT, hidden_size), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        return dyt(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    elif mode == "backward":
        # Run the forward pass once; the graph is retained so that backward
        # can be replayed on every timing iteration.
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward(dy)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    else:
        # Previously an unknown mode fell through and failed later with an
        # UnboundLocalError on ms_20.
        raise ValueError(f"Unsupported kernel operation mode: {mode!r}")

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
72
+
73
+
74
def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory use of one DyT forward + backward configuration.

    Args:
        input: Benchmark request. ``input.x`` is the number of rows (BT),
            ``input.kernel_provider`` selects the implementation
            ("liger", "torch" or "torch_compile"), and
            ``input.extra_benchmark_config`` supplies "hidden_size" and "dtype".

    Returns:
        The 20th/50th/80th percentile values reported by ``_test_memory``.

    Raises:
        ValueError: If the provider is not supported.
    """
    # Imported lazily: the repository root is only added to sys.path at module
    # import time, after the top-level imports have run.
    from test.transformers.test_dyt import LigerDyT
    from test.transformers.test_dyt import TorchDyT

    BT = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]

    # Build only the module under test and fail loudly on an unknown provider
    # instead of letting fwd() silently return None.
    if provider == "liger":
        dyt = LigerDyT(hidden_size=hidden_size).to(device)
    elif provider == "torch":
        dyt = TorchDyT(hidden_size=hidden_size).to(device)
    elif provider == "torch_compile":
        dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
    else:
        raise ValueError(f"Unsupported kernel provider: {provider!r}")

    x = torch.randn((BT, hidden_size), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        return dyt(x)

    def full():
        y = fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
111
+
112
+
113
if __name__ == "__main__":
    cli_args = parse_benchmark_script_args()

    # Sweep settings shared by the speed and the memory benchmark runs.
    shared_kwargs = dict(
        kernel_name="dyt",
        x_name="BT",
        x_label="batch_size * seq_len",
        x_values=[1 << exponent for exponent in range(10, 15)],
        kernel_providers=["liger", "torch", "torch_compile"],
        extra_benchmark_configs=[{"hidden_size": 4096, "dtype": torch.float32}],
        overwrite=cli_args.overwrite,
    )

    # Latency of forward, backward and forward+backward.
    run_benchmarks(
        bench_test_fn=bench_speed_dyt,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **shared_kwargs,
    )

    # Memory of a complete forward+backward step.
    run_benchmarks(
        bench_test_fn=bench_memory_dyt,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **shared_kwargs,
    )
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.5.dev20250320214749"
7
+ version = "0.5.5.dev20250324181221"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,225 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import compare_version
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+ from liger_kernel.ops.utils import infer_device
11
+
12
+ if compare_version("triton", operator.ge, "3.0.0"):
13
+ try:
14
+ # typical import path with dispatch available
15
+ from triton.language.extra.libdevice import tanh
16
+ except ModuleNotFoundError:
17
+ # for working with NGC containers
18
+ from triton.language.extra.cuda.libdevice import tanh
19
+ else:
20
+ from triton.language.math import tanh
21
+
22
+
23
@triton.jit
def _dyt_fwd_kernel(
    x_ptr,
    x_row_stride,
    alpha_ptr,
    gamma_ptr,
    beta_ptr,
    y_ptr,
    y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    DyT forward: y = gamma * tanh(alpha * x) + beta.

    Each program handles one row of x; a single block of BLOCK_SIZE lanes
    covers the row, with lanes at or beyond n_cols masked off.

    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - beta: (C)
    """
    # Promote the row index to int64 so that row_idx * row_stride cannot wrap
    # around in 32-bit arithmetic for tensors with more than 2**31 elements.
    row_idx = tl.program_id(0).to(tl.int64)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    # Move the base pointers to the start of this program's row.
    x_ptr += row_idx * x_row_stride
    y_ptr += row_idx * y_row_stride

    alpha = tl.load(alpha_ptr)
    gamma = tl.load(gamma_ptr + offsets, mask=mask)
    beta = tl.load(beta_ptr + offsets, mask=mask)
    x = tl.load(x_ptr + offsets, mask=mask)
    # tanh is evaluated on a float32 copy of alpha * x.
    y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
    tl.store(y_ptr + offsets, y, mask=mask)
58
+
59
+
60
@triton.jit
def _dyt_bwd_kernel(
    x_ptr,
    x_row_stride,
    dy_ptr,
    dy_row_stride,
    dx_ptr,
    dx_row_stride,
    alpha_ptr,
    dalpha_ptr,
    gamma_ptr,
    dgamma_ptr,
    dgamma_row_stride,
    n_cols,
    n_rows,
    ROWS_PER_PROGRAM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    DyT backward for y = gamma * tanh(alpha * x) + beta.

    Each program walks a contiguous slice of at most ROWS_PER_PROGRAM rows,
    writes dx for every row it visits, and accumulates one partial dalpha
    scalar and one partial dgamma row. The partials are stored at index
    `pid` and still have to be summed over the program axis by the caller.

    Reference:
    https://arxiv.org/abs/2503.10622

    Shapes:
        - x: (BT, C)
        - alpha: (1)
        - gamma: (C)
        - dx: (BT, C)
        - dy: (BT, C)
        - dgamma: (sm_count, C)
        - dalpha: (sm_count,)
    """
    # d(gamma * tanh(alpha * x) + beta) / dx
    # = gamma * (1 - tanh^2(alpha * x)) * alpha
    # d(gamma * tanh(alpha * x) + beta) / dalpha
    # = gamma * (1 - tanh^2(alpha * x)) * x
    # d(gamma * tanh(alpha * x) + beta) / dgamma
    # = tanh(alpha * x)
    # d(gamma * tanh(alpha * x) + beta) / dbeta = 1

    # Promote the program id to int64 so that row_start * row_stride cannot
    # wrap around in 32-bit arithmetic for tensors with more than 2**31 elements.
    pid = tl.program_id(0).to(tl.int64)

    row_start = pid * ROWS_PER_PROGRAM
    # The last programs may own fewer rows, or none at all.
    row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols

    # Per-program partial sums, accumulated in float32.
    dalpha = 0.0
    dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    x_ptr += row_start * x_row_stride
    dx_ptr += row_start * dx_row_stride
    dy_ptr += row_start * dy_row_stride
    alpha = tl.load(alpha_ptr)
    # Masked lanes read 0.0 so they contribute nothing to the reductions.
    gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)

    for _ in tl.range(row_start, row_end):
        dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        tanh_ax = tanh((alpha * x).cast(tl.float32))
        sech2_ax = 1 - tanh_ax * tanh_ax

        dx = dy * gamma * sech2_ax * alpha
        dalpha += tl.sum(dy * gamma * sech2_ax * x)
        dgamma += dy * tanh_ax
        tl.store(dx_ptr + offsets, dx, mask=mask)

        # Step all three row pointers to the next row.
        dy_ptr += dy_row_stride
        x_ptr += x_row_stride
        dx_ptr += dx_row_stride

    # Every program writes its slot, including programs that owned no rows,
    # so the partial buffers are fully defined without pre-zeroing.
    tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
    tl.store(dalpha_ptr + pid, dalpha)
133
+
134
+
135
def liger_dyt_fwd(x, alpha, gamma, beta):
    """Launch the DyT forward kernel and return y with the same shape as x.

    x is flattened to 2-D (rows, last dimension) for the launch; one program
    is started per row. The result is viewed back to the original shape.
    """
    original_shape = x.shape
    x_2d = x.view(-1, original_shape[-1])
    n_rows, n_cols = x_2d.shape
    out_2d = torch.empty_like(x_2d)

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    launch_grid = (n_rows,)
    _dyt_fwd_kernel[launch_grid](
        x_ptr=x_2d,
        alpha_ptr=alpha,
        gamma_ptr=gamma,
        beta_ptr=beta,
        y_ptr=out_2d,
        x_row_stride=x_2d.stride(0),
        y_row_stride=out_2d.stride(0),
        n_cols=n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return out_2d.view(*original_shape)
155
+
156
+
157
def liger_dyt_bwd(dy, x, alpha, gamma):
    """Launch the DyT backward kernel and reduce its per-program partials.

    Args:
        dy: Gradient of the loss with respect to the DyT output.
        x: Input that was given to the forward pass.
        alpha: Scalar scale parameter, shape (1).
        gamma: Per-feature weight, one value per element of the last dimension.

    Returns:
        Tuple ``(dx, dalpha, dgamma, dbeta)``. ``dx`` has the shape of ``dy``;
        all four gradients are returned in the dtype of ``x``.

    Raises:
        RuntimeError: If the feature dimension exceeds the block size chosen
            by ``calculate_settings``.
    """
    shape = dy.shape
    dtype = x.dtype
    dim = shape[-1]
    dy = dy.view(-1, dim)
    x = x.view(-1, dim)
    n_rows, n_cols = dy.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    # A single block must cover a whole row; check before doing any device
    # queries or allocations.
    if n_cols > BLOCK_SIZE:
        raise RuntimeError(
            f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
        )

    # One program per multiprocessor (or XPU subslice); a single program on
    # any other device type.
    sm_count = 1
    device = infer_device()
    if device == "cuda":
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    elif device == "xpu":
        sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count

    # The kernel writes float32 results; each program fills its own slot of
    # the partial buffers, so torch.empty is sufficient.
    dx = torch.empty_like(x, dtype=torch.float32)
    _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
    _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)

    grid = (sm_count,)
    rows_per_program = triton.cdiv(n_rows, sm_count)
    _dyt_bwd_kernel[grid](
        x_ptr=x,
        x_row_stride=x.stride(0),
        dy_ptr=dy,
        dy_row_stride=dy.stride(0),
        dx_ptr=dx,
        dx_row_stride=dx.stride(0),
        alpha_ptr=alpha,
        dalpha_ptr=_dalpha,
        gamma_ptr=gamma,
        dgamma_ptr=_dgamma,
        dgamma_row_stride=_dgamma.stride(0),
        n_cols=n_cols,
        n_rows=n_rows,
        ROWS_PER_PROGRAM=rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    # Reduce the per-program partials over the program axis.
    dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
    dgamma = _dgamma.sum(dim=0).to(dtype)
    # d y / d beta = 1, so dbeta is the column-wise sum of dy.
    dbeta = dy.sum(dim=0).to(dtype)
    # Cast dx like the other gradients; .to() is a no-op for float32 inputs.
    return dx.view(*shape).to(dtype), dalpha, dgamma, dbeta
204
+
205
+
206
class LigerDyTFunction(torch.autograd.Function):
    """Autograd wrapper connecting the DyT forward and backward launchers."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x, alpha, gamma, beta):
        # beta is not saved: the backward launcher does not take it.
        ctx.save_for_backward(x, alpha, gamma)
        return liger_dyt_fwd(x, alpha, gamma, beta)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        saved_x, saved_alpha, saved_gamma = ctx.saved_tensors
        grads = liger_dyt_bwd(grad_output, saved_x, saved_alpha, saved_gamma)
        # One gradient per forward input, in order: x, alpha, gamma, beta.
        grad_x, grad_alpha, grad_gamma, grad_beta = grads
        return (grad_x, grad_alpha, grad_gamma, grad_beta)
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
2
2
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
3
+ from liger_kernel.transformers.dyt import LigerDyT # noqa: F401
3
4
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
4
5
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
5
6
  from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
@@ -0,0 +1,20 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from liger_kernel.ops.dyt import LigerDyTFunction
5
+
6
+
7
class LigerDyT(nn.Module):
    """DyT layer backed by LigerDyTFunction.

    Holds a one-element ``alpha`` initialised to ``init_alpha``, a ``gamma``
    of ones and a ``beta`` of zeros, the latter two of length ``hidden_size``.
    """

    def __init__(self, hidden_size, init_alpha=0.5):
        super().__init__()
        # Kept only for extra_repr().
        self.hidden_size, self.init_alpha = hidden_size, init_alpha

        alpha_init = torch.ones(1) * init_alpha
        gamma_init = torch.ones(hidden_size)
        beta_init = torch.zeros(hidden_size)

        # Registration order (alpha, gamma, beta) fixes the parameter order.
        self.alpha = nn.Parameter(alpha_init)
        self.gamma = nn.Parameter(gamma_init)
        self.beta = nn.Parameter(beta_init)

    def forward(self, x):
        params = (self.alpha, self.gamma, self.beta)
        return LigerDyTFunction.apply(x, *params)

    def extra_repr(self):
        return f"{self.hidden_size}, init_alpha={self.init_alpha}"
@@ -1,6 +1,7 @@
1
1
  from typing import Optional
2
2
 
3
3
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
+ from liger_kernel.ops.dyt import LigerDyTFunction
4
5
  from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
5
6
  from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
6
7
  from liger_kernel.ops.geglu import LigerGELUMulFunction
@@ -192,3 +193,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
192
193
 
193
194
  def liger_swiglu(a, b):
194
195
  return LigerSiLUMulFunction.apply(a, b)
196
+
197
+
198
def liger_dyt(x, alpha, gamma, beta):
    """Functional form of DyT: forwards all four tensors to LigerDyTFunction."""
    dyt_output = LigerDyTFunction.apply(x, alpha, gamma, beta)
    return dyt_output
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250320214749
3
+ Version: 0.5.5.dev20250324181221
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -176,7 +176,7 @@ y = orpo_loss(lm_head.weight, x, target)
176
176
  - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
177
177
  - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
178
178
  - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
179
- - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
179
+ - **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)
180
180
 
181
181
  ## Installation
182
182
 
@@ -386,8 +386,8 @@ loss.backward()
386
386
  ## Contact
387
387
 
388
388
  - For issues, create a Github ticket in this repository
389
- - For open discussion, join [our discord channel](https://discord.gg/gpumode)
390
- - For formal collaboration, send an email to yannchen@linkedin.com
389
+ - For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
390
+ - For formal collaboration, send an email to yannchen@linkedin.com and hning@linkedin.com
391
391
 
392
392
  ## Cite this work
393
393
 
@@ -24,6 +24,7 @@ benchmark/scripts/benchmark_cpo_loss.py
24
24
  benchmark/scripts/benchmark_cross_entropy.py
25
25
  benchmark/scripts/benchmark_distill_jsd_loss.py
26
26
  benchmark/scripts/benchmark_dpo_loss.py
27
+ benchmark/scripts/benchmark_dyt.py
27
28
  benchmark/scripts/benchmark_embedding.py
28
29
  benchmark/scripts/benchmark_fused_linear_cross_entropy.py
29
30
  benchmark/scripts/benchmark_fused_linear_jsd.py
@@ -121,6 +122,7 @@ src/liger_kernel/chunked_loss/orpo_loss.py
121
122
  src/liger_kernel/chunked_loss/simpo_loss.py
122
123
  src/liger_kernel/ops/__init__.py
123
124
  src/liger_kernel/ops/cross_entropy.py
125
+ src/liger_kernel/ops/dyt.py
124
126
  src/liger_kernel/ops/fused_linear_cross_entropy.py
125
127
  src/liger_kernel/ops/fused_linear_jsd.py
126
128
  src/liger_kernel/ops/geglu.py
@@ -139,6 +141,7 @@ src/liger_kernel/ops/experimental/mm_int8int2.py
139
141
  src/liger_kernel/transformers/__init__.py
140
142
  src/liger_kernel/transformers/auto_model.py
141
143
  src/liger_kernel/transformers/cross_entropy.py
144
+ src/liger_kernel/transformers/dyt.py
142
145
  src/liger_kernel/transformers/functional.py
143
146
  src/liger_kernel/transformers/fused_linear_cross_entropy.py
144
147
  src/liger_kernel/transformers/fused_linear_jsd.py
@@ -209,6 +212,7 @@ test/resources/tiny_shakespeare_tokenized/dataset_info.json
209
212
  test/resources/tiny_shakespeare_tokenized/state.json
210
213
  test/transformers/test_auto_model.py
211
214
  test/transformers/test_cross_entropy.py
215
+ test/transformers/test_dyt.py
212
216
  test/transformers/test_embedding.py
213
217
  test/transformers/test_flex_attention.py
214
218
  test/transformers/test_fused_linear_cross_entropy.py
@@ -0,0 +1,136 @@
1
+ import pytest
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from test.utils import assert_verbose_allclose
6
+ from test.utils import infer_device
7
+ from test.utils import set_seed
8
+ from test.utils import supports_bfloat16
9
+
10
+ from liger_kernel.ops.dyt import LigerDyTFunction
11
+ from liger_kernel.transformers.dyt import LigerDyT
12
+ from liger_kernel.transformers.functional import liger_dyt
13
+
14
+
15
class TorchDyT(nn.Module):
    """Plain PyTorch reference: gamma * tanh(alpha * x) + beta."""

    def __init__(self, hidden_size, init_alpha=0.5):
        super().__init__()
        alpha_init = torch.ones(1) * init_alpha
        gamma_init = torch.ones(hidden_size)
        beta_init = torch.zeros(hidden_size)

        # Registration order (alpha, gamma, beta) fixes the parameter order.
        self.alpha = nn.Parameter(alpha_init)
        self.gamma = nn.Parameter(gamma_init)
        self.beta = nn.Parameter(beta_init)

    def forward(self, x):
        squashed = torch.tanh(self.alpha * x)
        scaled = self.gamma * squashed
        return scaled + self.beta
24
+
25
+
26
+ set_seed(42)
27
+ device = infer_device()
28
+
29
+
30
@pytest.mark.parametrize("init_alpha", [0.5, 0.2, 1.0])
@pytest.mark.parametrize(
    "B, T, hidden_size",
    [
        (2, 8, 4096),
        (4, 16, 2048),
        (1, 1, 1023),  # Minimal batch/seq with near power-of-2 hidden
        (3, 7, 256),  # Prime numbers for batch/seq
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
    ],
)
def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol):
    """Compare LigerDyT against TorchDyT on outputs and on all four gradients."""
    base_input = torch.randn(B, T, hidden_size, device=device, dtype=dtype)

    # Independent leaf copies so each module gets its own input gradient.
    x_ref = base_input.clone().requires_grad_(True)
    x_liger = base_input.clone().requires_grad_(True)

    # Random weights shared by both modules; these replace the constructor
    # initialisation, including the alpha derived from init_alpha.
    alpha = torch.randn(1, device=device, dtype=dtype)
    gamma = torch.randn(hidden_size, device=device, dtype=dtype)
    beta = torch.randn(hidden_size, device=device, dtype=dtype)

    ref_module = TorchDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
    liger_module = LigerDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
    for module in (ref_module, liger_module):
        module.alpha.data = alpha.clone()
        module.gamma.data = gamma.clone()
        module.beta.data = beta.clone()

    ref_output = ref_module(x_ref)
    liger_output = liger_module(x_liger)

    assert_verbose_allclose(ref_output, liger_output, rtol=rtol, atol=atol)

    # Backpropagate the same upstream gradient through both graphs.
    grad_output = torch.randn_like(base_input)
    ref_output.backward(grad_output)
    liger_output.backward(grad_output)

    grad_pairs = [
        (x_ref.grad, x_liger.grad),
        (ref_module.alpha.grad, liger_module.alpha.grad),
        (ref_module.gamma.grad, liger_module.gamma.grad),
        (ref_module.beta.grad, liger_module.beta.grad),
    ]
    for expected_grad, actual_grad in grad_pairs:
        assert_verbose_allclose(expected_grad, actual_grad, rtol=rtol, atol=atol)
80
+
81
+
82
@pytest.mark.parametrize(
    "B, T, hidden_size",
    [
        (2, 8, 4096),
        (4, 16, 2048),
        (1, 1, 1023),  # Minimal batch/seq with near power-of-2 hidden
        (3, 7, 256),  # Prime numbers for batch/seq
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        # atol is for small values: they have more difference, so set atol higher
        # rtol is for larger values: they are very close, so set rtol lower
        (torch.float32, 1e-5, 1e-5),
        pytest.param(
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
def test_liger_dyt_functional(B, T, hidden_size, dtype, atol, rtol):
    """Check that the functional wrapper matches LigerDyTFunction.apply."""
    base_input = torch.randn(B, T, hidden_size, device=device, dtype=dtype)

    # Random weights used by both call paths.
    alpha = torch.randn(1, device=device, dtype=dtype)
    gamma = torch.randn(hidden_size, device=device, dtype=dtype)
    beta = torch.randn(hidden_size, device=device, dtype=dtype)

    # Each call path gets its own set of leaf tensors so gradients stay separate.
    sources = (base_input, alpha, gamma, beta)
    x1, alpha1, gamma1, beta1 = (t.clone().requires_grad_(True) for t in sources)
    x2, alpha2, gamma2, beta2 = (t.clone().requires_grad_(True) for t in sources)

    output1 = liger_dyt(x1, alpha=alpha1, gamma=gamma1, beta=beta1)
    output2 = LigerDyTFunction.apply(x2, alpha2, gamma2, beta2)

    assert_verbose_allclose(output1, output2, rtol=rtol, atol=atol)

    # Backpropagate the same upstream gradient through both graphs.
    grad_output = torch.randn_like(base_input)
    output1.backward(grad_output)
    output2.backward(grad_output)

    leaf_pairs = [(x1, x2), (alpha1, alpha2), (gamma1, gamma2), (beta1, beta2)]
    for functional_leaf, direct_leaf in leaf_pairs:
        assert_verbose_allclose(functional_leaf.grad, direct_leaf.grad, rtol=rtol, atol=atol)