liger-kernel-nightly 0.5.5.dev20250322021112__tar.gz → 0.5.5.dev20250324181221__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of liger-kernel-nightly might be problematic. More details are available on the registry page for this release.

Files changed (236)
  1. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/PKG-INFO +1 -1
  2. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/benchmarks_visualizer.py +2 -2
  3. liger_kernel_nightly-0.5.5.dev20250324181221/benchmark/scripts/benchmark_dyt.py +139 -0
  4. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/pyproject.toml +1 -1
  5. liger_kernel_nightly-0.5.5.dev20250324181221/src/liger_kernel/ops/dyt.py +225 -0
  6. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/__init__.py +1 -0
  7. liger_kernel_nightly-0.5.5.dev20250324181221/src/liger_kernel/transformers/dyt.py +20 -0
  8. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/functional.py +5 -0
  9. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/PKG-INFO +1 -1
  10. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/SOURCES.txt +4 -0
  11. liger_kernel_nightly-0.5.5.dev20250324181221/test/transformers/test_dyt.py +136 -0
  12. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/ISSUE_TEMPLATE/bug_report.yaml +0 -0
  13. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/ISSUE_TEMPLATE/feature_request.yaml +0 -0
  14. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/pull_request_template.md +0 -0
  15. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/amd-ci.yml +0 -0
  16. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/docs.yml +0 -0
  17. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/intel-ci.yml +0 -0
  18. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/nvi-ci.yml +0 -0
  19. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/publish-nightly.yml +0 -0
  20. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.github/workflows/publish-release.yml +0 -0
  21. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/.gitignore +0 -0
  22. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/LICENSE +0 -0
  23. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/Makefile +0 -0
  24. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/NOTICE +0 -0
  25. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/README.md +0 -0
  26. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/README.md +0 -0
  27. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/__init__.py +0 -0
  28. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/data/all_benchmark_data.csv +0 -0
  29. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/__init__.py +0 -0
  30. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_cpo_loss.py +0 -0
  31. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_cross_entropy.py +0 -0
  32. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_distill_jsd_loss.py +0 -0
  33. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_dpo_loss.py +0 -0
  34. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_embedding.py +0 -0
  35. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_fused_linear_cross_entropy.py +0 -0
  36. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_fused_linear_jsd.py +0 -0
  37. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_geglu.py +0 -0
  38. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_group_norm.py +0 -0
  39. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_jsd.py +0 -0
  40. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_kl_div.py +0 -0
  41. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_kto_loss.py +0 -0
  42. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_layer_norm.py +0 -0
  43. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_orpo_loss.py +0 -0
  44. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_qwen2vl_mrope.py +0 -0
  45. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_rms_norm.py +0 -0
  46. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_rope.py +0 -0
  47. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_simpo_loss.py +0 -0
  48. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_swiglu.py +0 -0
  49. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/benchmark_tvd.py +0 -0
  50. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/benchmark/scripts/utils.py +0 -0
  51. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/dev/fmt-requirements.txt +0 -0
  52. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/dev/modal/tests.py +0 -0
  53. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/dev/modal/tests_bwd.py +0 -0
  54. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/Examples.md +0 -0
  55. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/Getting-Started.md +0 -0
  56. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/High-Level-APIs.md +0 -0
  57. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/Low-Level-APIs.md +0 -0
  58. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/acknowledgement.md +0 -0
  59. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/contributing.md +0 -0
  60. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/banner.GIF +0 -0
  61. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/compose.gif +0 -0
  62. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/e2e-memory.png +0 -0
  63. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/e2e-tps.png +0 -0
  64. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/logo-banner.png +0 -0
  65. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/patch.gif +0 -0
  66. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/images/post-training.png +0 -0
  67. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/index.md +0 -0
  68. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/docs/license.md +0 -0
  69. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/alignment/accelerate_config.yaml +0 -0
  70. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/alignment/run_orpo.py +0 -0
  71. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/README.md +0 -0
  72. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/callback.py +0 -0
  73. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/config/fsdp_config.json +0 -0
  74. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/gemma_7b_mem.png +0 -0
  75. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/gemma_7b_tp.png +0 -0
  76. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/llama_mem_alloc.png +0 -0
  77. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/llama_tps.png +0 -0
  78. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/qwen_mem_alloc.png +0 -0
  79. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/img/qwen_tps.png +0 -0
  80. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/launch_on_modal.py +0 -0
  81. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/requirements.txt +0 -0
  82. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_benchmarks.sh +0 -0
  83. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_gemma.sh +0 -0
  84. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_llama.sh +0 -0
  85. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_qwen.sh +0 -0
  86. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/run_qwen2_vl.sh +0 -0
  87. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/training.py +0 -0
  88. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/huggingface/training_multimodal.py +0 -0
  89. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/lightning/README.md +0 -0
  90. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/lightning/requirements.txt +0 -0
  91. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/lightning/training.py +0 -0
  92. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/README.md +0 -0
  93. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/callback.py +0 -0
  94. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage1_num_head_3.png +0 -0
  95. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage1_num_head_5.png +0 -0
  96. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage2_num_head_3.png +0 -0
  97. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Memory_Stage2_num_head_5.png +0 -0
  98. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage1_num_head_3.png +0 -0
  99. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage1_num_head_5.png +0 -0
  100. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage2_num_head_3.png +0 -0
  101. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/docs/images/Throughput_Stage2_num_head_5.png +0 -0
  102. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/fsdp/acc-fsdp.conf +0 -0
  103. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/medusa_util.py +0 -0
  104. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/requirements.txt +0 -0
  105. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/scripts/llama3_8b_medusa.sh +0 -0
  106. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/examples/medusa/train.py +0 -0
  107. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-Apache-2.0 +0 -0
  108. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-AutoAWQ +0 -0
  109. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-Efficient-Cross-Entropy +0 -0
  110. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-llmc +0 -0
  111. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/licenses/LICENSE-MIT-triton +0 -0
  112. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/mkdocs.yml +0 -0
  113. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/setup.cfg +0 -0
  114. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/setup.py +0 -0
  115. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/__init__.py +0 -0
  116. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/README.md +0 -0
  117. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/__init__.py +0 -0
  118. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/cpo_loss.py +0 -0
  119. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/dpo_loss.py +0 -0
  120. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/functional.py +0 -0
  121. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_distillation.py +0 -0
  122. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_preference.py +0 -0
  123. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_rlhf.py +0 -0
  124. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/fused_linear_unpaired_preference.py +0 -0
  125. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/grpo_loss.py +0 -0
  126. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/jsd_loss.py +0 -0
  127. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/kto_loss.py +0 -0
  128. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/orpo_loss.py +0 -0
  129. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/chunked_loss/simpo_loss.py +0 -0
  130. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/env_report.py +0 -0
  131. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/__init__.py +0 -0
  132. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/cross_entropy.py +0 -0
  133. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/experimental/embedding.py +0 -0
  134. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/experimental/mm_int8int2.py +0 -0
  135. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/fused_linear_cross_entropy.py +0 -0
  136. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/fused_linear_jsd.py +0 -0
  137. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/geglu.py +0 -0
  138. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/group_norm.py +0 -0
  139. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/jsd.py +0 -0
  140. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/kl_div.py +0 -0
  141. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/layer_norm.py +0 -0
  142. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/qwen2vl_mrope.py +0 -0
  143. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/rms_norm.py +0 -0
  144. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/rope.py +0 -0
  145. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/swiglu.py +0 -0
  146. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/tvd.py +0 -0
  147. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/ops/utils.py +0 -0
  148. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/auto_model.py +0 -0
  149. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/cross_entropy.py +0 -0
  150. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/experimental/embedding.py +0 -0
  151. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/fused_linear_cross_entropy.py +0 -0
  152. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/fused_linear_jsd.py +0 -0
  153. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/geglu.py +0 -0
  154. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/group_norm.py +0 -0
  155. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/jsd.py +0 -0
  156. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/kl_div.py +0 -0
  157. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/layer_norm.py +0 -0
  158. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/__init__.py +0 -0
  159. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/gemma.py +0 -0
  160. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/gemma2.py +0 -0
  161. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/llama.py +0 -0
  162. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/loss_utils.py +0 -0
  163. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/mistral.py +0 -0
  164. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/mixtral.py +0 -0
  165. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/mllama.py +0 -0
  166. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/olmo2.py +0 -0
  167. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/paligemma.py +0 -0
  168. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/phi3.py +0 -0
  169. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/qwen2.py +0 -0
  170. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/qwen2_5_vl.py +0 -0
  171. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/model/qwen2_vl.py +0 -0
  172. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/monkey_patch.py +0 -0
  173. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/qwen2vl_mrope.py +0 -0
  174. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/rms_norm.py +0 -0
  175. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/rope.py +0 -0
  176. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/swiglu.py +0 -0
  177. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/trainer/__init__.py +0 -0
  178. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/trainer/orpo_trainer.py +0 -0
  179. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/trainer_integration.py +0 -0
  180. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/transformers/tvd.py +0 -0
  181. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/triton/__init__.py +0 -0
  182. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/triton/monkey_patch.py +0 -0
  183. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel/utils.py +0 -0
  184. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/dependency_links.txt +0 -0
  185. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/requires.txt +0 -0
  186. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/src/liger_kernel_nightly.egg-info/top_level.txt +0 -0
  187. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/__init__.py +0 -0
  188. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/__init__.py +0 -0
  189. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_cpo_loss.py +0 -0
  190. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_dpo_loss.py +0 -0
  191. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_grpo_loss.py +0 -0
  192. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_jsd_loss.py +0 -0
  193. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_kto_loss.py +0 -0
  194. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_orpo_loss.py +0 -0
  195. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/chunked_loss/test_simpo_loss.py +0 -0
  196. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/conftest.py +0 -0
  197. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/__init__.py +0 -0
  198. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/__init__.py +0 -0
  199. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/test_mini_models.py +0 -0
  200. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/test_mini_models_multimodal.py +0 -0
  201. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/bf16/test_mini_models_with_logits.py +0 -0
  202. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/__init__.py +0 -0
  203. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/test_mini_models.py +0 -0
  204. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/test_mini_models_multimodal.py +0 -0
  205. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/convergence/fp32/test_mini_models_with_logits.py +0 -0
  206. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/Google/Paligemma/paligemma-3b-pt-224/tokenizer_config.json +0 -0
  207. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/Qwen/Qwen2-VL-7B-Instruct/tokenizer_config.json +0 -0
  208. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/Qwen/Qwen2.5-VL-7B-Instruct/tokenizer_config.json +0 -0
  209. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/fake_configs/meta-llama/Llama-3.2-11B-Vision-Instruct/tokenizer_config.json +0 -0
  210. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/scripts/generate_tokenized_dataset.py +0 -0
  211. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare.txt +0 -0
  212. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare_tokenized/data-00000-of-00001.arrow +0 -0
  213. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare_tokenized/dataset_info.json +0 -0
  214. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/resources/tiny_shakespeare_tokenized/state.json +0 -0
  215. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_auto_model.py +0 -0
  216. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_cross_entropy.py +0 -0
  217. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_embedding.py +0 -0
  218. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_flex_attention.py +0 -0
  219. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_fused_linear_cross_entropy.py +0 -0
  220. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_fused_linear_jsd.py +0 -0
  221. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_geglu.py +0 -0
  222. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_group_norm.py +0 -0
  223. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_jsd.py +0 -0
  224. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_kl_div.py +0 -0
  225. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_layer_norm.py +0 -0
  226. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_mm_int8int2.py +0 -0
  227. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_monkey_patch.py +0 -0
  228. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_qwen2vl_mrope.py +0 -0
  229. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_rms_norm.py +0 -0
  230. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_rope.py +0 -0
  231. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_swiglu.py +0 -0
  232. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_trainer_integration.py +0 -0
  233. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_transformers.py +0 -0
  234. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/transformers/test_tvd.py +0 -0
  235. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/triton/test_triton_monkey_patch.py +0 -0
  236. {liger_kernel_nightly-0.5.5.dev20250322021112 → liger_kernel_nightly-0.5.5.dev20250324181221}/test/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250322021112
3
+ Version: 0.5.5.dev20250324181221
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -8,8 +8,8 @@ import matplotlib.pyplot as plt
8
8
  import pandas as pd
9
9
  import seaborn as sns
10
10
 
11
- DATA_PATH = "data/all_benchmark_data.csv"
12
- VISUALIZATIONS_PATH = "visualizations/"
11
+ DATA_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv"))
12
+ VISUALIZATIONS_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "visualizations/"))
13
13
 
14
14
 
15
15
  @dataclass
@@ -0,0 +1,139 @@
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+ import triton
6
+
7
+ from utils import QUANTILES
8
+ from utils import SingleBenchmarkRunInput
9
+ from utils import SingleBenchmarkRunOutput
10
+ from utils import _test_memory
11
+ from utils import parse_benchmark_script_args
12
+ from utils import run_benchmarks
13
+
14
+ from liger_kernel.utils import infer_device
15
+
16
+ device = infer_device()
17
+
18
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
19
+
20
+
21
+ def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
22
+ from test.transformers.test_dyt import LigerDyT
23
+ from test.transformers.test_dyt import TorchDyT
24
+
25
+ BT = input.x
26
+ provider = input.kernel_provider
27
+ mode = input.kernel_operation_mode
28
+ extra_benchmark_config = input.extra_benchmark_config
29
+ hidden_size = extra_benchmark_config["hidden_size"]
30
+ dtype = extra_benchmark_config["dtype"]
31
+
32
+ x_shape = (BT, hidden_size)
33
+ torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
34
+ torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
35
+ triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
36
+
37
+ x = torch.randn(x_shape, dtype=dtype, device=device)
38
+ dy = torch.randn_like(x)
39
+ x.requires_grad_(True)
40
+
41
+ def fwd():
42
+ if provider == "liger":
43
+ return triton_dyt(x)
44
+ elif provider == "torch":
45
+ return torch_dyt(x)
46
+ elif provider == "torch_compile":
47
+ return torch_compile_dyt(x)
48
+
49
+ if mode == "forward":
50
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
51
+ elif mode == "backward":
52
+ y = fwd()
53
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(
54
+ lambda: y.backward(dy, retain_graph=True),
55
+ quantiles=QUANTILES,
56
+ grad_to_none=[x],
57
+ rep=500,
58
+ )
59
+ elif mode == "full":
60
+
61
+ def full():
62
+ y = fwd()
63
+ y.backward(dy)
64
+
65
+ ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
66
+
67
+ return SingleBenchmarkRunOutput(
68
+ y_20=ms_20,
69
+ y_50=ms_50,
70
+ y_80=ms_80,
71
+ )
72
+
73
+
74
+ def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
75
+ from test.transformers.test_dyt import LigerDyT
76
+ from test.transformers.test_dyt import TorchDyT
77
+
78
+ BT = input.x
79
+ provider = input.kernel_provider
80
+ extra_benchmark_config = input.extra_benchmark_config
81
+ hidden_size = extra_benchmark_config["hidden_size"]
82
+ dtype = extra_benchmark_config["dtype"]
83
+
84
+ x_shape = (BT, hidden_size)
85
+ torch_dyt = TorchDyT(hidden_size=hidden_size).to(device)
86
+ torch_compile_dyt = torch.compile(TorchDyT(hidden_size=hidden_size).to(device))
87
+ triton_dyt = LigerDyT(hidden_size=hidden_size).to(device)
88
+
89
+ x = torch.randn(x_shape, dtype=dtype, device=device)
90
+ dy = torch.randn_like(x)
91
+ x.requires_grad_(True)
92
+
93
+ def fwd():
94
+ if provider == "liger":
95
+ return triton_dyt(x)
96
+ elif provider == "torch":
97
+ return torch_dyt(x)
98
+ elif provider == "torch_compile":
99
+ return torch_compile_dyt(x)
100
+
101
+ def full():
102
+ y = fwd()
103
+ y.backward(dy, retain_graph=True)
104
+
105
+ mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
106
+ return SingleBenchmarkRunOutput(
107
+ y_20=mem_20,
108
+ y_50=mem_50,
109
+ y_80=mem_80,
110
+ )
111
+
112
+
113
+ if __name__ == "__main__":
114
+ args = parse_benchmark_script_args()
115
+
116
+ common_configs = {
117
+ "kernel_name": "dyt",
118
+ "x_name": "BT",
119
+ "x_label": "batch_size * seq_len",
120
+ "x_values": [2**i for i in range(10, 15)],
121
+ "kernel_providers": ["liger", "torch", "torch_compile"],
122
+ "extra_benchmark_configs": [{"hidden_size": 4096, "dtype": torch.float32}],
123
+ "overwrite": args.overwrite,
124
+ }
125
+
126
+ run_benchmarks(
127
+ bench_test_fn=bench_speed_dyt,
128
+ kernel_operation_modes=["forward", "backward", "full"],
129
+ metric_name="speed",
130
+ metric_unit="ms",
131
+ **common_configs,
132
+ )
133
+ run_benchmarks(
134
+ bench_test_fn=bench_memory_dyt,
135
+ kernel_operation_modes=["full"],
136
+ metric_name="memory",
137
+ metric_unit="MB",
138
+ **common_configs,
139
+ )
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "liger_kernel_nightly"
7
- version = "0.5.5.dev20250322021112"
7
+ version = "0.5.5.dev20250324181221"
8
8
  description = "Efficient Triton kernels for LLM Training"
9
9
  urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
10
10
  readme = { file = "README.md", content-type = "text/markdown" }
@@ -0,0 +1,225 @@
1
+ import operator
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ from liger_kernel.ops.utils import calculate_settings
8
+ from liger_kernel.ops.utils import compare_version
9
+ from liger_kernel.ops.utils import ensure_contiguous
10
+ from liger_kernel.ops.utils import infer_device
11
+
12
+ if compare_version("triton", operator.ge, "3.0.0"):
13
+ try:
14
+ # typical import path with dispatch available
15
+ from triton.language.extra.libdevice import tanh
16
+ except ModuleNotFoundError:
17
+ # for working with NGC containers
18
+ from triton.language.extra.cuda.libdevice import tanh
19
+ else:
20
+ from triton.language.math import tanh
21
+
22
+
23
+ @triton.jit
24
+ def _dyt_fwd_kernel(
25
+ x_ptr,
26
+ x_row_stride,
27
+ alpha_ptr,
28
+ gamma_ptr,
29
+ beta_ptr,
30
+ y_ptr,
31
+ y_row_stride,
32
+ n_cols,
33
+ BLOCK_SIZE: tl.constexpr,
34
+ ):
35
+ """
36
+ Reference:
37
+ https://arxiv.org/abs/2503.10622
38
+
39
+ Shapes:
40
+ - x: (BT, C)
41
+ - alpha: (1)
42
+ - gamma: (C)
43
+ - beta: (C)
44
+ """
45
+ row_idx = tl.program_id(0)
46
+ offsets = tl.arange(0, BLOCK_SIZE)
47
+ mask = offsets < n_cols
48
+
49
+ x_ptr += row_idx * x_row_stride
50
+ y_ptr += row_idx * y_row_stride
51
+
52
+ alpha = tl.load(alpha_ptr)
53
+ gamma = tl.load(gamma_ptr + offsets, mask=mask)
54
+ beta = tl.load(beta_ptr + offsets, mask=mask)
55
+ x = tl.load(x_ptr + offsets, mask=mask)
56
+ y = gamma * tanh((alpha * x).cast(tl.float32)) + beta
57
+ tl.store(y_ptr + offsets, y, mask=mask)
58
+
59
+
60
+ @triton.jit
61
+ def _dyt_bwd_kernel(
62
+ x_ptr,
63
+ x_row_stride,
64
+ dy_ptr,
65
+ dy_row_stride,
66
+ dx_ptr,
67
+ dx_row_stride,
68
+ alpha_ptr,
69
+ dalpha_ptr,
70
+ gamma_ptr,
71
+ dgamma_ptr,
72
+ dgamma_row_stride,
73
+ n_cols,
74
+ n_rows,
75
+ ROWS_PER_PROGRAM: tl.constexpr,
76
+ BLOCK_SIZE: tl.constexpr,
77
+ ):
78
+ """
79
+ Reference:
80
+ https://arxiv.org/abs/2503.10622
81
+
82
+ Shapes:
83
+ - x: (BT, C)
84
+ - alpha: (1)
85
+ - gamma: (C)
86
+ - dx: (BT, C)
87
+ - dy: (BT, C)
88
+ - dgamma: (sm_count, C)
89
+ - dalpha: (sm_count,)
90
+ """
91
+ # d(gamma * tanh(alpha * x) + beta) / dx
92
+ # = gamma * (1 - tanh^2(alpha * x)) * alpha
93
+ # d(gamma * tanh(alpha * x) + beta) / dalpha
94
+ # = gamma * (1 - tanh^2(alpha * x)) * x
95
+ # d(gamma * tanh(alpha * x) + beta) / dgamma
96
+ # = tanh(alpha * x)
97
+ # d(gamma * tanh(alpha * x)) / dbeta = 1
98
+ pid = tl.program_id(0)
99
+
100
+ row_start = pid * ROWS_PER_PROGRAM
101
+ row_end = min((pid + 1) * ROWS_PER_PROGRAM, n_rows)
102
+ offsets = tl.arange(0, BLOCK_SIZE)
103
+ mask = offsets < n_cols
104
+
105
+ dalpha = 0.0
106
+ dgamma = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
107
+
108
+ x_ptr += row_start * x_row_stride
109
+ dx_ptr += row_start * dx_row_stride
110
+ dy_ptr += row_start * dy_row_stride
111
+ alpha = tl.load(alpha_ptr)
112
+ gamma = tl.load(gamma_ptr + offsets, mask=mask, other=0.0)
113
+
114
+ for _ in tl.range(row_start, row_end):
115
+ dy = tl.load(dy_ptr + offsets, mask=mask, other=0.0)
116
+ x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
117
+ tanh_ax = tanh((alpha * x).cast(tl.float32))
118
+ sech2_ax = 1 - tanh_ax * tanh_ax
119
+
120
+ dx = dy * gamma * sech2_ax * alpha
121
+ dalpha += tl.sum(dy * gamma * sech2_ax * x)
122
+ dgamma += dy * tanh_ax
123
+ tl.store(dx_ptr + offsets, dx, mask=mask)
124
+
125
+ dy_ptr += dy_row_stride
126
+ x_ptr += x_row_stride
127
+ dx_ptr += dx_row_stride
128
+
129
+ tl.store(dgamma_ptr + pid * dgamma_row_stride + offsets, dgamma, mask=mask)
130
+ tl.store(dalpha_ptr + pid, dalpha)
131
+
132
+ pass
133
+
134
+
135
+ def liger_dyt_fwd(x, alpha, gamma, beta):
136
+ shape = x.shape
137
+ dim = shape[-1]
138
+ x = x.view(-1, dim)
139
+ n_rows, n_cols = x.shape
140
+ y = torch.empty_like(x)
141
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
142
+ _dyt_fwd_kernel[(n_rows,)](
143
+ x_ptr=x,
144
+ alpha_ptr=alpha,
145
+ gamma_ptr=gamma,
146
+ beta_ptr=beta,
147
+ y_ptr=y,
148
+ x_row_stride=x.stride(0),
149
+ y_row_stride=y.stride(0),
150
+ n_cols=n_cols,
151
+ BLOCK_SIZE=BLOCK_SIZE,
152
+ num_warps=num_warps,
153
+ )
154
+ return y.view(*shape)
155
+
156
+
157
+ def liger_dyt_bwd(dy, x, alpha, gamma):
158
+ shape = dy.shape
159
+ dtype = x.dtype
160
+ dim = shape[-1]
161
+ dy = dy.view(-1, dim)
162
+ x = x.view(-1, dim)
163
+ n_rows, n_cols = dy.shape
164
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
165
+ sm_count = 1
166
+ device = infer_device()
167
+ if device == "cuda":
168
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
169
+ elif device == "xpu":
170
+ sm_count = torch.xpu.get_device_properties(x.device).gpu_subslice_count
171
+ if n_cols > BLOCK_SIZE:
172
+ raise RuntimeError(
173
+ f"Feature dimension {dim} exceeds maximum supported size of {BLOCK_SIZE}. Consider using a smaller feature dimension."
174
+ )
175
+
176
+ dx = torch.empty_like(x, dtype=torch.float32)
177
+ _dalpha = torch.empty((sm_count,), dtype=torch.float32, device=x.device)
178
+ _dgamma = torch.empty((sm_count, n_cols), dtype=torch.float32, device=x.device)
179
+
180
+ grid = (sm_count,)
181
+ rows_per_program = triton.cdiv(n_rows, sm_count)
182
+ _dyt_bwd_kernel[grid](
183
+ x_ptr=x,
184
+ x_row_stride=x.stride(0),
185
+ dy_ptr=dy,
186
+ dy_row_stride=dy.stride(0),
187
+ dx_ptr=dx,
188
+ dx_row_stride=dx.stride(0),
189
+ alpha_ptr=alpha,
190
+ dalpha_ptr=_dalpha,
191
+ gamma_ptr=gamma,
192
+ dgamma_ptr=_dgamma,
193
+ dgamma_row_stride=_dgamma.stride(0),
194
+ n_cols=n_cols,
195
+ n_rows=n_rows,
196
+ ROWS_PER_PROGRAM=rows_per_program,
197
+ BLOCK_SIZE=BLOCK_SIZE,
198
+ num_warps=num_warps,
199
+ )
200
+ dalpha = _dalpha.sum(dim=0, keepdim=True).to(dtype)
201
+ dgamma = _dgamma.sum(dim=0).to(dtype)
202
+ dbeta = dy.sum(dim=0).to(dtype)
203
+ return dx.view(*shape), dalpha, dgamma, dbeta
204
+
205
+
206
+ class LigerDyTFunction(torch.autograd.Function):
207
+ @staticmethod
208
+ @ensure_contiguous
209
+ def forward(ctx, x, alpha, gamma, beta):
210
+ y = liger_dyt_fwd(x, alpha, gamma, beta)
211
+ ctx.save_for_backward(x, alpha, gamma)
212
+ return y
213
+
214
+ @staticmethod
215
+ @ensure_contiguous
216
+ def backward(ctx, grad_output):
217
+ x, alpha, gamma = ctx.saved_tensors
218
+ dx, dalpha, dgamma, dbeta = liger_dyt_bwd(
219
+ grad_output,
220
+ x,
221
+ alpha,
222
+ gamma,
223
+ )
224
+
225
+ return (dx, dalpha, dgamma, dbeta)
@@ -1,5 +1,6 @@
1
1
  from liger_kernel.transformers.auto_model import AutoLigerKernelForCausalLM # noqa: F401
2
2
  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss # noqa: F401
3
+ from liger_kernel.transformers.dyt import LigerDyT # noqa: F401
3
4
  from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss # noqa: F401
4
5
  from liger_kernel.transformers.fused_linear_jsd import LigerFusedLinearJSD # noqa: F401
5
6
  from liger_kernel.transformers.geglu import LigerGEGLUMLP # noqa: F401
@@ -0,0 +1,20 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from liger_kernel.ops.dyt import LigerDyTFunction
5
+
6
+
7
+ class LigerDyT(nn.Module):
8
+ def __init__(self, hidden_size, init_alpha=0.5):
9
+ super().__init__()
10
+ self.hidden_size = hidden_size
11
+ self.init_alpha = init_alpha
12
+ self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
13
+ self.gamma = nn.Parameter(torch.ones(hidden_size))
14
+ self.beta = nn.Parameter(torch.zeros(hidden_size))
15
+
16
+ def forward(self, x):
17
+ return LigerDyTFunction.apply(x, self.alpha, self.gamma, self.beta)
18
+
19
+ def extra_repr(self):
20
+ return f"{self.hidden_size}, init_alpha={self.init_alpha}"
@@ -1,6 +1,7 @@
1
1
  from typing import Optional
2
2
 
3
3
  from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
4
+ from liger_kernel.ops.dyt import LigerDyTFunction
4
5
  from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
5
6
  from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
6
7
  from liger_kernel.ops.geglu import LigerGELUMulFunction
@@ -192,3 +193,7 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
192
193
 
193
194
  def liger_swiglu(a, b):
194
195
  return LigerSiLUMulFunction.apply(a, b)
196
+
197
+
198
+ def liger_dyt(x, alpha, gamma, beta):
199
+ return LigerDyTFunction.apply(x, alpha, gamma, beta)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: liger_kernel_nightly
3
- Version: 0.5.5.dev20250322021112
3
+ Version: 0.5.5.dev20250324181221
4
4
  Summary: Efficient Triton kernels for LLM Training
5
5
  License: BSD 2-CLAUSE LICENSE
6
6
  Copyright 2024 LinkedIn Corporation
@@ -24,6 +24,7 @@ benchmark/scripts/benchmark_cpo_loss.py
24
24
  benchmark/scripts/benchmark_cross_entropy.py
25
25
  benchmark/scripts/benchmark_distill_jsd_loss.py
26
26
  benchmark/scripts/benchmark_dpo_loss.py
27
+ benchmark/scripts/benchmark_dyt.py
27
28
  benchmark/scripts/benchmark_embedding.py
28
29
  benchmark/scripts/benchmark_fused_linear_cross_entropy.py
29
30
  benchmark/scripts/benchmark_fused_linear_jsd.py
@@ -121,6 +122,7 @@ src/liger_kernel/chunked_loss/orpo_loss.py
121
122
  src/liger_kernel/chunked_loss/simpo_loss.py
122
123
  src/liger_kernel/ops/__init__.py
123
124
  src/liger_kernel/ops/cross_entropy.py
125
+ src/liger_kernel/ops/dyt.py
124
126
  src/liger_kernel/ops/fused_linear_cross_entropy.py
125
127
  src/liger_kernel/ops/fused_linear_jsd.py
126
128
  src/liger_kernel/ops/geglu.py
@@ -139,6 +141,7 @@ src/liger_kernel/ops/experimental/mm_int8int2.py
139
141
  src/liger_kernel/transformers/__init__.py
140
142
  src/liger_kernel/transformers/auto_model.py
141
143
  src/liger_kernel/transformers/cross_entropy.py
144
+ src/liger_kernel/transformers/dyt.py
142
145
  src/liger_kernel/transformers/functional.py
143
146
  src/liger_kernel/transformers/fused_linear_cross_entropy.py
144
147
  src/liger_kernel/transformers/fused_linear_jsd.py
@@ -209,6 +212,7 @@ test/resources/tiny_shakespeare_tokenized/dataset_info.json
209
212
  test/resources/tiny_shakespeare_tokenized/state.json
210
213
  test/transformers/test_auto_model.py
211
214
  test/transformers/test_cross_entropy.py
215
+ test/transformers/test_dyt.py
212
216
  test/transformers/test_embedding.py
213
217
  test/transformers/test_flex_attention.py
214
218
  test/transformers/test_fused_linear_cross_entropy.py
@@ -0,0 +1,136 @@
1
+ import pytest
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from test.utils import assert_verbose_allclose
6
+ from test.utils import infer_device
7
+ from test.utils import set_seed
8
+ from test.utils import supports_bfloat16
9
+
10
+ from liger_kernel.ops.dyt import LigerDyTFunction
11
+ from liger_kernel.transformers.dyt import LigerDyT
12
+ from liger_kernel.transformers.functional import liger_dyt
13
+
14
+
15
+ class TorchDyT(nn.Module):
16
+ def __init__(self, hidden_size, init_alpha=0.5):
17
+ super().__init__()
18
+ self.alpha = nn.Parameter(torch.ones(1) * init_alpha)
19
+ self.gamma = nn.Parameter(torch.ones(hidden_size))
20
+ self.beta = nn.Parameter(torch.zeros(hidden_size))
21
+
22
+ def forward(self, x):
23
+ return self.gamma * torch.tanh(self.alpha * x) + self.beta
24
+
25
+
26
+ set_seed(42)
27
+ device = infer_device()
28
+
29
+
30
+ @pytest.mark.parametrize("init_alpha", [0.5, 0.2, 1.0])
31
+ @pytest.mark.parametrize(
32
+ "B, T, hidden_size",
33
+ [
34
+ (2, 8, 4096),
35
+ (4, 16, 2048),
36
+ (1, 1, 1023), # Minimal batch/seq with near power-of-2 hidden
37
+ (3, 7, 256), # Prime numbers for batch/seq
38
+ ],
39
+ )
40
+ @pytest.mark.parametrize(
41
+ "dtype, atol, rtol",
42
+ [
43
+ (torch.float32, 1e-5, 1e-5),
44
+ ],
45
+ )
46
+ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol):
47
+ _input = torch.randn(B, T, hidden_size, device=device, dtype=dtype)
48
+
49
+ x1 = _input.clone().requires_grad_(True)
50
+ x2 = _input.clone().requires_grad_(True)
51
+
52
+ # initialize weights
53
+ alpha = torch.randn(1, device=device, dtype=dtype)
54
+ gamma = torch.randn(hidden_size, device=device, dtype=dtype)
55
+ beta = torch.randn(hidden_size, device=device, dtype=dtype)
56
+
57
+ torch_dyt = TorchDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
58
+ torch_dyt.alpha.data = alpha.clone()
59
+ torch_dyt.gamma.data = gamma.clone()
60
+ torch_dyt.beta.data = beta.clone()
61
+
62
+ liger_dyt = LigerDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
63
+ liger_dyt.alpha.data = alpha.clone()
64
+ liger_dyt.gamma.data = gamma.clone()
65
+ liger_dyt.beta.data = beta.clone()
66
+
67
+ torch_output = torch_dyt(x1)
68
+ liger_output = liger_dyt(x2)
69
+
70
+ assert_verbose_allclose(torch_output, liger_output, rtol=rtol, atol=atol)
71
+
72
+ grad_output = torch.randn_like(_input)
73
+ torch_output.backward(grad_output)
74
+ liger_output.backward(grad_output)
75
+
76
+ assert_verbose_allclose(x1.grad, x2.grad, rtol=rtol, atol=atol)
77
+ assert_verbose_allclose(torch_dyt.alpha.grad, liger_dyt.alpha.grad, rtol=rtol, atol=atol)
78
+ assert_verbose_allclose(torch_dyt.gamma.grad, liger_dyt.gamma.grad, rtol=rtol, atol=atol)
79
+ assert_verbose_allclose(torch_dyt.beta.grad, liger_dyt.beta.grad, rtol=rtol, atol=atol)
80
+
81
+
82
+ @pytest.mark.parametrize(
83
+ "B, T, hidden_size",
84
+ [
85
+ (2, 8, 4096),
86
+ (4, 16, 2048),
87
+ (1, 1, 1023), # Minimal batch/seq with near power-of-2 hidden
88
+ (3, 7, 256), # Prime numbers for batch/seq
89
+ ],
90
+ )
91
+ @pytest.mark.parametrize(
92
+ "dtype, atol, rtol",
93
+ [
94
+ # atol is for small values: they have more difference, so set atol higher
95
+ # rtol is for larger values: they are very close, so set rtol lower
96
+ (torch.float32, 1e-5, 1e-5),
97
+ pytest.param(
98
+ torch.bfloat16,
99
+ 1e-8,
100
+ 5e-2,
101
+ marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
102
+ ),
103
+ ],
104
+ )
105
+ def test_liger_dyt_functional(B, T, hidden_size, dtype, atol, rtol):
106
+ _input = torch.randn(B, T, hidden_size, device=device, dtype=dtype)
107
+
108
+ x1 = _input.clone().requires_grad_(True)
109
+ x2 = _input.clone().requires_grad_(True)
110
+
111
+ # initialize weights
112
+ alpha = torch.randn(1, device=device, dtype=dtype)
113
+ gamma = torch.randn(hidden_size, device=device, dtype=dtype)
114
+ beta = torch.randn(hidden_size, device=device, dtype=dtype)
115
+
116
+ alpha1 = alpha.clone().requires_grad_(True)
117
+ gamma1 = gamma.clone().requires_grad_(True)
118
+ beta1 = beta.clone().requires_grad_(True)
119
+
120
+ alpha2 = alpha.clone().requires_grad_(True)
121
+ gamma2 = gamma.clone().requires_grad_(True)
122
+ beta2 = beta.clone().requires_grad_(True)
123
+
124
+ output1 = liger_dyt(x1, alpha=alpha1, gamma=gamma1, beta=beta1)
125
+ output2 = LigerDyTFunction.apply(x2, alpha2, gamma2, beta2)
126
+
127
+ assert_verbose_allclose(output1, output2, rtol=rtol, atol=atol)
128
+
129
+ grad_output = torch.randn_like(_input)
130
+ output1.backward(grad_output)
131
+ output2.backward(grad_output)
132
+
133
+ assert_verbose_allclose(x1.grad, x2.grad, rtol=rtol, atol=atol)
134
+ assert_verbose_allclose(alpha1.grad, alpha2.grad, rtol=rtol, atol=atol)
135
+ assert_verbose_allclose(gamma1.grad, gamma2.grad, rtol=rtol, atol=atol)
136
+ assert_verbose_allclose(beta1.grad, beta2.grad, rtol=rtol, atol=atol)